File size: 7,255 Bytes
8247a04
 
 
0443b19
8247a04
 
 
 
 
 
0443b19
 
 
aa0c79b
8247a04
fc21604
 
8247a04
 
aa0c79b
 
 
 
 
 
 
 
 
a57efb7
 
 
 
 
aa0c79b
 
a57efb7
 
 
 
 
 
 
 
8247a04
7ddc847
 
 
8247a04
 
7ddc847
 
 
8247a04
0443b19
8247a04
 
 
7ddc847
 
 
fc21604
 
 
7ddc847
8247a04
0443b19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8247a04
7ddc847
 
 
 
 
 
8247a04
 
 
 
 
 
 
 
fc21604
 
 
8247a04
 
aa0c79b
8247a04
 
 
 
 
 
 
 
aa0c79b
8247a04
 
 
 
 
 
 
fc21604
 
0443b19
 
 
 
 
 
 
8247a04
 
 
 
 
 
 
 
 
 
0443b19
8247a04
 
0443b19
 
 
 
 
 
 
 
 
 
8247a04
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import gradio as gr
import config
from inference import DiffusionInference
from controlnet_pipeline import ControlNetPipeline
from PIL import Image
import io

# Module-level singletons, constructed once at import time and shared by both
# UI tabs. NOTE(review): construction may load models/resources — confirm in
# inference.py / controlnet_pipeline.py.
# Initialize the inference class
inference = DiffusionInference()

# Initialize the ControlNet (depth) pipeline — used by the Image-to-Image tab
# when the "Use ControlNet (Depth)" checkbox is enabled
controlnet = ControlNetPipeline()

def text_to_image_fn(prompt, model, negative_prompt=None, guidance_scale=7.5, num_inference_steps=50, seed=None):
    """
    Handle a text-to-image generation request from the UI.

    Args:
        prompt: Text prompt describing the desired image.
        model: Model name; falls back to config.DEFAULT_TEXT2IMG_MODEL when blank.
        negative_prompt: Optional text describing what to exclude from the image.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        seed: Optional seed as a string (value comes from a gr.Textbox);
            blank/None means a random seed.

    Returns:
        (image, error_message) tuple — exactly one element is None, matching
        the two Gradio output components (image, error textbox).
    """
    try:
        # Model validation - fallback to default if empty
        if not model or model.strip() == '':
            model = config.DEFAULT_TEXT2IMG_MODEL

        # Parse the seed textbox: empty means "random seed" (None)
        seed_value = None
        if seed and seed.strip() != '':
            try:
                seed_value = int(seed)
            except (ValueError, TypeError):
                # BUGFIX: previously an unparsable seed was silently ignored —
                # the old comment claimed "let inference handle invalid seed",
                # but None was passed instead, so the user's input was dropped
                # without any feedback. Report it via the error textbox.
                return None, f"Error: seed must be an integer, got {seed!r}"

        # Assemble keyword arguments for the inference call
        kwargs = {
            "prompt": prompt,
            "model_name": model,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "seed": seed_value
        }

        # Only forward negative_prompt when one was supplied
        if negative_prompt is not None:
            kwargs["negative_prompt"] = negative_prompt

        # Call the inference module with unpacked kwargs
        image = inference.text_to_image(**kwargs)

        if image is None:
            return None, "No image was generated. Check the model and parameters."

        return image, None
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: any failure is shown
        # in the error textbox instead of crashing the app.
        error_msg = f"Error: {str(e)}"
        print(error_msg)
        return None, error_msg

def image_to_image_fn(image, prompt, model, use_controlnet=False, negative_prompt=None, guidance_scale=7.5, num_inference_steps=50):
    """
    Handle an image-to-image transformation request from the UI.

    Args:
        image: Input PIL image from the Gradio component (None when absent).
        prompt: Text prompt; blank falls back to config.DEFAULT_IMG2IMG_PROMPT.
        model: Model name, used only when use_controlnet is False; blank falls
            back to config.DEFAULT_IMG2IMG_MODEL.
        use_controlnet: When True, use the local ControlNet (depth) pipeline
            instead of the regular inference API.
        negative_prompt: Optional text describing what to exclude from the image.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.

    Returns:
        (image, error_message) tuple — exactly one element is None, matching
        the two Gradio output components (image, error textbox).
    """
    if image is None:
        return None, "No input image provided."

    # Handle empty prompt - use default if completely empty
    if prompt is None or prompt.strip() == "":
        prompt = config.DEFAULT_IMG2IMG_PROMPT

    try:
        if use_controlnet:
            # Use ControlNet pipeline directly on the device
            result = controlnet.generate(
                prompt=prompt,
                image=image,
                negative_prompt=negative_prompt,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps)
            )
        else:
            # Model validation - fallback to default if empty
            if not model or model.strip() == '':
                model = config.DEFAULT_IMG2IMG_MODEL

            # Use regular inference API
            result = inference.image_to_image(
                image=image,
                prompt=prompt,
                model_name=model,
                negative_prompt=negative_prompt,
                guidance_scale=float(guidance_scale) if guidance_scale is not None else None,
                num_inference_steps=int(num_inference_steps) if num_inference_steps is not None else None
            )

        # BUGFIX: the ControlNet branch previously returned its result without
        # a None check, unlike the API branch — both paths are now checked.
        if result is None:
            return None, "No image was generated. Check the model and parameters."

        return result, None
    except Exception as e:
        # Broad catch is deliberate at this UI boundary: failures are shown in
        # the error textbox; extra context goes to the server log for debugging.
        error_msg = f"Error: {str(e)}"
        print(error_msg)
        print(f"Input image type: {type(image)}")
        print(f"Prompt: {prompt}")
        print(f"Model: {model}")
        return None, error_msg

# Create Gradio UI: two tabs sharing the (image, error) handler convention.
with gr.Blocks(title="Diffusion Models") as app:
    gr.Markdown("# Hugging Face Diffusion Models")
    
    with gr.Tab("Text to Image"):
        with gr.Row():
            with gr.Column():
                # Generation controls; defaults come from config
                txt2img_prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", value=config.DEFAULT_TEXT2IMG_PROMPT)
                txt2img_negative = gr.Textbox(label="Negative Prompt (Optional)", placeholder="What to exclude from the image", value=config.DEFAULT_NEGATIVE_PROMPT)
                txt2img_model = gr.Textbox(label="Model", placeholder=f"Enter model name", value=config.DEFAULT_TEXT2IMG_MODEL)
                txt2img_guidance = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.5, label="Guidance Scale")
                txt2img_steps = gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Inference Steps")
                txt2img_seed = gr.Textbox(label="Seed (Optional)", placeholder="Leave empty for random seed", value="")
                txt2img_button = gr.Button("Generate Image")
            
            with gr.Column():
                # Outputs mirror the (image, error) tuple from text_to_image_fn
                txt2img_output = gr.Image(type="pil", label="Generated Image")
                txt2img_error = gr.Textbox(label="Error", visible=True)
        
        txt2img_button.click(
            fn=text_to_image_fn,
            inputs=[txt2img_prompt, txt2img_model, txt2img_negative, txt2img_guidance, txt2img_steps, txt2img_seed],
            outputs=[txt2img_output, txt2img_error]
        )
    
    with gr.Tab("Image to Image"):
        with gr.Row():
            with gr.Column():
                img2img_input = gr.Image(type="pil", label="Input Image")
                img2img_prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", value=config.DEFAULT_IMG2IMG_PROMPT)
                img2img_negative = gr.Textbox(label="Negative Prompt (Optional)", placeholder="What to exclude from the image", value=config.DEFAULT_NEGATIVE_PROMPT)
                
                with gr.Row():
                    with gr.Column(scale=1):
                        img2img_controlnet = gr.Checkbox(label="Use ControlNet (Depth)", value=False)
                    with gr.Column(scale=2):
                        img2img_model = gr.Textbox(label="Model (used only if ControlNet is disabled)", placeholder=f"Enter model name", value=config.DEFAULT_IMG2IMG_MODEL, visible=True)
                
                img2img_guidance = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.5, label="Guidance Scale")
                img2img_steps = gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Inference Steps")
                img2img_button = gr.Button("Transform Image")
            
            with gr.Column():
                # Outputs mirror the (image, error) tuple from image_to_image_fn
                img2img_output = gr.Image(type="pil", label="Generated Image")
                img2img_error = gr.Textbox(label="Error", visible=True)
        
        img2img_button.click(
            fn=image_to_image_fn,
            inputs=[img2img_input, img2img_prompt, img2img_model, img2img_controlnet, img2img_negative, img2img_guidance, img2img_steps],
            outputs=[img2img_output, img2img_error]
        )
        
        # Toggle the model textbox when the ControlNet checkbox changes.
        # BUGFIX: the callback previously returned a bare bool; Gradio assigns
        # a plain return value to the output component's *value* (the textbox
        # would display "True"/"False") rather than changing its visibility.
        # Visibility must be changed via gr.update(visible=...).
        def toggle_model_visibility(use_controlnet):
            return gr.update(visible=not use_controlnet)
            
        img2img_controlnet.change(
            fn=toggle_model_visibility,
            inputs=[img2img_controlnet],
            outputs=[img2img_model]
        )

# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    # Host/port are configured centrally (config.GRADIO_HOST / GRADIO_PORT)
    app.launch(server_name=config.GRADIO_HOST, server_port=config.GRADIO_PORT)