Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -15,23 +15,22 @@ cosinesimilarity = CosineSimilarity()
|
|
15 |
|
16 |
|
17 |
|
18 |
-
def load_candidates(candidate_dir):
|
19 |
def preprocess(examples):
|
20 |
-
images = [image
|
21 |
examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"])["pooler_output"]
|
22 |
return examples
|
23 |
-
|
24 |
-
dataset = [dict(image=Image.open(tempfile.name).convert("RGB").resize((224, 224))) for tempfile in candidate_dir]
|
25 |
dataset = Dataset.from_list(dataset)
|
26 |
with torch.no_grad():
|
27 |
-
dataset = dataset.map(preprocess, batched=True, batch_size=
|
28 |
-
|
29 |
return dataset
|
30 |
|
31 |
|
32 |
def load_candidates_in_cache(candidate_files):
|
33 |
global candidates
|
34 |
candidates = load_candidates(candidate_files)
|
|
|
35 |
|
36 |
|
37 |
def scribble_matching(input_img: Image):
|
@@ -42,12 +41,11 @@ def scribble_matching(input_img: Image):
|
|
42 |
image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
|
43 |
|
44 |
|
45 |
-
|
46 |
sim = cosinesimilarity(scribble_embedding, image_embeddings)
|
47 |
|
48 |
predicts = torch.topk(sim, k=15)
|
49 |
|
50 |
-
output_imgs = candidates[predicts.indices.tolist()]
|
51 |
labels = predicts.values.tolist()
|
52 |
labels = [f"{label:.3f}" for label in labels]
|
53 |
|
@@ -58,16 +56,16 @@ def main():
|
|
58 |
with gr.Blocks() as demo:
|
59 |
with gr.Row():
|
60 |
input_img = gr.Image(type="pil", label="scribble", height=512, width=512, source="canvas", tool="color-sketch", brush_radius=10)
|
61 |
-
prediction_gallery = gr.Gallery(min_width=512, columns=4, show_label=True
|
62 |
|
63 |
with gr.Row():
|
64 |
candidate_dir = gr.File(file_count="directory", min_width=300, height=300)
|
65 |
load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
|
66 |
btn = gr.Button("Scribble Matching", variant="primary")
|
67 |
-
load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
|
68 |
btn.click(fn=scribble_matching, inputs=[input_img], outputs=[prediction_gallery])
|
69 |
|
70 |
-
demo.launch(
|
71 |
|
72 |
if __name__ == "__main__":
|
73 |
main()
|
|
|
15 |
|
16 |
|
17 |
|
18 |
+
def load_candidates(candidate_dir, progress=gr.Progress()):
|
19 |
def preprocess(examples):
|
20 |
+
images = [image for image in examples["image"]]
|
21 |
examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"])["pooler_output"]
|
22 |
return examples
|
23 |
+
dataset = [dict(image=Image.open(tempfile.name).convert("RGB").resize((224, 224))) for tempfile in progress.tqdm(candidate_dir)]
|
|
|
24 |
dataset = Dataset.from_list(dataset)
|
25 |
with torch.no_grad():
|
26 |
+
dataset = dataset.map(preprocess, batched=True, batch_size=1024)
|
|
|
27 |
return dataset
|
28 |
|
29 |
|
30 |
def load_candidates_in_cache(candidate_files):
|
31 |
global candidates
|
32 |
candidates = load_candidates(candidate_files)
|
33 |
+
return [f.name for f in candidate_files]
|
34 |
|
35 |
|
36 |
def scribble_matching(input_img: Image):
|
|
|
41 |
image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
|
42 |
|
43 |
|
|
|
44 |
sim = cosinesimilarity(scribble_embedding, image_embeddings)
|
45 |
|
46 |
predicts = torch.topk(sim, k=15)
|
47 |
|
48 |
+
output_imgs = candidates["image"][predicts.indices.tolist()]
|
49 |
labels = predicts.values.tolist()
|
50 |
labels = [f"{label:.3f}" for label in labels]
|
51 |
|
|
|
56 |
with gr.Blocks() as demo:
|
57 |
with gr.Row():
|
58 |
input_img = gr.Image(type="pil", label="scribble", height=512, width=512, source="canvas", tool="color-sketch", brush_radius=10)
|
59 |
+
prediction_gallery = gr.Gallery(min_width=512, columns=4, show_label=True)
|
60 |
|
61 |
with gr.Row():
|
62 |
candidate_dir = gr.File(file_count="directory", min_width=300, height=300)
|
63 |
load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
|
64 |
btn = gr.Button("Scribble Matching", variant="primary")
|
65 |
+
load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir], outputs=candidate_dir)
|
66 |
btn.click(fn=scribble_matching, inputs=[input_img], outputs=[prediction_gallery])
|
67 |
|
68 |
+
demo.queue().launch()
|
69 |
|
70 |
if __name__ == "__main__":
|
71 |
main()
|