Set up SageMaker#
Let's load the environment variables, which contain the ROLE for SageMaker.
```py
import os

from dotenv import load_dotenv
from sagemaker.utils import name_from_base

load_dotenv('.env.example')
```
Create an endpoint name
```py
# append a timestamp to the string
endpoint_name = name_from_base('sdxl-1-0-jumpstart')
endpoint_name
```
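`name_from_base` appends a timestamp so that each run produces a unique endpoint name. A rough pure-Python sketch of what it does (the real implementation lives in `sagemaker.utils` and may differ in format details):

```py
from datetime import datetime, timezone

def name_from_base_sketch(base: str) -> str:
    # rough approximation: append a millisecond-resolution UTC timestamp
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3]
    return f"{base}-{ts}"

name_from_base_sketch("sdxl-1-0")
```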
Model Data#
First, let's get the default bucket
```py
from sagemaker import Session

bucket = Session().default_bucket()
```
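The default bucket name follows the pattern `sagemaker-{region}-{account_id}`, so the snippet below (with a made-up account id) shows roughly what to expect:

```py
# illustration only: the account id here is hypothetical
region = "us-east-1"
account_id = "123456789012"
default_bucket_name = f"sagemaker-{region}-{account_id}"
```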
Then, download the pre-trained model from Stability AI's public S3 bucket and upload it to our own bucket. First, define the source and destination URIs:
```py
model_filename = "sdxlv1-sgm0.1.0.tar.gz"
model_source_uri = f"s3://stabilityai-public-packages/model-packages/sdxl-v1-0-dlc/sgm0.1.0/{model_filename}"
model_uri = f"s3://{bucket}/stabilityai/sdxl-v1-0-dlc/sgm0.1.0/{model_filename}"
```
Download the model and upload it to S3
```py
!aws s3 cp {model_source_uri} {model_filename}
!aws s3 cp {model_filename} {model_uri}
```
Double-check the size of the model
```py
!aws s3 ls s3://{bucket}/stabilityai/sdxl-v1-0-dlc/sgm0.1.0/ \
    --human-readable \
    --recursive
```
Async Endpoint#
Let's create a model object. First, import the required classes
```py
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import BytesDeserializer
from sagemaker.async_inference import AsyncInferenceConfig
```
We need to specify the `model_data`, the `image_uri`, and the instance type
```py
inference_image_uri_region = "us-east-1"
inference_image_uri_region_acct = "763104351884"
inference_image_uri = f"{inference_image_uri_region_acct}.dkr.ecr.{inference_image_uri_region}.amazonaws.com/stabilityai-pytorch-inference:2.0.1-sgm0.1.0-gpu-py310-cu118-ubuntu20.04-sagemaker"
```
Let's create the model
```py
pytorch_model = PyTorchModel(
    name=endpoint_name,
    model_data=model_uri,
    image_uri=inference_image_uri,
    role=os.environ['ROLE'],
)
```
Let's create an async inference configuration. The `output_path` (an arbitrary prefix in our bucket, chosen here for illustration) is where the endpoint writes its results:

```py
async_config = AsyncInferenceConfig(
    # results of async invocations are written under this prefix
    output_path=f"s3://{bucket}/stable-diffusion-async-output",
)
```
Finally, deploy the model
```py
deployed_model = pytorch_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type="ml.g5.4xlarge",
    serializer=JSONSerializer(),
    deserializer=BytesDeserializer(accept="image/png"),
    async_inference_config=async_config,
)
```
Autoscaling#
- Register the endpoint as a scalable target
- Create an autoscaling policy
- Apply the policy to the scaling target
- Scale based on custom metrics monitored in CloudWatch
Async endpoints can scale down to zero instances. First, let's register the endpoint as a scalable target
```py
import boto3

client = boto3.client("application-autoscaling")
resource_id = "endpoint/" + endpoint_name + "/variant/" + "AllTraffic"
response = client.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    MinCapacity=0,
    MaxCapacity=1,
)
```
Then, create a scaling policy and apply it to the target endpoint
```py
response = client.put_scaling_policy(
    PolicyName="Invocations-ScalingPolicy",
    ServiceNamespace="sagemaker",  # the namespace of the AWS service that provides the resource
    ResourceId=resource_id,  # endpoint name
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only instance count
    PolicyType="TargetTrackingScaling",  # 'StepScaling'|'TargetTrackingScaling'
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 5.0,  # the target value for the metric defined below
        "CustomizedMetricSpecification": {
            "MetricName": "ApproximateBacklogSizePerInstance",
            "Namespace": "AWS/SageMaker",
            "Dimensions": [{"Name": "EndpointName", "Value": endpoint_name}],
            "Statistic": "Average",
        },
        "ScaleInCooldown": 1800,  # seconds after a scale-in activity completes before another can start
        "ScaleOutCooldown": 300,  # seconds after a scale-out activity completes before another can start
        # "DisableScaleIn": True|False - if True, scale-in is disabled and the
        # target tracking policy won't remove capacity from the scalable resource
    },
)
```
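To build intuition for how target tracking drives the instance count, here is a simplified model of the scaling math. This is a sketch only: the real service also applies cooldowns, alarm evaluation periods, and its own scale-from-zero logic, all of which this ignores:

```py
import math

def desired_capacity(current: int, metric: float, target: float,
                     min_cap: int = 0, max_cap: int = 1) -> int:
    """Simplified target tracking: scale so that metric per instance ~= target."""
    if metric <= 0:
        desired = 0  # no backlog: allow scale-in to zero
    elif current == 0:
        desired = 1  # any backlog triggers scale-out from zero
    else:
        # scale proportionally so the metric approaches the target
        desired = math.ceil(current * metric / target)
    return max(min_cap, min(max_cap, desired))

desired_capacity(1, 12.0, 5.0, min_cap=0, max_cap=4)  # → 3
```

With the policy above (`MinCapacity=0`, `MaxCapacity=1`), a sustained backlog of zero would drive the endpoint down to zero instances.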
- **ApproximateBacklogSize**: The number of items in the queue for an endpoint that are currently being processed or yet to be processed.
- **ApproximateBacklogSizePerInstance**: The number of items in the queue divided by the number of instances behind the endpoint. This metric is primarily used for setting up application autoscaling for an async-enabled endpoint.
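As a quick worked example of the metric the policy tracks (the numbers here are made up):

```py
backlog = 10     # ApproximateBacklogSize: queued + in-flight requests
instances = 2    # instances currently behind the endpoint
backlog_per_instance = backlog / instances  # ApproximateBacklogSizePerInstance
# with TargetValue = 5.0, the endpoint is exactly at target at this load
```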
Under the hood, two CloudWatch alarms are created to monitor the metric and trigger the scaling policy action

<LinkedImage
  height={400}
  alt="alarm-auto-scaling-sagemaker-endpoint"
  src="/thumbnail/alarm-autoscaling-sm-endpoint.png"
/>

Predictor#

Let's use the predictor to invoke the async endpoint

```py
output = deployed_model.predict({
    "text_prompts": [
        {"text": "tiger and wife in anime style", "weight": 1.0}
    ],
    "width": 512,
    "height": 512,
    # "style_preset": "anime",
    "cfg_scale": 15,
    "samples": 1,
    "seed": 3,
    "num_inference_steps": 30,
    # "sampler": "DDIM",
})
```
Save the image to disk
```py
filename = 'some_image.png'
with open(filename, 'wb') as f:
    f.write(output)
```
SageMaker Runtime#
Let's create a SageMaker runtime client
```py
import json

import boto3

sm_client = boto3.client('sagemaker-runtime')
```
Then, create the request payload
```py
# request payload
payload = json.dumps({
    "text_prompts": [
        {
            "text": "tiger and wife in anime style",
            "weight": 1.0
        }
    ],
    "cfg_scale": 15,
    "samples": 1,
    "seed": 3,
    "style_preset": "anime",
    "num_inference_steps": 30,
    "height": 1024,
    "width": 1024
}, indent=4)
```
Then, upload the input (sample.json) to the S3 input path for async invocation
```py
with open('sample.json', 'w') as file:
    file.write(payload)

session = Session()
session.upload_data(
    "sample.json",
    bucket=bucket,
    key_prefix="stable-diffusion-async-input",
    extra_args={"ContentType": "application/json"},
)
```
Invoke the async endpoint
```py
# invoke the endpoint
response = sm_client.invoke_endpoint_async(
    EndpointName=endpoint_name,
    InputLocation=f"s3://{bucket}/stable-diffusion-async-input/sample.json",
    ContentType="application/json",
    InvocationTimeoutSeconds=3600,
    Accept="image/png",
)
```
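The async invocation returns immediately; the response carries S3 locations where the output or failure details will eventually land. A sketch of pulling those locations out, using a mocked response dict (the S3 paths here are made up, but the header names are the ones used in the download cells below):

```py
# mocked response for illustration; a real call returns this nested structure
response = {
    "ResponseMetadata": {
        "HTTPHeaders": {
            "x-amzn-sagemaker-outputlocation": "s3://example-bucket/async-output/abc123.out",
            "x-amzn-sagemaker-failurelocation": "s3://example-bucket/async-failures/abc123.err",
        }
    },
}

headers = response["ResponseMetadata"]["HTTPHeaders"]
output_location = headers["x-amzn-sagemaker-outputlocation"]
failure_location = headers["x-amzn-sagemaker-failurelocation"]
```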
Inspect the response. If the invocation failed, download the error details from the failure location
```py
!aws s3 cp {response['ResponseMetadata']['HTTPHeaders']['x-amzn-sagemaker-failurelocation']} .
```
Download the output image
```py
!aws s3 cp {response['ResponseMetadata']['HTTPHeaders']['x-amzn-sagemaker-outputlocation']} .
```