lionelgarnier committed
Commit a196f30 · 1 Parent(s): 067e31b
add default system prompt and refactor parameters for text generation
app.py CHANGED
```diff
@@ -12,9 +12,28 @@ from PIL import Image
 hf_token = os.getenv("hf_token")
 login(token=hf_token)
 
+# Global constants and default values
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
-PRELOAD_MODELS = False
+PRELOAD_MODELS = False
+
+# Default system prompt for text generation
+DEFAULT_SYSTEM_PROMPT = """Vous êtes un designer produit avec de solides connaissances dans la génération de texte en image. Vous recevrez une demande de produit sous forme de description succincte, et votre mission sera d'imaginer un nouveau design de produit répondant à ce besoin.
+
+Le livrable (réponse générée) sera exclusivement un texte de prompt pour l'IA de texte to image FLUX.1-schnell.
+
+Ce prompt devra inclure une description visuelle de l'objet mentionnant explicitement les aspects indispensables de sa fonction.
+A coté de ça vous devez aussi explicitement mentionner dans ce prompt les caractéristiques esthétiques/photo du rendu image (ex : photoréaliste, haute qualité, focale, grain, etc.), sachant que l'image sera l'image principale de cet objet dans le catalogue produit. Le fond de l'image générée doit être entièrement blanc.
+Le prompt doit être sans narration, peut être long mais ne doit pas dépasser 77 jetons."""
+
+# Default Flux parameters
+DEFAULT_SEED = 42
+DEFAULT_RANDOMIZE_SEED = True
+DEFAULT_WIDTH = 512
+DEFAULT_HEIGHT = 512
+DEFAULT_NUM_INFERENCE_STEPS = 6
+DEFAULT_GUIDANCE_SCALE = 0.0
+DEFAULT_TEMPERATURE = 0.9
 
 _text_gen_pipeline = None
 _image_gen_pipeline = None
```
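A note on the 77-token ceiling in the new `DEFAULT_SYSTEM_PROMPT`: it matches the 77-token context window of the CLIP text encoder that FLUX.1-schnell uses alongside T5. If you wanted to sanity-check a refined prompt against that budget offline, a minimal sketch follows; the tokenizer checkpoint is an illustrative stand-in, and app.py does not perform this validation itself:

```python
# Illustrative check only -- app.py does not do this.
from transformers import CLIPTokenizer

# FLUX.1-schnell pairs a CLIP text encoder (77-token limit) with a T5 encoder;
# this checkpoint is a stand-in for the one bundled with the pipeline.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def within_clip_budget(prompt: str, limit: int = 77) -> bool:
    # Token count includes the special start/end tokens CLIP adds.
    return len(tokenizer(prompt).input_ids) <= limit

print(within_clip_budget("minimalist kids backpack, floral pattern, white background"))
```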
```diff
@@ -64,15 +83,6 @@ def get_text_gen_pipeline():
         return None
     return _text_gen_pipeline
 
-# Default system prompt for text generation
-DEFAULT_SYSTEM_PROMPT = """Vous êtes un designer produit avec de solides connaissances dans la génération de texte en image. Vous recevrez une demande de produit sous forme de description succincte, et votre mission sera d'imaginer un nouveau design de produit répondant à ce besoin.
-
-Le livrable (réponse générée) sera exclusivement un texte de prompt pour l'IA de texte to image FLUX.1-schnell.
-
-Ce prompt devra inclure une description visuelle de l'objet mentionnant explicitement les aspects indispensables de sa fonction.
-A coté de ça vous devez aussi explicitement mentionner dans ce prompt les caractéristiques esthétiques/photo du rendu image (ex : photoréaliste, haute qualité, focale, grain, etc.), sachant que l'image sera l'image principale de cet objet dans le catalogue produit. Le fond de l'image générée doit être entièrement blanc.
-Le prompt doit être sans narration, peut être long mais ne doit pas dépasser 77 jetons."""
-
 @spaces.GPU()
 def refine_prompt(prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
     text_gen = get_text_gen_pipeline()
```
```diff
@@ -114,12 +124,18 @@ def validate_dimensions(width, height):
     return True, None
 
 @spaces.GPU()
-def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in…
+def infer(prompt, seed=DEFAULT_SEED,
+          randomize_seed=DEFAULT_RANDOMIZE_SEED,
+          width=DEFAULT_WIDTH,
+          height=DEFAULT_HEIGHT,
+          num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
+          progress=gr.Progress(track_tqdm=True)):
     try:
         # Validate that prompt is not empty
         if not prompt or prompt.strip() == "":
             return None, "Please provide a valid prompt."
 
+        progress(0.1, desc="Loading model")
         pipe = get_image_gen_pipeline()
         if pipe is None:
             return None, "Image generation model is unavailable."
```
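The refactored signature adds a `randomize_seed` parameter, but this hunk does not show where it is consumed. In the usual Gradio pattern the seed is re-rolled before generation when the box is checked; a hedged sketch of that pattern, with a hypothetical helper name:

```python
import random

MAX_SEED = 2**31 - 1  # mirrors np.iinfo(np.int32).max from the Space

def resolve_seed(seed: int, randomize_seed: bool) -> int:
    """Hypothetical helper: draw a fresh seed when randomization is enabled."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed

# With the new defaults (DEFAULT_SEED=42, DEFAULT_RANDOMIZE_SEED=True),
# each call would normally produce a new seed:
print(resolve_seed(42, True))
```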
```diff
@@ -134,6 +150,7 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
         # Use default torch generator instead of cuda-specific generator
         generator = torch.Generator().manual_seed(seed)
 
+        progress(0.3, desc="Running inference")
         # Match the working example's parameters
         output = pipe(
             prompt=prompt,
```
```diff
@@ -141,20 +158,23 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
             height=height,
             num_inference_steps=num_inference_steps,
             generator=generator,
-            guidance_scale=…,
+            guidance_scale=DEFAULT_GUIDANCE_SCALE,
         )
 
+        progress(0.8, desc="Processing output")
         image = output.images[0]
+        progress(1.0, desc="Complete")
         return image, f"Image generated successfully with seed {seed}"
     except Exception as e:
         print(f"Error in infer: {str(e)}")
         return None, f"Error generating image: {str(e)}"
 
-
+
+# Format: [prompt, system_prompt]
 examples = [
-    "a backpack for kids, flower style",
-    "medieval flip flops",
-    "cat shaped cake mold",
+    ["a backpack for kids, flower style", DEFAULT_SYSTEM_PROMPT],
+    ["medieval flip flops", DEFAULT_SYSTEM_PROMPT],
+    ["cat shaped cake mold", DEFAULT_SYSTEM_PROMPT],
 ]
 
 css="""
```
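`DEFAULT_GUIDANCE_SCALE = 0.0` is consistent with FLUX.1-schnell being a timestep-distilled model that runs without classifier-free guidance. A standalone sketch of the same call pattern using diffusers' `FluxPipeline` — assuming that is what `get_image_gen_pipeline()` returns; the model id and dtype are assumptions, since the loader is outside this diff:

```python
import torch
from diffusers import FluxPipeline

# Assumed to match the Space's image pipeline; not shown in this diff.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)

generator = torch.Generator().manual_seed(42)  # DEFAULT_SEED
image = pipe(
    prompt="minimalist kids backpack, floral pattern, pure white background",
    width=512,                # DEFAULT_WIDTH
    height=512,               # DEFAULT_HEIGHT
    num_inference_steps=6,    # DEFAULT_NUM_INFERENCE_STEPS
    guidance_scale=0.0,       # DEFAULT_GUIDANCE_SCALE; schnell is guidance-free
    generator=generator,
).images[0]
image.save("preview.png")
```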
```diff
@@ -165,26 +185,10 @@ css="""
 """
 
 def preload_models():
-    global _text_gen_pipeline, _image_gen_pipeline
-
     print("Preloading models...")
-
-    success = True
-    try:
-        _text_gen_pipeline = get_text_gen_pipeline()
-        if _text_gen_pipeline is None:
-            success = False
-    except Exception as e:
-        print(f"Error preloading text generation model: {str(e)}")
-        success = False
-
-    try:
-        _image_gen_pipeline = get_image_gen_pipeline()
-        if _image_gen_pipeline is None:
-            success = False
-    except Exception as e:
-        print(f"Error preloading image generation model: {str(e)}")
-        success = False
+    text_success = get_text_gen_pipeline() is not None
+    image_success = get_image_gen_pipeline() is not None
+    success = text_success and image_success
 
     status = "Models preloaded successfully!" if success else "Error preloading models"
     print(status)
```
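Dropping the try/except blocks here implicitly assumes each `get_*_pipeline()` catches loading errors internally and returns `None` on failure, which the earlier `return None` path in `get_text_gen_pipeline()` suggests. A minimal sketch of a loader honoring that contract; the names mirror the Space, but the body (and `load_flux_pipeline`) is hypothetical:

```python
_image_gen_pipeline = None

def get_image_gen_pipeline():
    """Return the cached pipeline, or None if loading fails (never raises)."""
    global _image_gen_pipeline
    if _image_gen_pipeline is None:
        try:
            _image_gen_pipeline = load_flux_pipeline()  # hypothetical loader
        except Exception as e:
            print(f"Error loading image generation model: {e}")
            return None
    return _image_gen_pipeline
```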
```diff
@@ -196,7 +200,6 @@ def preload_models():
 def process_example_pipeline(example_prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
     # Step 1: Update status
     progress(0, desc="Starting example processing")
-    progress_status = "Selected example: " + example_prompt
 
     # Step 2: Refine the prompt
     progress(0.1, desc="Refining prompt with Mistral")
```
```diff
@@ -254,7 +257,7 @@ def create_interface():
             # Mistral settings
             temperature = gr.Slider(
                 label="Temperature",
-                value=…,
+                value=DEFAULT_TEMPERATURE,
                 minimum=0.0,
                 maximum=1.0,
                 step=0.05,
```
```diff
@@ -270,19 +273,19 @@ def create_interface():
 
         with gr.Tab("Flux"):
             # Flux settings
-            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=…)
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=…)
+            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED)
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=DEFAULT_RANDOMIZE_SEED)
 
             with gr.Row():
-                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=…)
-                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=…)
+                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
+                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
 
             num_inference_steps = gr.Slider(
                 label="Number of inference steps",
                 minimum=1,
                 maximum=50,
                 step=1,
-                value=…,
+                value=DEFAULT_NUM_INFERENCE_STEPS,
             )
 
     # Examples section - simplified version that only updates the prompt fields
```
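The examples list now carries `[prompt, system_prompt]` pairs, which lines up with a two-input `gr.Examples` component that "only updates the prompt fields" per the comment above. A sketch of how such pairs are typically wired; the component variables and placeholder strings are hypothetical, since the actual wiring is outside this diff:

```python
import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    system_prompt = gr.Textbox(label="System prompt")

    # Each example row fills both textboxes; no generation is triggered,
    # matching the "simplified version" described in the diff.
    gr.Examples(
        examples=[
            ["a backpack for kids, flower style", "(system prompt)"],
            ["medieval flip flops", "(system prompt)"],
        ],
        inputs=[prompt, system_prompt],
    )
```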
|