lionelgarnier committed on
Commit cd82b78 · 1 Parent(s): fd3ce93

Update progress tracking method in image generation pipeline

Files changed (1): app.py (+7 -7)
app.py CHANGED
@@ -100,20 +100,21 @@ def validate_dimensions(width, height):
     return True, None
 
 @spaces.GPU()
-def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress()):
     try:
-        progress(0, desc="Starting generation...")
+        # Use progress.update() instead of directly calling progress()
+        progress.update(0)
 
         # Validate that prompt is not empty
         if not prompt or prompt.strip() == "":
             return None, "Please provide a valid prompt."
 
-        progress(0.1, desc="Loading image generation model...")
+        progress.update(0.2, desc="Loading model...")
         pipe = get_image_gen_pipeline()
         if pipe is None:
             return None, "Image generation model is unavailable."
 
-        progress(0.2, desc="Validating dimensions...")
+        progress.update(0.4, desc="Validating dimensions...")
         is_valid, error_msg = validate_dimensions(width, height)
         if not is_valid:
             return None, error_msg
@@ -121,10 +122,9 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
         if randomize_seed:
             seed = random.randint(0, MAX_SEED)
 
-        progress(0.3, desc="Setting up generator...")
+        progress.update(0.6, desc="Generating image...")
         generator = torch.Generator().manual_seed(seed)
 
-        progress(0.4, desc="Generating image...")
         with torch.autocast('cuda'):
             image = pipe(
                 prompt=prompt,
@@ -137,7 +137,7 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
             ).images[0]
 
         torch.cuda.empty_cache()  # Clean up GPU memory after generation
-        progress(1.0, desc="Done!")
+        progress.update(1.0, desc="Done!")
         return image, seed
     except Exception as e:
         return None, f"Error generating image: {str(e)}"
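For reference, Gradio injects a live tracker into any handler that declares a progress=gr.Progress() default argument when the function is served through a Gradio app. Below is a minimal, self-contained sketch of that pattern; fake_infer and its dummy loop are hypothetical stand-ins for this Space's infer() and diffusion pipeline, and the callable and tqdm forms shown are the documented gr.Progress API rather than the exact calls introduced by this commit.

import time

import gradio as gr


def fake_infer(prompt, progress=gr.Progress()):
    # Hypothetical stand-in for infer(): Gradio supplies a live Progress
    # tracker for the gr.Progress() default argument when invoked via the UI.
    progress(0, desc="Starting generation...")
    for _ in progress.tqdm(range(4), desc="Generating image..."):
        time.sleep(0.5)  # stand-in for one inference step
    return f"Generated image for prompt: {prompt!r}"


demo = gr.Interface(fn=fake_infer, inputs="text", outputs="text")
demo.launch()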