lionelgarnier commited on
Commit
827b490
·
1 Parent(s): b6b421e

debug cursor image

Browse files
Files changed (1) hide show
  1. app.py +32 -8
app.py CHANGED
@@ -15,13 +15,37 @@ MAX_SEED = np.iinfo(np.int32).max
15
  MAX_IMAGE_SIZE = 2048
16
 
17
  _text_gen_pipeline = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @spaces.GPU()
19
  def get_text_gen_pipeline():
20
  global _text_gen_pipeline
21
  if _text_gen_pipeline is None:
22
  try:
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
- _text_gen_pipeline = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", max_new_tokens=2048, device=device)
 
 
 
 
 
 
25
  except Exception as e:
26
  print(f"Error loading text generation model: {e}")
27
  return None
@@ -52,9 +76,9 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
52
  try:
53
  progress(0, desc="Starting generation...")
54
 
55
- device = "cuda" if torch.cuda.is_available() else "cpu"
56
- dtype = torch.bfloat16
57
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
58
 
59
  # Validate that prompt is not empty
60
  if not prompt or prompt.strip() == "":
@@ -79,7 +103,7 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
79
  height=height,
80
  num_inference_steps=num_inference_steps,
81
  generator=generator,
82
- guidance_scale=0.0,
83
  max_sequence_length=2048
84
  ).images[0]
85
 
@@ -138,7 +162,7 @@ with gr.Blocks(css=css) as demo:
138
 
139
  run_button = gr.Button("Create visual", scale=0)
140
 
141
- result = gr.Image(label="Result", show_label=False)
142
 
143
  with gr.Accordion("Advanced Settings Mistral", open=False):
144
  gr.Slider(
@@ -223,7 +247,7 @@ with gr.Blocks(css=css) as demo:
223
  examples = examples,
224
  fn = infer,
225
  inputs = [prompt],
226
- outputs = [result, seed],
227
  cache_examples="lazy"
228
  )
229
 
@@ -239,7 +263,7 @@ with gr.Blocks(css=css) as demo:
239
  triggers=[run_button.click],
240
  fn = infer,
241
  inputs = [refined_prompt, seed, randomize_seed, width, height, num_inference_steps],
242
- outputs = [result, seed]
243
  )
244
 
245
  demo.launch()
 
15
  MAX_IMAGE_SIZE = 2048
16
 
17
  _text_gen_pipeline = None
18
+ _image_gen_pipeline = None
19
+
20
+ @spaces.GPU()
21
+ def get_image_gen_pipeline():
22
+ global _image_gen_pipeline
23
+ if _image_gen_pipeline is None:
24
+ try:
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ dtype = torch.bfloat16
27
+ _image_gen_pipeline = DiffusionPipeline.from_pretrained(
28
+ "black-forest-labs/FLUX.1-schnell",
29
+ torch_dtype=dtype
30
+ ).to(device)
31
+ except Exception as e:
32
+ print(f"Error loading image generation model: {e}")
33
+ return None
34
+ return _image_gen_pipeline
35
+
36
  @spaces.GPU()
37
  def get_text_gen_pipeline():
38
  global _text_gen_pipeline
39
  if _text_gen_pipeline is None:
40
  try:
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ _text_gen_pipeline = pipeline(
43
+ "text-generation",
44
+ model="mistralai/Mistral-7B-Instruct-v0.3",
45
+ max_new_tokens=2048,
46
+ device=device,
47
+ tokenizer_kwargs={"add_prefix_space": False}
48
+ )
49
  except Exception as e:
50
  print(f"Error loading text generation model: {e}")
51
  return None
 
76
  try:
77
  progress(0, desc="Starting generation...")
78
 
79
+ pipe = get_image_gen_pipeline()
80
+ if pipe is None:
81
+ return None, "Image generation model is unavailable."
82
 
83
  # Validate that prompt is not empty
84
  if not prompt or prompt.strip() == "":
 
103
  height=height,
104
  num_inference_steps=num_inference_steps,
105
  generator=generator,
106
+ guidance_scale=5.0,
107
  max_sequence_length=2048
108
  ).images[0]
109
 
 
162
 
163
  run_button = gr.Button("Create visual", scale=0)
164
 
165
+ generated_image = gr.Image(label="Generated Image", show_label=False)
166
 
167
  with gr.Accordion("Advanced Settings Mistral", open=False):
168
  gr.Slider(
 
247
  examples = examples,
248
  fn = infer,
249
  inputs = [prompt],
250
+ outputs = [generated_image, seed],
251
  cache_examples="lazy"
252
  )
253
 
 
263
  triggers=[run_button.click],
264
  fn = infer,
265
  inputs = [refined_prompt, seed, randomize_seed, width, height, num_inference_steps],
266
+ outputs = [generated_image, seed]
267
  )
268
 
269
  demo.launch()