import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
import timm
import numpy as np
import gradio as gr


class ImageEmbedder:
    def __init__(self, model_name='vit_base_patch16_224'):
        # Load a pretrained ViT and strip the classification head so the
        # forward pass returns feature embeddings instead of class logits.
        self.model = timm.create_model(model_name, pretrained=True)
        self.model.head = torch.nn.Identity()  # Remove classification head
        self.model.eval()
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def get_embedding(self, image):
        # Convert to RGB, preprocess, and run a single forward pass
        # without tracking gradients.
        image = image.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0)
        with torch.no_grad():
            embedding = self.model(image_tensor)
        return embedding.squeeze().numpy()


# Instantiate once at module level so the pretrained weights are not
# reloaded on every comparison request.
embedder = ImageEmbedder()


def compare_images(image1, image2, similarity_threshold=0.85):
    # Get embeddings for both images
    embedding1 = embedder.get_embedding(image1)
    embedding2 = embedder.get_embedding(image2)

    # Calculate cosine similarity between the two embedding vectors
    similarity = cosine_similarity(embedding1.reshape(1, -1),
                                   embedding2.reshape(1, -1))[0][0]

    # Report whether the score clears the similarity threshold
    if similarity > similarity_threshold:
        return f"The images are similar. Similarity score: {similarity:.4f}"
    else:
        return f"The images are not similar. Similarity score: {similarity:.4f}"


def main(image1, image2):
    return compare_images(image1, image2)


iface = gr.Interface(
    fn=main,
    inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
    outputs="text",
    title="Image Similarity Checker",
    description="Upload two images to check their similarity based on embeddings."
)

if __name__ == "__main__":
    iface.launch()
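
# Optional: programmatic use without the Gradio UI (a minimal sketch; the
# file paths below are hypothetical placeholders, not part of this script):
#
#   from PIL import Image
#   img_a = Image.open("photo_a.jpg")
#   img_b = Image.open("photo_b.jpg")
#   print(compare_images(img_a, img_b, similarity_threshold=0.85))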