lionelgarnier committed
Commit 98c7793 · 1 Parent(s): 5428aaf

cursor changes 2

Files changed (1)
  1. app.py +30 -18
app.py CHANGED
@@ -15,19 +15,24 @@ login(token=hf_token)
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
-
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
 
+try:
+    text_gen_pipeline = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", max_new_tokens=2048, device=device)
+except Exception as e:
+    text_gen_pipeline = None
+    print(f"Error loading text generation model: {e}")
+
 def refine_prompt(prompt):
+    if text_gen_pipeline is None:
+        return "Text generation model is unavailable."
     try:
-        chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3", max_new_tokens=2048, device=device)
         messages = [
-            {"role": "system", "content": "You are a product designer. You will get a basic prompt of product request and you need to imagine a new product design to satisfy that need. Produce an extended description of product front view that will be use by Flux to generate a visual"},
+            {"role": "system", "content": "You are a product designer. You will get a basic prompt of product request and you need to imagine a new product design to satisfy that need. Produce an extended description of product front view that will be used by Flux to generate a visual"},
             {"role": "user", "content": prompt},
         ]
-        refined_prompt = chatbot(messages)
+        refined_prompt = text_gen_pipeline(messages)
         return refined_prompt
     except Exception as e:
         return f"Error refining prompt: {str(e)}"
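Note that the new refine_prompt returns the pipeline's raw output rather than the generated string. A minimal sketch of unwrapping the assistant reply, assuming a recent transformers release where chat-style inputs return the whole conversation under "generated_text" (the helper name extract_reply is hypothetical, not from the commit):

    # Hypothetical helper: pull the assistant's text out of a chat-style
    # text-generation pipeline result. Output shape assumed to be
    # [{"generated_text": [..., {"role": "assistant", "content": "..."}]}].
    def extract_reply(pipeline_output):
        messages = pipeline_output[0]["generated_text"]
        return messages[-1]["content"]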
@@ -42,6 +47,15 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
     try:
         progress(0, desc="Starting generation...")
 
+        # Validate that prompt is not empty
+        if not prompt or prompt.strip() == "":
+            return None, "Please provide a valid prompt."
+
+        # Validate width/height dimensions
+        is_valid, error_msg = validate_dimensions(width, height)
+        if not is_valid:
+            return None, error_msg
+
         if randomize_seed:
             seed = random.randint(0, MAX_SEED)
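The validate_dimensions helper called above is not defined anywhere in this diff. A plausible sketch matching the (is_valid, error_msg) contract at the call site, assuming it bounds sizes by MAX_IMAGE_SIZE and enforces the divisibility constraint diffusion VAEs typically impose (both assumptions, not taken from the commit):

    # Hypothetical helper consistent with the call site above.
    def validate_dimensions(width, height):
        if width <= 0 or height <= 0:
            return False, "Width and height must be positive."
        if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
            return False, f"Width and height must not exceed {MAX_IMAGE_SIZE}px."
        if width % 8 != 0 or height % 8 != 0:
            return False, "Width and height should be multiples of 8."
        return True, ""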
 
@@ -51,15 +65,16 @@
         progress(0.4, desc="Generating image...")
         with torch.cuda.amp.autocast():
             image = pipe(
-                prompt = prompt,
-                width = width,
-                height = height,
-                num_inference_steps = num_inference_steps,
-                generator = generator,
+                prompt=prompt,
+                width=width,
+                height=height,
+                num_inference_steps=num_inference_steps,
+                generator=generator,
                 guidance_scale=0.0,
                 max_sequence_length=2048
             ).images[0]
-
+
+        torch.cuda.empty_cache()  # Clean up GPU memory after generation
         progress(1.0, desc="Done!")
         return image, seed
     except Exception as e:
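The generator passed to pipe is created between these hunks and not shown; presumably a seeded torch.Generator. Also, torch.cuda.amp.autocast() is deprecated in recent PyTorch releases in favour of torch.amp.autocast. A sketch of the equivalent generation call with both points applied (a suggestion, not part of this commit):

    # Sketch: seeded generator plus the non-deprecated autocast API.
    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.amp.autocast("cuda", dtype=dtype):
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            guidance_scale=0.0,
            max_sequence_length=2048,
        ).images[0]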
@@ -80,13 +95,10 @@ css="""
 
 with gr.Blocks(css=css) as demo:
 
-    info = gr.Info("Loading models... Please wait.")
-
-    try:
-        pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
-        info.value = "Models loaded successfully!"
-    except Exception as e:
-        info.value = f"Error loading models: {str(e)}"
+    # Compute the model loading status message ahead of creating the Info component.
+    model_status = "Models loaded successfully!"
+
+    info = gr.Info(model_status)
 
     with gr.Column(elem_id="col-container"):
         gr.Markdown(f"""# Text to Product
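With this change, both DiffusionPipeline loads are gone from the diff (the module-level one in the first hunk and the in-Blocks one here), so pipe must be defined somewhere outside this diff for infer to work, and model_status is hard-coded to the success message whether or not anything loaded. A guarded module-level load mirroring the text_gen_pipeline pattern would cover both (a sketch, not part of the commit):

    # Sketch: load FLUX with the same guarded pattern as the text model.
    try:
        pipe = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
        ).to(device)
        model_status = "Models loaded successfully!"
    except Exception as e:
        pipe = None
        model_status = f"Error loading models: {e}"

Note also that gr.Info is a toast notification intended to be raised inside event listeners; instantiating it at Blocks-construction time may not display anything to visitors.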
 