|
import io
import logging

import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import Response
from PIL import Image

import config
from inference import DiffusionInference
from controlnet_pipeline import ControlNetPipeline
|
|
|
# The ASGI application that every route decorator below registers on.
app: FastAPI = FastAPI(title="Diffusion Models API")

# Pipeline objects are created once, at import time, and shared by all
# requests handled by this process.
inference: DiffusionInference = DiffusionInference()

# NOTE(review): built unconditionally, although the /image-to-image handler
# only uses it when config.USE_CONTROLNET is truthy. If the constructor loads
# model weights, that memory is spent even when ControlNet is disabled — confirm.
controlnet: ControlNetPipeline = ControlNetPipeline()
|
|
|
@app.get("/")
async def root():
    """Liveness probe: report that the API process is up and serving."""
    status_message = "Diffusion Models API is running"
    payload = {"message": status_message}
    return payload
|
|
|
@app.post("/text-to-image")
async def text_to_image(
    prompt: str = Form(config.DEFAULT_TEXT2IMG_PROMPT),
    model: str = Form(config.DEFAULT_TEXT2IMG_MODEL),
    negative_prompt: str = Form(config.DEFAULT_NEGATIVE_PROMPT),
    guidance_scale: float = Form(7.5),
    num_inference_steps: int = Form(50),
    seed: str = Form(None)
):
    """
    Generate an image from a text prompt.

    Blank (empty or whitespace-only) ``prompt``, ``model`` and
    ``negative_prompt`` fields fall back to the corresponding ``config``
    defaults, the same values used when the field is omitted.

    Args:
        prompt: Text prompt forwarded to ``inference.text_to_image``.
        model: Model name forwarded as ``model_name``.
        negative_prompt: Negative prompt forwarded to the pipeline.
        guidance_scale: Forwarded unchanged.
        num_inference_steps: Forwarded unchanged.
        seed: Optional integer seed, sent as a form string. Omitted or blank
            means ``seed=None`` is passed to the pipeline.

    Returns:
        The generated image encoded as a PNG ``Response``.

    Raises:
        HTTPException: 400 if ``seed`` is not a valid integer; 500 if
            generation or PNG encoding fails.
    """
    # Depending on the FastAPI version, an empty form value may reach us as ""
    # instead of triggering the Form default, and whitespace-only values always
    # do, so normalise blanks explicitly (mirrors /image-to-image).
    if not prompt or not prompt.strip():
        prompt = config.DEFAULT_TEXT2IMG_PROMPT
    if not model or not model.strip():
        model = config.DEFAULT_TEXT2IMG_MODEL
    if not negative_prompt or not negative_prompt.strip():
        negative_prompt = config.DEFAULT_NEGATIVE_PROMPT

    seed_value = None
    if seed is not None and seed.strip():
        try:
            seed_value = int(seed)
        except ValueError as e:
            # Reject rather than ignore: silently dropping a bad seed would
            # return an unseeded image the client believes is reproducible.
            raise HTTPException(
                status_code=400,
                detail=f"seed must be an integer, got {seed!r}",
            ) from e

    # NOTE(review): inference runs synchronously inside this async handler,
    # so it blocks the event loop (and every other request, including GET /)
    # until it finishes. It also serialises generations. Offloading it with
    # run_in_threadpool would need a lock around the shared pipelines unless
    # they are known to be thread-safe — confirm before changing.
    try:
        image = inference.text_to_image(
            prompt=prompt,
            model_name=model,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            seed=seed_value
        )
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
    except Exception as e:
        # The HTTPException below is handled by FastAPI without logging, so
        # record the traceback here before converting it to a 500.
        logging.getLogger(__name__).exception("text-to-image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(content=buffer.getvalue(), media_type="image/png")
|
|
|
@app.post("/image-to-image")
async def image_to_image(
    image: UploadFile = File(...),
    prompt: str = Form(config.DEFAULT_IMG2IMG_PROMPT),
    model: str = Form(config.DEFAULT_IMG2IMG_MODEL),
    use_controlnet: bool = Form(False),
    negative_prompt: str = Form(config.DEFAULT_NEGATIVE_PROMPT),
    guidance_scale: float = Form(7.5),
    num_inference_steps: int = Form(50)
):
    """
    Generate a new image from an input image and optional prompt.

    The ControlNet pipeline is used only when the client sets
    ``use_controlnet`` AND ``config.USE_CONTROLNET`` is truthy. Otherwise
    (including when ControlNet is requested but disabled in config) the
    request falls back to ``inference.image_to_image``.

    Args:
        image: Uploaded image file; decoded with Pillow.
        prompt: Text prompt; blank falls back to ``config.DEFAULT_IMG2IMG_PROMPT``.
        model: Model name for the img2img path; blank falls back to
            ``config.DEFAULT_IMG2IMG_MODEL``. Not passed to the ControlNet path.
        use_controlnet: Client request to use the ControlNet pipeline.
        negative_prompt: Blank falls back to ``config.DEFAULT_NEGATIVE_PROMPT``.
        guidance_scale: Forwarded unchanged.
        num_inference_steps: Forwarded unchanged.

    Returns:
        The generated image encoded as a PNG ``Response``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image (or
            exceeds Pillow's decompression-bomb limit); 500 if generation or
            PNG encoding fails.
    """
    contents = await image.read()
    try:
        input_image = Image.open(io.BytesIO(contents))
        # Image.open only reads the header; force a full decode here so that
        # truncated/corrupt uploads are reported as client errors rather than
        # failing later inside the pipeline as a 500.
        input_image.load()
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError (not an image / empty upload) is an OSError.
        raise HTTPException(
            status_code=400, detail=f"Invalid image upload: {e}"
        ) from e

    # Treat blank/whitespace-only fields as omitted, on BOTH pipeline paths.
    if not model or not model.strip():
        model = config.DEFAULT_IMG2IMG_MODEL
    if not prompt or not prompt.strip():
        prompt = config.DEFAULT_IMG2IMG_PROMPT
    if not negative_prompt or not negative_prompt.strip():
        negative_prompt = config.DEFAULT_NEGATIVE_PROMPT

    # NOTE(review): inference runs synchronously inside this async handler,
    # blocking the event loop for the duration of generation (see the same
    # note on /text-to-image). Offloading needs a lock around the shared
    # pipelines unless they are known to be thread-safe — confirm.
    try:
        if use_controlnet and config.USE_CONTROLNET:
            result = controlnet.generate(
                prompt=prompt,
                image=input_image,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps
            )
        else:
            result = inference.image_to_image(
                image=input_image,
                prompt=prompt,
                model_name=model,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps
            )
        buffer = io.BytesIO()
        result.save(buffer, format='PNG')
    except Exception as e:
        # FastAPI does not log HTTPExceptions, so keep the traceback.
        logging.getLogger(__name__).exception("image-to-image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(content=buffer.getvalue(), media_type="image/png")
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app with auto-reload enabled.
    #
    # NOTE(review): running this file directly executes the module body above
    # as __main__, constructing DiffusionInference()/ControlNetPipeline() in
    # this (supervisor) process. uvicorn's reload worker then imports
    # "api:app" and constructs them again. If those constructors load model
    # weights, they are loaded twice. The import string also assumes this file
    # is named api.py — confirm both.
    server_options = dict(
        host=config.API_HOST,
        port=config.API_PORT,
        reload=True,
    )
    uvicorn.run("api:app", **server_options)
|
|