Develop#
Let's create a simple TypeScript project to test first.
npm init
Install the AWS SDK packages for the S3 client, credential providers, SageMaker runtime, and S3 request presigner.
npm install @aws-sdk/client-s3 @aws-sdk/client-sagemaker-runtime @aws-sdk/credential-providers @aws-sdk/s3-request-presigner
Let's load temporary credentials from ~/.aws/credentials using fromIni and create the clients.
// Both clients share one set of temporary credentials, read from the "demo"
// profile in ~/.aws/credentials, and both target us-east-1.
const awsRegion = "us-east-1";
const credentials = fromIni({ profile: "demo" });
const s3Client = new S3Client({ region: awsRegion, credentials });
const smClient = new SageMakerRuntimeClient({ region: awsRegion, credentials });
Get a presigned download URL from S3.
/**
 * Create a presigned GET URL, valid for 3600 seconds, for `key` in the bucket
 * named by the BUCKET environment variable.
 *
 * @param key - S3 object key to sign
 * @returns the presigned URL. It is also logged, which is convenient for local
 *   testing but should not be done in production because the URL grants access.
 */
const get_signed_url = async ({ key }: { key: string }): Promise<string> => {
  const command = new GetObjectCommand({
    Bucket: process.env["BUCKET"],
    Key: key,
  });
  // `as any` kept from the original; presumably it works around a type mismatch
  // between installed @aws-sdk package versions — TODO verify and remove.
  const url = await getSignedUrl(s3Client as any, command as any, {
    expiresIn: 3600,
  });
  console.log(url);
  // Return the URL (previously only logged) so callers can actually use it.
  return url;
};
Call the SageMaker async inference endpoint.
import { randomUUID } from "crypto";

/**
 * Upload a text-to-image request to S3, then start an async inference on the
 * SageMaker endpoint named by the SM_ENDPOINT_NAME environment variable.
 *
 * @param prompt - text prompt for the image
 * @returns the InvokeEndpointAsync response; its OutputLocation is the S3 URI
 *   where the endpoint will write its result
 */
const callSdxEndpoint = async ({ prompt }: { prompt: string }) => {
  // Use a unique S3 input key per request. With a fixed key such as
  // ".../sample.json", concurrent calls overwrite each other's payload
  // before the async endpoint reads it.
  const input = `stable-diffusion-async-input/${randomUUID()}.json`;
  // Model request body.
  const payload = JSON.stringify({
    text_prompts: [
      {
        text: prompt,
        weight: 1.0,
      },
    ],
    cfg_scale: 15,
    samples: 1,
    seed: 3,
    style_preset: "anime",
    num_inference_steps: 30,
    height: 1024,
    width: 1024,
  });
  // Upload the payload; the endpoint is pointed at it via InputLocation below.
  await s3Client.send(
    new PutObjectCommand({
      Bucket: process.env["BUCKET"],
      Key: input,
      Body: payload,
    })
  );
  console.log("async invoke sagemaker endpoint ...");
  const command = new InvokeEndpointAsyncCommand({
    EndpointName: process.env["SM_ENDPOINT_NAME"],
    InputLocation: `s3://${process.env["BUCKET"]}/${input}`,
    ContentType: "application/json",
    Accept: "image/png",
    InvocationTimeoutSeconds: 3600,
  });
  const response = await smClient.send(command);
  console.log(response);
  return response;
};
Server Function#
Let's create a function named callSdxEndpoint which performs four steps:
- Create the prompt payload as JSON
- Upload the payload to the S3 input path
- Invoke the SageMaker async endpoint
- Return the endpoint response, whose OutputLocation is the S3 output path of the generated image
import { randomUUID } from "crypto";

/**
 * Upload a text-to-image request to S3, then start an async inference on the
 * SageMaker endpoint named by the SM_ENDPOINT_NAME environment variable.
 *
 * @param prompt - text prompt for the image
 * @returns the InvokeEndpointAsync response; its OutputLocation is the S3 URI
 *   where the endpoint will write its result
 */
const callSdxEndpoint = async ({ prompt }: { prompt: string }) => {
  console.log("prompt: ", prompt);
  // Use a unique S3 input key per request. With a fixed key such as
  // ".../sample.json", concurrent calls (e.g. two users at once) overwrite
  // each other's payload before the async endpoint reads it.
  const input = `stable-diffusion-async-input/${randomUUID()}.json`;
  // Model request body.
  const payload = JSON.stringify({
    text_prompts: [
      {
        text: prompt,
        weight: 1.0,
      },
    ],
    cfg_scale: 15,
    samples: 1,
    seed: 3,
    style_preset: "anime",
    num_inference_steps: 30,
    height: 1024,
    width: 1024,
  });
  // Upload the payload; the endpoint is pointed at it via InputLocation below.
  await s3Client.send(
    new PutObjectCommand({
      Bucket: process.env["BUCKET"],
      Key: input,
      Body: payload,
    })
  );
  console.log("async invoke sagemaker endpoint ...");
  const command = new InvokeEndpointAsyncCommand({
    EndpointName: process.env["SM_ENDPOINT_NAME"],
    InputLocation: `s3://${process.env["BUCKET"]}/${input}`,
    ContentType: "application/json",
    Accept: "image/png",
    InvocationTimeoutSeconds: 3600,
  });
  const response = await smClient.send(command);
  console.log(response);
  return response;
};
This logic goes into a server function in actions.ts, following the project structure below.
|--app
|  |--page.tsx
|  |--actions.ts
|  |--global.css
|--package.json
|--next.config.mjs
|--tailwind.config.js
|--tsconfig.json
Here is the complete actions.ts:
actions.ts
"use server";import {GetObjectCommand,S3Client,PutObjectCommand,} from "@aws-sdk/client-s3";import { fromIni } from "@aws-sdk/credential-providers";import { getSignedUrl } from "@aws-sdk/s3-request-presigner";import {InvokeEndpointAsyncCommand,SageMakerRuntimeClient,} from "@aws-sdk/client-sagemaker-runtime";// no need in nextjs// import * as dotevn from "dotenv";// dotevn.config();const credentials = fromIni({profile: "demo",});const s3Client = new S3Client({region: "us-east-1",credentials: credentials,});const smClient = new SageMakerRuntimeClient({region: "us-east-1",credentials: credentials,});const getImage = async ({ key }: { key: string }) => {console.log(process.env["BUCKET"], key);const command = new GetObjectCommand({Bucket: process.env["BUCKET"],Key: key,});const url = await getSignedUrl(s3Client as any, command as any, {expiresIn: 3600,});console.log(url);return url;};const callSdxEndpoint = async ({ prompt }: { prompt: string }) => {console.log("prompt: ", prompt);// s3 input locationconst input = "stable-diffusion-async-input/sample.json";// create promptconst payload = JSON.stringify({text_prompts: [{text: prompt,weight: 1.0,},],cfg_scale: 15,samples: 1,seed: 3,style_preset: "anime",num_inference_steps: 30,height: 1024,width: 1024,});// upload to s3 input locationawait s3Client.send(new PutObjectCommand({Bucket: process.env["BUCKET"],Key: input,Body: payload,}));console.log("async invoke sagemaker endpoint ...");const command = new InvokeEndpointAsyncCommand({EndpointName: process.env["SM_ENDPOINT_NAME"],InputLocation: `s3://${process.env["BUCKET"]}/${input}`,ContentType: "application/json",Accept: "image/png",InvocationTimeoutSeconds: 3600,});const response = await smClient.send(command);console.log(response);return response;};export { getImage, callSdxEndpoint };
Frontend#
Let's create a simple form that captures the user's prompt and displays the generated image.
"use client";<form className="mt-10 px-10"><div className="relative"><inputtype="text"id="prompt"name="prompt"className="w-[100%] bg-gray-300 px-5 py-5 rounded-sm"// value={prompt}// onChange={(event) => {// setPrompt(event.target.value);// }}></input><buttonclassName="bg-green-400 px-5 py-3 w-[150px] rounded-sm absolute right-1 translate-y-[-50%] top-[50%]"onClick={(event) => {event.preventDefault();genImage(prompt);}}>Gen Image</button></div><div className="relative mt-5"><inputtype="text"id="imageurl"name="imageurl"className="w-[100%] bg-gray-300 px-5 py-5 rounded-sm"placeholder="s3://generated-image-url"disabled// value={output}// onChange={(event) => {}}></input><buttonclassName="bg-orange-400 px-5 py-3 w-[150px] rounded-sm absolute right-1 top-[50%] translate-y-[-50%]"onClick={(event) => {event.preventDefault();getImageUrl();}}>Get Image</button></div><img src={url}></img></form>;
Here is the complete page.tsx, which defines and calls genImage and getImageUrl:
page.tsx
"use client";import { useEffect, useState } from "react";import { callSdxEndpoint, getImage } from "./actions";const HomePage = () => {const [url, setUrl] = useState<string>("");const [output, setOuput] = useState<string>("s3://generated-image-output");const [prompt, setPrompt] = useState<string>("");useEffect(() => {}, [url]);useEffect(() => {}, [output]);const getImageUrl = async () => {const key = output.split("/").pop();try {const url = await getImage({key: `stable-diffusion-async/${key}`,});setUrl(url);} catch (error) {console.log(error);}};const genImage = async (prompt: string) => {const response = await callSdxEndpoint({// prompt: prompt,prompt: (document.getElementById("prompt") as HTMLInputElement).value,});const key = response.OutputLocation!.split("/").pop();setOuput(key as string);(document.getElementById("imageurl") as HTMLInputElement).value =key as string;};return (<main><div className="max-w-3xl mx-auto"><form className="mt-10 px-10"><div className="relative"><inputtype="text"id="prompt"name="prompt"className="w-[100%] bg-gray-300 px-5 py-5 rounded-sm"// value={prompt}// onChange={(event) => {// setPrompt(event.target.value);// }}></input><buttonclassName="bg-green-400 px-5 py-3 w-[150px] rounded-sm absolute right-1 translate-y-[-50%] top-[50%]"onClick={(event) => {event.preventDefault();genImage(prompt);}}>Gen Image</button></div><div className="relative mt-5"><inputtype="text"id="imageurl"name="imageurl"className="w-[100%] bg-gray-300 px-5 py-5 rounded-sm"placeholder="s3://generated-image-url"disabled// value={output}// onChange={(event) => {}}></input><buttonclassName="bg-orange-400 px-5 py-3 w-[150px] rounded-sm absolute right-1 top-[50%] translate-y-[-50%]"onClick={(event) => {event.preventDefault();getImageUrl();}}>Get Image</button></div><img src={url}></img></form></div></main>);};export default HomePage;