pattern / model /recipe_search.py
sakshamlakhera
Initial commit
733fcd8
raw
history blame
5.38 kB
import os
import csv
import ast
import pickle
import gdown
import torch
import torch.nn.functional as F
import streamlit as st
from transformers import BertTokenizer, BertModel
from config import GOOGLE_DRIVE_FILES
def download_file_from_drive(file_id: str, destination: str, file_name: str) -> bool:
try:
with st.spinner(f"Downloading {file_name}..."):
url = f"https://drive.google.com/uc?id={file_id}"
gdown.download(url, destination, quiet=False)
return True
except Exception as e:
st.error(f"Failed to download {file_name}: {e}")
return False
def ensure_files_downloaded():
for filename, file_id in GOOGLE_DRIVE_FILES.items():
if not os.path.exists(filename):
success = download_file_from_drive(file_id, filename, filename)
if not success:
return False
return True
class GoogleDriveRecipeSearch:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not ensure_files_downloaded():
self.is_ready = False
return
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
self.model = BertModel.from_pretrained("bert-base-uncased")
if os.path.exists("assets/nlp/tag_based_bert_model.pth"):
self.model.load_state_dict(
torch.load("assets/nlp/tag_based_bert_model.pth", map_location=self.device)
)
st.success("Trained model loaded successfully!")
else:
st.warning("Using untrained model")
self.model.to(self.device)
self.model.eval()
self.load_data()
self.is_ready = True
def load_data(self):
self.recipe_embeddings = torch.load("assets/nlp/torch_recipe_embeddings_231630.pt", map_location=self.device)
self.recipes = self._load_recipes("assets/nlp/RAW_recipes.csv")
self.recipe_stats = pickle.load(open("assets/nlp/recipe_statistics_231630.pkl", "rb"))
self.recipe_scores = pickle.load(open("assets/nlp/recipe_scores_231630.pkl", "rb"))
def _load_recipes(self, path):
recipes = []
with open(path, "r", encoding="utf-8") as file:
reader = csv.DictReader(file)
for idx, row in enumerate(reader):
name = row.get("name", "").strip()
if not name or name.lower() in ["nan", "unknown recipe"]:
continue
try:
recipe = {
"id": int(row.get("id", idx)),
"name": name,
"ingredients": ast.literal_eval(row.get("ingredients", "[]")),
"tags": ast.literal_eval(row.get("tags", "[]")),
"minutes": int(float(row.get("minutes", 0))),
"n_steps": int(float(row.get("n_steps", 0))),
"description": row.get("description", ""),
"steps": ast.literal_eval(row.get("steps", "[]"))
}
recipes.append(recipe)
except:
continue
return recipes
def search_recipes(self, query, num_results=5, min_rating=3.0):
if not query.strip():
return []
print('im here')
tokens = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
tokens = {k: v.to(self.device) for k, v in tokens.items()}
with torch.no_grad():
outputs = self.model(**tokens)
query_embedding = outputs.last_hidden_state[:, 0, :]
query_embedding = F.normalize(query_embedding, dim=1)
recipe_embeddings = F.normalize(self.recipe_embeddings, dim=1)
similarity_scores = torch.matmul(recipe_embeddings, query_embedding.T).squeeze()
final_scores = []
for i in range(len(self.recipe_embeddings)):
recipe = self.recipes[i]
avg_rating, num_ratings, *_ = self.recipe_stats.get(recipe["id"], (0.0, 0, 0))
if avg_rating < min_rating or num_ratings < 2:
continue
combined_score = (
0.6 * similarity_scores[i].item() +
0.4 * self.recipe_scores.get(recipe["id"], 0)
)
final_scores.append((combined_score, i))
top_matches = sorted(final_scores, key=lambda x: x[0], reverse=True)[:num_results]
results = []
for score, idx in top_matches:
recipe = self.recipes[idx]
avg_rating, num_ratings, *_ = self.recipe_stats.get(recipe["id"], (0.0, 0, 0))
results.append({
"name": recipe["name"],
"tags": recipe.get("tags", []),
"ingredients": recipe.get("ingredients", []),
"minutes": recipe.get("minutes", 0),
"n_steps": recipe.get("n_steps", 0),
"avg_rating": avg_rating,
"num_ratings": num_ratings,
"similarity_score": similarity_scores[idx].item(),
"combined_score": score,
"steps": recipe.get("steps", []),
"description": recipe.get("description", "")
})
return results
@st.cache_resource
def load_search_system():
return GoogleDriveRecipeSearch()