safinal committed · verified
Commit a263f63 · 1 Parent(s): 3dad86d

Update app.py

Files changed (1):
  1. app.py (+27 -2)

app.py CHANGED
@@ -4,12 +4,38 @@ import numpy as np
 from PIL import Image
 import pandas as pd
 from sklearn.metrics.pairwise import cosine_similarity
+from tqdm import tqdm
 
 
 from token_classifier import load_token_classifier, predict
 from model import Model
 from dataset import RetrievalDataset
-from generate_embeds import encode_database
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+batch_size = 512
+
+
+def encode_database(model, df: pd.DataFrame) -> np.ndarray:
+    """
+    Process database images and generate embeddings.
+
+    Args:
+        df (pd.DataFrame): DataFrame with column:
+            - target_image: str, paths to database images
+
+    Returns:
+        np.ndarray: Embeddings array (num_images, embedding_dim)
+    """
+    model.eval()
+    all_embeddings = []
+    for i in tqdm(range(0, len(df), batch_size)):
+        target_imgs = torch.stack([model.processor(Image.open(target_image_path)) for target_image_path in df['target_image'][i:i+batch_size]]).to(device)
+        with torch.no_grad():
+            # target_imgs_embedding = model.encode_database_image(target_imgs)
+            target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
+            target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
+        all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
+    return np.concatenate(all_embeddings)
 
 
 # Load model and configurations
@@ -21,7 +47,6 @@ def load_model():
 
 
 def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Process query image
     query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
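For context, a minimal sketch of how the now-inlined encode_database could be wired together with process_single_query in app.py. The call to load_model (its signature), the "database.csv" path, and the example query inputs are assumptions for illustration, not part of this commit:

```python
# Usage sketch (assumed wiring; not part of this commit).
# Precompute database embeddings once with the inlined encode_database,
# then reuse them for every incoming query.
import pandas as pd

model = load_model()                               # load_model is defined later in app.py; exact signature assumed
database_df = pd.read_csv("database.csv")          # hypothetical path; needs a 'target_image' column of image paths
database_embeddings = encode_database(model, database_df)

result = process_single_query(
    model,
    query_image_path="query.jpg",                  # hypothetical example inputs
    query_text="same chair but in red",
    database_embeddings=database_embeddings,
    database_df=database_df,
)
```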