naonauno commited on
Commit
96bb4b8
·
verified ·
1 Parent(s): 097fdcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -25
app.py CHANGED
@@ -45,10 +45,13 @@ pipe.load_lora_weights(
45
  weight_name="40kHalf.safetensors"
46
  )
47
 
 
 
 
48
  def get_random_condition_image():
49
  conditions_dir = Path("conditions")
50
  if conditions_dir.exists():
51
- image_files = list(conditions_dir.glob("*.[jp][pn][g]")) # matches .jpg, .png, .jpeg
52
  if image_files:
53
  random_image = random.choice(image_files)
54
  return str(random_image)
@@ -57,38 +60,65 @@ def get_random_condition_image():
57
  def get_canny_image(image, low_threshold=100, high_threshold=200):
58
  if isinstance(image, Image.Image):
59
  image = np.array(image)
 
 
60
 
61
- if image.shape[2] == 4:
 
 
62
  image = image[..., :3]
63
 
64
  canny_image = cv2.Canny(image, low_threshold, high_threshold)
65
  canny_image = np.stack([canny_image] * 3, axis=-1)
66
  return Image.fromarray(canny_image)
67
 
68
- @spaces.GPU(duration=120)
69
- def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps, low_threshold, high_threshold, seed):
70
- if seed is not None and seed != "":
71
- try:
72
- generator = torch.Generator().manual_seed(int(seed))
73
- except ValueError:
 
 
 
 
 
 
74
  generator = torch.Generator()
75
- else:
76
- generator = torch.Generator()
77
 
78
- canny_image = get_canny_image(input_image, low_threshold, high_threshold)
79
-
80
- with torch.no_grad():
81
- image = pipe(
82
- prompt=prompt,
83
- negative_prompt=negative_prompt,
84
- num_inference_steps=steps,
85
- guidance_scale=guidance_scale,
86
- image=canny_image,
87
- extra_condition_scale=1.0,
88
- generator=generator
89
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- return canny_image, image
 
 
 
 
92
 
93
  def random_image_click():
94
  image_path = get_random_condition_image()
@@ -99,7 +129,7 @@ def random_image_click():
99
  # Example data
100
  examples = [
101
  [
102
- "conditions/example1.jpg", # Replace with actual paths
103
  "a futuristic cyberpunk city",
104
  "blurry, bad quality",
105
  7.5,
@@ -109,7 +139,7 @@ examples = [
109
  42
110
  ],
111
  [
112
- "conditions/example2.jpg", # Replace with actual paths
113
  "a serene mountain landscape",
114
  "dark, gloomy",
115
  7.0,
@@ -134,6 +164,7 @@ with gr.Blocks() as demo:
134
  with gr.Column():
135
  input_image = gr.Image(label="Input Image", type="numpy")
136
  random_image_btn = gr.Button("Load Random Reference Image")
 
137
 
138
  prompt = gr.Textbox(
139
  label="Prompt",
@@ -195,4 +226,5 @@ with gr.Blocks() as demo:
195
  outputs=[canny_output, result]
196
  )
197
 
 
198
  demo.launch()
 
45
  weight_name="40kHalf.safetensors"
46
  )
47
 
48
+ # Move to CPU initially
49
+ pipe = pipe.to("cpu")
50
+
51
  def get_random_condition_image():
52
  conditions_dir = Path("conditions")
53
  if conditions_dir.exists():
54
+ image_files = list(conditions_dir.glob("*.[jp][pn][g]"))
55
  if image_files:
56
  random_image = random.choice(image_files)
57
  return str(random_image)
 
60
  def get_canny_image(image, low_threshold=100, high_threshold=200):
61
  if isinstance(image, Image.Image):
62
  image = np.array(image)
63
+ elif isinstance(image, str):
64
+ image = np.array(Image.open(image))
65
 
66
+ if len(image.shape) == 2: # If grayscale
67
+ image = np.stack([image] * 3, axis=-1)
68
+ elif image.shape[2] == 4: # If RGBA
69
  image = image[..., :3]
70
 
71
  canny_image = cv2.Canny(image, low_threshold, high_threshold)
72
  canny_image = np.stack([canny_image] * 3, axis=-1)
73
  return Image.fromarray(canny_image)
74
 
75
+ @spaces.GPU(duration=300) # Increased duration to 5 minutes
76
+ def generate_image(input_image, prompt, negative_prompt, guidance_scale, steps, low_threshold, high_threshold, seed, progress=gr.Progress()):
77
+ if input_image is None:
78
+ raise gr.Error("Please provide an input image!")
79
+
80
+ try:
81
+ if seed is not None and seed != "":
82
+ try:
83
+ generator = torch.Generator().manual_seed(int(seed))
84
+ except ValueError:
85
+ generator = torch.Generator()
86
+ else:
87
  generator = torch.Generator()
 
 
88
 
89
+ progress(0.1, desc="Processing input image...")
90
+ canny_image = get_canny_image(input_image, low_threshold, high_threshold)
91
+
92
+ progress(0.2, desc="Moving model to device...")
93
+ # Move pipeline to GPU
94
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95
+ pipe.to(device)
96
+
97
+ progress(0.3, desc="Generating image...")
98
+ with torch.no_grad():
99
+ image = pipe(
100
+ prompt=prompt,
101
+ negative_prompt=negative_prompt,
102
+ num_inference_steps=steps,
103
+ guidance_scale=guidance_scale,
104
+ image=canny_image,
105
+ extra_condition_scale=1.0,
106
+ generator=generator
107
+ ).images[0]
108
+
109
+ progress(0.9, desc="Moving model back to CPU...")
110
+ # Move back to CPU to free up GPU memory
111
+ pipe.to("cpu")
112
+ torch.cuda.empty_cache()
113
+
114
+ progress(1.0, desc="Done!")
115
+ return canny_image, image
116
 
117
+ except Exception as e:
118
+ # Move back to CPU in case of error
119
+ pipe.to("cpu")
120
+ torch.cuda.empty_cache()
121
+ raise gr.Error(f"An error occurred: {str(e)}")
122
 
123
  def random_image_click():
124
  image_path = get_random_condition_image()
 
129
  # Example data
130
  examples = [
131
  [
132
+ "conditions/example1.jpg",
133
  "a futuristic cyberpunk city",
134
  "blurry, bad quality",
135
  7.5,
 
139
  42
140
  ],
141
  [
142
+ "conditions/example2.jpg",
143
  "a serene mountain landscape",
144
  "dark, gloomy",
145
  7.0,
 
164
  with gr.Column():
165
  input_image = gr.Image(label="Input Image", type="numpy")
166
  random_image_btn = gr.Button("Load Random Reference Image")
167
+ status_text = gr.Textbox(label="Status", value="Ready", interactive=False)
168
 
169
  prompt = gr.Textbox(
170
  label="Prompt",
 
226
  outputs=[canny_output, result]
227
  )
228
 
229
+ demo.queue() # Enable queuing for better handling of concurrent requests
230
  demo.launch()