AkinyemiAra commited on
Commit
c0e2011
·
verified ·
1 Parent(s): 30bbdee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -10
app.py CHANGED
@@ -15,8 +15,19 @@ model.eval()
15
  DATASET_DIR = Path("dataset")
16
  CACHE_FILE = "cache.pkl"
17
 
 
 
 
 
 
 
 
 
 
 
 
18
  def get_embedding(image: Image.Image, device="cpu"):
19
- # Use CLIP's built-in preprocessing instead of custom resize
20
  inputs = processor(images=image, return_tensors="pt").to(device)
21
  model_device = model.to(device)
22
  with torch.no_grad():
@@ -25,9 +36,11 @@ def get_embedding(image: Image.Image, device="cpu"):
25
  emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
26
  return emb
27
 
 
28
  def get_reference_embeddings():
29
  # Get all current image files
30
- current_images = set(img_path.name for img_path in DATASET_DIR.glob("*.jpg"))
 
31
 
32
  # Load existing cache if it exists
33
  cached_embeddings = {}
@@ -44,11 +57,15 @@ def get_reference_embeddings():
44
  embeddings = {}
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
47
- for img_path in DATASET_DIR.glob("*.jpg"):
48
  print(f"Processing {img_path.name}...")
49
- img = Image.open(img_path).convert("RGB")
50
- emb = get_embedding(img, device=device)
51
- embeddings[img_path.name] = emb.cpu()
 
 
 
 
52
 
53
  # Save updated cache
54
  with open(CACHE_FILE, "wb") as f:
@@ -94,15 +111,16 @@ def add_image(name: str, image):
94
  if not name.strip():
95
  return "Please provide a valid image name."
96
 
97
- path = DATASET_DIR / f"{name}.jpg"
98
- image.save(path)
 
99
 
100
  # Use GPU for consistency if available
101
  device = "cuda" if torch.cuda.is_available() else "cpu"
102
  emb = get_embedding(image, device=device)
103
 
104
  # Add to current embeddings and save cache
105
- reference_embeddings[f"{name}.jpg"] = emb.cpu()
106
 
107
  with open(CACHE_FILE, "wb") as f:
108
  pickle.dump(reference_embeddings, f)
@@ -120,4 +138,4 @@ add_interface = gr.Interface(fn=add_image,
120
  allow_flagging="never")
121
 
122
  demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
123
- demo.launch()
 
15
  DATASET_DIR = Path("dataset")
16
  CACHE_FILE = "cache.pkl"
17
 
18
+ # Define supported image formats
19
+ IMAGE_EXTENSIONS = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif", "*.webp", "*.tiff", "*.tif"]
20
+
21
+ def get_all_image_files():
22
+ """Get all image files from dataset directory"""
23
+ image_files = []
24
+ for ext in IMAGE_EXTENSIONS:
25
+ image_files.extend(DATASET_DIR.glob(ext))
26
+ image_files.extend(DATASET_DIR.glob(ext.upper())) # Also check uppercase
27
+ return image_files
28
+
29
  def get_embedding(image: Image.Image, device="cpu"):
30
+ # Use CLIP's built-in preprocessing
31
  inputs = processor(images=image, return_tensors="pt").to(device)
32
  model_device = model.to(device)
33
  with torch.no_grad():
 
36
  emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
37
  return emb
38
 
39
+ @spaces.GPU
40
  def get_reference_embeddings():
41
  # Get all current image files
42
+ current_image_files = get_all_image_files()
43
+ current_images = set(img_path.name for img_path in current_image_files)
44
 
45
  # Load existing cache if it exists
46
  cached_embeddings = {}
 
57
  embeddings = {}
58
  device = "cuda" if torch.cuda.is_available() else "cpu"
59
 
60
+ for img_path in current_image_files:
61
  print(f"Processing {img_path.name}...")
62
+ try:
63
+ img = Image.open(img_path).convert("RGB")
64
+ emb = get_embedding(img, device=device)
65
+ embeddings[img_path.name] = emb.cpu()
66
+ except Exception as e:
67
+ print(f"Error processing {img_path.name}: {e}")
68
+ continue
69
 
70
  # Save updated cache
71
  with open(CACHE_FILE, "wb") as f:
 
111
  if not name.strip():
112
  return "Please provide a valid image name."
113
 
114
+ # Save as PNG to preserve quality for all input formats
115
+ path = DATASET_DIR / f"{name}.png"
116
+ image.save(path, "PNG")
117
 
118
  # Use GPU for consistency if available
119
  device = "cuda" if torch.cuda.is_available() else "cpu"
120
  emb = get_embedding(image, device=device)
121
 
122
  # Add to current embeddings and save cache
123
+ reference_embeddings[f"{name}.png"] = emb.cpu()
124
 
125
  with open(CACHE_FILE, "wb") as f:
126
  pickle.dump(reference_embeddings, f)
 
138
  allow_flagging="never")
139
 
140
  demo = gr.TabbedInterface([search_interface, add_interface], tab_names=["Search", "Add Product"])
141
+ demo.launch(mcp_server=True)