## Setup SageMaker

Let's load the environment variables. The `.env.example` file contains `ROLE`, the SageMaker execution role.

```py
import os
from dotenv import load_dotenv
from sagemaker.utils import name_from_base

# load ROLE (the SageMaker execution role) into os.environ
load_dotenv('.env.example')
```

Create an endpoint name

```py
# append a timestamp to the base name so each deployment gets a unique name
endpoint_name = name_from_base('sdxl-1-0-jumstart')
endpoint_name
```
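To make the naming rule explicit, here is a stdlib-only sketch of what `name_from_base` produces. It is a hypothetical stand-in (`make_endpoint_name` is not part of the SDK) and assumes the SDK's format of base name, hyphen, UTC timestamp with milliseconds, plus the 63-character limit on endpoint names; the real function's trimming details may differ.

```py
import re
from datetime import datetime, timezone
from typing import Optional

# SageMaker endpoint names are limited to 63 characters
MAX_ENDPOINT_NAME_LEN = 63
# allowed pattern: alphanumerics, optionally separated by hyphens
ENDPOINT_NAME_RE = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9])*")


def make_endpoint_name(base: str, now: Optional[datetime] = None) -> str:
    """Hypothetical stand-in for name_from_base: '<base>-<UTC timestamp with ms>'."""
    now = now or datetime.now(timezone.utc)
    # e.g. '2024-01-31-12-00-00-123': drop the last 3 microsecond digits to keep milliseconds
    timestamp = now.strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3]
    # trim the base so '<base>-<timestamp>' still fits in the length limit
    base = base[: MAX_ENDPOINT_NAME_LEN - len(timestamp) - 1]
    name = f"{base}-{timestamp}"
    if not ENDPOINT_NAME_RE.fullmatch(name):
        raise ValueError(f"invalid endpoint name: {name!r}")
    return name
```

Because the timestamp has millisecond resolution, re-running the notebook creates a new endpoint rather than colliding with an old one, so remember to delete endpoints you no longer need.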

## Model Data

First, get the default SageMaker bucket for this session:

```py
from sagemaker import Session

bucket = Session().default_bucket()
```

Then, define where the pre-trained model archive lives in Stability AI's public S3 bucket and where we will store it in our own bucket:

```py
model_filename = "sdxlv1-sgm0.1.0.tar.gz"
model_source_uri = f"s3://stabilityai-public-packages/model-packages/sdxl-v1-0-dlc/sgm0.1.0/{model_filename}"
model_uri = f's3://{bucket}/stabilityai/sdxl-v1-0-dlc/sgm0.1.0/{model_filename}'
```

Download the archive locally, then upload it to our bucket:

```py
!aws s3 cp {model_source_uri} {model_filename}
!aws s3 cp {model_filename} {model_uri}
```
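The two commands above stage the archive on the notebook's local disk before uploading it again. A server-side copy avoids that step (`aws s3 cp` also accepts two S3 URIs directly). Below is a boto3 sketch, assuming your credentials can read the public source bucket; `split_s3_uri` and `copy_s3_object` are hypothetical helper names.

```py
from typing import Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/key' into ('bucket', 'key')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def copy_s3_object(source_uri: str, dest_uri: str) -> None:
    """Server-side S3 copy: the bytes never pass through the local disk."""
    import boto3  # imported lazily so split_s3_uri has no AWS dependency

    src_bucket, src_key = split_s3_uri(source_uri)
    dst_bucket, dst_key = split_s3_uri(dest_uri)
    # managed transfer: boto3 switches to multipart copy for large objects
    boto3.client("s3").copy(
        CopySource={"Bucket": src_bucket, "Key": src_key},
        Bucket=dst_bucket,
        Key=dst_key,
    )
```

Usage: `copy_s3_object(model_source_uri, model_uri)` replaces both `!aws s3 cp` lines.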

Double-check that the archive arrived and look at its size:

```py
!aws s3 ls s3://{bucket}/stabilityai/sdxl-v1-0-dlc/sgm0.1.0/ --human-readable --recursive
```

## Async Endpoint

First, import the classes we need to create and deploy the model:

```py
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import BytesDeserializer
from sagemaker.async_inference import AsyncInferenceConfig
```

We need to specify the model data, the image URI, and (at deploy time) the instance type. The image is Stability AI's PyTorch inference container in `us-east-1`:

```py
inference_image_uri_region = "us-east-1"
inference_image_uri_region_acct = "763104351884"
inference_image_uri = f"{inference_image_uri_region_acct}.dkr.ecr.{inference_image_uri_region}.amazonaws.com/stabilityai-pytorch-inference:2.0.1-sgm0.1.0-gpu-py310-cu118-ubuntu20.04-sagemaker"
```
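The image URI is regional, and the model's image generally has to come from the same region as the endpoint. Here is a hypothetical helper that rebuilds the URI for another region. It assumes the same account ID and tag are published there, which is not verified here, so check image availability for your region first.

```py
# Values copied from the cell above; whether this exact image is published in
# other regions under the same account is an assumption to verify.
DLC_ACCOUNT = "763104351884"
SDXL_IMAGE_TAG = "2.0.1-sgm0.1.0-gpu-py310-cu118-ubuntu20.04-sagemaker"


def stabilityai_image_uri(region: str, account: str = DLC_ACCOUNT, tag: str = SDXL_IMAGE_TAG) -> str:
    """Hypothetical helper: build the ECR URI of the inference image for a region."""
    return f"{account}.dkr.ecr.{region}.amazonaws.com/stabilityai-pytorch-inference:{tag}"
```

In a notebook, `stabilityai_image_uri(boto3.Session().region_name)` would pick up the session's current region.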

Let's create the model:

```py
pytorch_model = PyTorchModel(
    name=endpoint_name,
    model_data=model_uri,
    image_uri=inference_image_uri,
    role=os.environ['ROLE'],
)
```

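Next, create the async inference configuration. It tells SageMaker where in S3 to write successful outputs and failed requests:

```py
# async_config is passed to deploy() below; the output/failure prefixes are
# example locations in our default bucket (rename them as you like).
# failure_path requires a recent version of the SageMaker Python SDK.
async_config = AsyncInferenceConfig(
    output_path=f"s3://{bucket}/stable-diffusion-async-output",
    failure_path=f"s3://{bucket}/stable-diffusion-async-failure",
)
```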

Finally, deploy the model

```py
# with async_inference_config set, deploy() returns an AsyncPredictor
deployed_model = pytorch_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type="ml.g5.4xlarge",
    serializer=JSONSerializer(),
    deserializer=BytesDeserializer(accept="image/png"),
    async_inference_config=async_config,
)
```

## Autoscaling

- Register the endpoint variant as a scalable target with Application Auto Scaling
- Create a scaling policy
- Apply the policy to the scaling target
- Scale based on a custom metric (`ApproximateBacklogSizePerInstance`) monitored in CloudWatch

Async endpoints can scale down to zero instances. First, register the endpoint variant as a scalable target with Application Auto Scaling:

```py
import boto3

client = boto3.client("application-autoscaling")
# the scalable resource is the endpoint's production variant ("AllTraffic" is the SDK's default name)
resource_id = f"endpoint/{endpoint_name}/variant/AllTraffic"
response = client.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    MinCapacity=0,  # allow scaling down to zero instances
    MaxCapacity=1,
)
```

Then, create a target-tracking scaling policy and attach it to the endpoint variant:

```py
response = client.put_scaling_policy(
    PolicyName="Invocations-ScalingPolicy",
    ServiceNamespace="sagemaker",  # the namespace of the AWS service that provides the resource
    ResourceId=resource_id,  # endpoint/<endpoint_name>/variant/AllTraffic
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only instance count
    PolicyType="TargetTrackingScaling",  # 'StepScaling' | 'TargetTrackingScaling'
    TargetTrackingScalingPolicyConfiguration={
        # target value for the metric below: about 5 queued requests per instance
        "TargetValue": 5.0,
        "CustomizedMetricSpecification": {
            "MetricName": "ApproximateBacklogSizePerInstance",
            "Namespace": "AWS/SageMaker",
            "Dimensions": [{"Name": "EndpointName", "Value": endpoint_name}],
            "Statistic": "Average",
        },
        # seconds after a scale-in activity completes before another scale-in can start
        "ScaleInCooldown": 1800,
        # seconds after a scale-out activity completes before another scale-out can start
        "ScaleOutCooldown": 300,
        # "DisableScaleIn": True would stop this policy from removing capacity
    },
)
```

| Metric | Description |
| --- | --- |
| `ApproximateBacklogSize` | The number of items in the queue for an endpoint that are currently being processed or yet to be processed. |
| `ApproximateBacklogSizePerInstance` | The number of items in the queue divided by the number of instances behind an endpoint. This metric is primarily used for setting up application autoscaling for an async-enabled endpoint. |

Under the hood, Application Auto Scaling creates two CloudWatch alarms (one for scaling out, one for scaling in) that monitor the metric and trigger the scaling policy:

<LinkedImage
height={400}
alt="alarm-auto-scaling-sagemaker-endpoint"
src="/thumbnail/alarm-autoscaling-sm-endpoint.png"
/>
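
To build intuition for `TargetValue: 5.0`, here is a back-of-the-envelope model of target tracking on `ApproximateBacklogSizePerInstance`: choose enough instances that the backlog per instance is at most the target, then clamp to `MinCapacity`/`MaxCapacity`. This is a simplification, not AWS's actual algorithm (which also depends on alarm evaluation periods and cooldowns), and `approx_desired_instances` is a hypothetical name. With `MaxCapacity=1` the policy effectively toggles between zero and one instance. Also note that a per-instance metric may not trigger a scale-out while zero instances are running; the AWS documentation for async inference describes a separate `HasBacklogWithoutCapacity` metric for scaling out from zero.

```py
import math


def approx_desired_instances(
    backlog: int,
    target_per_instance: float,
    min_capacity: int = 0,
    max_capacity: int = 1,
) -> int:
    """Rough model: instances needed so backlog / instances <= target, clamped to [min, max]."""
    if target_per_instance <= 0:
        raise ValueError("target_per_instance must be positive")
    wanted = math.ceil(backlog / target_per_instance)
    return max(min_capacity, min(max_capacity, wanted))
```

For example, a backlog of 12 requests with a target of 5 would call for 3 instances, but `MaxCapacity=1` caps it at 1.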
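
## Predictor

Let's use the predictor returned by `deploy()` to invoke the async endpoint. Because the endpoint is asynchronous, `deploy()` returns an `AsyncPredictor`, whose `predict()` uploads the request to S3, invokes the endpoint, and waits until the result is written back to S3:

```py
output = deployed_model.predict(
    {
        "text_prompts": [{"text": "tiger and wife in anime style", "weight": 1.0}],
        "width": 512,
        "height": 512,
        # "style_preset": "anime",
        "cfg_scale": 15,
        "samples": 1,
        "seed": 3,
        "num_inference_steps": 30,
        # "sampler": "DDIM",
    }
)
```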

Save the returned PNG bytes to a file:

```py
filename = 'some_image.png'
# output is raw bytes (BytesDeserializer), so write in binary mode
with open(filename, 'wb') as f:
    f.write(output)
```
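If the endpoint returns an error body instead of an image, writing it blindly produces a corrupt `.png`. A small guard checks the 8-byte PNG signature before saving. `save_png` is a hypothetical helper, and it assumes the endpoint returns raw PNG bytes, as requested with `accept="image/png"`.

```py
from pathlib import Path

# every PNG file starts with these 8 bytes
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def save_png(data: bytes, path: str) -> Path:
    """Write bytes to disk only if they look like a PNG image."""
    if not data.startswith(PNG_SIGNATURE):
        # e.g. a JSON error message returned instead of an image
        raise ValueError(f"expected PNG bytes, got {data[:16]!r}")
    out = Path(path)
    out.write_bytes(data)
    return out
```

Usage: `save_png(output, 'some_image.png')`.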

## SageMaker Runtime

Alternatively, call the endpoint directly through the SageMaker Runtime API. First, create the client:

```py
import boto3
import json

sm_client = boto3.client('sagemaker-runtime')
```

Then, create the request payload:

```py
# payload request
payload = json.dumps({
    "text_prompts": [
        {
            "text": "tiger and wife in anime style",
            "weight": 1.0
        }
    ],
    "cfg_scale": 15,
    "samples": 1,
    "seed": 3,
    "style_preset": "anime",
    "num_inference_steps": 30,
    "height": 1024,
    "width": 1024
}, indent=4)
```
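Both invocation paths send the same JSON schema, so it can help to build the request in one place. This is a hypothetical helper whose field names are copied from the payloads in this post; which fields the container actually honors (for example `style_preset`, which is commented out in the predictor call) is not verified here.

```py
import json
from typing import Any, Dict, Optional


def build_sdxl_payload(
    prompt: str,
    *,
    width: int = 1024,
    height: int = 1024,
    cfg_scale: float = 15,
    seed: int = 3,
    num_inference_steps: int = 30,
    samples: int = 1,
    style_preset: Optional[str] = None,
    weight: float = 1.0,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble the request fields used in this post."""
    payload: Dict[str, Any] = {
        "text_prompts": [{"text": prompt, "weight": weight}],
        "cfg_scale": cfg_scale,
        "samples": samples,
        "seed": seed,
        "num_inference_steps": num_inference_steps,
        "height": height,
        "width": width,
    }
    # only send style_preset when one is requested
    if style_preset is not None:
        payload["style_preset"] = style_preset
    return payload


# same request as the cell above
payload = json.dumps(build_sdxl_payload("tiger and wife in anime style", style_preset="anime"), indent=4)
```

The same dict, built with `width=512, height=512` and no style preset, can be passed straight to `deployed_model.predict(...)`.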

Then, write the payload to `sample.json` and upload it to the S3 input path used for the async invocation:

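```py
# save the payload locally, then upload it to s3://{bucket}/stable-diffusion-async-input/sample.json
with open('sample.json', "w") as file:
    file.write(payload)

session = Session()
session.upload_data(
    "sample.json",
    bucket=bucket,
    key_prefix="stable-diffusion-async-input",
    extra_args={"ContentType": "application/json"},
)
```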

Invoke the async endpoint. The call returns immediately, along with the S3 locations where the result will be written:

```py
# invoke endpoint
response = sm_client.invoke_endpoint_async(
    EndpointName=endpoint_name,
    InputLocation=f"s3://{bucket}/stable-diffusion-async-input/sample.json",
    ContentType="application/json",
    InvocationTimeoutSeconds=3600,
    Accept="image/png",
)
```
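Because `invoke_endpoint_async` returns before the image exists, the downloads below only succeed once the request has finished. Here is a polling sketch that waits for whichever object appears first. The helper names are hypothetical, and the sketch assumes the response's top-level `OutputLocation` and `FailureLocation` fields, which carry the same values as the headers used below (`FailureLocation` only when a failure path is configured).

```py
import time
from typing import Callable, Optional, Tuple


def s3_object_exists(uri: str) -> bool:
    """HEAD the object at 's3://bucket/key'; False if it does not exist yet."""
    import boto3  # imported lazily so the polling logic can be tested without AWS
    from botocore.exceptions import ClientError

    bucket, key = uri[len("s3://"):].split("/", 1)
    try:
        boto3.client("s3").head_object(Bucket=bucket, Key=key)
        return True
    except ClientError as err:
        if err.response.get("Error", {}).get("Code") in ("404", "NoSuchKey", "NotFound"):
            return False
        raise  # permission and other errors should surface


def wait_for_async_result(
    output_uri: str,
    failure_uri: Optional[str] = None,
    exists: Callable[[str], bool] = s3_object_exists,
    timeout_s: float = 3600.0,
    poll_s: float = 15.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[Tuple[str, str]]:
    """Return ('output', uri) or ('failure', uri), or None on timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        if exists(output_uri):
            return "output", output_uri
        if failure_uri and exists(failure_uri):
            return "failure", failure_uri
        if time.monotonic() >= deadline:
            return None
        sleep(poll_s)
```

Usage: `wait_for_async_result(response["OutputLocation"], response.get("FailureLocation"))`. Injecting `exists` and `sleep` keeps the loop testable without AWS.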

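Once the request has finished, download the result. If inference failed, the error is written to the failure location (this header is present only when a failure path is configured):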

```py
!aws s3 cp {response['ResponseMetadata']['HTTPHeaders']['x-amzn-sagemaker-failurelocation']} .
```

If inference succeeded, download the output image from the output location:

```py
!aws s3 cp {response['ResponseMetadata']['HTTPHeaders']['x-amzn-sagemaker-outputlocation']} .
```

## Reference