gaur3009 commited on
Commit
7a99816
·
verified ·
1 Parent(s): 0b9bc1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -31
app.py CHANGED
@@ -6,73 +6,71 @@ import gradio as gr
6
  from src.image_prep import canny_from_pil
7
  from src.pix2pix_turbo import Pix2Pix_Turbo
8
 
 
9
  model = Pix2Pix_Turbo("edge_to_image")
10
 
11
-
12
  def process(input_image, prompt, low_threshold, high_threshold):
13
- # resize to be a multiple of 8
14
  new_width = input_image.width - input_image.width % 8
15
  new_height = input_image.height - input_image.height % 8
16
  input_image = input_image.resize((new_width, new_height))
 
 
17
  canny = canny_from_pil(input_image, low_threshold, high_threshold)
 
 
18
  with torch.no_grad():
19
  c_t = transforms.ToTensor()(canny).unsqueeze(0).cuda()
20
  output_image = model(c_t, prompt)
21
  output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
22
- # flippy canny values, map all 0s to 1s and 1s to 0s
 
23
  canny_viz = 1 - (np.array(canny) / 255)
24
  canny_viz = Image.fromarray((canny_viz * 255).astype(np.uint8))
 
25
  return canny_viz, output_pil
26
 
27
-
28
  if __name__ == "__main__":
29
- # load the model
30
  with gr.Blocks() as demo:
31
  gr.Markdown("# Pix2pix-Turbo: **Canny Edge -> Image**")
 
32
  with gr.Row():
33
  with gr.Column():
34
- input_image = gr.Image(sources="upload", type="pil")
35
  prompt = gr.Textbox(label="Prompt")
36
  low_threshold = gr.Slider(
37
  label="Canny low threshold",
38
  minimum=1,
39
  maximum=255,
40
  value=100,
41
- step=10,
42
  )
43
  high_threshold = gr.Slider(
44
  label="Canny high threshold",
45
  minimum=1,
46
  maximum=255,
47
  value=200,
48
- step=10,
49
  )
50
  run_button = gr.Button(value="Run")
 
51
  with gr.Column():
52
  result_canny = gr.Image(type="pil")
 
53
  with gr.Column():
54
  result_output = gr.Image(type="pil")
55
-
56
- prompt.submit(
57
- fn=process,
58
- inputs=[input_image, prompt, low_threshold, high_threshold],
59
- outputs=[result_canny, result_output],
60
- )
61
- low_threshold.change(
62
- fn=process,
63
- inputs=[input_image, prompt, low_threshold, high_threshold],
64
- outputs=[result_canny, result_output],
65
- )
66
- high_threshold.change(
67
- fn=process,
68
- inputs=[input_image, prompt, low_threshold, high_threshold],
69
- outputs=[result_canny, result_output],
70
- )
71
- run_button.click(
72
- fn=process,
73
- inputs=[input_image, prompt, low_threshold, high_threshold],
74
- outputs=[result_canny, result_output],
75
- )
76
-
77
  demo.queue()
78
- demo.launch(debug=True, share=False)
 
6
  from src.image_prep import canny_from_pil
7
  from src.pix2pix_turbo import Pix2Pix_Turbo
8
 
9
+ # Initialize the model
10
  model = Pix2Pix_Turbo("edge_to_image")
11
 
12
+ # Define the processing function
13
  def process(input_image, prompt, low_threshold, high_threshold):
14
+ # Resize to be a multiple of 8
15
  new_width = input_image.width - input_image.width % 8
16
  new_height = input_image.height - input_image.height % 8
17
  input_image = input_image.resize((new_width, new_height))
18
+
19
+ # Generate canny edge image
20
  canny = canny_from_pil(input_image, low_threshold, high_threshold)
21
+
22
+ # Convert to tensor and process with model
23
  with torch.no_grad():
24
  c_t = transforms.ToTensor()(canny).unsqueeze(0).cuda()
25
  output_image = model(c_t, prompt)
26
  output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
27
+
28
+ # Visualize canny edges (invert colors)
29
  canny_viz = 1 - (np.array(canny) / 255)
30
  canny_viz = Image.fromarray((canny_viz * 255).astype(np.uint8))
31
+
32
  return canny_viz, output_pil
33
 
 
34
  if __name__ == "__main__":
35
+ # Create the Gradio interface
36
  with gr.Blocks() as demo:
37
  gr.Markdown("# Pix2pix-Turbo: **Canny Edge -> Image**")
38
+
39
  with gr.Row():
40
  with gr.Column():
41
+ input_image = gr.Image(source="upload", type="pil")
42
  prompt = gr.Textbox(label="Prompt")
43
  low_threshold = gr.Slider(
44
  label="Canny low threshold",
45
  minimum=1,
46
  maximum=255,
47
  value=100,
48
+ step=10
49
  )
50
  high_threshold = gr.Slider(
51
  label="Canny high threshold",
52
  minimum=1,
53
  maximum=255,
54
  value=200,
55
+ step=10
56
  )
57
  run_button = gr.Button(value="Run")
58
+
59
  with gr.Column():
60
  result_canny = gr.Image(type="pil")
61
+
62
  with gr.Column():
63
  result_output = gr.Image(type="pil")
64
+
65
+ # Set up event handlers
66
+ inputs = [input_image, prompt, low_threshold, high_threshold]
67
+ outputs = [result_canny, result_output]
68
+
69
+ prompt.submit(fn=process, inputs=inputs, outputs=outputs)
70
+ low_threshold.change(fn=process, inputs=inputs, outputs=outputs)
71
+ high_threshold.change(fn=process, inputs=inputs, outputs=outputs)
72
+ run_button.click(fn=process, inputs=inputs, outputs=outputs)
73
+
74
+ # Launch the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
75
  demo.queue()
76
+ demo.launch(debug=True, share=False)