Gemini899 committed
Commit 93adb8b · verified · 1 Parent(s): 5416f76

Update app.py

Files changed (1)
  1. app.py +34 -68
app.py CHANGED
@@ -2,76 +2,59 @@ import spaces
 import gradio as gr
 import re
 from PIL import Image
-
 import os
 import numpy as np
 import torch
-from diffusers import FluxImg2ImgPipeline
+from diffusers import StableDiffusionImg2ImgPipeline
 
-# Use default float32 precision for CPU
-dtype = torch.float32
-device = "cpu"
+# Use float16 for lower VRAM usage
+dtype = torch.float16
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load the pipeline on CPU
-pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
+# Load the lighter model on the GPU
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype)
+pipe.to(device)
 
 def sanitize_prompt(prompt):
-    # Allow only alphanumeric characters, spaces, and basic punctuation
     allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
-    sanitized_prompt = allowed_chars.sub("", prompt)
-    return sanitized_prompt
+    return allowed_chars.sub("", prompt)
 
 def convert_to_fit_size(original_width_and_height, maximum_size=2048):
     width, height = original_width_and_height
     if width <= maximum_size and height <= maximum_size:
         return width, height
-
-    if width > height:
-        scaling_factor = maximum_size / width
-    else:
-        scaling_factor = maximum_size / height
-
-    new_width = int(width * scaling_factor)
-    new_height = int(height * scaling_factor)
-    return new_width, new_height
+    scaling_factor = maximum_size / max(width, height)
+    return int(width * scaling_factor), int(height * scaling_factor)
 
 def adjust_to_multiple_of_32(width: int, height: int):
-    width = width - (width % 32)
-    height = height - (height % 32)
-    return width, height
+    return width - (width % 32), height - (height % 32)
 
-@spaces.CPU(duration=120)
-def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
-    progress(0, desc="Starting")
+@spaces.GPU(duration=120)
+def process_images(image, prompt="a woman", strength=0.75, seed=0, inference_step=50, progress=gr.Progress(track_tqdm=True)):
+    progress(0, desc="Starting processing")
 
-    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
+    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=50):
         if image is None:
-            print("empty input image returned")
+            print("Empty input image returned")
             return None
-        # Create a CPU generator
-        generator = torch.Generator("cpu").manual_seed(seed)
+        generator = torch.Generator(device).manual_seed(seed)
         fit_width, fit_height = convert_to_fit_size(image.size)
         width, height = adjust_to_multiple_of_32(fit_width, fit_height)
         image = image.resize((width, height), Image.LANCZOS)
 
         output = pipe(
             prompt=prompt,
             image=image,
             generator=generator,
             strength=strength,
-            width=width,
-            height=height,
-            guidance_scale=0,
+            guidance_scale=7.5,
             num_inference_steps=num_inference_steps,
-            max_sequence_length=256
         )
 
         pil_image = output.images[0]
-        new_width, new_height = pil_image.size
-
-        if (new_width != fit_width) or (new_height != fit_height):
-            resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
-            return resized_image
+        # Resize back to the original fitted dimensions if needed
+        if pil_image.size != (fit_width, fit_height):
+            pil_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
         return pil_image
 
     output = process_img2img(image, prompt, strength, seed, inference_step)
@@ -79,8 +62,7 @@ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step
 
 def read_file(path: str) -> str:
     with open(path, 'r', encoding='utf-8') as f:
-        content = f.read()
-        return content
+        return f.read()
 
 css = """
 #col-left {
@@ -95,15 +77,13 @@ css = """
     display: flex;
     align-items: center;
     justify-content: center;
-    gap:10px
+    gap: 10px;
 }
-
 .image {
     width: 128px;
     height: 128px;
     object-fit: cover;
 }
-
 .text {
     font-size: 16px;
 }
@@ -115,40 +95,26 @@ with gr.Blocks(css=css, elem_id="demo-container") as demo:
     gr.HTML(read_file("demo_tools.html"))
     with gr.Row():
         with gr.Column():
-            image = gr.Image(
-                height=800,
-                sources=['upload','clipboard'],
-                image_mode='RGB',
-                elem_id="image_upload",
-                type="pil",
-                label="Upload"
-            )
+            image = gr.Image(height=800, sources=['upload','clipboard'], image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
             with gr.Row(elem_id="prompt-container", equal_height=False):
                 with gr.Row():
-                    prompt = gr.Textbox(
-                        label="Prompt",
-                        value="a women",
-                        placeholder="Your prompt (what you want in place of what is erased)",
-                        elem_id="prompt"
-                    )
+                    prompt = gr.Textbox(label="Prompt", value="a woman", placeholder="Your prompt", elem_id="prompt")
                     btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
-
             with gr.Accordion(label="Advanced Settings", open=False):
                 with gr.Row(equal_height=True):
-                    strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="strength")
-                    seed = gr.Number(value=100, minimum=0, step=1, label="seed")
-                    inference_step = gr.Number(value=4, minimum=1, step=4, label="inference_step")
+                    strength = gr.Number(value=0.75, minimum=0, maximum=1.0, step=0.01, label="Strength")
+                    seed = gr.Number(value=100, minimum=0, step=1, label="Seed")
+                    inference_step = gr.Number(value=50, minimum=1, step=1, label="Inference Steps")
                     id_input = gr.Text(label="Name", visible=False)
-
         with gr.Column():
             image_out = gr.Image(height=800, sources=[], label="Output", elem_id="output-img", format="jpg")
 
     gr.Examples(
         examples=[
-            ["examples/draw_input.jpg", "examples/draw_output.jpg", "a women ,eyes closed,mouth opened"],
-            ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a women ,eyes closed,mouth opened"],
-            ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a women ,hand on neck"],
-            ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a women ,hand on neck"]
+            ["examples/draw_input.jpg", "examples/draw_output.jpg", "a woman with blue eyes"],
+            ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a woman with a serene expression"],
+            ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a woman in a garden"],
+            ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a woman in a futuristic city"]
         ],
         inputs=[image, image_out, prompt],
     )
 
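For reference, below is a minimal standalone sketch of the img2img path the updated app.py uses. It is an illustration under assumptions, not part of the commit: it presumes a recent diffusers release in which StableDiffusionImg2ImgPipeline receives the input picture through the image= keyword, and the input and output file paths are hypothetical.

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # float16 needs a GPU

# Same checkpoint the commit loads
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
).to(device)

init_image = Image.open("input.jpg").convert("RGB").resize((512, 512))  # hypothetical path
generator = torch.Generator(device).manual_seed(100)

result = pipe(
    prompt="a woman in a garden",
    image=init_image,
    strength=0.75,           # how far the sampler may drift from the input
    guidance_scale=7.5,      # classifier-free guidance weight the commit sets
    num_inference_steps=50,
    generator=generator,
)
result.images[0].save("output.jpg")  # hypothetical path

Note that img2img runs roughly strength × num_inference_steps denoising steps (about 37 here), so strength trades fidelity to the input image against both the size of the edit and compute time.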