lionelgarnier commited on
Commit
a196f30
·
1 Parent(s): 067e31b

add default system prompt and refactor parameters for text generation

Browse files
Files changed (1) hide show
  1. app.py +45 -42
app.py CHANGED
@@ -12,9 +12,28 @@ from PIL import Image
12
  hf_token = os.getenv("hf_token")
13
  login(token=hf_token)
14
 
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  MAX_IMAGE_SIZE = 2048
17
- PRELOAD_MODELS = False # Easy switch for preloading
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  _text_gen_pipeline = None
20
  _image_gen_pipeline = None
@@ -64,15 +83,6 @@ def get_text_gen_pipeline():
64
  return None
65
  return _text_gen_pipeline
66
 
67
- # Default system prompt for text generation
68
- DEFAULT_SYSTEM_PROMPT = """Vous êtes un designer produit avec de solides connaissances dans la génération de texte en image. Vous recevrez une demande de produit sous forme de description succincte, et votre mission sera d'imaginer un nouveau design de produit répondant à ce besoin.
69
-
70
- Le livrable (réponse générée) sera exclusivement un texte de prompt pour l'IA de texte to image FLUX.1-schnell.
71
-
72
- Ce prompt devra inclure une description visuelle de l'objet mentionnant explicitement les aspects indispensables de sa fonction.
73
- A coté de ça vous devez aussi explicitement mentionner dans ce prompt les caractéristiques esthétiques/photo du rendu image (ex : photoréaliste, haute qualité, focale, grain, etc.), sachant que l'image sera l'image principale de cet objet dans le catalogue produit. Le fond de l'image générée doit être entièrement blanc.
74
- Le prompt doit être sans narration, peut être long mais ne doit pas dépasser 77 jetons."""
75
-
76
  @spaces.GPU()
77
  def refine_prompt(prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
78
  text_gen = get_text_gen_pipeline()
@@ -114,12 +124,18 @@ def validate_dimensions(width, height):
114
  return True, None
115
 
116
  @spaces.GPU()
117
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
118
  try:
119
  # Validate that prompt is not empty
120
  if not prompt or prompt.strip() == "":
121
  return None, "Please provide a valid prompt."
122
 
 
123
  pipe = get_image_gen_pipeline()
124
  if pipe is None:
125
  return None, "Image generation model is unavailable."
@@ -134,6 +150,7 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
134
  # Use default torch generator instead of cuda-specific generator
135
  generator = torch.Generator().manual_seed(seed)
136
 
 
137
  # Match the working example's parameters
138
  output = pipe(
139
  prompt=prompt,
@@ -141,20 +158,23 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
141
  height=height,
142
  num_inference_steps=num_inference_steps,
143
  generator=generator,
144
- guidance_scale=0.0, # Changed from 7.5 to 0.0
145
  )
146
 
 
147
  image = output.images[0]
 
148
  return image, f"Image generated successfully with seed {seed}"
149
  except Exception as e:
150
  print(f"Error in infer: {str(e)}")
151
  return None, f"Error generating image: {str(e)}"
152
 
153
- # Update examples to be a list of prompts only, not including other parameters
 
154
  examples = [
155
- "a backpack for kids, flower style",
156
- "medieval flip flops",
157
- "cat shaped cake mold",
158
  ]
159
 
160
  css="""
@@ -165,26 +185,10 @@ css="""
165
  """
166
 
167
  def preload_models():
168
- global _text_gen_pipeline, _image_gen_pipeline
169
-
170
  print("Preloading models...")
171
- success = True
172
-
173
- try:
174
- _text_gen_pipeline = get_text_gen_pipeline()
175
- if _text_gen_pipeline is None:
176
- success = False
177
- except Exception as e:
178
- print(f"Error preloading text generation model: {str(e)}")
179
- success = False
180
-
181
- try:
182
- _image_gen_pipeline = get_image_gen_pipeline()
183
- if _image_gen_pipeline is None:
184
- success = False
185
- except Exception as e:
186
- print(f"Error preloading image generation model: {str(e)}")
187
- success = False
188
 
189
  status = "Models preloaded successfully!" if success else "Error preloading models"
190
  print(status)
@@ -196,7 +200,6 @@ def preload_models():
196
  def process_example_pipeline(example_prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
197
  # Step 1: Update status
198
  progress(0, desc="Starting example processing")
199
- progress_status = "Selected example: " + example_prompt
200
 
201
  # Step 2: Refine the prompt
202
  progress(0.1, desc="Refining prompt with Mistral")
@@ -254,7 +257,7 @@ def create_interface():
254
  # Mistral settings
255
  temperature = gr.Slider(
256
  label="Temperature",
257
- value=0.9,
258
  minimum=0.0,
259
  maximum=1.0,
260
  step=0.05,
@@ -270,19 +273,19 @@ def create_interface():
270
 
271
  with gr.Tab("Flux"):
272
  # Flux settings
273
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
274
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
275
 
276
  with gr.Row():
277
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
278
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
279
 
280
  num_inference_steps = gr.Slider(
281
  label="Number of inference steps",
282
  minimum=1,
283
  maximum=50,
284
  step=1,
285
- value=6,
286
  )
287
 
288
  # Examples section - simplified version that only updates the prompt fields
 
12
  hf_token = os.getenv("hf_token")
13
  login(token=hf_token)
14
 
15
+ # Global constants and default values
16
  MAX_SEED = np.iinfo(np.int32).max
17
  MAX_IMAGE_SIZE = 2048
18
+ PRELOAD_MODELS = False
19
+
20
+ # Default system prompt for text generation
21
+ DEFAULT_SYSTEM_PROMPT = """Vous êtes un designer produit avec de solides connaissances dans la génération de texte en image. Vous recevrez une demande de produit sous forme de description succincte, et votre mission sera d'imaginer un nouveau design de produit répondant à ce besoin.
22
+
23
+ Le livrable (réponse générée) sera exclusivement un texte de prompt pour l'IA de texte to image FLUX.1-schnell.
24
+
25
+ Ce prompt devra inclure une description visuelle de l'objet mentionnant explicitement les aspects indispensables de sa fonction.
26
+ A coté de ça vous devez aussi explicitement mentionner dans ce prompt les caractéristiques esthétiques/photo du rendu image (ex : photoréaliste, haute qualité, focale, grain, etc.), sachant que l'image sera l'image principale de cet objet dans le catalogue produit. Le fond de l'image générée doit être entièrement blanc.
27
+ Le prompt doit être sans narration, peut être long mais ne doit pas dépasser 77 jetons."""
28
+
29
+ # Default Flux parameters
30
+ DEFAULT_SEED = 42
31
+ DEFAULT_RANDOMIZE_SEED = True
32
+ DEFAULT_WIDTH = 512
33
+ DEFAULT_HEIGHT = 512
34
+ DEFAULT_NUM_INFERENCE_STEPS = 6
35
+ DEFAULT_GUIDANCE_SCALE = 0.0
36
+ DEFAULT_TEMPERATURE = 0.9
37
 
38
  _text_gen_pipeline = None
39
  _image_gen_pipeline = None
 
83
  return None
84
  return _text_gen_pipeline
85
 
 
 
 
 
 
 
 
 
 
86
  @spaces.GPU()
87
  def refine_prompt(prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
88
  text_gen = get_text_gen_pipeline()
 
124
  return True, None
125
 
126
  @spaces.GPU()
127
+ def infer(prompt, seed=DEFAULT_SEED,
128
+ randomize_seed=DEFAULT_RANDOMIZE_SEED,
129
+ width=DEFAULT_WIDTH,
130
+ height=DEFAULT_HEIGHT,
131
+ num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
132
+ progress=gr.Progress(track_tqdm=True)):
133
  try:
134
  # Validate that prompt is not empty
135
  if not prompt or prompt.strip() == "":
136
  return None, "Please provide a valid prompt."
137
 
138
+ progress(0.1, desc="Loading model")
139
  pipe = get_image_gen_pipeline()
140
  if pipe is None:
141
  return None, "Image generation model is unavailable."
 
150
  # Use default torch generator instead of cuda-specific generator
151
  generator = torch.Generator().manual_seed(seed)
152
 
153
+ progress(0.3, desc="Running inference")
154
  # Match the working example's parameters
155
  output = pipe(
156
  prompt=prompt,
 
158
  height=height,
159
  num_inference_steps=num_inference_steps,
160
  generator=generator,
161
+ guidance_scale=DEFAULT_GUIDANCE_SCALE,
162
  )
163
 
164
+ progress(0.8, desc="Processing output")
165
  image = output.images[0]
166
+ progress(1.0, desc="Complete")
167
  return image, f"Image generated successfully with seed {seed}"
168
  except Exception as e:
169
  print(f"Error in infer: {str(e)}")
170
  return None, f"Error generating image: {str(e)}"
171
 
172
+
173
+ # Format: [prompt, system_prompt]
174
  examples = [
175
+ ["a backpack for kids, flower style", DEFAULT_SYSTEM_PROMPT],
176
+ ["medieval flip flops", DEFAULT_SYSTEM_PROMPT],
177
+ ["cat shaped cake mold", DEFAULT_SYSTEM_PROMPT],
178
  ]
179
 
180
  css="""
 
185
  """
186
 
187
  def preload_models():
 
 
188
  print("Preloading models...")
189
+ text_success = get_text_gen_pipeline() is not None
190
+ image_success = get_image_gen_pipeline() is not None
191
+ success = text_success and image_success
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  status = "Models preloaded successfully!" if success else "Error preloading models"
194
  print(status)
 
200
  def process_example_pipeline(example_prompt, system_prompt=DEFAULT_SYSTEM_PROMPT, progress=gr.Progress()):
201
  # Step 1: Update status
202
  progress(0, desc="Starting example processing")
 
203
 
204
  # Step 2: Refine the prompt
205
  progress(0.1, desc="Refining prompt with Mistral")
 
257
  # Mistral settings
258
  temperature = gr.Slider(
259
  label="Temperature",
260
+ value=DEFAULT_TEMPERATURE,
261
  minimum=0.0,
262
  maximum=1.0,
263
  step=0.05,
 
273
 
274
  with gr.Tab("Flux"):
275
  # Flux settings
276
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED)
277
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=DEFAULT_RANDOMIZE_SEED)
278
 
279
  with gr.Row():
280
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
281
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
282
 
283
  num_inference_steps = gr.Slider(
284
  label="Number of inference steps",
285
  minimum=1,
286
  maximum=50,
287
  step=1,
288
+ value=DEFAULT_NUM_INFERENCE_STEPS,
289
  )
290
 
291
  # Examples section - simplified version that only updates the prompt fields