HaohuaLv committed on
Commit 9db3d14
1 Parent(s): 3639f29

Upload app.py

Files changed (1)
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
from transformers import ViTModel, ViTImageProcessor
from PIL import Image, ImageOps
import gradio as gr
import torch
from datasets import Dataset
from torch.nn import CosineSimilarity

# ViT image processor; this string resolves only if a local copy exists under
# this name (the canonical Hub id is "google/vit-base-patch16-224").
image_processor = ViTImageProcessor.from_pretrained("vit-base-patch16-224")
# Fine-tuned encoders: one for natural images, one for scribbles.
image_encoder = ViTModel.from_pretrained("output/image_encoder/epoch_29").eval().to("cuda")
scribble_encoder = ViTModel.from_pretrained("output/scibble_encoder/epoch_29").eval().to("cuda")

# Candidate gallery with precomputed embeddings, filled by load_candidates_in_cache().
candidates: Dataset = None

cosinesimilarity = CosineSimilarity()
def load_candidates(candidate_dir):
    """Embed every uploaded candidate image with the image encoder."""
    def preprocess(examples):
        images = [image.convert("RGB") for image in examples["image"]]
        examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"].to("cuda"))["pooler_output"]
        return examples

    # gr.File yields temp-file wrappers; open each image via its .name path.
    dataset = [dict(image=Image.open(tempfile.name).convert("RGB").resize((224, 224))) for tempfile in candidate_dir]
    dataset = Dataset.from_list(dataset)
    with torch.no_grad():
        dataset = dataset.map(preprocess, batched=True, batch_size=1024)

    return dataset


def load_candidates_in_cache(candidate_files):
    # Cache the embedded candidates in the module-level `candidates` dataset.
    global candidates
    candidates = load_candidates(candidate_files)
def scribble_matching(input_img: Image.Image):
    # The sketch canvas produces dark strokes on a white background; invert it
    # before encoding.
    input_img = ImageOps.invert(input_img)

    scribble = input_img
    with torch.no_grad():
        scribble_embedding = scribble_encoder(image_processor(scribble, return_tensors="pt")["pixel_values"].to("cuda"))["pooler_output"].to("cpu")
    image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)

    # Cosine similarity between the scribble embedding and every candidate embedding.
    sim = cosinesimilarity(scribble_embedding, image_embeddings)

    # Keep the 15 most similar candidates.
    predicts = torch.topk(sim, k=15)

    output_imgs = candidates[predicts.indices.tolist()]["image"]
    labels = [f"{label:.3f}" for label in predicts.values.tolist()]

    # Gallery items: the inverted scribble as a preview, then matches with scores.
    return list(zip([input_img] + output_imgs, ["preview"] + labels))
def main():
    with gr.Blocks() as demo:
        with gr.Row():
            # Drawing canvas for the query scribble.
            input_img = gr.Image(type="pil", label="scribble", height=512, width=512, source="canvas", tool="color-sketch", brush_radius=10)
            # Gallery showing the top-15 matches with their similarity scores.
            prediction_gallery = gr.Gallery(min_width=512, columns=4, show_label=True)

        with gr.Row():
            # Directory upload for the candidate images to search over.
            candidate_dir = gr.File(file_count="directory", min_width=300, height=300)
            load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
            btn = gr.Button("Scribble Matching", variant="primary")

        load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
        btn.click(fn=scribble_matching, inputs=[input_img], outputs=[prediction_gallery])

    demo.launch(debug=True)


if __name__ == "__main__":
    main()
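
For reference, the retrieval path can also be exercised without the Gradio UI. The following is a minimal sketch, not part of the commit: it assumes the fine-tuned checkpoints referenced in app.py exist, a CUDA device is available, and candidate images sit in a hypothetical local folder candidates/ (my_scribble.png is likewise hypothetical). It reuses load_candidates_in_cache and scribble_matching from app.py; the stand-in objects only need a .name attribute, mirroring the temp-file wrappers gr.File passes in.

# demo_cli.py — minimal sketch with hypothetical paths; importing app loads the models.
from glob import glob
from types import SimpleNamespace

from PIL import Image

import app  # the module added in this commit

# Lightweight stand-ins for the gr.File upload objects (only .name is used).
files = [SimpleNamespace(name=p) for p in sorted(glob("candidates/*.png"))]
app.load_candidates_in_cache(files)

# A scribble drawn with dark strokes on a white background, like the canvas output.
scribble = Image.open("my_scribble.png").convert("RGB").resize((224, 224))
for img, label in app.scribble_matching(scribble)[1:]:  # skip the "preview" item
    print(label)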