Harshithtd commited on
Commit
995307a
·
verified ·
1 Parent(s): 2824ccb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
+ import timm
6
+ import numpy as np
7
+ import gradio as gr
8
+
9
+ class ImageEmbedder:
10
+ def __init__(self, model_name='vit_base_patch16_224'):
11
+ self.model = timm.create_model(model_name, pretrained=True)
12
+ self.model.head = torch.nn.Identity() # Remove classification head
13
+ self.model.eval()
14
+
15
+ self.transform = transforms.Compose([
16
+ transforms.Resize((224, 224)),
17
+ transforms.ToTensor(),
18
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
+ ])
20
+
21
+ def get_embedding(self, image):
22
+ image = image.convert('RGB')
23
+ image_tensor = self.transform(image).unsqueeze(0)
24
+
25
+ with torch.no_grad():
26
+ embedding = self.model(image_tensor)
27
+
28
+ return embedding.squeeze().numpy()
29
+
30
+ def compare_images(image1, image2, similarity_threshold=0.85):
31
+ embedder = ImageEmbedder()
32
+
33
+ # Get embeddings
34
+ embedding1 = embedder.get_embedding(image1)
35
+ embedding2 = embedder.get_embedding(image2)
36
+
37
+ # Calculate similarity
38
+ similarity = cosine_similarity(embedding1.reshape(1, -1), embedding2.reshape(1, -1))[0][0]
39
+
40
+ # Determine if images are similar
41
+ if similarity > similarity_threshold:
42
+ return f"The images are similar. Similarity score: {similarity:.4f}"
43
+ else:
44
+ return f"The images are not similar. Similarity score: {similarity:.4f}"
45
+
46
+ def main(image1, image2):
47
+ return compare_images(image1, image2)
48
+
49
+ iface = gr.Interface(
50
+ fn=main,
51
+ inputs=[gr.inputs.Image(type="pil"), gr.inputs.Image(type="pil")],
52
+ outputs="text",
53
+ title="Image Similarity Checker",
54
+ description="Upload two images to check their similarity based on embeddings."
55
+ )
56
+
57
+ iface.launch()