lionelgarnier commited on
Commit
47d8bfc
·
1 Parent(s): ac3fb1c

Improve image generation with better processing and model configuration

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -7,6 +7,7 @@ import torch
7
  from diffusers import DiffusionPipeline
8
  from transformers import pipeline, AutoTokenizer
9
  from huggingface_hub import login
 
10
 
11
  hf_token = os.getenv("hf_token")
12
  login(token=hf_token)
@@ -30,6 +31,8 @@ def get_image_gen_pipeline():
30
  "black-forest-labs/FLUX.1-dev",
31
  torch_dtype=dtype,
32
  ).to(device)
 
 
33
  except Exception as e:
34
  print(f"Error loading image generation model: {e}")
35
  return None
@@ -69,7 +72,7 @@ def refine_prompt(prompt):
69
  return "Text generation model is unavailable."
70
  try:
71
  messages = [
72
- {"role": "system", "content": "Vous êtes un designer produit avec de solides connaissances dans la génération de texte en image. Vous recevrez une demande de produit sous forme de description succincte, et votre mission sera d'imaginer un nouveau design de produit répondant à ce besoin.\n\nLe livrable (réponse générée) sera exclusivement un texte de prompt pour l'IA de texte to image FLUX.1-schnell.\n\nCe prompt devra inclure une description visuelle de l'objet mentionnant explicitement les aspects indispensables de sa fonction.\nA coté de ça vous devez aussi explicitement mentionner dans ce prompt les caractéristiques esthétiques/photo du rendu image (ex : photoréaliste, haute qualité, focale, grain, etc.), sachant que l'image sera l'image principale de cet objet dans le catalogue produit. Le fond de l'image générée doit être entièrement blanc.\nLe prompt doit être sans narration, peut être long mais ne doit pas dépasser 512 jetons."}, {"role": "user", "content": prompt},
73
  ]
74
  refined_prompt = text_gen(messages)
75
 
@@ -93,20 +96,17 @@ def validate_dimensions(width, height):
93
  return True, None
94
 
95
  @spaces.GPU()
96
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4): # , progress=gr.Progress(track_tqdm=True)
97
  try:
98
- # progress(0, desc="Starting generation...")
99
 
100
  # Validate that prompt is not empty
101
  if not prompt or prompt.strip() == "":
102
  return None, "Please provide a valid prompt."
103
 
104
- # progress(0.1, desc="Loading image generation model...")
105
  pipe = get_image_gen_pipeline()
106
  if pipe is None:
107
  return None, "Image generation model is unavailable."
108
 
109
- # progress(0.2, desc="Validating dimensions...")
110
  is_valid, error_msg = validate_dimensions(width, height)
111
  if not is_valid:
112
  return None, error_msg
@@ -114,26 +114,30 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
114
  if randomize_seed:
115
  seed = random.randint(0, MAX_SEED)
116
 
117
- # progress(0.3, desc="Setting up generator...")
118
  generator = torch.Generator("cuda").manual_seed(seed) # Explicitly use CUDA generator
119
 
120
- # progress(0.4, desc="Generating image...")
121
  with torch.autocast('cuda'):
122
- image = pipe(
123
  prompt=prompt,
124
  width=width,
125
  height=height,
126
  num_inference_steps=num_inference_steps,
127
  generator=generator,
128
- guidance_scale=0.0, # Increased guidance scale
129
- # max_sequence_length=512
130
- ).images[0]
 
 
 
 
 
 
 
131
 
132
- torch.cuda.empty_cache() # Clean up GPU memory after generation
133
- # progress(1.0, desc="Done!")
134
  return image, seed
135
  except Exception as e:
136
- print(f"Error in infer: {str(e)}") # Add detailed error logging
137
  return None, f"Error generating image: {str(e)}"
138
 
139
  examples = [
 
7
  from diffusers import DiffusionPipeline
8
  from transformers import pipeline, AutoTokenizer
9
  from huggingface_hub import login
10
+ from PIL import Image
11
 
12
  hf_token = os.getenv("hf_token")
13
  login(token=hf_token)
 
31
  "black-forest-labs/FLUX.1-dev",
32
  torch_dtype=dtype,
33
  ).to(device)
34
+ _image_gen_pipeline.enable_model_cpu_offload()
35
+ _image_gen_pipeline.enable_vae_slicing()
36
  except Exception as e:
37
  print(f"Error loading image generation model: {e}")
38
  return None
 
72
  return "Text generation model is unavailable."
73
  try:
74
  messages = [
75
+ {"role": "system", "content": "Vous êtes un designer produit avec de solides connaissances dans la génération de texte en image. Vous recevrez une demande de produit sous forme de description succincte, et votre mission sera d'imaginer un nouveau design de produit répondant à ce besoin.\n\nLe livrable (réponse générée) sera exclusivement un texte de prompt pour l'IA de texte to image FLUX.1-schnell.\n\nCe prompt devra inclure une description visuelle de l'objet mentionnant explicitement les aspects indispensables de sa fonction.\nA coté de ça vous devez aussi explicitement mentionner dans ce prompt les caractéristiques esthétiques/photo du rendu image (ex : photoréaliste, haute qualité, focale, grain, etc.), sachant que l'image sera l'image principale de cet objet dans le catalogue produit. Le fond de l'image générée doit être entièrement blanc.\nLe prompt doit être sans narration, peut être long mais ne doit pas dépasser 77 jetons."}, {"role": "user", "content": prompt},
76
  ]
77
  refined_prompt = text_gen(messages)
78
 
 
96
  return True, None
97
 
98
  @spaces.GPU()
99
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
100
  try:
 
101
 
102
  # Validate that prompt is not empty
103
  if not prompt or prompt.strip() == "":
104
  return None, "Please provide a valid prompt."
105
 
 
106
  pipe = get_image_gen_pipeline()
107
  if pipe is None:
108
  return None, "Image generation model is unavailable."
109
 
 
110
  is_valid, error_msg = validate_dimensions(width, height)
111
  if not is_valid:
112
  return None, error_msg
 
114
  if randomize_seed:
115
  seed = random.randint(0, MAX_SEED)
116
 
 
117
  generator = torch.Generator("cuda").manual_seed(seed) # Explicitly use CUDA generator
118
 
 
119
  with torch.autocast('cuda'):
120
+ output = pipe(
121
  prompt=prompt,
122
  width=width,
123
  height=height,
124
  num_inference_steps=num_inference_steps,
125
  generator=generator,
126
+ guidance_scale=7.5,
127
+ max_sequence_length=512
128
+ )
129
+
130
+ # Ensure the image is properly normalized and converted
131
+ image = output.images[0]
132
+ if isinstance(image, torch.Tensor):
133
+ image = (image.clamp(-1, 1) + 1) / 2
134
+ image = (image * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
135
+ image = Image.fromarray(image)
136
 
137
+ torch.cuda.empty_cache()
 
138
  return image, seed
139
  except Exception as e:
140
+ print(f"Error in infer: {str(e)}")
141
  return None, f"Error generating image: {str(e)}"
142
 
143
  examples = [