akhaliq HF Staff commited on
Commit
1bafe30
·
verified ·
1 Parent(s): 7d6d309

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import spaces
4
+ import torch
5
+ import random
6
+ from PIL import Image
7
+
8
+ # Import the pipeline from diffusers
9
+ from diffusers import FluxKontextPipeline
10
+
11
+ # --- Constants and Model Loading ---
12
+ MAX_SEED = np.iinfo(np.int32).max
13
+
14
+ # Load the pretrained model
15
+ # Note: This requires a CUDA-enabled GPU. Error handling is added for environments without it.
16
+ try:
17
+ pipe = FluxKontextPipeline.from_pretrained(
18
+ "black-forest-labs/FLUX.1-Kontext-dev",
19
+ torch_dtype=torch.bfloat16
20
+ ).to("cuda")
21
+ except Exception as e:
22
+ pipe = None
23
+ print(f"Warning: Could not load the model on CUDA. GPU is required. Error: {e}")
24
+
25
+ # --- Core Inference Function for ChatInterface ---
26
+
27
+ @spaces.GPU
28
+ def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress(track_tqdm=True)):
29
+ """
30
+ Performs image generation or editing based on user input from the chat interface.
31
+
32
+ Args:
33
+ message (dict): A dictionary from gr.MultimodalTextbox, containing:
34
+ - "text" (str): The user's text prompt.
35
+ - "files" (list): A list of paths to uploaded files.
36
+ chat_history (list): The history of the conversation (managed by ChatInterface).
37
+ seed (int): The random seed for generation.
38
+ randomize_seed (bool): If True, a random seed is used.
39
+ guidance_scale (float): Controls adherence to the prompt.
40
+ steps (int): Number of inference steps.
41
+ progress (gr.Progress): Gradio progress tracker.
42
+
43
+ Returns:
44
+ PIL.Image.Image: The generated or edited image to be displayed in the chat.
45
+ """
46
+ if pipe is None:
47
+ raise gr.Error("Model could not be loaded. A CUDA-enabled GPU is required to run this application.")
48
+
49
+ prompt = message["text"]
50
+ files = message["files"]
51
+
52
+ # Input validation
53
+ if not prompt and not files:
54
+ raise gr.Error("Please provide a prompt and/or upload an image.")
55
+
56
+ if randomize_seed:
57
+ seed = random.randint(0, MAX_SEED)
58
+
59
+ # Set up a PyTorch generator for reproducible results
60
+ generator = torch.Generator(device="cuda").manual_seed(seed)
61
+
62
+ input_image = None
63
+ if files:
64
+ # User has uploaded an image for editing (image-to-image)
65
+ print(f"Received image: {files[0]}")
66
+ input_image = Image.open(files[0]).convert("RGB")
67
+ image = pipe(
68
+ image=input_image,
69
+ prompt=prompt,
70
+ guidance_scale=guidance_scale,
71
+ num_inference_steps=steps,
72
+ generator=generator,
73
+ ).images[0]
74
+ else:
75
+ # No image uploaded, perform text-to-image generation
76
+ print(f"Received prompt for text-to-image: {prompt}")
77
+ image = pipe(
78
+ prompt=prompt,
79
+ guidance_scale=guidance_scale,
80
+ num_inference_steps=steps,
81
+ generator=generator,
82
+ ).images[0]
83
+
84
+ # To also inform the user of the seed, you could optionally return a tuple,
85
+ # but for a clean image output, we just return the image.
86
+ # For example: return (image, f"Seed: {seed}")
87
+ return image
88
+
89
+ # --- UI Definition using gr.ChatInterface ---
90
+
91
+ # Define the components for "Advanced Settings" that will be passed to `additional_inputs`
92
+ seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
93
+ randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
94
+ guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
95
+ steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)
96
+
97
+ # Create the ChatInterface
98
+ demo = gr.ChatInterface(
99
+ fn=chat_fn,
100
+ title="FLUX.1 Kontext [dev]",
101
+ description="""<p style='text-align: center;'>
102
+ A simple chat UI for the <b>FLUX.1 Kontext</b> model.
103
+ <br>
104
+ To edit an image, upload it and type your instructions (e.g., "Add a hat").
105
+ <br>
106
+ To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
107
+ <br>
108
+ Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
109
+ </p>""",
110
+ # Use a multimodal textbox to allow both text and image uploads
111
+ textbox=gr.MultimodalTextbox(
112
+ file_types=["image"],
113
+ placeholder="Type a prompt and/or upload an image...",
114
+ render=False # Important: Let ChatInterface render the textbox
115
+ ),
116
+ additional_inputs=[
117
+ seed_slider,
118
+ randomize_checkbox,
119
+ guidance_slider,
120
+ steps_slider
121
+ ],
122
+ examples=[
123
+ {"text": "A cute robot reading a book", "files": []},
124
+ {"text": "change his shirt to a hawaiian shirt", "files": ["https://gradio-builds.s3.amazonaws.com/demo-files/chewbacca.png"]},
125
+ {"text": "make it a wooden house", "files": ["https://gradio-builds.s3.amazonaws.com/demo-files/house.png"]},
126
+ ],
127
+ theme="soft"
128
+ )
129
+
130
+ # Launch the application
131
+ if __name__ == "__main__":
132
+ demo.launch()