netrosec commited on
Commit
d00dc43
·
verified ·
1 Parent(s): 182e985

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from diffusers import StableDiffusionInpaintPipeline
5
+ from PIL import Image
6
+ from segment_anything import SamPredictor, sam_model_registry
7
+
8
+ device = "cuda"
9
+ sam_checkpoint = "/home/jupyter/diffusers/examples/sam_vit_h_4b8939.pth" # Added missing forward slash at the beginning
10
+ model_type = "vit_h"
11
+ # Load the model using the function from the registry and pass the checkpoint path
12
+ model_fn = sam_model_registry[model_type]
13
+ model = model_fn(checkpoint=sam_checkpoint)
14
+
15
+ # Move the model to the desired device (GPU)
16
+ model.to(device)
17
+
18
+ predictor = SamPredictor(model)
19
+
20
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
21
+ "stabilityai/stable-diffusion-2-inpainting",
22
+ torch_dtype=torch.float16,
23
+ ) # Removed space
24
+
25
+ pipe = pipe.to(device)
26
+
27
+ selected_pixels = []
28
+
29
+
30
+ with gr.Blocks() as demo:
31
+ with gr.Row():
32
+ input_img = gr.Image(label="Input") # Removed space
33
+ mask_img = gr.Image(label="Mask") # Corrected "Mas" to "Mask"
34
+ output_img = gr.Image(label="Output") # Removed space
35
+
36
+ with gr.Row():
37
+ prompt_text = gr.Textbox(lines=1, label="Prompt") # Removed space
38
+
39
+ with gr.Row():
40
+ submit = gr.Button("Submit")
41
+
42
+ def generate_mask(image, evt: gr.SelectData):
43
+ selected_pixels.append(evt.index) # Removed space
44
+
45
+ predictor.set_image(image) # Removed space
46
+ input_points = np.array(selected_pixels)
47
+ input_labels = np.ones(input_points.shape[0])
48
+ mask, _, _ = predictor.predict(
49
+ point_coords=input_points,
50
+ point_labels=input_labels,
51
+ multimask_output=False
52
+ )
53
+ # (n, sz, sz)
54
+ mask = Image.fromarray(mask[0, :, :]) # Removed space
55
+ return mask
56
+
57
+ def inpaint(image, mask, prompt):
58
+ image = Image.fromarray(image) # Removed space
59
+ mask = Image.fromarray(mask) # Removed space
60
+
61
+ image = image.resize((512, 512))
62
+ mask = mask.resize((512, 512))
63
+
64
+ output = pipe(
65
+ prompt=prompt,
66
+ image=image,
67
+ mask_image=mask,
68
+ ).images[0]
69
+
70
+ return output
71
+
72
+ input_img.select(generate_mask, [input_img], [mask_img])
73
+ submit.click(inpaint, inputs=[input_img, mask_img, prompt_text], outputs=[output_img])
74
+
75
+ if __name__ == "__main__":
76
+ demo.launch()