HaohuaLv commited on
Commit
734dc0e
·
1 Parent(s): d28444d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -15,23 +15,22 @@ cosinesimilarity = CosineSimilarity()
15
 
16
 
17
 
18
- def load_candidates(candidate_dir):
19
  def preprocess(examples):
20
- images = [image.convert("RGB") for image in examples["image"]]
21
  examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"])["pooler_output"]
22
  return examples
23
-
24
- dataset = [dict(image=Image.open(tempfile.name).convert("RGB").resize((224, 224))) for tempfile in candidate_dir]
25
  dataset = Dataset.from_list(dataset)
26
  with torch.no_grad():
27
- dataset = dataset.map(preprocess, batched=True, batch_size=64)
28
-
29
  return dataset
30
 
31
 
32
  def load_candidates_in_cache(candidate_files):
33
  global candidates
34
  candidates = load_candidates(candidate_files)
 
35
 
36
 
37
  def scribble_matching(input_img: Image):
@@ -42,12 +41,11 @@ def scribble_matching(input_img: Image):
42
  image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
43
 
44
 
45
-
46
  sim = cosinesimilarity(scribble_embedding, image_embeddings)
47
 
48
  predicts = torch.topk(sim, k=15)
49
 
50
- output_imgs = candidates[predicts.indices.tolist()]["image"]
51
  labels = predicts.values.tolist()
52
  labels = [f"{label:.3f}" for label in labels]
53
 
@@ -58,16 +56,16 @@ def main():
58
  with gr.Blocks() as demo:
59
  with gr.Row():
60
  input_img = gr.Image(type="pil", label="scribble", height=512, width=512, source="canvas", tool="color-sketch", brush_radius=10)
61
- prediction_gallery = gr.Gallery(min_width=512, columns=4, show_label=True, )
62
 
63
  with gr.Row():
64
  candidate_dir = gr.File(file_count="directory", min_width=300, height=300)
65
  load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
66
  btn = gr.Button("Scribble Matching", variant="primary")
67
- load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
68
  btn.click(fn=scribble_matching, inputs=[input_img], outputs=[prediction_gallery])
69
 
70
- demo.launch(debug=True)
71
 
72
  if __name__ == "__main__":
73
  main()
 
15
 
16
 
17
 
18
+ def load_candidates(candidate_dir, progress=gr.Progress()):
19
  def preprocess(examples):
20
+ images = [image for image in examples["image"]]
21
  examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"])["pooler_output"]
22
  return examples
23
+ dataset = [dict(image=Image.open(tempfile.name).convert("RGB").resize((224, 224))) for tempfile in progress.tqdm(candidate_dir)]
 
24
  dataset = Dataset.from_list(dataset)
25
  with torch.no_grad():
26
+ dataset = dataset.map(preprocess, batched=True, batch_size=1024)
 
27
  return dataset
28
 
29
 
30
  def load_candidates_in_cache(candidate_files):
31
  global candidates
32
  candidates = load_candidates(candidate_files)
33
+ return [f.name for f in candidate_files]
34
 
35
 
36
  def scribble_matching(input_img: Image):
 
41
  image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
42
 
43
 
 
44
  sim = cosinesimilarity(scribble_embedding, image_embeddings)
45
 
46
  predicts = torch.topk(sim, k=15)
47
 
48
+ output_imgs = candidates["image"][predicts.indices.tolist()]
49
  labels = predicts.values.tolist()
50
  labels = [f"{label:.3f}" for label in labels]
51
 
 
56
  with gr.Blocks() as demo:
57
  with gr.Row():
58
  input_img = gr.Image(type="pil", label="scribble", height=512, width=512, source="canvas", tool="color-sketch", brush_radius=10)
59
+ prediction_gallery = gr.Gallery(min_width=512, columns=4, show_label=True)
60
 
61
  with gr.Row():
62
  candidate_dir = gr.File(file_count="directory", min_width=300, height=300)
63
  load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
64
  btn = gr.Button("Scribble Matching", variant="primary")
65
+ load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir], outputs=candidate_dir)
66
  btn.click(fn=scribble_matching, inputs=[input_img], outputs=[prediction_gallery])
67
 
68
+ demo.queue().launch()
69
 
70
  if __name__ == "__main__":
71
  main()