safinal commited on
Commit
ffd2453
Β·
verified Β·
1 Parent(s): 20858a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -15
app.py CHANGED
@@ -3,34 +3,32 @@ import torch
3
  import numpy as np
4
  from PIL import Image
5
  import pandas as pd
6
- from pathlib import Path
7
  from sklearn.metrics.pairwise import cosine_similarity
8
 
9
- # Import your model and necessary functions
10
- from src.config import ConfigManager
11
- from src.token_classifier import load_token_classifier, predict
12
- from your_model_file import YourModel # Replace with your actual model import
 
 
13
 
14
  # Load model and configurations
15
  def load_model():
16
- model = YourModel() # Initialize your model
 
17
  model.eval()
18
  return model
19
 
20
- def load_dataset():
21
- # Load your default dataset
22
- database_df = pd.read_csv('database.csv') # Adjust path as needed
23
- return database_df
24
 
25
  def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
26
- device = ConfigManager().get("training")["device"]
27
 
28
  # Process query image
29
  query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
30
 
31
  # Get token classifier
32
  token_classifier, token_classifier_tokenizer = load_token_classifier(
33
- ConfigManager().get("paths")["pretrained_token_classifier_path"],
34
  device
35
  )
36
 
@@ -87,8 +85,16 @@ def process_single_query(model, query_image_path, query_text, database_embedding
87
 
88
  # Initialize model and database
89
  model = load_model()
90
- database_df = load_dataset()
91
- database_embeddings = encode_database(model, database_df) # Using your existing function
 
 
 
 
 
 
 
 
92
 
93
  def interface_fn(selected_image, query_text):
94
  result_image_path = process_single_query(
@@ -96,7 +102,7 @@ def interface_fn(selected_image, query_text):
96
  selected_image,
97
  query_text,
98
  database_embeddings,
99
- database_df
100
  )
101
  return Image.open(result_image_path)
102
 
 
3
  import numpy as np
4
  from PIL import Image
5
  import pandas as pd
 
6
  from sklearn.metrics.pairwise import cosine_similarity
7
 
8
+
9
+ from token_classifier import load_token_classifier, predict
10
+ from model import Model
11
+ from dataset import RetrievalDataset
12
+ from generate_embeds import encode_database
13
+
14
 
15
  # Load model and configurations
16
  def load_model():
17
+ model = Model(model_name="ViTamin-L-384", pretrained=None)
18
+ model.load("weights.pth")
19
  model.eval()
20
  return model
21
 
 
 
 
 
22
 
23
  def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
  # Process query image
27
  query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
28
 
29
  # Get token classifier
30
  token_classifier, token_classifier_tokenizer = load_token_classifier(
31
+ "trained_distil_bert_base",
32
  device
33
  )
34
 
 
85
 
86
  # Initialize model and database
87
  model = load_model()
88
+
89
+ test_dataset = RetrievalDataset(
90
+ img_dir_path="sample_evaluation/images",
91
+ annotations_file_path="sample_evaluation/data.csv",
92
+ split='test',
93
+ transform=model.processor,
94
+ tokenizer=model.tokenizer
95
+ )
96
+
97
+ database_embeddings = encode_database(model, test_dataset.load_database()) # Using your existing function
98
 
99
  def interface_fn(selected_image, query_text):
100
  result_image_path = process_single_query(
 
102
  selected_image,
103
  query_text,
104
  database_embeddings,
105
+ test_dataset.load_database()
106
  )
107
  return Image.open(result_image_path)
108