Spaces:
Runtime error
Runtime error
lionelgarnier
commited on
Commit
·
47d8bfc
1
Parent(s):
ac3fb1c
Improve image generation with better processing and model configuration
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import torch
|
|
7 |
from diffusers import DiffusionPipeline
|
8 |
from transformers import pipeline, AutoTokenizer
|
9 |
from huggingface_hub import login
|
|
|
10 |
|
11 |
hf_token = os.getenv("hf_token")
|
12 |
login(token=hf_token)
|
@@ -30,6 +31,8 @@ def get_image_gen_pipeline():
|
|
30 |
"black-forest-labs/FLUX.1-dev",
|
31 |
torch_dtype=dtype,
|
32 |
).to(device)
|
|
|
|
|
33 |
except Exception as e:
|
34 |
print(f"Error loading image generation model: {e}")
|
35 |
return None
|
@@ -69,7 +72,7 @@ def refine_prompt(prompt):
|
|
69 |
return "Text generation model is unavailable."
|
70 |
try:
|
71 |
messages = [
|
72 |
-
{"role": "system", "content": "Vous êtes un designer produit avec de solides connaissances dans la génération de texte en image. Vous recevrez une demande de produit sous forme de description succincte, et votre mission sera d'imaginer un nouveau design de produit répondant à ce besoin.\n\nLe livrable (réponse générée) sera exclusivement un texte de prompt pour l'IA de texte to image FLUX.1-schnell.\n\nCe prompt devra inclure une description visuelle de l'objet mentionnant explicitement les aspects indispensables de sa fonction.\nA coté de ça vous devez aussi explicitement mentionner dans ce prompt les caractéristiques esthétiques/photo du rendu image (ex : photoréaliste, haute qualité, focale, grain, etc.), sachant que l'image sera l'image principale de cet objet dans le catalogue produit. Le fond de l'image générée doit être entièrement blanc.\nLe prompt doit être sans narration, peut être long mais ne doit pas dépasser
|
73 |
]
|
74 |
refined_prompt = text_gen(messages)
|
75 |
|
@@ -93,20 +96,17 @@ def validate_dimensions(width, height):
|
|
93 |
return True, None
|
94 |
|
95 |
@spaces.GPU()
|
96 |
-
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
|
97 |
try:
|
98 |
-
# progress(0, desc="Starting generation...")
|
99 |
|
100 |
# Validate that prompt is not empty
|
101 |
if not prompt or prompt.strip() == "":
|
102 |
return None, "Please provide a valid prompt."
|
103 |
|
104 |
-
# progress(0.1, desc="Loading image generation model...")
|
105 |
pipe = get_image_gen_pipeline()
|
106 |
if pipe is None:
|
107 |
return None, "Image generation model is unavailable."
|
108 |
|
109 |
-
# progress(0.2, desc="Validating dimensions...")
|
110 |
is_valid, error_msg = validate_dimensions(width, height)
|
111 |
if not is_valid:
|
112 |
return None, error_msg
|
@@ -114,26 +114,30 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
|
|
114 |
if randomize_seed:
|
115 |
seed = random.randint(0, MAX_SEED)
|
116 |
|
117 |
-
# progress(0.3, desc="Setting up generator...")
|
118 |
generator = torch.Generator("cuda").manual_seed(seed) # Explicitly use CUDA generator
|
119 |
|
120 |
-
# progress(0.4, desc="Generating image...")
|
121 |
with torch.autocast('cuda'):
|
122 |
-
|
123 |
prompt=prompt,
|
124 |
width=width,
|
125 |
height=height,
|
126 |
num_inference_steps=num_inference_steps,
|
127 |
generator=generator,
|
128 |
-
guidance_scale=
|
129 |
-
|
130 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
-
torch.cuda.empty_cache()
|
133 |
-
# progress(1.0, desc="Done!")
|
134 |
return image, seed
|
135 |
except Exception as e:
|
136 |
-
print(f"Error in infer: {str(e)}")
|
137 |
return None, f"Error generating image: {str(e)}"
|
138 |
|
139 |
examples = [
|
|
|
7 |
from diffusers import DiffusionPipeline
|
8 |
from transformers import pipeline, AutoTokenizer
|
9 |
from huggingface_hub import login
|
10 |
+
from PIL import Image
|
11 |
|
12 |
hf_token = os.getenv("hf_token")
|
13 |
login(token=hf_token)
|
|
|
31 |
"black-forest-labs/FLUX.1-dev",
|
32 |
torch_dtype=dtype,
|
33 |
).to(device)
|
34 |
+
_image_gen_pipeline.enable_model_cpu_offload()
|
35 |
+
_image_gen_pipeline.enable_vae_slicing()
|
36 |
except Exception as e:
|
37 |
print(f"Error loading image generation model: {e}")
|
38 |
return None
|
|
|
72 |
return "Text generation model is unavailable."
|
73 |
try:
|
74 |
messages = [
|
75 |
+
{"role": "system", "content": "Vous êtes un designer produit avec de solides connaissances dans la génération de texte en image. Vous recevrez une demande de produit sous forme de description succincte, et votre mission sera d'imaginer un nouveau design de produit répondant à ce besoin.\n\nLe livrable (réponse générée) sera exclusivement un texte de prompt pour l'IA de texte to image FLUX.1-schnell.\n\nCe prompt devra inclure une description visuelle de l'objet mentionnant explicitement les aspects indispensables de sa fonction.\nA coté de ça vous devez aussi explicitement mentionner dans ce prompt les caractéristiques esthétiques/photo du rendu image (ex : photoréaliste, haute qualité, focale, grain, etc.), sachant que l'image sera l'image principale de cet objet dans le catalogue produit. Le fond de l'image générée doit être entièrement blanc.\nLe prompt doit être sans narration, peut être long mais ne doit pas dépasser 77 jetons."}, {"role": "user", "content": prompt},
|
76 |
]
|
77 |
refined_prompt = text_gen(messages)
|
78 |
|
|
|
96 |
return True, None
|
97 |
|
98 |
@spaces.GPU()
|
99 |
+
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
|
100 |
try:
|
|
|
101 |
|
102 |
# Validate that prompt is not empty
|
103 |
if not prompt or prompt.strip() == "":
|
104 |
return None, "Please provide a valid prompt."
|
105 |
|
|
|
106 |
pipe = get_image_gen_pipeline()
|
107 |
if pipe is None:
|
108 |
return None, "Image generation model is unavailable."
|
109 |
|
|
|
110 |
is_valid, error_msg = validate_dimensions(width, height)
|
111 |
if not is_valid:
|
112 |
return None, error_msg
|
|
|
114 |
if randomize_seed:
|
115 |
seed = random.randint(0, MAX_SEED)
|
116 |
|
|
|
117 |
generator = torch.Generator("cuda").manual_seed(seed) # Explicitly use CUDA generator
|
118 |
|
|
|
119 |
with torch.autocast('cuda'):
|
120 |
+
output = pipe(
|
121 |
prompt=prompt,
|
122 |
width=width,
|
123 |
height=height,
|
124 |
num_inference_steps=num_inference_steps,
|
125 |
generator=generator,
|
126 |
+
guidance_scale=7.5,
|
127 |
+
max_sequence_length=512
|
128 |
+
)
|
129 |
+
|
130 |
+
# Ensure the image is properly normalized and converted
|
131 |
+
image = output.images[0]
|
132 |
+
if isinstance(image, torch.Tensor):
|
133 |
+
image = (image.clamp(-1, 1) + 1) / 2
|
134 |
+
image = (image * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
|
135 |
+
image = Image.fromarray(image)
|
136 |
|
137 |
+
torch.cuda.empty_cache()
|
|
|
138 |
return image, seed
|
139 |
except Exception as e:
|
140 |
+
print(f"Error in infer: {str(e)}")
|
141 |
return None, f"Error generating image: {str(e)}"
|
142 |
|
143 |
examples = [
|