File size: 3,696 Bytes
8247a04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc21604
 
 
8247a04
 
 
 
 
 
 
fc21604
 
8247a04
 
fc21604
 
 
 
8247a04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc21604
 
 
8247a04
 
 
 
 
 
 
 
 
 
 
fc21604
 
8247a04
 
fc21604
 
 
 
 
 
 
 
8247a04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import io
import logging

import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import Response
from PIL import Image

import config
from inference import DiffusionInference

# FastAPI application object; the route decorators in this module register
# their endpoints on it.
app = FastAPI(title="Diffusion Models API")

# Single shared inference backend used by every endpoint below. It is built at
# import time, so whatever setup DiffusionInference() performs (presumably
# model loading — TODO confirm) runs when this module is imported, before the
# first request is served.
inference = DiffusionInference()

@app.get("/")
async def root():
    """Liveness probe: report that the service is up and answering."""
    status = {"message": "Diffusion Models API is running"}
    return status

@app.post("/text-to-image")
async def text_to_image(
    prompt: str = Form(config.DEFAULT_TEXT2IMG_PROMPT),
    model: str = Form(config.DEFAULT_TEXT2IMG_MODEL),
    negative_prompt: str = Form(config.DEFAULT_NEGATIVE_PROMPT),
    guidance_scale: float = Form(7.5),
    num_inference_steps: int = Form(50)
):
    """
    Generate an image from a text prompt.

    Blank (empty or whitespace-only) ``prompt``, ``model`` and
    ``negative_prompt`` form fields fall back to the corresponding ``config``
    defaults, the same way ``/image-to-image`` treats its fields.

    Returns:
        A ``Response`` whose body is the generated image encoded as PNG.

    Raises:
        HTTPException: 500, with the underlying error message as ``detail``,
            if generation or PNG encoding fails.
    """
    # A blank form field means "use the default", not "send an empty string".
    if not prompt or not prompt.strip():
        prompt = config.DEFAULT_TEXT2IMG_PROMPT
    if not model or not model.strip():
        model = config.DEFAULT_TEXT2IMG_MODEL
    if not negative_prompt or not negative_prompt.strip():
        negative_prompt = config.DEFAULT_NEGATIVE_PROMPT

    try:
        # NOTE(review): this is a blocking call inside an async endpoint, so
        # it stalls the event loop (including "/") for the whole generation.
        # Offloading it to a worker thread would let requests overlap on the
        # same pipeline; confirm DiffusionInference is thread-safe first.
        image = inference.text_to_image(
            prompt=prompt,
            model_name=model,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps
        )

        # Serialise the PIL image to PNG bytes in memory.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
    except Exception as e:
        # Keep the full traceback in the server log; the client only gets
        # the message.
        logging.getLogger(__name__).exception("text-to-image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(content=buffer.getvalue(), media_type="image/png")

@app.post("/image-to-image")
async def image_to_image(
    image: UploadFile = File(...),
    prompt: str = Form(config.DEFAULT_IMG2IMG_PROMPT),
    model: str = Form(config.DEFAULT_IMG2IMG_MODEL),
    negative_prompt: str = Form(config.DEFAULT_NEGATIVE_PROMPT),
    guidance_scale: float = Form(7.5),
    num_inference_steps: int = Form(50)
):
    """
    Generate a new image from an input image and optional prompt.

    The uploaded file is decoded with Pillow and converted to RGB before it
    is passed to the inference backend. Blank (empty or whitespace-only)
    ``prompt``, ``model`` and ``negative_prompt`` form fields fall back to
    the corresponding ``config`` defaults.

    Returns:
        A ``Response`` whose body is the generated image encoded as PNG.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500, with the underlying error message as ``detail``, if
            generation or PNG encoding fails.
    """
    contents = await image.read()

    # Decode up front so a bad upload is reported as a client error (400)
    # instead of being folded into the generic 500 below. Image.open() is
    # lazy, so convert() is what forces the full decode. PIL signals
    # unrecognised or truncated data with OSError (UnidentifiedImageError is
    # a subclass of it).
    try:
        with Image.open(io.BytesIO(contents)) as uploaded:
            # NOTE(review): assumes the backend wants 3-channel RGB input
            # (alpha / palette / grayscale get flattened) — confirm.
            input_image = uploaded.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid image file: {e}"
        ) from e

    # A blank form field means "use the default", not "send an empty string".
    if not model or not model.strip():
        model = config.DEFAULT_IMG2IMG_MODEL
    if not prompt or not prompt.strip():
        prompt = config.DEFAULT_IMG2IMG_PROMPT
    if not negative_prompt or not negative_prompt.strip():
        negative_prompt = config.DEFAULT_NEGATIVE_PROMPT

    try:
        # NOTE(review): this is a blocking call inside an async endpoint, so
        # it stalls the event loop for the whole generation. Offloading it to
        # a worker thread would let requests overlap on the same pipeline;
        # confirm DiffusionInference is thread-safe first.
        result = inference.image_to_image(
            image=input_image,
            prompt=prompt,
            model_name=model,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps
        )

        # Serialise the PIL image to PNG bytes in memory.
        buffer = io.BytesIO()
        result.save(buffer, format="PNG")
    except Exception as e:
        # Keep the full traceback in the server log; the client only gets
        # the message.
        logging.getLogger(__name__).exception("image-to-image generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(content=buffer.getvalue(), media_type="image/png")

if __name__ == "__main__":
    from pathlib import Path

    # With reload=True uvicorn needs a "module:attribute" import string rather
    # than the app object. Derive the module name from this file instead of
    # hard-coding "api", so the launcher still works if the file is renamed.
    module_name = Path(__file__).stem

    uvicorn.run(
        f"{module_name}:app",
        host=config.API_HOST,
        port=config.API_PORT,
        # NOTE(review): reload=True runs a file watcher and restarts the
        # server on code changes — a development setting; consider turning
        # it off for production deployments.
        reload=True,
    )