Introduction#
This note shows how to build a simple image-generation app with Next.js and Stable Diffusion hosted on Amazon SageMaker.
Lambda Function#
- Deploy via ECR
- Configure provisioned concurrency
The sagemaker and Stable Diffusion client dependencies are heavy (303 MB), which is more than the 250 MB unzipped limit of a zip deployment package, so I have to deploy the Lambda function as a container image via ECR.

    const draw = new aws_lambda.Function(this, 'DiffusionLambdaPublic', {
      functionName: 'DiffusionLambdaPublic',
      code: aws_lambda.EcrImageCode.fromAssetImage(
        path.join(__dirname, './../lambda-diffusion')
      ),
      handler: aws_lambda.Handler.FROM_IMAGE,
      runtime: aws_lambda.Runtime.FROM_IMAGE,
      memorySize: 1024,
      timeout: Duration.seconds(25),
      role: roleForLambda,
      environment: {
        BUCKET_NAME: props.bucketName,
        ENDPOINT_NAME: props.endpointName
      }
    })

Generating an image with Stable Diffusion takes about 25 seconds, so together with a Lambda cold start the total latency might exceed 29 seconds, the default integration timeout of API Gateway. As a quick fix for the demo, I configure provisioned concurrency on a Lambda alias.

    const alias = new aws_lambda.Alias(this, 'LiveAliasConcurrencyProvisioned', {
      aliasName: 'live',
      version: draw.currentVersion,
      provisionedConcurrentExecutions: 10
    })

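The timing argument above can be made concrete with a little arithmetic. The sketch below is only illustrative: the 25-second generation time and the 29-second limit come from the text, while the cold-start durations passed in are hypothetical values, not measurements.

```python
# Figures taken from the text above.
GENERATION_S = 25.0
TIMEOUT_S = 29.0


def headroom_s(cold_start_s: float) -> float:
    """Seconds left before the timeout after cold start plus generation."""
    return TIMEOUT_S - (cold_start_s + GENERATION_S)


def fits_in_budget(cold_start_s: float) -> bool:
    """True if the request is expected to finish before the timeout."""
    return headroom_s(cold_start_s) > 0


if __name__ == "__main__":
    # Hypothetical cold-start durations in seconds; 0.0 models a warm,
    # already-initialized instance.
    for cold_start in (0.0, 2.0, 6.0):
        print(cold_start, headroom_s(cold_start), fits_in_budget(cold_start))
```

Provisioned concurrency keeps instances initialized ahead of time, so the demo relies on the first case, where only the generation time counts against the limit.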
Project structure for the Lambda container image

    |--bin
    |--lib
    |--lambda-diffusion
       |--Dockerfile
       |--.dockerignore
       |--requirements.txt
       |--index.py
       |--package

We need to install the dependencies into the target directory package
python -m pip install -r requirements.txt --target package
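Since the size of the installed dependencies is what drives the choice of ECR, it can help to measure the package directory after running the command above. This is a minimal sketch using only the standard library; the helper name `dir_size_mb` is hypothetical, and `package` is the target directory from the command.

```python
import os


def dir_size_mb(path: str) -> float:
    """Sum the sizes of all regular files under path, in MB (10**6 bytes)."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            full = os.path.join(root, name)
            # Skip symlinks so a linked file is not counted twice.
            if not os.path.islink(full):
                total += os.path.getsize(full)
    return total / 1_000_000


if __name__ == "__main__":
    if os.path.isdir("package"):
        print(f"package: {dir_size_mb('package'):.1f} MB")
```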
Content of the Dockerfile

    FROM public.ecr.aws/lambda/python:3.9

    # create code dir inside container
    RUN mkdir ${LAMBDA_TASK_ROOT}/source

    # copy code to container
    COPY "requirements.txt" ${LAMBDA_TASK_ROOT}/source

    # copy handler function to container
    COPY ./index.py ${LAMBDA_TASK_ROOT}

    # install dependencies for running time environment
    COPY ./package/ ${LAMBDA_TASK_ROOT}

    # set the CMD to your handler
    CMD [ "index.handler" ]

API Gateway#
Let's integrate the Lambda alias with API Gateway. Make sure that API Gateway has permission to invoke the Lambda alias.

    const roleForApiGw = new aws_iam.Role(this, 'RoleForApiGwInvokeDrawPublic', {
      roleName: 'RoleForApiGwInvokeDrawPublic',
      assumedBy: new aws_iam.ServicePrincipal('apigateway.amazonaws.com')
    })

    roleForApiGw.addToPolicy(
      new aws_iam.PolicyStatement({
        effect: aws_iam.Effect.ALLOW,
        actions: ['lambda:InvokeFunction'],
        resources: [draw.functionArn, `${draw.functionArn}:*`]
      })
    )

    roleForApiGw.addToPolicy(
      new aws_iam.PolicyStatement({
        effect: aws_iam.Effect.ALLOW,
        actions: [
          'logs:CreateLogGroup',
          'logs:CreateLogStream',
          'logs:DescribeLogGroups',
          'logs:DescribeLogStreams',
          'logs:PutLogEvents',
          'logs:GetLogEvents',
          'logs:FilterLogEvents'
        ],
        resources: ['*']
      })
    )

Enable CORS and logging

    const apigw = new aws_apigateway.RestApi(this, 'ApiForDiffusionModelPublic', {
      restApiName: 'ApiForDiffusionModelPublic',
      deploy: false,
      cloudWatchRole: true
    })

    const image = apigw.root.addResource('image')

    const getImageMethod = image.addMethod(
      'GET',
      new aws_apigateway.LambdaIntegration(alias, {
        credentialsRole: roleForApiGw
      })
    )

    image.addCorsPreflight({
      allowOrigins: ['*'],
      allowMethods: ['GET', 'POST', 'OPTIONS'],
      allowHeaders: ['*']
    })

    const logGroup = new aws_logs.LogGroup(this, 'AccessLogForDiffusionPublic', {
      logGroupName: 'AccessLogForDiffusionPublic',
      removalPolicy: RemovalPolicy.DESTROY,
      retention: RetentionDays.ONE_WEEK
    })

Create deployment stage

    const deployment = new aws_apigateway.Deployment(
      this,
      'DeployDiffusionApiPublic',
      {
        api: apigw
      }
    )

    const prodStage = new aws_apigateway.Stage(this, 'DiffusionProdStagePublic', {
      stageName: 'prod',
      deployment,
      dataTraceEnabled: true,
      accessLogDestination: new aws_apigateway.LogGroupLogDestination(logGroup),
      accessLogFormat: aws_apigateway.AccessLogFormat.jsonWithStandardFields()
    })

Finally, consider adding an API key with a usage plan to throttle requests and cap the daily quota

    new aws_apigateway.RateLimitedApiKey(this, 'RateLimitForDiffusionPublic', {
      apiKeyName: 'RateLimitForDiffusionPublic',
      customerId: 'DiffusionPublic',
      apiStages: [
        {
          api: apigw,
          stage: prodStage,
          throttle: [
            {
              method: getImageMethod,
              throttle: {
                burstLimit: 20,
                rateLimit: 10
              }
            }
          ]
        }
      ],
      quota: {
        limit: 300,
        period: aws_apigateway.Period.DAY
      },
      throttle: {
        burstLimit: 20,
        rateLimit: 10
      },
      enabled: true,
      generateDistinctId: true,
      description: 'rate limit for customer a by api key'
    })

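API Gateway describes its throttling as a token bucket, where rateLimit is the steady refill rate in requests per second and burstLimit is the bucket capacity. The sketch below simulates that idea with the values used above (rateLimit 10, burstLimit 20). The `TokenBucket` class and `count_allowed` helper are hypothetical illustrations, not API Gateway's actual implementation.

```python
class TokenBucket:
    """Toy token bucket: refills at rate_limit tokens/second up to burst_limit."""

    def __init__(self, rate_limit: float, burst_limit: int) -> None:
        self.rate_limit = rate_limit
        self.capacity = float(burst_limit)
        self.tokens = float(burst_limit)  # start with a full bucket
        self.last = 0.0                   # time of the previous request

    def allow(self, now: float) -> bool:
        """Return True if a request arriving at time `now` (seconds) is admitted."""
        elapsed = max(0.0, now - self.last)
        self.last = now
        # Refill for the elapsed time, never above the bucket capacity.
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_limit)
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False


def count_allowed(bucket: TokenBucket, timestamps: list) -> int:
    """Count how many requests at the given timestamps are admitted."""
    return sum(1 for t in timestamps if bucket.allow(t))


if __name__ == "__main__":
    bucket = TokenBucket(rate_limit=10, burst_limit=20)
    print(count_allowed(bucket, [0.0] * 30))  # burst of 30 requests at t=0
    print(count_allowed(bucket, [1.0] * 15))  # 15 more requests one second later
```

In this simulation a burst larger than the bucket is cut off at the burst limit, and one second later only the refilled tokens are available again.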
Front End#
Just create a simple form that submits a prompt and calls the API to get the generated image
FrontEnd

    'use client'

    import axios from 'axios'
    import { useState } from 'react'
    import { config } from '@/config'

    const HomePage = () => {
      const [url, setUrl] = useState<string | null>(null)
      const [modal, setModal] = useState<boolean>(false)
      const [counter, setCounter] = useState<number>(25)

      // call the API with the prompt and store the returned presigned url
      const generateImage = async (prompt: string) => {
        // const token = localStorage.getItem("IdToken");
        try {
          const { data } = await axios.get(config.API_DIFFUSION, {
            // headers: {
            //   Authorization: `Bearer ${token}`,
            //   "Content-Type": "application/json",
            // },
            params: {
              prompt: prompt
            }
          })
          console.log(data)
          setUrl(data.url)
        } catch (error) {
          console.log(error)
        }
        setModal(false)
      }

      // start a 25-second countdown and return its id so the caller can stop it
      const timer = () => {
        let timeleft = 25
        const downloadTimer = setInterval(() => {
          if (timeleft <= 0) {
            clearInterval(downloadTimer)
          }
          setCounter(timeleft)
          timeleft -= 1
        }, 1000)
        return downloadTimer
      }

      return (
        <div className="min-h-screen dark:bg-slate-800">
          <div className="mx-auto max-w-3xl dark:bg-slate-800 dark:text-white px-10">
            <div className="mb-5">
              <textarea
                id="prompt"
                name="prompt"
                rows={2}
                placeholder="describe an image you want..."
                className="p-2.5 w-full text-gray-900 bg-slate-200 rounded-lg border border-gray-300 focus:border-2 focus:ring-blue-500 focus:border-blue-500 dark:bg-gray-700 dark:border-gray-600 dark:placeholder-gray-400 dark:text-white dark:focus:ring-blue-500 dark:focus:border-blue-500 my-5 outline-none focus:outline-none"
              ></textarea>
              <button
                className="bg-orange-400 px-10 py-3 rounded-sm"
                onClick={async () => {
                  let prompt = (
                    document.getElementById('prompt') as HTMLTextAreaElement
                  ).value
                  if (prompt === '') {
                    prompt = 'a big house'
                  }
                  setUrl(null)
                  setCounter(25)
                  setModal(true)
                  const downloadTimer = timer()
                  await generateImage(prompt)
                  // stop the countdown so a second submit does not run two timers
                  clearInterval(downloadTimer)
                }}
              >
                Submit
              </button>
            </div>
            {url ? (
              <div>
                <img src={url} alt="test"></img>
              </div>
            ) : (
              ''
            )}
            {modal === true ? (
              <div
                className="fixed top-0 bottom-0 left-0 right-0 bg-slate-500 bg-opacity-70"
                id="modal"
              >
                <div className="mx-auto max-w-3xl sm:p-10 p-5">
                  <div className="justify-center items-center flex bg-white py-20 px-10 rounded-lg relative">
                    <h1 className="text-black" id="countdown">
                      Please wait {String(counter)} seconds while your image is generated
                    </h1>
                  </div>
                </div>
              </div>
            ) : (
              ''
            )}
          </div>
        </div>
      )
    }

    export default HomePage

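The GET request that the form sends can also be reproduced outside the browser, which is handy for testing the API on its own. This is a minimal sketch using only the Python standard library. The base URL is a hypothetical placeholder for the value stored in `config.API_DIFFUSION`, and `build_image_url` and `fetch_image_url` are illustrative helper names.

```python
import json
from urllib.parse import urlencode
from urllib.request import urlopen

# Hypothetical endpoint; replace with the real invoke URL of the prod stage.
API_DIFFUSION = "https://example.execute-api.us-east-1.amazonaws.com/prod/image"


def build_image_url(base_url: str, prompt: str) -> str:
    """Build the request URL, falling back to the form's default prompt."""
    text = prompt.strip() or "a big house"
    return f"{base_url}?{urlencode({'prompt': text})}"


def fetch_image_url(base_url: str, prompt: str, timeout_s: float = 30.0) -> str:
    """Call the API and return the 'url' field of the JSON response body."""
    with urlopen(build_image_url(base_url, prompt), timeout=timeout_s) as resp:
        return json.loads(resp.read().decode("utf-8"))["url"]


if __name__ == "__main__":
    # Only prints the URL; call fetch_image_url() against a deployed API.
    print(build_image_url(API_DIFFUSION, "a red fox"))
```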
Amplify Hosting#
[!IMPORTANT]
Please set up a token in your GitHub account settings and then store the token in AWS Secrets Manager.
Let's create a stack to deploy the app on Amplify

    import { SecretValue, Stack, StackProps, aws_codebuild } from 'aws-cdk-lib'
    import { Construct } from 'constructs'
    import * as Amplify from '@aws-cdk/aws-amplify-alpha'

    interface AmplifyHostingProps extends StackProps {
      owner: string
      repository: string
      token: string
      envVariables: any
      commands: any
    }

    export class AmplifyHosting extends Stack {
      constructor(scope: Construct, id: string, props: AmplifyHostingProps) {
        super(scope, id, props)

        const amplify = new Amplify.App(this, 'NextStableDiffusionDemo', {
          sourceCodeProvider: new Amplify.GitHubSourceCodeProvider({
            owner: props.owner,
            repository: props.repository,
            // props.token is the name of the secret in Secrets Manager
            oauthToken: SecretValue.secretsManager(props.token)
            // oauthToken: SecretValue.unsafePlainText(props.token),
          }),
          buildSpec: aws_codebuild.BuildSpec.fromObjectToYaml({
            version: '1.0',
            frontend: {
              phases: {
                preBuild: {
                  commands: ['npm ci']
                },
                build: {
                  commands: props.commands
                }
              },
              artifacts: {
                baseDirectory: '.next',
                files: ['**/*']
              },
              // the Amplify build spec expects the key "paths" under cache
              cache: {
                paths: ['node_modules/**/*']
              }
            }
          }),
          platform: Amplify.Platform.WEB_COMPUTE,
          environmentVariables: props.envVariables
        })

        amplify.addBranch('main', { stage: 'PRODUCTION' })
      }
    }

Logic Code#
Model version and instance type

    instance_type="ml.g5.2xlarge"
    sdxl-1-0-jumpstart

For the Lambda Python implementation, we need to install the dependencies. This is requirements.txt

    sagemaker
    stability-sdk[sagemaker] @ git+https://github.com/Stability-AI/stability-sdk.git@sagemaker

Then write a simple handler

    import sagemaker
    from sagemaker import ModelPackage, get_execution_role
    from stability_sdk_sagemaker.predictor import StabilityPredictor
    from stability_sdk_sagemaker.models import get_model_package_arn
    from stability_sdk.api import GenerationRequest, GenerationResponse, TextPrompt
    from PIL import Image
    from typing import Union, Tuple
    import io
    import os
    import base64
    import boto3

    sagemaker_session = sagemaker.Session()
    endpoint_name = "sdxl-1-0-jumpstart-2024-01-24-09-54-52-906"
    s3_client = boto3.client("s3", region_name="us-east-1")


    def decode_and_show(model_response: GenerationResponse) -> None:
        """
        Decodes and displays an image from SDXL output

        Args:
            model_response (GenerationResponse): The response object from the deployed SDXL model.

        Returns:
            None
        """
        image = model_response.artifacts[0].base64
        image_data = base64.b64decode(image.encode())
        image = Image.open(io.BytesIO(image_data))
        #
        key = "diffusion/cat.png"
        s3_client.upload_fileobj(io.BytesIO(image_data), "cdk-entest-videos", key)
        # signed url s3
        response = s3_client.generate_presigned_url(
            "get_object",
            Params={"Bucket": "cdk-entest-videos", "Key": key},
            ExpiresIn=3600,
        )
        print(response)
        image.save("hehe.png")
        # display(image)


    deployed_model = StabilityPredictor(
        endpoint_name=endpoint_name, sagemaker_session=sagemaker_session
    )

    output = deployed_model.predict(
        GenerationRequest(
            text_prompts=[TextPrompt(text="tiger and wife")],
            style_preset="anime",
            # style_preset="origami",
            seed=3,
            height=1024,
            width=1024,
        )
    )

    decode_and_show(output)

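The decoding step above turns the base64 string from the model response into raw image bytes. Here is a small sketch of that step on its own, using only the standard library. The helper `decode_png` is hypothetical, and it assumes the endpoint returns PNG data, which the `.png` key suggests but the code above does not verify.

```python
import base64

# Every PNG file starts with this fixed 8-byte signature.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def decode_png(b64_text: str) -> bytes:
    """Decode a base64 string and check that the result looks like a PNG."""
    data = base64.b64decode(b64_text.encode(), validate=True)
    if not data.startswith(PNG_SIGNATURE):
        raise ValueError("decoded payload is not a PNG image")
    return data


if __name__ == "__main__":
    # Fake payload: a real response would carry a complete image.
    sample = base64.b64encode(PNG_SIGNATURE + b"fake-image-bytes").decode()
    print(len(decode_png(sample)))
```

Checking the signature before uploading is one way to avoid storing an error payload in S3 under an image key.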
The deployed handler
index.py

    """
    lambda call sagemaker diffusion model
    """

    import sagemaker
    from stability_sdk_sagemaker.predictor import StabilityPredictor
    from stability_sdk.api import GenerationRequest, GenerationResponse, TextPrompt
    from PIL import Image
    import io
    import base64
    import boto3
    import datetime
    import json
    import os

    STYLE = "anime"

    s3_client = boto3.client("s3")
    sagemaker_session = sagemaker.Session()
    deployed_model = StabilityPredictor(
        endpoint_name=os.environ["ENDPOINT_NAME"], sagemaker_session=sagemaker_session
    )


    def decode_and_show(model_response: GenerationResponse) -> str:
        """
        Decodes an image from SDXL output, uploads it to S3 and returns a presigned URL

        Args:
            model_response (GenerationResponse): The response object from the deployed SDXL model.

        Returns:
            str: presigned URL of the uploaded image, valid for one hour
        """
        # file name
        name = datetime.datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
        # key
        key = f"diffusion/{name}.png"
        # image
        image = model_response.artifacts[0].base64
        image_data = base64.b64decode(image.encode())
        s3_client.upload_fileobj(io.BytesIO(image_data), os.environ["BUCKET_NAME"], key)
        # signed url
        sign_url = s3_client.generate_presigned_url(
            "get_object",
            Params={"Bucket": os.environ["BUCKET_NAME"], "Key": key},
            ExpiresIn=3600,
        )
        # image = Image.open(io.BytesIO(image_data))
        # image.save("hehe.png")
        return sign_url


    def handler(event, context):
        """
        handler
        """
        # parse prompt, falling back to a default when the query string is missing
        try:
            prompt = event["queryStringParameters"]["prompt"]
        except (KeyError, TypeError):
            prompt = "fish"
        # call model
        output = deployed_model.predict(
            GenerationRequest(
                text_prompts=[TextPrompt(text=prompt)],
                style_preset=STYLE,
                seed=3,
                height=1024,
                width=1024,
            )
        )
        # save image to s3
        try:
            url = decode_and_show(output)
        except Exception:
            url = "ERROR"
        # return
        return {
            "statusCode": 200,
            "headers": {
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "Content-Type",
                "Access-Control-Allow-Methods": "OPTIONS,GET",
            },
            "body": json.dumps(
                {
                    "url": url,
                }
            ),
        }
