lionelgarnier committed
Commit 640d399 · 1 Parent(s): b39baa9

simplify code

Files changed (1)
  1. app.py +88 -152
app.py CHANGED
@@ -42,27 +42,25 @@ def get_image_gen_pipeline():
 def get_text_gen_pipeline():
     global _text_gen_pipeline
     if _text_gen_pipeline is None:
-        try:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            tokenizer = AutoTokenizer.from_pretrained(
-                "mistralai/Mistral-7B-Instruct-v0.3",
-                use_fast=True
-            )
-            # Set pad_token_id to eos_token_id if pad_token is not set
-            if tokenizer.pad_token is None:
-                tokenizer.pad_token = tokenizer.eos_token
-
-            _text_gen_pipeline = pipeline(
-                "text-generation",
-                model="mistralai/Mistral-7B-Instruct-v0.3",
-                tokenizer=tokenizer,
-                max_new_tokens=2048,
-                device=device,
-                pad_token_id=tokenizer.pad_token_id  # Explicitly set pad_token_id
-            )
-        except Exception as e:
-            print(f"Error loading text generation model: {e}")
-            return None
+        try:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            tokenizer = AutoTokenizer.from_pretrained(
+                "mistralai/Mistral-7B-Instruct-v0.3",
+                use_fast=True
+            )
+            tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
+
+            _text_gen_pipeline = pipeline(
+                "text-generation",
+                model="mistralai/Mistral-7B-Instruct-v0.3",
+                tokenizer=tokenizer,
+                max_new_tokens=2048,
+                device=device,
+                pad_token_id=tokenizer.pad_token_id
+            )
+        except Exception as e:
+            print(f"Error loading text generation model: {e}")
+            return None
     return _text_gen_pipeline
 
 @spaces.GPU()
@@ -127,15 +125,10 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
                 max_sequence_length=512
             )
 
-            # Ensure the image is properly normalized and converted
             image = output.images[0]
-            # if isinstance(image, torch.Tensor):
-            #     image = (image.clamp(-1, 1) + 1) / 2
-            #     image = (image * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
-            #     image = Image.fromarray(image)
 
             #torch.cuda.empty_cache()
-            return image, seed
+            return image, f"Image generated successfully with seed {seed}"
     except Exception as e:
         print(f"Error in infer: {str(e)}")
         return None, f"Error generating image: {str(e)}"
@@ -154,150 +147,93 @@ css="""
 """
 
 def preload_models():
-    print("Préchargement des modèles...")
+    global _text_gen_pipeline, _image_gen_pipeline
+
+    print("Preloading models...")
+    success = True
+
     try:
-        # Préchargement du modèle de génération de texte
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        # Explicitly load the fast tokenizer LGR
-        tokenizer = AutoTokenizer.from_pretrained(
-            "mistralai/Mistral-7B-Instruct-v0.3",
-            use_fast=True  # Ensures a fast tokenizer is used
-        )
-        _text_gen_pipeline = pipeline(
-            "text-generation",
-            model="mistralai/Mistral-7B-Instruct-v0.3",
-            tokenizer=tokenizer,  # Pass the fast tokenizer in LGR
-            max_new_tokens=2048,
-            device=device,
-        )
-
-        # Préchargement du modèle de génération d'images
-        dtype = torch.bfloat16
-        _image_gen_pipeline = DiffusionPipeline.from_pretrained(
-            "black-forest-labs/FLUX.1-schnell",
-            # "black-forest-labs/FLUX.1-dev",
-            torch_dtype=dtype
-        ).to(device)
-
-        print("Modèles préchargés avec succès!")
-        return True
+        _text_gen_pipeline = get_text_gen_pipeline()
+        if _text_gen_pipeline is None:
+            success = False
     except Exception as e:
-        print(f"Erreur lors du préchargement des modèles: {str(e)}")
-        return False
+        print(f"Error preloading text generation model: {str(e)}")
+        success = False
+
+    try:
+        _image_gen_pipeline = get_image_gen_pipeline()
+        if _image_gen_pipeline is None:
+            success = False
+    except Exception as e:
+        print(f"Error preloading image generation model: {str(e)}")
+        success = False
+
+    status = "Models preloaded successfully!" if success else "Error preloading models"
+    print(status)
+    return success
 
 def create_interface():
-    # Modify the preloading logic
+    # Preload models if needed
     if PRELOAD_MODELS:
         models_loaded = preload_models()
-        model_status = "✅ Modèles chargés avec succès!" if models_loaded else "⚠️ Erreur lors du chargement des modèles"
+        model_status = "✅ Models loaded successfully!" if models_loaded else "⚠️ Error loading models"
     else:
-        model_status = "ℹ️ Modèles seront chargés à la demande"
+        model_status = "ℹ️ Models will be loaded on demand"
 
     with gr.Blocks(css=css) as demo:
-        info = gr.Info(model_status)
+        gr.Info(model_status)
 
         with gr.Column(elem_id="col-container"):
-            gr.Markdown(f"""# Text to Product
-            Using Mistral-7B-Instruct-v0.3 + FLUX.1-dev + Trellis
-            """)
+            gr.Markdown("# Text to Product\nUsing Mistral-7B-Instruct-v0.3 + FLUX.1-dev + Trellis")
 
+            # Basic inputs
             with gr.Row():
-
                 prompt = gr.Text(
-                    label="Prompt",
                     show_label=False,
                     max_lines=1,
                     placeholder="Enter basic object prompt",
                     container=False,
                 )
-
-                prompt_button = gr.Button("Refine prompt with Mistral", scale=0)
+                prompt_button = gr.Button("Refine prompt with Mistral")
 
             refined_prompt = gr.Text(
-                label="Refined Prompt",
                 show_label=False,
                 max_lines=10,
                 placeholder="Detailed object prompt",
                 container=False,
                 max_length=2048,
-            )
-
-
-            run_button = gr.Button("Create visual with Flux", scale=0)
-
-            generated_image = gr.Image(label="Generated Image", show_label=False)
+            )
 
-        with gr.Accordion("Advanced Settings Mistral", open=False):
-            gr.Slider(
-                label="Temperature",
-                value=0.9,
-                minimum=0.0,
-                maximum=1.0,
-                step=0.05,
-                interactive=True,
-                info="Higher values produce more diverse outputs",
-            ),
-            gr.Slider(
-                label="Max new tokens",
-                value=256,
-                minimum=0,
-                maximum=1048,
-                step=64,
-                interactive=True,
-                info="The maximum numbers of new tokens",
-            ),
-            gr.Slider(
-                label="Top-p (nucleus sampling)",
-                value=0.90,
-                minimum=0.0,
-                maximum=1,
-                step=0.05,
-                interactive=True,
-                info="Higher values sample more low-probability tokens",
-            ),
-            gr.Slider(
-                label="Repetition penalty",
-                value=1.2,
-                minimum=1.0,
-                maximum=2.0,
-                step=0.05,
-                interactive=True,
-                info="Penalize repeated tokens",
-            )
+            visual_button = gr.Button("Create visual with Flux")
+            generated_image = gr.Image(show_label=False)
+            error_box = gr.Textbox(
+                label="Status Messages",
+                interactive=False,
+                placeholder="Status messages will appear here",
+            )
 
-        with gr.Accordion("Advanced Settings Flux", open=False):
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
+            # Accordion sections for advanced settings
+            with gr.Accordion("Advanced Settings", open=False):
+                with gr.Tab("Mistral"):
+                    # Mistral settings
+                    temperature = gr.Slider(
+                        label="Temperature",
+                        value=0.9,
+                        minimum=0.0,
+                        maximum=1.0,
+                        step=0.05,
+                        info="Higher values produce more diverse outputs",
                 )
 
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
-                )
-
-            with gr.Row():
+                with gr.Tab("Flux"):
+                    # Flux settings
+                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                    with gr.Row():
+                        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
+                        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
 
-
                 num_inference_steps = gr.Slider(
                     label="Number of inference steps",
                     minimum=1,
@@ -305,29 +241,29 @@ def create_interface():
                     step=1,
                     value=10,
                 )
-
+
+            # Examples section
             gr.Examples(
                 examples=examples,
                 fn=refine_prompt,
-                inputs = [prompt],
-                outputs = [refined_prompt],
+                inputs=[prompt],
+                outputs=[refined_prompt],
                 cache_examples=True,
-                cache_mode='lazy'
             )
 
-
+            # Event handlers
             gr.on(
                 triggers=[prompt_button.click, prompt.submit],
-                fn = refine_prompt,
-                inputs = [prompt],
-                outputs = [refined_prompt]
+                fn=refine_prompt,
+                inputs=[prompt],
+                outputs=[refined_prompt]
             )
 
             gr.on(
-                triggers=[run_button.click],
-                fn = infer,
-                inputs = [refined_prompt, seed, randomize_seed, width, height, num_inference_steps],
-                outputs = [generated_image, prompt]
+                triggers=[visual_button.click],
+                fn=infer,
+                inputs=[refined_prompt, seed, randomize_seed, width, height, num_inference_steps],
+                outputs=[generated_image, error_box]
            )
 
     return demo
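
For context, a minimal usage sketch of how the refactored create_interface() would typically be driven in a Gradio Space. The __main__ guard and the launch() call are not shown in this diff and are assumed here.

# Assumed usage sketch, not part of this commit.
# create_interface() preloads both pipelines when PRELOAD_MODELS is truthy;
# otherwise get_text_gen_pipeline() / get_image_gen_pipeline() load lazily on first use.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()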