# FastAPI server exposing text-to-image and image-to-image diffusion endpoints.
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import Response
from PIL import Image
import io
import uvicorn
import config
from inference import DiffusionInference
from controlnet_pipeline import ControlNetPipeline
app = FastAPI(title="Diffusion Models API")
# Module-level singleton: text2img/img2img inference engine, constructed once
# at import time and shared by every request handler
inference = DiffusionInference()
# Module-level singleton: ControlNet generation pipeline, used only when a
# request opts in via the use_controlnet form field
controlnet = ControlNetPipeline()
@app.get("/")
async def root():
    """Health-check endpoint: confirms the API service is up."""
    status = {"message": "Diffusion Models API is running"}
    return status
@app.post("/text-to-image")
async def text_to_image(
    prompt: str = Form(config.DEFAULT_TEXT2IMG_PROMPT),
    model: str = Form(config.DEFAULT_TEXT2IMG_MODEL),
    negative_prompt: str = Form(config.DEFAULT_NEGATIVE_PROMPT),
    guidance_scale: float = Form(7.5),
    num_inference_steps: int = Form(50),
    seed: str = Form(None)
):
    """
    Generate an image from a text prompt.

    Form fields:
        prompt: text description of the desired image; blank falls back to
            the configured default.
        model: model identifier; blank falls back to the configured default.
        negative_prompt: concepts to avoid; blank falls back to the
            configured default.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: number of denoising steps.
        seed: optional integer seed given as a string; blank or non-integer
            values result in a random seed chosen by the inference module.

    Returns:
        Response with PNG-encoded image bytes (media type image/png).

    Raises:
        HTTPException: 500 on any generation failure.
    """
    try:
        # Fall back to configured defaults for blank fields, mirroring the
        # handling in /image-to-image for consistency.
        if not prompt or not prompt.strip():
            prompt = config.DEFAULT_TEXT2IMG_PROMPT
        if not model or not model.strip():
            model = config.DEFAULT_TEXT2IMG_MODEL
        if not negative_prompt or not negative_prompt.strip():
            negative_prompt = config.DEFAULT_NEGATIVE_PROMPT

        # Parse the seed. A blank or non-integer value yields None, which
        # tells the inference module to generate a random seed; the raw
        # invalid string is never forwarded downstream.
        seed_value = None
        if seed is not None and seed.strip():
            try:
                seed_value = int(seed)
            except ValueError:
                seed_value = None  # invalid seed -> random seed downstream

        image = inference.text_to_image(
            prompt=prompt,
            model_name=model,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            seed=seed_value
        )

        # Serialize the PIL image to PNG bytes for the HTTP response.
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        return Response(content=buffer.getvalue(), media_type="image/png")
    except HTTPException:
        raise  # keep deliberate HTTP errors from being rewrapped as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/image-to-image")
async def image_to_image(
    image: UploadFile = File(...),
    prompt: str = Form(config.DEFAULT_IMG2IMG_PROMPT),
    model: str = Form(config.DEFAULT_IMG2IMG_MODEL),
    use_controlnet: bool = Form(False),
    negative_prompt: str = Form(config.DEFAULT_NEGATIVE_PROMPT),
    guidance_scale: float = Form(7.5),
    num_inference_steps: int = Form(50)
):
    """
    Generate a new image from an input image and optional prompt.

    Form fields:
        image: uploaded source image (any format Pillow can decode).
        prompt: text guidance; blank falls back to the configured default.
        model: model identifier for the plain img2img path; blank falls back
            to the configured default (ignored on the ControlNet path).
        use_controlnet: route through the ControlNet pipeline when True and
            config.USE_CONTROLNET is enabled.
        negative_prompt: concepts to avoid; blank falls back to the
            configured default.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: number of denoising steps.

    Returns:
        Response with PNG-encoded image bytes (media type image/png).

    Raises:
        HTTPException: 400 for an undecodable upload, 500 on any other
            generation failure.
    """
    try:
        # Apply configured defaults for blank fields up front so both the
        # ControlNet and the plain img2img paths see the same values
        # (previously the ControlNet path skipped these fallbacks).
        if not prompt or not prompt.strip():
            prompt = config.DEFAULT_IMG2IMG_PROMPT
        if not model or not model.strip():
            model = config.DEFAULT_IMG2IMG_MODEL
        if not negative_prompt or not negative_prompt.strip():
            negative_prompt = config.DEFAULT_NEGATIVE_PROMPT

        # Read and decode the uploaded image.
        contents = await image.read()
        try:
            input_image = Image.open(io.BytesIO(contents))
            input_image.load()  # force full decode now to surface bad uploads
        except Exception:
            # A malformed upload is a client error, not a server fault.
            raise HTTPException(status_code=400, detail="Invalid image file")

        if use_controlnet and config.USE_CONTROLNET:
            # Process with the ControlNet pipeline.
            result = controlnet.generate(
                prompt=prompt,
                image=input_image,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps
            )
        else:
            # Plain image-to-image path via the inference module.
            result = inference.image_to_image(
                image=input_image,
                prompt=prompt,
                model_name=model,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps
            )

        # Serialize the PIL image to PNG bytes for the HTTP response.
        buffer = io.BytesIO()
        result.save(buffer, format='PNG')
        return Response(content=buffer.getvalue(), media_type="image/png")
    except HTTPException:
        raise  # keep deliberate HTTP errors from being rewrapped as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Launch the development server directly. reload=True auto-restarts on
    # code changes and is not intended for production deployments.
    # NOTE(review): the "api:app" import string assumes this module is saved
    # as api.py — confirm the filename matches.
    uvicorn.run(
        "api:app",
        host=config.API_HOST,
        port=config.API_PORT,
        reload=True
    )