File size: 5,378 Bytes
733fcd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import csv
import ast
import pickle
import gdown
import torch
import torch.nn.functional as F
import streamlit as st
from transformers import BertTokenizer, BertModel
from config import GOOGLE_DRIVE_FILES


def download_file_from_drive(file_id: str, destination: str, file_name: str) -> bool:
    """Fetch one file from Google Drive via gdown.

    A Streamlit spinner is shown while the transfer runs. Failures are
    reported through ``st.error`` and surfaced as a False return value
    instead of an exception.
    """
    url = f"https://drive.google.com/uc?id={file_id}"
    try:
        with st.spinner(f"Downloading {file_name}..."):
            gdown.download(url, destination, quiet=False)
    except Exception as e:
        st.error(f"Failed to download {file_name}: {e}")
        return False
    return True

def ensure_files_downloaded():
    """Download every configured Google Drive asset that is missing locally.

    Returns True once all files exist (or were fetched successfully);
    returns False as soon as any single download fails.
    """
    missing = [
        (name, file_id)
        for name, file_id in GOOGLE_DRIVE_FILES.items()
        if not os.path.exists(name)
    ]
    for name, file_id in missing:
        if not download_file_from_drive(file_id, name, name):
            return False
    return True

class GoogleDriveRecipeSearch:
    """Recipe search engine backed by BERT sentence embeddings.

    Model weights, precomputed recipe embeddings, and recipe metadata
    are downloaded from Google Drive on first use. After construction,
    ``is_ready`` tells callers whether every asset loaded successfully.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if not ensure_files_downloaded():
            self.is_ready = False
            return

        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertModel.from_pretrained("bert-base-uncased")

        # Fine-tuned weights are optional; fall back to the base model.
        if os.path.exists("assets/nlp/tag_based_bert_model.pth"):
            self.model.load_state_dict(
                torch.load("assets/nlp/tag_based_bert_model.pth", map_location=self.device)
            )
            st.success("Trained model loaded successfully!")
        else:
            st.warning("Using untrained model")

        self.model.to(self.device)
        self.model.eval()

        self.load_data()
        self.is_ready = True

    def load_data(self):
        """Load precomputed embeddings, recipe rows, and rating data from disk."""
        self.recipe_embeddings = torch.load(
            "assets/nlp/torch_recipe_embeddings_231630.pt", map_location=self.device
        )
        self.recipes = self._load_recipes("assets/nlp/RAW_recipes.csv")
        # Use context managers so the pickle file handles are closed promptly
        # (previously `pickle.load(open(...))` left them to the GC).
        with open("assets/nlp/recipe_statistics_231630.pkl", "rb") as f:
            self.recipe_stats = pickle.load(f)
        with open("assets/nlp/recipe_scores_231630.pkl", "rb") as f:
            self.recipe_scores = pickle.load(f)

    def _load_recipes(self, path):
        """Parse the raw recipes CSV into a list of recipe dicts.

        Rows with a missing/placeholder name, or with fields that fail
        literal/numeric parsing, are skipped instead of aborting the load.
        """
        recipes = []
        with open(path, "r", encoding="utf-8") as file:
            reader = csv.DictReader(file)
            for idx, row in enumerate(reader):
                # `or ""` guards against None values from short CSV rows.
                name = (row.get("name") or "").strip()
                if not name or name.lower() in ("nan", "unknown recipe"):
                    continue
                try:
                    recipe = {
                        "id": int(row.get("id", idx)),
                        "name": name,
                        "ingredients": ast.literal_eval(row.get("ingredients", "[]")),
                        "tags": ast.literal_eval(row.get("tags", "[]")),
                        "minutes": int(float(row.get("minutes", 0))),
                        "n_steps": int(float(row.get("n_steps", 0))),
                        "description": row.get("description", ""),
                        "steps": ast.literal_eval(row.get("steps", "[]")),
                    }
                except (ValueError, SyntaxError, TypeError):
                    # Malformed literal or numeric field — skip the row.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    continue
                recipes.append(recipe)
        return recipes

    def search_recipes(self, query, num_results=5, min_rating=3.0):
        """Return up to `num_results` recipes best matching `query`.

        Ranking combines BERT cosine similarity (weight 0.6) with a
        precomputed per-recipe score (weight 0.4). Recipes whose average
        rating is below `min_rating`, or that have fewer than 2 ratings,
        are excluded.
        """
        if not query.strip():
            return []

        tokens = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
        tokens = {k: v.to(self.device) for k, v in tokens.items()}

        with torch.no_grad():
            outputs = self.model(**tokens)
            # CLS-token embedding as the sentence representation.
            query_embedding = outputs.last_hidden_state[:, 0, :]

        query_embedding = F.normalize(query_embedding, dim=1)
        recipe_embeddings = F.normalize(self.recipe_embeddings, dim=1)

        # [N, 1] -> [N]; squeeze(1) rather than squeeze() so a
        # single-recipe corpus does not collapse to a 0-dim tensor.
        similarity_scores = torch.matmul(recipe_embeddings, query_embedding.T).squeeze(1)

        # NOTE(review): assumes embedding row i corresponds to
        # self.recipes[i] — verify the assets were built together.
        # min() guards against an IndexError on a length mismatch.
        final_scores = []
        for i in range(min(len(self.recipe_embeddings), len(self.recipes))):
            recipe = self.recipes[i]
            avg_rating, num_ratings, *_ = self.recipe_stats.get(recipe["id"], (0.0, 0, 0))
            if avg_rating < min_rating or num_ratings < 2:
                continue
            combined_score = (
                0.6 * similarity_scores[i].item() +
                0.4 * self.recipe_scores.get(recipe["id"], 0)
            )
            final_scores.append((combined_score, i))

        top_matches = sorted(final_scores, key=lambda x: x[0], reverse=True)[:num_results]

        results = []
        for score, idx in top_matches:
            recipe = self.recipes[idx]
            avg_rating, num_ratings, *_ = self.recipe_stats.get(recipe["id"], (0.0, 0, 0))
            results.append({
                "name": recipe["name"],
                "tags": recipe.get("tags", []),
                "ingredients": recipe.get("ingredients", []),
                "minutes": recipe.get("minutes", 0),
                "n_steps": recipe.get("n_steps", 0),
                "avg_rating": avg_rating,
                "num_ratings": num_ratings,
                "similarity_score": similarity_scores[idx].item(),
                "combined_score": score,
                "steps": recipe.get("steps", []),
                "description": recipe.get("description", "")
            })

        return results

@st.cache_resource
def load_search_system():
    """Build the recipe-search engine once and cache it for the session."""
    search_system = GoogleDriveRecipeSearch()
    return search_system