Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,8 +6,8 @@ from datasets import Dataset
|
|
6 |
from torch.nn import CosineSimilarity
|
7 |
|
8 |
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
9 |
-
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval()
|
10 |
-
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval()
|
11 |
|
12 |
candidates: Dataset = None
|
13 |
|
@@ -18,7 +18,7 @@ cosinesimilarity = CosineSimilarity()
|
|
18 |
def load_candidates(candidate_dir):
|
19 |
def preprocess(examples):
|
20 |
images = [image.convert("RGB") for image in examples["image"]]
|
21 |
-
examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"]
|
22 |
return examples
|
23 |
|
24 |
dataset = [dict(image=Image.open(tempfile.name).convert("RGB").resize((224, 224))) for tempfile in candidate_dir]
|
@@ -38,7 +38,7 @@ def scribble_matching(input_img: Image):
|
|
38 |
input_img = ImageOps.invert(input_img)
|
39 |
|
40 |
scribble = input_img
|
41 |
-
scribble_embedding = scribble_encoder(image_processor(scribble, return_tensors="pt")["pixel_values"]
|
42 |
image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
|
43 |
|
44 |
|
|
|
6 |
from torch.nn import CosineSimilarity
|
7 |
|
8 |
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
9 |
+
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29").eval()
|
10 |
+
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29").eval()
|
11 |
|
12 |
candidates: Dataset = None
|
13 |
|
|
|
18 |
def load_candidates(candidate_dir):
|
19 |
def preprocess(examples):
|
20 |
images = [image.convert("RGB") for image in examples["image"]]
|
21 |
+
examples["image_embedding"] = image_encoder(image_processor(images, return_tensors="pt")["pixel_values"])["pooler_output"]
|
22 |
return examples
|
23 |
|
24 |
dataset = [dict(image=Image.open(tempfile.name).convert("RGB").resize((224, 224))) for tempfile in candidate_dir]
|
|
|
38 |
input_img = ImageOps.invert(input_img)
|
39 |
|
40 |
scribble = input_img
|
41 |
+
scribble_embedding = scribble_encoder(image_processor(scribble, return_tensors="pt")["pixel_values"])["pooler_output"].to("cpu")
|
42 |
image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
|
43 |
|
44 |
|