File size: 4,893 Bytes
8247a04
 
 
 
 
 
 
0443b19
8247a04
 
 
 
 
 
0443b19
 
 
8247a04
 
 
 
 
 
fc21604
 
 
8247a04
aa0c79b
 
8247a04
 
 
 
 
fc21604
 
8247a04
 
fc21604
 
 
 
aa0c79b
 
 
 
 
 
 
 
5ddf7bf
aa0c79b
 
8247a04
 
 
 
 
 
aa0c79b
 
8247a04
 
 
 
 
 
 
 
 
 
 
 
 
 
fc21604
 
0443b19
fc21604
8247a04
 
 
 
 
 
 
 
 
 
 
0443b19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8247a04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import Response
from PIL import Image
import io
import uvicorn
import config
from inference import DiffusionInference
from controlnet_pipeline import ControlNetPipeline

app = FastAPI(title="Diffusion Models API")

# Initialize the inference class
# NOTE(review): constructed at import time — if model loading happens in
# __init__, startup cost is paid on import; confirm against DiffusionInference.
inference = DiffusionInference()

# Initialize the ControlNet pipeline
# Shared singleton used by /image-to-image when use_controlnet is requested.
controlnet = ControlNetPipeline()

@app.get("/")
async def root():
    """Health-check endpoint confirming the API is reachable."""
    status = {"message": "Diffusion Models API is running"}
    return status

@app.post("/text-to-image")
async def text_to_image(
    prompt: str = Form(config.DEFAULT_TEXT2IMG_PROMPT),
    model: str = Form(config.DEFAULT_TEXT2IMG_MODEL),
    negative_prompt: str = Form(config.DEFAULT_NEGATIVE_PROMPT),
    guidance_scale: float = Form(7.5),
    num_inference_steps: int = Form(50),
    seed: str = Form(None)
):
    """
    Generate an image from a text prompt.

    Form parameters:
        prompt: Text description of the desired image; empty values fall
            back to config.DEFAULT_TEXT2IMG_PROMPT.
        model: Model identifier; empty values fall back to
            config.DEFAULT_TEXT2IMG_MODEL.
        negative_prompt: Concepts to steer away from; empty values fall
            back to config.DEFAULT_NEGATIVE_PROMPT.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        seed: Optional integer seed (as a string). Non-numeric or empty
            values are treated as "no seed" and the inference module
            generates a random one.

    Returns:
        PNG image bytes with media type image/png.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Normalize empty form values to configured defaults.
        # (Consistent with /image-to-image, which also normalizes prompt —
        # previously an explicitly-empty prompt was passed through as "".)
        if not prompt or not prompt.strip():
            prompt = config.DEFAULT_TEXT2IMG_PROMPT
        if not model or not model.strip():
            model = config.DEFAULT_TEXT2IMG_MODEL
        if not negative_prompt or not negative_prompt.strip():
            negative_prompt = config.DEFAULT_NEGATIVE_PROMPT

        # Parse the seed: seed_value stays None (random seed generated by
        # the inference module) when the field is absent, blank, or not a
        # valid integer. Invalid seeds are deliberately ignored, not errors.
        seed_value = None
        if seed is not None and seed.strip():
            try:
                seed_value = int(seed)
            except ValueError:
                pass  # best-effort: fall back to a random seed

        # Delegate the actual generation to the inference module.
        image = inference.text_to_image(
            prompt=prompt,
            model_name=model,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            seed=seed_value
        )

        # Serialize the PIL image to PNG bytes for the HTTP response.
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')

        return Response(content=img_byte_arr.getvalue(), media_type="image/png")
    except Exception as e:
        # Boundary handler: surface any failure as a 500 with its message.
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/image-to-image")
async def image_to_image(
    image: UploadFile = File(...),
    prompt: str = Form(config.DEFAULT_IMG2IMG_PROMPT),
    model: str = Form(config.DEFAULT_IMG2IMG_MODEL),
    use_controlnet: bool = Form(False),
    negative_prompt: str = Form(config.DEFAULT_NEGATIVE_PROMPT),
    guidance_scale: float = Form(7.5),
    num_inference_steps: int = Form(50)
):
    """
    Generate a new image from an input image and optional prompt.

    Form parameters:
        image: Uploaded source image (any format Pillow can decode).
        prompt: Text guidance; empty values fall back to
            config.DEFAULT_IMG2IMG_PROMPT.
        model: Model identifier (ignored on the ControlNet path); empty
            values fall back to config.DEFAULT_IMG2IMG_MODEL.
        use_controlnet: Route through the ControlNet pipeline when True
            and config.USE_CONTROLNET is enabled.
        negative_prompt: Concepts to steer away from; empty values fall
            back to config.DEFAULT_NEGATIVE_PROMPT.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.

    Returns:
        PNG image bytes with media type image/png.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Decode the upload and force 3-channel RGB: RGBA/palette/grayscale
        # uploads would otherwise reach the diffusion pipeline with the
        # wrong channel count. convert() also forces the lazy Image.open
        # to fully decode, so corrupt files fail here inside the handler.
        contents = await image.read()
        input_image = Image.open(io.BytesIO(contents)).convert("RGB")

        if use_controlnet and config.USE_CONTROLNET:
            # ControlNet path: the `model` form field is not used here.
            result = controlnet.generate(
                prompt=prompt,
                image=input_image,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps
            )
        else:
            # Normalize empty form values to configured defaults.
            if not model or not model.strip():
                model = config.DEFAULT_IMG2IMG_MODEL
            if not prompt or not prompt.strip():
                prompt = config.DEFAULT_IMG2IMG_PROMPT
            if not negative_prompt or not negative_prompt.strip():
                negative_prompt = config.DEFAULT_NEGATIVE_PROMPT

            # Standard img2img path via the inference module.
            result = inference.image_to_image(
                image=input_image,
                prompt=prompt,
                model_name=model,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps
            )

        # Serialize the PIL image to PNG bytes for the HTTP response.
        img_byte_arr = io.BytesIO()
        result.save(img_byte_arr, format='PNG')

        return Response(content=img_byte_arr.getvalue(), media_type="image/png")
    except Exception as e:
        # Boundary handler: surface any failure as a 500 with its message.
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    # Development entry point: reload=True watches source files and
    # restarts the server on change. Host/port come from config.
    uvicorn.run("api:app", host=config.API_HOST, port=config.API_PORT, reload=True)