HaohuaLv committed
Commit d28444d · 1 Parent(s): d708f3f

Update app.py

Files changed (1): app.py +4 -4
app.py CHANGED
@@ -6,8 +6,8 @@ from datasets import Dataset
 from torch.nn import CosineSimilarity
 
 image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
-image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval().to("cuda")
-scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval().to("cuda")
+image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval()
+scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval()
 
 candidates: Dataset = None
 
@@ -18,7 +18,7 @@ cosinesimilarity = CosineSimilarity()
 def load_candidates(candidate_dir):
     def preprocess(examples):
         images = [image.convert("RGB") for image in examples["image"]]
-        examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"].to("cuda"))["pooler_output"]
+        examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"])["pooler_output"]
         return examples
 
     dataset = [dict(image=Image.open(tempfile.name).convert("RGB").resize((224, 224))) for tempfile in candidate_dir]
@@ -38,7 +38,7 @@ def scribble_matching(input_img: Image):
     input_img = ImageOps.invert(input_img)
 
     scribble = input_img
-    scribble_embedding = scribble_encoder(image_processor(scribble, return_tensors="pt")["pixel_values"].to("cuda"))["pooler_output"].to("cpu")
+    scribble_embedding = scribble_encoder(image_processor(scribble, return_tensors="pt")["pixel_values"])["pooler_output"].to("cpu")
     image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
 
 
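
The diff removes the hard-coded .to("cuda") calls so the encoders load on CPU. A more general pattern is to choose the device at runtime instead of hard-coding it. The sketch below reuses the checkpoint paths and names from the file above, but it is only an illustration of that pattern, not the committed code:

# Minimal sketch: device-agnostic loading (assumption, not the code in this commit).
import torch
from transformers import ViTImageProcessor, ViTModel

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU on a GPU-less host

image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval().to(device)
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval().to(device)

# Inputs must be moved to the same device as the model before the forward pass, e.g.:
# pixel_values = image_processor(images, return_tensors="pt")["pixel_values"].to(device)

With this pattern the later encoder calls stay the same whether or not a GPU is present.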
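For context, the matching step in scribble_matching compares the scribble's pooler_output embedding against the precomputed candidate image embeddings with cosine similarity. A minimal sketch of that ranking, assuming a (1, hidden_size) scribble embedding and an (N, hidden_size) candidate matrix; rank_candidates and top_k are illustrative names, not part of the app:

# Minimal sketch: cosine-similarity ranking of candidates (assumed shapes and names).
import torch
from torch.nn import CosineSimilarity

cosinesimilarity = CosineSimilarity(dim=-1)

def rank_candidates(scribble_embedding, image_embeddings, top_k=5):
    # Broadcasting a (1, H) query against an (N, H) matrix yields one score per candidate.
    scores = cosinesimilarity(scribble_embedding, image_embeddings)
    # Return the best-scoring candidates and their indices into the dataset.
    return torch.topk(scores, k=min(top_k, scores.numel()))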