dheena commited on
Commit
f6f293b
·
1 Parent(s): e8ee9c0
Files changed (2) hide show
  1. Dockerfile +1 -0
  2. src/model.py +6 -7
Dockerfile CHANGED
@@ -7,6 +7,7 @@ RUN apt-get update && apt-get install -y \
7
  curl \
8
  software-properties-common \
9
  git \
 
10
  && rm -rf /var/lib/apt/lists/*
11
 
12
  COPY requirements.txt ./
 
7
  curl \
8
  software-properties-common \
9
  git \
10
+ libgl1-mesa-glx \
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
  COPY requirements.txt ./
src/model.py CHANGED
@@ -1,16 +1,14 @@
1
  import faiss
2
  import torch
3
  import clip
4
- from openai import OpenAI
5
  import numpy as np
6
  from PIL import Image
7
  from fastapi import FastAPI
8
  from typing import List
9
  import segmentation
10
 
11
- client = OpenAI()
12
  device = "cpu"
13
- model, preprocess = clip.load("ViT-B/32", device=device)
14
 
15
  def get_image_features(image: Image.Image) -> np.ndarray:
16
  """Extract CLIP features from an image."""
@@ -29,9 +27,9 @@ def save_image_in_index(image_features: np.ndarray, metadata: dict):
29
  index.add(image_features)
30
  meta_data_store.append(metadata)
31
 
32
- def process_image_embedding(image_url: str, labels=['clothes']) -> np.ndarray:
33
  """Get feature embedding for a query image."""
34
- search_image, search_detections = segmentation.grounded_segmentation(image=image_url, labels=labels)
35
  cropped_image = segmentation.cut_image(search_image, search_detections[0].mask, search_detections[0].box)
36
 
37
  # Convert to valid RGB
@@ -40,8 +38,7 @@ def process_image_embedding(image_url: str, labels=['clothes']) -> np.ndarray:
40
  if cropped_image.ndim == 2:
41
  cropped_image = np.stack([cropped_image] * 3, axis=-1)
42
 
43
- pil_image = Image.fromarray(cropped_image)
44
- return pil_image
45
 
46
  def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
47
  """Find top-k similar images from the index."""
@@ -63,3 +60,5 @@ def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
63
 
64
 
65
 
 
 
 
1
  import faiss
2
  import torch
3
  import clip
 
4
  import numpy as np
5
  from PIL import Image
6
  from fastapi import FastAPI
7
  from typing import List
8
  import segmentation
9
 
 
10
  device = "cpu"
11
+ model, preprocess = clip.load("ViT-B/32", device=device, download_root="./clip_cache")
12
 
13
  def get_image_features(image: Image.Image) -> np.ndarray:
14
  """Extract CLIP features from an image."""
 
27
  index.add(image_features)
28
  meta_data_store.append(metadata)
29
 
30
+ def process_image_embedding(image_or_url, labels=['clothes']) -> np.ndarray:
31
  """Get feature embedding for a query image."""
32
+ search_image, search_detections = segmentation.grounded_segmentation(image=image_or_url, labels=labels)
33
  cropped_image = segmentation.cut_image(search_image, search_detections[0].mask, search_detections[0].box)
34
 
35
  # Convert to valid RGB
 
38
  if cropped_image.ndim == 2:
39
  cropped_image = np.stack([cropped_image] * 3, axis=-1)
40
 
41
+ return Image.fromarray(cropped_image)
 
42
 
43
  def get_top_k_results(image_url: str, k: int = 10) -> List[dict]:
44
  """Find top-k similar images from the index."""
 
60
 
61
 
62
 
63
+
64
+