sakshamlakhera committed · Commit 733fcd8 · Parent(s): 050283f
Initial commit

Files changed:
- Dockerfile +1 -1
- Home.py +24 -0
- LICENSE +21 -0
- assets/.DS_Store +0 -0
- assets/css/styles.css +57 -0
- assets/modelWeights/best_model_onion_v1.pth +3 -0
- assets/modelWeights/best_model_pear_v1.pth +3 -0
- assets/modelWeights/best_model_strawberry_v1.pth +3 -0
- assets/modelWeights/best_model_tomato_v1.pth +3 -0
- assets/modelWeights/best_model_v1.pth +3 -0
- assets/nlp/.DS_Store +0 -0
- assets/nlp/WEIGHTS.md +0 -0
- config.py +15 -0
- model/.DS_Store +0 -0
- model/__init__.py +0 -0
- model/classifier.py +43 -0
- model/recipe_search.py +139 -0
- pages/1_Image_Classification.py +34 -0
- pages/2_Variation_Detection.py +53 -0
- pages/3_Recipe_Recommendation.py +90 -0
- pages/4_Report.py +107 -0
- scripts/.DS_Store +0 -0
- scripts/CV/.DS_Store +0 -0
- scripts/CV/script.ipynb +0 -0
- scripts/NLP/nlp_colab.py +475 -0
- scripts/NLP/processing_files_for_app.py +393 -0
- scripts/NLP/search_script.py +216 -0
- utils/.DS_Store +0 -0
- utils/__init__.py +0 -0
- utils/layout.py +33 -0
Dockerfile
CHANGED
@@ -18,4 +18,4 @@ EXPOSE 8501
 
 HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
 
-ENTRYPOINT ["streamlit", "run", "
+ENTRYPOINT ["streamlit", "run", "Home.py", "--server.port=8501", "--server.address=0.0.0.0"]
Home.py
ADDED
@@ -0,0 +1,24 @@
import streamlit as st
from utils.layout import set_custom_page_config, render_header

with open("assets/css/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

set_custom_page_config()
render_header()

st.markdown("""
<div class="about-box">
Welcome to our Smart Kitchen Assistant — a CSE555 Final Project developed by Group 5 (Saksham & Ahmed).
<br><br>
🔍 This tool leverages AI to assist in:
- Classifying images of vegetables and fruits.
- Detecting their variations (cut, whole, sliced).
- Recommending recipes based on natural language input.
</div>

### 🔗 Use the left sidebar to navigate between:
- 🥦 Task A: Classification
- 🧊 Task B: Variation Detection
- 🧠 NLP Recipe Recommendation
""", unsafe_allow_html=True)
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 azaher1215

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
assets/.DS_Store
ADDED
Binary file (6.15 kB)
assets/css/styles.css
ADDED
@@ -0,0 +1,57 @@
body {
    font-family: 'Segoe UI', sans-serif;
}

.block-container {
    max-width: 900px;
    margin: 0 auto;
    padding: 2rem;
}

.project-header {
    text-align: center;
    margin-top: 1rem;
    margin-bottom: 2rem;
}

.home-container {
    display: flex;
    justify-content: center;
    align-items: center;
    height: 70vh;
}

.home-card {
    background: #ffffff;
    border-radius: 12px;
    padding: 2rem;
    box-shadow: 0 8px 20px rgba(0, 0, 0, 0.1);
    max-width: 600px;
    text-align: center;
}

.about-box {
    background-color: #f1f3f6;
    border-left: 5px solid #4a90e2;
    padding: 1rem;
    margin-bottom: 1.5rem;
    border-radius: 6px;
    font-size: 0.95rem;
}

img {
    border-radius: 10px;
}

/* Reduce sidebar width */
.css-1d391kg, .css-1d391kg > div {
    width: 250px !important;
}

/* Standard text sizes */
h1 { font-size: 2.2rem; }
h2 { font-size: 1.5rem; }
p, li { font-size: 1rem; }

/* Sidebar tweaks */
.css-1lcbmhc { padding-top: 2rem; }
assets/modelWeights/best_model_onion_v1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ce6d74a4b1ccf494999e60addc2f8995072eca00837eb77eabd71ee859a0023
size 16343319
assets/modelWeights/best_model_pear_v1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:07e5a67e49f46112e14f0e533c7df4edaf4562ebbffcf65393f0d8bd130a8a37
size 16342953
assets/modelWeights/best_model_strawberry_v1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:deed87390d881b658d39db29ec6e1850bf6c09bbf47882bd611a3a1de821fe4e
size 16345405
assets/modelWeights/best_model_tomato_v1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bb1db959f9732f49d95d174a6ba01da3271f57f5169b8af94a01abff7e78d329
size 16343685
assets/modelWeights/best_model_v1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6aced5beaeea31c3cf030c250bbaf4c4c3f8d644b4dda6db5d21b4358d27b994
size 16346243
assets/nlp/.DS_Store
ADDED
Binary file (6.15 kB)
assets/nlp/WEIGHTS.md
ADDED
File without changes
config.py
ADDED
@@ -0,0 +1,15 @@
CLASS_LABELS = ['onion', 'pear', 'strawberry', 'tomato']

MODEL_PATH = "assets/modelWeights/best_model_v1.pth"
MODEL_PATH_ONION = "assets/modelWeights/best_model_onion_v1.pth"
MODEL_PATH_PEAR = "assets/modelWeights/best_model_pear_v1.pth"
MODEL_PATH_TOMATO = "assets/modelWeights/best_model_tomato_v1.pth"
MODEL_PATH_STRAWBERRY = "assets/modelWeights/best_model_strawberry_v1.pth"

GOOGLE_DRIVE_FILES = {
    'assets/nlp/torch_recipe_embeddings_231630.pt': '1PSidY1toSfgECXDxa4pGza56Jq6vOq6t',
    'assets/nlp/tag_based_bert_model.pth': '1LBl7yFs5JFqOsgfn88BF9g83W9mxiBm6',
    'assets/nlp/RAW_recipes.csv': '1rFJQzg_ErwEpN6WmhQ4jRyiXv6JCINyf',
    'assets/nlp/recipe_statistics_231630.pkl': '1n8TNT-6EA_usv59CCCU1IXqtuM7i084E',
    'assets/nlp/recipe_scores_231630.pkl': '1gfPBzghKHOZqgJu4VE9NkandAd6FGjrA'
}
model/.DS_Store
ADDED
Binary file (6.15 kB)
model/__init__.py
ADDED
File without changes
model/classifier.py
ADDED
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
from typing import Tuple, List
from torchvision import models, transforms
from PIL import Image
from config import CLASS_LABELS, MODEL_PATH
import torch.nn.functional as F


def get_model():
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(CLASS_LABELS))
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    model.eval()
    return model

def get_model_by_name(model_path: str, num_classes: int):
    model = models.efficientnet_b0(weights=None)
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    model.eval()
    return model


def predict(image: Image.Image, model, class_labels: List[str] = None) -> Tuple[str, float]:
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    image_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = F.softmax(output, dim=1)
        confidence, pred = torch.max(probabilities, dim=1)
        print(pred)

    if class_labels is None:
        class_labels = CLASS_LABELS

    return class_labels[pred.item()], confidence.item()
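For reference, a minimal usage sketch of the classifier module above; it is not one of the committed files, and the image path and the onion variation labels are illustrative assumptions (the labels mirror VARIATION_CLASS_MAP in pages/2_Variation_Detection.py):

from PIL import Image
from model.classifier import get_model, get_model_by_name, predict
from config import MODEL_PATH_ONION

# Four-class produce classifier (onion / pear / strawberry / tomato).
model = get_model()
img = Image.open("sample.jpg").convert("RGB")  # hypothetical local test image
label, confidence = predict(img, model)
print(f"{label}: {confidence:.2%}")

# Per-product variation head, e.g. the onion model with its three variation classes.
onion_model = get_model_by_name(MODEL_PATH_ONION, num_classes=3)
variation, var_conf = predict(img, onion_model, class_labels=['halved', 'sliced', 'whole'])
print(f"{variation}: {var_conf:.2%}")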
model/recipe_search.py
ADDED
@@ -0,0 +1,139 @@
import os
import csv
import ast
import pickle
import gdown
import torch
import torch.nn.functional as F
import streamlit as st
from transformers import BertTokenizer, BertModel
from config import GOOGLE_DRIVE_FILES


def download_file_from_drive(file_id: str, destination: str, file_name: str) -> bool:
    try:
        with st.spinner(f"Downloading {file_name}..."):
            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url, destination, quiet=False)
        return True
    except Exception as e:
        st.error(f"Failed to download {file_name}: {e}")
        return False

def ensure_files_downloaded():
    for filename, file_id in GOOGLE_DRIVE_FILES.items():
        if not os.path.exists(filename):
            success = download_file_from_drive(file_id, filename, filename)
            if not success:
                return False
    return True

class GoogleDriveRecipeSearch:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if not ensure_files_downloaded():
            self.is_ready = False
            return

        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertModel.from_pretrained("bert-base-uncased")

        if os.path.exists("assets/nlp/tag_based_bert_model.pth"):
            self.model.load_state_dict(
                torch.load("assets/nlp/tag_based_bert_model.pth", map_location=self.device)
            )
            st.success("Trained model loaded successfully!")
        else:
            st.warning("Using untrained model")

        self.model.to(self.device)
        self.model.eval()

        self.load_data()
        self.is_ready = True

    def load_data(self):
        self.recipe_embeddings = torch.load("assets/nlp/torch_recipe_embeddings_231630.pt", map_location=self.device)
        self.recipes = self._load_recipes("assets/nlp/RAW_recipes.csv")
        self.recipe_stats = pickle.load(open("assets/nlp/recipe_statistics_231630.pkl", "rb"))
        self.recipe_scores = pickle.load(open("assets/nlp/recipe_scores_231630.pkl", "rb"))

    def _load_recipes(self, path):
        recipes = []
        with open(path, "r", encoding="utf-8") as file:
            reader = csv.DictReader(file)
            for idx, row in enumerate(reader):
                name = row.get("name", "").strip()
                if not name or name.lower() in ["nan", "unknown recipe"]:
                    continue
                try:
                    recipe = {
                        "id": int(row.get("id", idx)),
                        "name": name,
                        "ingredients": ast.literal_eval(row.get("ingredients", "[]")),
                        "tags": ast.literal_eval(row.get("tags", "[]")),
                        "minutes": int(float(row.get("minutes", 0))),
                        "n_steps": int(float(row.get("n_steps", 0))),
                        "description": row.get("description", ""),
                        "steps": ast.literal_eval(row.get("steps", "[]"))
                    }
                    recipes.append(recipe)
                except:
                    continue
        return recipes

    def search_recipes(self, query, num_results=5, min_rating=3.0):
        if not query.strip():
            return []
        print('im here')

        tokens = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
        tokens = {k: v.to(self.device) for k, v in tokens.items()}

        with torch.no_grad():
            outputs = self.model(**tokens)
            query_embedding = outputs.last_hidden_state[:, 0, :]

        query_embedding = F.normalize(query_embedding, dim=1)
        recipe_embeddings = F.normalize(self.recipe_embeddings, dim=1)

        similarity_scores = torch.matmul(recipe_embeddings, query_embedding.T).squeeze()

        final_scores = []
        for i in range(len(self.recipe_embeddings)):
            recipe = self.recipes[i]
            avg_rating, num_ratings, *_ = self.recipe_stats.get(recipe["id"], (0.0, 0, 0))
            if avg_rating < min_rating or num_ratings < 2:
                continue
            combined_score = (
                0.6 * similarity_scores[i].item() +
                0.4 * self.recipe_scores.get(recipe["id"], 0)
            )
            final_scores.append((combined_score, i))

        top_matches = sorted(final_scores, key=lambda x: x[0], reverse=True)[:num_results]

        results = []
        for score, idx in top_matches:
            recipe = self.recipes[idx]
            avg_rating, num_ratings, *_ = self.recipe_stats.get(recipe["id"], (0.0, 0, 0))
            results.append({
                "name": recipe["name"],
                "tags": recipe.get("tags", []),
                "ingredients": recipe.get("ingredients", []),
                "minutes": recipe.get("minutes", 0),
                "n_steps": recipe.get("n_steps", 0),
                "avg_rating": avg_rating,
                "num_ratings": num_ratings,
                "similarity_score": similarity_scores[idx].item(),
                "combined_score": score,
                "steps": recipe.get("steps", []),
                "description": recipe.get("description", "")
            })

        return results

@st.cache_resource
def load_search_system():
    return GoogleDriveRecipeSearch()
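A minimal sketch of how the search system above is driven, assuming the Google Drive assets listed in config.py are already present under assets/nlp/; the query string is only an example:

from model.recipe_search import load_search_system

search = load_search_system()  # cached GoogleDriveRecipeSearch instance
if search.is_ready:
    results = search.search_recipes("quick vegetarian pasta", num_results=3, min_rating=3.5)
    for r in results:
        # combined_score = 0.6 * similarity + 0.4 * popularity, as computed in search_recipes
        print(f"{r['name']}: match {r['similarity_score']:.1%}, overall {r['combined_score']:.1%}")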
pages/1_Image_Classification.py
ADDED
@@ -0,0 +1,34 @@
from utils.layout import render_layout
import streamlit as st
from PIL import Image
from model.classifier import get_model, predict

def classification_page():
    st.markdown("## 🖼️ Task A: Image Classification")

    st.markdown("""
    <div class="about-box">
    This module classifies images into <b>Onion, Pear, Strawberry, or Tomato</b>
    using an EfficientNet-B0 model.
    </div>
    """, unsafe_allow_html=True)

    model = load_model()

    uploaded = st.file_uploader("📤 Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
    if uploaded:
        img = Image.open(uploaded).convert("RGB")
        label, confidence = predict(img, model)
        print(label)

        st.success(f"🎯 Prediction: **{label.upper()}** ({confidence*100:.2f}% confidence)")

        st.markdown("<div style='text-align: center;'>", unsafe_allow_html=True)
        st.image(img, caption="Uploaded Image", width=300)
        st.markdown("</div>", unsafe_allow_html=True)

@st.cache_resource
def load_model():
    return get_model()

render_layout(classification_page)
pages/2_Variation_Detection.py
ADDED
@@ -0,0 +1,53 @@
from utils.layout import render_layout
import streamlit as st
from PIL import Image
from model.classifier import predict, get_model_by_name
import config

VARIATION_CLASS_MAP = {
    "Onion": ['halved', 'sliced', 'whole'],
    "Strawberry": ['Hulled', 'sliced', 'whole'],
    "Tomato": ['diced', 'vines', 'whole'],
    "Pear": ['halved', 'sliced', 'whole']
}

MODEL_PATH_MAP = {
    "Onion": config.MODEL_PATH_ONION,
    "Pear": config.MODEL_PATH_PEAR,
    "Strawberry": config.MODEL_PATH_STRAWBERRY,
    "Tomato": config.MODEL_PATH_TOMATO
}

@st.cache_resource
def load_model(product_name):
    model_path = MODEL_PATH_MAP[product_name]
    num_classes = len(VARIATION_CLASS_MAP[product_name])
    return get_model_by_name(model_path, num_classes=num_classes)

def variation_detection_page():
    st.markdown("## 🔍 Task B: Variation Detection")

    st.markdown("""
    <div class="about-box">
    This module detects variations such as <code>Whole</code>, <code>Halved</code>, <code>Diced</code>, etc.
    for Onion, Pear, Strawberry, and Tomato using individually fine-tuned models.
    </div>
    """, unsafe_allow_html=True)

    product = st.selectbox("Select Product Type", list(MODEL_PATH_MAP.keys()))

    model = load_model(product)
    class_labels = VARIATION_CLASS_MAP[product]

    uploaded = st.file_uploader("📤 Upload an image (JPG/PNG)", type=["jpg", "jpeg", "png"])
    if uploaded:
        img = Image.open(uploaded).convert("RGB")
        label, confidence = predict(img, model, class_labels=class_labels)

        st.success(f"🔍 Detected Variation: **{label}** ({confidence * 100:.2f}% confidence)")

        st.markdown("<div style='text-align: center;'>", unsafe_allow_html=True)
        st.image(img, caption=f"Uploaded Image - {product}", width=300)
        st.markdown("</div>", unsafe_allow_html=True)

render_layout(variation_detection_page)
pages/3_Recipe_Recommendation.py
ADDED
@@ -0,0 +1,90 @@
from utils.layout import render_layout
import streamlit as st
import time
from model.recipe_search import load_search_system  # assumed you modularized this logic
import streamlit.components.v1 as components

def recipe_search_page():
    st.markdown("""
    ## 🍽️ Advanced Recipe Recommendation
    <div class="about-box">
    This module uses a custom-trained BERT model to semantically search recipes
    based on your query, ingredients, and tags.
    </div>
    """, unsafe_allow_html=True)

    if 'search_system' not in st.session_state:
        with st.spinner("🔄 Initializing recipe search system..."):
            st.session_state.search_system = load_search_system()

    search_system = st.session_state.search_system

    if not search_system.is_ready:
        st.error("❌ System not ready. Please check data files and try again.")
        return

    query = st.text_input(
        "Search for recipes:",
        placeholder="e.g., 'chicken pasta', 'vegetarian salad', 'chocolate dessert'"
    )

    col1, col2 = st.columns(2)
    with col1:
        num_results = st.slider("Number of results", 1, 15, 5)
    with col2:
        min_rating = st.slider("Minimum rating", 1.0, 5.0, 3.0, 0.1)

    if st.button("🔍 Search Recipes") and query:
        with st.spinner(f"Searching for '{query}'..."):
            start = time.time()
            print(query, num_results, min_rating)
            results = search_system.search_recipes(query, num_results, min_rating)
            elapsed = time.time() - start

        if results:
            st.markdown(f"### 🎯 Top {len(results)} recipe recommendations for: *'{query}'*")
            st.markdown("<sub>📊 Sorted by best match using semantic search and popularity</sub>", unsafe_allow_html=True)
            st.markdown("<hr>", unsafe_allow_html=True)

            for i, recipe in enumerate(results, 1):
                steps_html = "".join([f"<li>{step.strip().capitalize()}</li>" for step in recipe.get("steps", [])])
                description = recipe.get("description", "").strip().capitalize()

                html_code = f"""
                <div style="margin-bottom: 24px; padding: 16px; border-radius: 12px; background-color: #fdfdfd; box-shadow: 0 2px 8px rgba(0,0,0,0.06); font-family: Arial, sans-serif;">
                    <div style="font-size: 18px; font-weight: bold; color: #333;">🔝 {i}. {recipe['name']}</div>

                    <div style="margin: 4px 0 8px 0; font-size: 14px; color: #555;">
                        ⏱️ <b>{recipe['minutes']} min</b> | 🔥 <b>{recipe['n_steps']} steps</b> | ⭐ <b>{recipe['avg_rating']:.1f}/5.0</b>
                        <span style="font-size: 12px; color: #999;">({recipe['num_ratings']} ratings)</span>
                    </div>

                    <div style="margin-bottom: 6px; font-size: 14px;">
                        <b>🔍 Match Score:</b> <span style="color: #007acc; font-weight: bold;">{recipe['similarity_score']:.1%}</span>
                        <span style="font-size: 12px; color: #888;">(query match)</span><br>
                        <b>🏆 Overall Score:</b> <span style="color: green; font-weight: bold;">{recipe['combined_score']:.1%}</span>
                        <span style="font-size: 12px; color: #888;">(match + popularity)</span>
                    </div>

                    <div style="margin-bottom: 6px;">
                        <b>🏷️ Tags:</b><br>
                        {" ".join([f"<span style='background:#eee;padding:4px 8px;border-radius:6px;margin:2px;display:inline-block;font-size:12px'>{tag}</span>" for tag in recipe['tags']])}
                    </div>

                    <div style="margin-bottom: 6px;">
                        <b>🥘 Ingredients:</b><br>
                        <span style="font-size: 13px; color: #444;">{', '.join(recipe['ingredients'][:8])}
                        {'...' if len(recipe['ingredients']) > 8 else ''}</span>
                    </div>

                    {"<div style='margin-top: 10px; font-size: 13px; color: #333;'><b>📖 Description:</b><br>" + description + "</div>" if description else ""}

                    {"<div style='margin-top: 10px; font-size: 13px;'><b>🧑‍🍳 Steps:</b><ol style='margin: 6px 0 0 18px; padding: 0;'>" + steps_html + "</ol></div>" if steps_html else ""}
                </div>
                """
                components.html(html_code, height=360 + len(recipe.get("steps", [])) * 20)

        else:
            st.warning(f"😔 No recipes found for '{query}' with a minimum rating of {min_rating}/5.0.")

render_layout(recipe_search_page)
pages/4_Report.py
ADDED
@@ -0,0 +1,107 @@
import streamlit as st

def render_report():
    st.title("📊 Recipe Search System Report")

    st.markdown("""
    ## Overview
    This report summarizes the working of the **custom BERT-based Recipe Recommendation System**, dataset characteristics, scoring algorithm, and evaluation metrics.
    """)

    st.markdown("### 🔍 Query Embedding and Similarity Calculation")
    st.latex(r"""
    \text{Similarity}(q, r_i) = \cos(\hat{q}, \hat{r}_i) = \frac{\hat{q} \cdot \hat{r}_i}{\|\hat{q}\|\|\hat{r}_i\|}
    """)
    st.markdown("""
    Here, $\\hat{q}$ is the BERT embedding of the query, and $\\hat{r}_i$ is the embedding of the i-th recipe.
    """)

    st.markdown("### 🏆 Final Score Calculation")
    st.latex(r"""
    \text{Score}_i = 0.6 \times \text{Similarity}_i + 0.4 \times \text{Popularity}_i
    """)

    st.markdown("### 📊 Dataset Summary")
    st.markdown("""
    - **Total Recipes:** 231,630
    - **Average Tags per Recipe:** ~6
    - **Ingredients per Recipe:** 3 to 20
    - **Ratings Data:** Extracted from user interaction dataset
    """)

    st.markdown("### 🧪 Evaluation Strategy")
    st.markdown("""
    We use a combination of:
    - Manual inspection
    - Recipe diversity analysis
    - Match vs rating correlation
    - Qualitative feedback from test queries
    """)

    st.markdown("---")
    st.markdown("© 2025 Your Name. All rights reserved.")

# If using a layout wrapper:
render_report()


# LaTeX content as string
latex_report = r"""
\documentclass{article}
\usepackage{amsmath}
\usepackage{geometry}
\geometry{margin=1in}
\title{Recipe Recommendation System Report}
\author{Saksham Lakhera}
\date{\today}

\begin{document}
\maketitle

\section*{Overview}
This report summarizes the working of the \textbf{custom BERT-based Recipe Recommendation System}, dataset characteristics, scoring algorithm, and evaluation metrics.

\section*{Query Embedding and Similarity Calculation}
\[
\text{Similarity}(q, r_i) = \cos(\hat{q}, \hat{r}_i) = \frac{\hat{q} \cdot \hat{r}_i}{\|\hat{q}\|\|\hat{r}_i\|}
\]
Here, $\hat{q}$ is the BERT embedding of the query, and $\hat{r}_i$ is the embedding of the i-th recipe.

\section*{Final Score Calculation}
\[
\text{Score}_i = 0.6 \times \text{Similarity}_i + 0.4 \times \text{Popularity}_i
\]

\section*{Dataset Summary}
\begin{itemize}
\item \textbf{Total Recipes:} 231,630
\item \textbf{Average Tags per Recipe:} $\sim$6
\item \textbf{Ingredients per Recipe:} 3 to 20
\item \textbf{Ratings Source:} User interaction dataset
\end{itemize}

\section*{Evaluation Strategy}
We use a combination of:
\begin{itemize}
\item Manual inspection
\item Recipe diversity analysis
\item Match vs rating correlation
\item Qualitative user feedback
\end{itemize}

\end{document}
"""

# ⬇️ Download button to get the .tex file
st.markdown("### 📥 Download LaTeX Report")
st.download_button(
    label="Download LaTeX (.tex)",
    data=latex_report,
    file_name="recipe_report.tex",
    mime="text/plain"
)

# 📤 Optional: Show the .tex content in the app
with st.expander("📄 View LaTeX (.tex) File Content"):
    st.code(latex_report, language="latex")
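A small worked example of the report's scoring formula, using made-up similarity and popularity values:

# Score_i = 0.6 * Similarity_i + 0.4 * Popularity_i  (weights taken from the report above)
similarity = 0.82   # example cosine similarity between query and recipe embeddings
popularity = 0.65   # example normalized popularity score
score = 0.6 * similarity + 0.4 * popularity
print(round(score, 3))  # 0.752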
scripts/.DS_Store
ADDED
Binary file (6.15 kB)
scripts/CV/.DS_Store
ADDED
Binary file (6.15 kB)
scripts/CV/script.ipynb
ADDED
The diff for this file is too large to render.
scripts/NLP/nlp_colab.py
ADDED
@@ -0,0 +1,475 @@
import pandas as pd
from ast import literal_eval
from transformers import BertTokenizer, BertModel
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch
import os
from sklearn.model_selection import train_test_split
import random
import re

def clean_text(text):
    # helper function to clean the text from whitespace and double spaces;
    # converts to lowercase and checks that the text is a string first to avoid errors
    if not isinstance(text, str):
        return ''
    text = text.lower()
    text = ' '.join(text.split())
    return text.strip()

def setup_tag_categories():
    tag_categories = {
        'cuisine': [
            'italian', 'chinese', 'mexican', 'indian', 'french', 'greek', 'thai',
            'japanese', 'american', 'european', 'asian', 'mediterranean', 'spanish',
            'german', 'korean', 'vietnamese', 'turkish', 'moroccan', 'lebanese'
        ],
        'course': [
            'main-dish', 'side-dishes', 'appetizers', 'desserts', 'breakfast',
            'lunch', 'dinner', 'snacks', 'beverages', 'salads', 'soups'
        ],
        'main_ingredient': [
            'chicken', 'beef', 'pork', 'fish', 'seafood', 'vegetables', 'fruit',
            'pasta', 'rice', 'cheese', 'chocolate', 'potato', 'lamb', 'turkey',
            'beans', 'nuts', 'eggs', 'tofu'
        ],
        'dietary': [
            'vegetarian', 'vegan', 'gluten-free', 'low-carb', 'healthy', 'low-fat',
            'diabetic', 'dairy-free', 'keto', 'paleo', 'whole30'
        ],
        'cooking_method': [
            'oven', 'stove-top', 'no-cook', 'microwave', 'slow-cooker', 'grilling',
            'baking', 'roasting', 'frying', 'steaming', 'braising'
        ],
        'difficulty': ['easy', 'beginner-cook', 'advanced', 'intermediate', 'quick'],
        'time': [
            '15-minutes-or-less', '30-minutes-or-less', '60-minutes-or-less',
            '4-hours-or-less', 'weeknight'
        ],
        'occasion': [
            'holiday-event', 'christmas', 'thanksgiving', 'valentines-day',
            'summer', 'winter', 'spring', 'fall', 'party', 'picnic'
        ]
    }
    return tag_categories

def setup_ingredient_groups():
    ingredient_groups = {
        'proteins': [
            'chicken', 'beef', 'pork', 'fish', 'salmon', 'tuna', 'shrimp', 'turkey',
            'lamb', 'bacon', 'ham', 'sausage', 'eggs', 'tofu', 'beans', 'lentils'
        ],
        'vegetables': [
            'onion', 'garlic', 'tomato', 'carrot', 'celery', 'pepper', 'mushroom',
            'spinach', 'broccoli', 'zucchini', 'potato', 'sweet potato'
        ],
        'grains_starches': [
            'rice', 'pasta', 'bread', 'flour', 'oats', 'quinoa', 'barley', 'noodles'
        ],
        'dairy': [
            'milk', 'butter', 'cheese', 'cream', 'yogurt', 'sour cream', 'cream cheese'
        ]
    }
    return ingredient_groups

def categorize_recipe_tags(recipe_tags, tag_categories):
    categorized_tags = {}

    # Initialize empty lists for each category
    for category_name in tag_categories.keys():
        categorized_tags[category_name] = []

    # Check each tag
    for tag in recipe_tags:
        tag_lower = tag.lower()

        # Check each category
        for category_name in tag_categories.keys():
            category_keywords = tag_categories[category_name]

            # Check if any keyword matches this tag
            for keyword in category_keywords:
                if keyword in tag_lower:
                    categorized_tags[category_name].append(tag)
                    break

    return categorized_tags

def extract_main_ingredients(ingredients_list, ingredient_groups):
    if not ingredients_list or not isinstance(ingredients_list, list):
        return []

    # Clean each ingredient
    cleaned_ingredients = []

    for ingredient in ingredients_list:
        # Convert to string
        ingredient_string = str(ingredient) if ingredient is not None else ''
        if not ingredient_string or ingredient_string == 'nan':
            continue

        # Make lowercase
        cleaned_ingredient = ingredient_string.lower()

        # Remove common descriptor words
        words_to_remove = ['fresh', 'dried', 'chopped', 'minced', 'sliced', 'diced', 'ground', 'large', 'small', 'medium']
        for word in words_to_remove:
            cleaned_ingredient = cleaned_ingredient.replace(word, '')

        # Remove numbers
        cleaned_ingredient = re.sub(r'\d+', '', cleaned_ingredient)

        # Remove measurement words
        measurement_words = ['cup', 'cups', 'tablespoon', 'tablespoons', 'teaspoon', 'teaspoons', 'pound', 'pounds', 'ounce', 'ounces']
        for measurement in measurement_words:
            cleaned_ingredient = cleaned_ingredient.replace(measurement, '')

        # Clean up extra spaces
        cleaned_ingredient = re.sub(r'\s+', ' ', cleaned_ingredient).strip()

        # Only keep if it's long enough
        if cleaned_ingredient and len(cleaned_ingredient) > 2:
            cleaned_ingredients.append(cleaned_ingredient)

    # Put ingredients in order of importance
    ordered_ingredients = []

    # First, add proteins (most important)
    for ingredient in cleaned_ingredients:
        for protein in ingredient_groups['proteins']:
            if protein in ingredient:
                ordered_ingredients.append(ingredient)
                break

    # Then add vegetables, grains, and dairy
    other_groups = ['vegetables', 'grains_starches', 'dairy']
    for group_name in other_groups:
        for ingredient in cleaned_ingredients:
            if ingredient not in ordered_ingredients:
                for group_item in ingredient_groups[group_name]:
                    if group_item in ingredient:
                        ordered_ingredients.append(ingredient)
                        break

    # Finally, add any remaining ingredients
    for ingredient in cleaned_ingredients:
        if ingredient not in ordered_ingredients:
            ordered_ingredients.append(ingredient)

    return ordered_ingredients

def create_structured_recipe_text(recipe, tag_categories, ingredient_groups):
    # Get recipe tags and categorize them
    recipe_tags = recipe['tags'] if isinstance(recipe['tags'], list) else []
    categorized_tags = categorize_recipe_tags(recipe_tags, tag_categories)

    # Choose tags in priority order
    priority_categories = ['main_ingredient', 'cuisine', 'course', 'dietary', 'cooking_method']
    selected_tags = []

    for category in priority_categories:
        if category in categorized_tags:
            # Take up to 2 tags from each category
            category_tags = categorized_tags[category][:2]
            for tag in category_tags:
                selected_tags.append(tag)

    # Add some additional important tags
    important_keywords = ['easy', 'quick', 'healthy', 'spicy', 'sweet']
    remaining_tags = []

    for tag in recipe_tags:
        if tag not in selected_tags:
            for keyword in important_keywords:
                if keyword in tag.lower():
                    remaining_tags.append(tag)
                    break

    # Add up to 3 remaining tags
    for i in range(min(3, len(remaining_tags))):
        selected_tags.append(remaining_tags[i])

    # Process ingredients
    recipe_ingredients = recipe['ingredients'] if isinstance(recipe['ingredients'], list) else []
    main_ingredients = extract_main_ingredients(recipe_ingredients, ingredient_groups)

    # Step 5: Create the final structured text
    # Join first 8 ingredients
    ingredients_text = ', '.join(main_ingredients[:8])

    # Join first 10 tags
    tags_text = ', '.join(selected_tags[:10])

    # Get recipe name
    recipe_name = str(recipe['name']).replace(' ', ' ').strip()

    # Create final structured text
    structured_text = f"Recipe: {recipe_name}. Ingredients: {ingredients_text}. Style: {tags_text}"

    return structured_text

def create_pair_data(recipes_df: pd.DataFrame, interactions_df: pd.DataFrame, num_pairs: int = 15000):
    # This function creates the training pairs for the model.
    # We first analyzed the data to create categories for the tags and ingredients. Under each of these, we have a list for cuisine, dietary, poultry, etc.
    # As we trained the model, we found that it was not able to learn the tags and ingredients, so we created a structured text representation it can learn more easily.
    # The prompt used is: Analyze the two csv files attached and create a structured text representation to be used for training a bert model to understand
    # tags and ingredients such that if a user later searches for a quick recipe, it can be used to find a recipe that is quick to make.

    # Set up the structured text categories and groups
    tag_categories = setup_tag_categories()
    ingredient_groups = setup_ingredient_groups()

    # Make a list to store all our pairs
    pair_data_list = []

    # create the pairs
    for pair_number in range(num_pairs):

        # Pick a random recipe from our dataframe
        random_recipe_data = recipes_df.iloc[random.randint(0, len(recipes_df) - 1)]

        # Get the tags from this recipe
        recipe_tags_list = random_recipe_data['tags']

        # Select some random tags (maximum 5, but maybe less if the recipe has fewer tags)
        num_tags_to_select = min(5, len(recipe_tags_list))
        selected_tags_list = []

        # Pick a random sample of tags and join them into a query string
        selected_tags_list = random.sample(recipe_tags_list, num_tags_to_select)

        # Create the positive recipe text using the structured format
        positive_recipe_text = create_structured_recipe_text(random_recipe_data, tag_categories, ingredient_groups)

        # Find a negative recipe that has no more than 2 tags in common with the query
        anchor = ' '.join(selected_tags_list)
        anchor_tags_set = set(anchor.split())

        negative_recipe_text = None
        attempts_counter = 0
        max_attempts_allowed = 100

        # Keep trying until we find a good negative recipe (a max-attempts cap avoids an infinite loop)
        while negative_recipe_text is None and attempts_counter < max_attempts_allowed:
            random_negative_recipe = recipes_df.iloc[random.randint(0, len(recipes_df) - 1)]

            # Get tags from this negative recipe
            negative_recipe_tags = random_negative_recipe['tags']
            negative_recipe_tags_set = set(negative_recipe_tags)

            # Count how many tags overlap
            overlap_count = 0
            for anchor_tag in anchor_tags_set:
                if anchor_tag in negative_recipe_tags_set:
                    overlap_count = overlap_count + 1

            attempts_counter = attempts_counter + 1

            # If the overlap is small enough (2 or less), we can use this as the negative
            if overlap_count <= 2:
                # Create the negative recipe text using the structured format
                negative_recipe_text = create_structured_recipe_text(random_negative_recipe, tag_categories, ingredient_groups)

                print(f"Found a negative recipe. Overlap: {overlap_count}")
                break

        # If we found a negative recipe, add this pair to our list
        if negative_recipe_text is not None:
            # Create a tuple with the three parts
            pair_data_list.append((anchor, positive_recipe_text, negative_recipe_text))
            print(f"Created pair {pair_number + 1}: Anchor='{anchor}', Overlap={overlap_count}")
        else:
            print(f"Could not find negative recipe for anchor '{anchor}' after {max_attempts_allowed} attempts")

        # Show progress every 1000 pairs
        if (pair_number + 1) % 1000 == 0:
            print(f"Progress: Created {pair_number + 1}/{num_pairs} pairs")

    # Convert our list to a pandas DataFrame and return it
    result_dataframe = pd.DataFrame(pair_data_list, columns=['anchor', 'positive', 'negative'])

    print(f"Final result: Created {len(result_dataframe)} pairs total")
    return result_dataframe

class pos_neg_pair_dataset(Dataset):
    # typical dataset class that tokenizes for the bert model and returns the ids and masks
    def __init__(self, pair_data, tokenizer, max_length=128):
        self.pair_data = pair_data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.pair_data)

    def __getitem__(self, idx):

        anchor = self.tokenizer(
            self.pair_data.iloc[idx]['anchor'],
            return_tensors='pt',
            truncation=True,
            max_length=self.max_length,
            padding='max_length')
        positive = self.tokenizer(
            self.pair_data.iloc[idx]['positive'],
            return_tensors='pt',
            truncation=True,
            max_length=self.max_length,
            padding='max_length')
        negative = self.tokenizer(
            self.pair_data.iloc[idx]['negative'],
            return_tensors='pt',
            truncation=True,
            max_length=self.max_length,
            padding='max_length')

        return {
            'anchor_input_ids': anchor['input_ids'].squeeze(),
            'anchor_attention_mask': anchor['attention_mask'].squeeze(),
            'positive_input_ids': positive['input_ids'].squeeze(),
            'positive_attention_mask': positive['attention_mask'].squeeze(),
            'negative_input_ids': negative['input_ids'].squeeze(),
            'negative_attention_mask': negative['attention_mask'].squeeze()
        }

def evaluate_model(model, val_loader):
    # evaluation method, same as training but with no gradient updates
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    total_loss = 0
    criterion = nn.TripletMarginLoss(margin=1.0)
    with torch.no_grad():
        for batch in val_loader:
            anchor_input_ids = batch['anchor_input_ids'].to(device)
            anchor_attention_mask = batch['anchor_attention_mask'].to(device)
            positive_input_ids = batch['positive_input_ids'].to(device)
            positive_attention_mask = batch['positive_attention_mask'].to(device)
            negative_input_ids = batch['negative_input_ids'].to(device)
            negative_attention_mask = batch['negative_attention_mask'].to(device)

            # Forward pass - get raw BERT embeddings
            anchor_outputs = model(anchor_input_ids, anchor_attention_mask)
            positive_outputs = model(positive_input_ids, positive_attention_mask)
            negative_outputs = model(negative_input_ids, negative_attention_mask)

            # Extract [CLS] token embeddings
            anchor_emb = anchor_outputs.last_hidden_state[:, 0, :]
            positive_emb = positive_outputs.last_hidden_state[:, 0, :]
            negative_emb = negative_outputs.last_hidden_state[:, 0, :]

            # Calculate loss
            loss = criterion(anchor_emb, positive_emb, negative_emb)

            total_loss += loss.item()

    print(f"Average loss on validation set: {total_loss/len(val_loader):.4f}")

def train_model(train_loader, num_epochs=3):
    # initialize the model, criterion, and optimizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BertModel.from_pretrained('bert-base-uncased')
    model.to(device)
    criterion = nn.TripletMarginLoss(margin=1.0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            # load the ids and masks onto the device
            anchor_input_ids = batch['anchor_input_ids'].to(device)
            anchor_attention_mask = batch['anchor_attention_mask'].to(device)
            positive_input_ids = batch['positive_input_ids'].to(device)
            positive_attention_mask = batch['positive_attention_mask'].to(device)
            negative_input_ids = batch['negative_input_ids'].to(device)
            negative_attention_mask = batch['negative_attention_mask'].to(device)

            # forward pass to get the embeddings
            anchor_outputs = model(anchor_input_ids, anchor_attention_mask)
            positive_outputs = model(positive_input_ids, positive_attention_mask)
            negative_outputs = model(negative_input_ids, negative_attention_mask)

            # Extract the [CLS] token embeddings
            anchor_emb = anchor_outputs.last_hidden_state[:, 0, :]
            positive_emb = positive_outputs.last_hidden_state[:, 0, :]
            negative_emb = negative_outputs.last_hidden_state[:, 0, :]

            # Calculate loss
            loss = criterion(anchor_emb, positive_emb, negative_emb)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # per-epoch average loss: total loss / number of batches
        print(f'Epoch {epoch+1}, Average Loss: {total_loss/len(train_loader):.4f}')

    return model

if __name__ == '__main__':

    if not os.path.exists('pair_data.parquet'):
        # Load and prepare the data
        print("Loading recipe data")
        recipes_df = pd.read_csv('RAW_recipes.csv')

        # Clean the data
        recipes_df['name'] = recipes_df['name'].apply(clean_text)
        recipes_df['tags'] = recipes_df['tags'].apply(literal_eval)
        recipes_df['ingredients'] = recipes_df['ingredients'].apply(literal_eval)

        # Filter recipes with meaningful data (no empty tags)
        recipes_df = recipes_df[recipes_df['tags'].str.len() > 0]

        # Load interactions
        print("Loading interaction data")
        interactions_df = pd.read_csv('RAW_interactions.csv')
        interactions_df = interactions_df.dropna(subset=['rating'])
        interactions_df['rating'] = pd.to_numeric(interactions_df['rating'], errors='coerce')
        interactions_df = interactions_df.dropna(subset=['rating'])

        # Create training pairs
        pair_data = create_pair_data(recipes_df, interactions_df, num_pairs=15000)

        # Save the pair data
        pair_data.to_parquet('pair_data.parquet', index=False)
        print('Data saved to pair_data.parquet')

    else:
        pair_data = pd.read_parquet('pair_data.parquet')
        print('Data loaded from pair_data.parquet')

    # Split data into training and validation (80% training, 20% validation)
    train_data, val_data = train_test_split(pair_data, test_size=0.2, random_state=42)

    # initialize the tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Create the datasets with reduced max_length for better performance
    train_dataset = pos_neg_pair_dataset(train_data, tokenizer, max_length=128)
    val_dataset = pos_neg_pair_dataset(val_data, tokenizer, max_length=128)

    # Create dataloaders with a smaller batch size for stability
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    # Train model
    print("Starting training...")
    model = train_model(train_loader, num_epochs=3)

    # evaluate the model
    print("Evaluating model...")
    evaluate_model(model, val_loader)

    # Save model
    torch.save(model.state_dict(), 'tag_based_bert_model.pth')
    print("Model saved to tag_based_bert_model.pth")
    print("Training Complete")
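For reference, a sketch of the triplet format the script above trains on; the anchor/positive/negative strings are invented examples in the "Recipe: … Ingredients: … Style: …" shape produced by create_structured_recipe_text, and the random tensors stand in for BERT [CLS] embeddings:

import torch
import torch.nn as nn

# One hypothetical training triplet (anchor = sampled recipe tags joined by spaces).
anchor = "30-minutes-or-less vegetarian italian easy main-dish"
positive = "Recipe: quick tomato basil pasta. Ingredients: pasta, tomato, garlic, onion, cheese. Style: pasta, italian, main-dish, vegetarian, easy"
negative = "Recipe: slow roasted beef brisket. Ingredients: beef, onion, carrot, celery. Style: beef, american, main-dish, oven, 4-hours-or-less"

# TripletMarginLoss pulls the anchor embedding toward the positive and pushes it
# away from the negative by at least the margin, as in train_model above.
criterion = nn.TripletMarginLoss(margin=1.0)
a, p, n = torch.randn(1, 768), torch.randn(1, 768), torch.randn(1, 768)  # stand-ins for [CLS] embeddings
loss = criterion(a, p, n)
print(loss.item())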
scripts/NLP/processing_files_for_app.py
ADDED
@@ -0,0 +1,393 @@
import pandas as pd
import torch
import numpy as np
from transformers import BertTokenizer, BertModel
from ast import literal_eval
import re
import pickle
from datetime import datetime

def clean_text(text):
    # helper function to clean the text from whitespace and double spaces;
    # converts to lowercase and checks that the text is a string first to avoid errors
    if not isinstance(text, str):
        return ''
    text = text.lower()
    text = ' '.join(text.split())
    return text.strip()

def setup_tag_categories():
    tag_categories = {
        'cuisine': [
            'italian', 'chinese', 'mexican', 'indian', 'french', 'greek', 'thai',
            'japanese', 'american', 'european', 'asian', 'mediterranean', 'spanish',
            'german', 'korean', 'vietnamese', 'turkish', 'moroccan', 'lebanese'
        ],
        'course': [
            'main-dish', 'side-dishes', 'appetizers', 'desserts', 'breakfast',
            'lunch', 'dinner', 'snacks', 'beverages', 'salads', 'soups'
        ],
        'main_ingredient': [
            'chicken', 'beef', 'pork', 'fish', 'seafood', 'vegetables', 'fruit',
            'pasta', 'rice', 'cheese', 'chocolate', 'potato', 'lamb', 'turkey',
            'beans', 'nuts', 'eggs', 'tofu'
        ],
        'dietary': [
            'vegetarian', 'vegan', 'gluten-free', 'low-carb', 'healthy', 'low-fat',
            'diabetic', 'dairy-free', 'keto', 'paleo', 'whole30'
        ],
        'cooking_method': [
            'oven', 'stove-top', 'no-cook', 'microwave', 'slow-cooker', 'grilling',
            'baking', 'roasting', 'frying', 'steaming', 'braising'
        ],
        'difficulty': ['easy', 'beginner-cook', 'advanced', 'intermediate', 'quick'],
        'time': [
            '15-minutes-or-less', '30-minutes-or-less', '60-minutes-or-less',
            '4-hours-or-less', 'weeknight'
        ],
        'occasion': [
            'holiday-event', 'christmas', 'thanksgiving', 'valentines-day',
            'summer', 'winter', 'spring', 'fall', 'party', 'picnic'
        ]
    }
    return tag_categories

def setup_ingredient_groups():

    ingredient_groups = {
        'proteins': [
            'chicken', 'beef', 'pork', 'fish', 'salmon', 'tuna', 'shrimp', 'turkey',
            'lamb', 'bacon', 'ham', 'sausage', 'eggs', 'tofu', 'beans', 'lentils'
        ],
        'vegetables': [
            'onion', 'garlic', 'tomato', 'carrot', 'celery', 'pepper', 'mushroom',
            'spinach', 'broccoli', 'zucchini', 'potato', 'sweet potato'
        ],
        'grains_starches': [
            'rice', 'pasta', 'bread', 'flour', 'oats', 'quinoa', 'barley', 'noodles'
        ],
        'dairy': [
            'milk', 'butter', 'cheese', 'cream', 'yogurt', 'sour cream', 'cream cheese'
        ]
    }
    return ingredient_groups

def load_and_clean_recipes(recipes_path):
    print(f"Loading recipes from {recipes_path}")

    # Load the CSV file
    recipes_df = pd.read_csv(recipes_path)

    # Clean the recipe names
    recipes_df['name'] = recipes_df['name'].fillna('unknown recipe').astype(str).apply(clean_text)

    # Update the dataframe
    recipes_df['description'] = recipes_df['description'].fillna('').astype(str).apply(clean_text)

    # cleaning tags and ingredients from string format
    recipes_df['tags'] = recipes_df['tags'].apply(literal_eval)
    recipes_df['ingredients'] = recipes_df['ingredients'].apply(literal_eval)

    # Filter out recipes with no tags or ingredients
    recipes_df = recipes_df[
        (recipes_df['tags'].str.len() > 0) &
        (recipes_df['ingredients'].str.len() > 0) &
        (recipes_df['name'].str.len() > 0) &
        (recipes_df['name'] != 'unknown recipe')
    ].reset_index(drop=True)

    print(f"Final number of valid recipes: {len(recipes_df)}")
    return recipes_df

def categorize_recipe_tags(recipe_tags, tag_categories):
    categorized_tags = {}

    # Initialize empty lists for each category
    for category_name in tag_categories.keys():
        categorized_tags[category_name] = []

    # Check each tag
    for tag in recipe_tags:
        tag_lower = tag.lower()

        # Check each category
        for category_name in tag_categories.keys():
            category_keywords = tag_categories[category_name]

            # Check if any keyword matches this tag
            for keyword in category_keywords:
                if keyword in tag_lower:
                    categorized_tags[category_name].append(tag)
                    break

    return categorized_tags

def extract_main_ingredients(ingredients_list, ingredient_groups):
    if not ingredients_list or not isinstance(ingredients_list, list):
        return []

    # Clean each ingredient
    cleaned_ingredients = []

    for ingredient in ingredients_list:
        # Convert to string
        ingredient_string = str(ingredient) if ingredient is not None else ''
        if not ingredient_string or ingredient_string == 'nan':
            continue

        # Make lowercase
        cleaned_ingredient = ingredient_string.lower()

        # Remove common descriptor words
        words_to_remove = ['fresh', 'dried', 'chopped', 'minced', 'sliced', 'diced', 'ground', 'large', 'small', 'medium']
        for word in words_to_remove:
            cleaned_ingredient = cleaned_ingredient.replace(word, '')

        # Remove numbers
        cleaned_ingredient = re.sub(r'\d+', '', cleaned_ingredient)

        # Remove measurement words
        measurement_words = ['cup', 'cups', 'tablespoon', 'tablespoons', 'teaspoon', 'teaspoons', 'pound', 'pounds', 'ounce', 'ounces']
        for measurement in measurement_words:
            cleaned_ingredient = cleaned_ingredient.replace(measurement, '')

        # Clean up extra spaces
        cleaned_ingredient = re.sub(r'\s+', ' ', cleaned_ingredient).strip()

        # Only keep if it's long enough
        if cleaned_ingredient and len(cleaned_ingredient) > 2:
+
cleaned_ingredients.append(cleaned_ingredient)
|
161 |
+
|
162 |
+
|
163 |
+
# Put ingredients in order of importance
|
164 |
+
ordered_ingredients = []
|
165 |
+
|
166 |
+
# First, add proteins (most important)
|
167 |
+
for ingredient in cleaned_ingredients:
|
168 |
+
for protein in ingredient_groups['proteins']:
|
169 |
+
if protein in ingredient:
|
170 |
+
ordered_ingredients.append(ingredient)
|
171 |
+
break
|
172 |
+
|
173 |
+
|
174 |
+
# Then add vegetables, grains, and dairy
|
175 |
+
other_groups = ['vegetables', 'grains_starches', 'dairy']
|
176 |
+
for group_name in other_groups:
|
177 |
+
for ingredient in cleaned_ingredients:
|
178 |
+
if ingredient not in ordered_ingredients:
|
179 |
+
for group_item in ingredient_groups[group_name]:
|
180 |
+
if group_item in ingredient:
|
181 |
+
ordered_ingredients.append(ingredient)
|
182 |
+
break
|
183 |
+
|
184 |
+
# Finally, add any remaining ingredients
|
185 |
+
for ingredient in cleaned_ingredients:
|
186 |
+
if ingredient not in ordered_ingredients:
|
187 |
+
ordered_ingredients.append(ingredient)
|
188 |
+
|
189 |
+
return ordered_ingredients
|
190 |
+
|
191 |
+
def create_structured_recipe_text(recipe, tag_categories, ingredient_groups):
|
192 |
+
# Get recipe tags and categorize them
|
193 |
+
recipe_tags = recipe['tags'] if isinstance(recipe['tags'], list) else []
|
194 |
+
categorized_tags = categorize_recipe_tags(recipe_tags, tag_categories)
|
195 |
+
|
196 |
+
# Choose tags in priority order
|
197 |
+
priority_categories = ['main_ingredient', 'cuisine', 'course', 'dietary', 'cooking_method']
|
198 |
+
selected_tags = []
|
199 |
+
|
200 |
+
for category in priority_categories:
|
201 |
+
if category in categorized_tags:
|
202 |
+
# Take up to 2 tags from each category
|
203 |
+
category_tags = categorized_tags[category][:2]
|
204 |
+
for tag in category_tags:
|
205 |
+
selected_tags.append(tag)
|
206 |
+
|
207 |
+
# Add some additional important tags
|
208 |
+
important_keywords = ['easy', 'quick', 'healthy', 'spicy', 'sweet']
|
209 |
+
remaining_tags = []
|
210 |
+
|
211 |
+
for tag in recipe_tags:
|
212 |
+
if tag not in selected_tags:
|
213 |
+
for keyword in important_keywords:
|
214 |
+
if keyword in tag.lower():
|
215 |
+
remaining_tags.append(tag)
|
216 |
+
break
|
217 |
+
|
218 |
+
|
219 |
+
# Add up to 3 remaining tags
|
220 |
+
for i in range(min(3, len(remaining_tags))):
|
221 |
+
selected_tags.append(remaining_tags[i])
|
222 |
+
|
223 |
+
# Process ingredients
|
224 |
+
recipe_ingredients = recipe['ingredients'] if isinstance(recipe['ingredients'], list) else []
|
225 |
+
main_ingredients = extract_main_ingredients(recipe_ingredients, ingredient_groups)
|
226 |
+
|
227 |
+
# Create the final structured text
|
228 |
+
# Join first 8 ingredients
|
229 |
+
ingredients_text = ', '.join(main_ingredients[:8])
|
230 |
+
|
231 |
+
# Join first 10 tags
|
232 |
+
tags_text = ', '.join(selected_tags[:10])
|
233 |
+
|
234 |
+
# Get recipe name
|
235 |
+
recipe_name = str(recipe['name']).replace('  ', ' ').strip()
|
236 |
+
|
237 |
+
# Create final structured text
|
238 |
+
structured_text = f"Recipe: {recipe_name}. Ingredients: {ingredients_text}. Style: {tags_text}"
|
239 |
+
|
240 |
+
return structured_text
|
241 |
+
|
242 |
+
|
243 |
+
def create_recipe_statistics(interactions_path='RAW_interactions.csv'):
|
244 |
+
print("Creating recipe statistics")
|
245 |
+
|
246 |
+
# Load interactions data
|
247 |
+
interactions_df = pd.read_csv(interactions_path)
|
248 |
+
# Clean interactions data
|
249 |
+
interactions_df = interactions_df.dropna(subset=['rating'])
|
250 |
+
# Convert ratings to numbers
|
251 |
+
interactions_df['rating'] = pd.to_numeric(interactions_df['rating'], errors='coerce')
|
252 |
+
|
253 |
+
# Remove rows where rating conversion failed
|
254 |
+
interactions_df = interactions_df.dropna(subset=['rating'])
|
255 |
+
|
256 |
+
print(f"Valid interactions after cleaning: {len(interactions_df)}")
|
257 |
+
|
258 |
+
# Calculate statistics for each recipe
|
259 |
+
recipe_stats = {}
|
260 |
+
unique_recipe_ids = interactions_df['recipe_id'].unique()
|
261 |
+
|
262 |
+
for recipe_id in unique_recipe_ids:
|
263 |
+
# Get all interactions for this recipe
|
264 |
+
recipe_interactions = interactions_df[interactions_df['recipe_id'] == recipe_id]
|
265 |
+
# Calculate average rating
|
266 |
+
ratings_list = recipe_interactions['rating'].tolist()
|
267 |
+
average_rating = sum(ratings_list) / len(ratings_list)
|
268 |
+
# Count number of ratings
|
269 |
+
number_of_ratings = len(recipe_interactions)
|
270 |
+
# Count unique users
|
271 |
+
unique_users = recipe_interactions['user_id'].nunique()
|
272 |
+
|
273 |
+
recipe_stats[recipe_id] = (average_rating, number_of_ratings, unique_users)
|
274 |
+
|
275 |
+
print(f"Created statistics for {len(recipe_stats)} recipes")
|
276 |
+
return recipe_stats
|
277 |
+
|
278 |
+
def create_recipe_embeddings(recipes_df, model, tokenizer, device, tag_categories, ingredient_groups):
|
279 |
+
print("Creating recipe embeddings (this will take a long time)")
|
280 |
+
|
281 |
+
recipe_embeddings_list = []
|
282 |
+
valid_recipes_list = []
|
283 |
+
|
284 |
+
# Process each recipe one by one
|
285 |
+
for i in range(len(recipes_df)):
|
286 |
+
recipe = recipes_df.iloc[i]
|
287 |
+
|
288 |
+
try:
|
289 |
+
# Create structured text for this recipe
|
290 |
+
recipe_text = create_structured_recipe_text(recipe, tag_categories, ingredient_groups)
|
291 |
+
|
292 |
+
# Tokenize the recipe text
|
293 |
+
tokenized_input = tokenizer(
|
294 |
+
recipe_text,
|
295 |
+
return_tensors='pt',
|
296 |
+
truncation=True,
|
297 |
+
max_length=128,
|
298 |
+
padding='max_length'
|
299 |
+
)
|
300 |
+
|
301 |
+
|
302 |
+
# Get embedding from model
|
303 |
+
with torch.no_grad():
|
304 |
+
input_ids = tokenized_input['input_ids'].to(device)
|
305 |
+
attention_mask = tokenized_input['attention_mask'].to(device)
|
306 |
+
model_outputs = model(input_ids, attention_mask)
|
307 |
+
# Get CLS token embedding (first token)
|
308 |
+
cls_embedding = model_outputs.last_hidden_state[:, 0, :]
|
309 |
+
# Move to CPU and convert to numpy
|
310 |
+
embedding_numpy = cls_embedding.cpu().numpy().flatten()
|
311 |
+
|
312 |
+
# Store the embedding and recipe
|
313 |
+
recipe_embeddings_list.append(embedding_numpy)
|
314 |
+
valid_recipes_list.append(recipe.copy())
|
315 |
+
|
316 |
+
# Show progress every 1000 recipes
|
317 |
+
if len(recipe_embeddings_list) % 1000 == 0:
|
318 |
+
print(f"Processed {len(recipe_embeddings_list)} recipes")
|
319 |
+
|
320 |
+
except Exception as e:
|
321 |
+
print(f"Error processing recipe {recipe.get('id', i)}: {e}")
|
322 |
+
continue
|
323 |
+
|
324 |
+
# Convert list to numpy array
|
325 |
+
embeddings_array = np.array(recipe_embeddings_list)
|
326 |
+
|
327 |
+
# Create new dataframe with only valid recipes
|
328 |
+
valid_recipes_df = pd.DataFrame(valid_recipes_list)
|
329 |
+
valid_recipes_df = valid_recipes_df.reset_index(drop=True)
|
330 |
+
|
331 |
+
print(f"Created {len(embeddings_array)} recipe embeddings")
|
332 |
+
return embeddings_array, valid_recipes_df
|
333 |
+
|
334 |
+
def save_all_files(recipes_df, recipe_embeddings, recipe_stats):
|
335 |
+
print("Saving all files...")
|
336 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
337 |
+
np.save(f'recipe_embeddings_{timestamp}.npy', recipe_embeddings)
|
338 |
+
print(f"Saved embeddings")
|
339 |
+
|
340 |
+
# Save filtered recipes dataframe
|
341 |
+
with open(f'filtered_recipes_{timestamp}.pkl', 'wb') as f:
|
342 |
+
pickle.dump(recipes_df, f)
|
343 |
+
print(f"Saved recipes.")
|
344 |
+
|
345 |
+
# Save recipe statistics
|
346 |
+
with open(f'recipe_statistics_{timestamp}.pkl', 'wb') as f:
|
347 |
+
pickle.dump(recipe_stats, f)
|
348 |
+
print(f"Saved statistics")
|
349 |
+
|
350 |
+
print("All files saved successfully!")
|
351 |
+
|
352 |
+
def create_all_necessary_files(recipes_path, interactions_path, model_path):
|
353 |
+
print("Starting full preprocessing pipeline")
|
354 |
+
|
355 |
+
# Set up device
|
356 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
357 |
+
print(f"Using device: {device}")
|
358 |
+
|
359 |
+
# Load tokenizer
|
360 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
361 |
+
|
362 |
+
# Load the trained model
|
363 |
+
model = BertModel.from_pretrained('bert-base-uncased')
|
364 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
365 |
+
model.to(device)
|
366 |
+
model.eval()
|
367 |
+
|
368 |
+
# Set up tag categories and ingredient groups
|
369 |
+
tag_categories = setup_tag_categories()
|
370 |
+
ingredient_groups = setup_ingredient_groups()
|
371 |
+
|
372 |
+
# Load and clean recipes
|
373 |
+
recipes_df = load_and_clean_recipes(recipes_path)
|
374 |
+
|
375 |
+
# Create recipe statistics
|
376 |
+
recipe_stats = create_recipe_statistics(interactions_path)
|
377 |
+
|
378 |
+
# Create recipe embeddings
|
379 |
+
recipe_embeddings, filtered_recipes_df = create_recipe_embeddings(
|
380 |
+
recipes_df, model, tokenizer, device, tag_categories, ingredient_groups
|
381 |
+
)
|
382 |
+
|
383 |
+
# Save all files
|
384 |
+
save_all_files(filtered_recipes_df, recipe_embeddings, recipe_stats)
|
385 |
+
|
386 |
+
if __name__ == "__main__":
|
387 |
+
create_all_necessary_files(
|
388 |
+
recipes_path='RAW_recipes.csv',
|
389 |
+
interactions_path='RAW_interactions.csv',
|
390 |
+
model_path='tag_based_bert_model.pth'
|
391 |
+
)
|
392 |
+
|
393 |
+
print("All preprocessing complete! You can now use the search system.")
|
scripts/NLP/search_script.py
ADDED
@@ -0,0 +1,216 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from transformers import BertTokenizer, BertModel
|
4 |
+
import pickle
|
5 |
+
import json
|
6 |
+
class RecipeSearchSystem:
|
7 |
+
|
8 |
+
def __init__(self, model_path='tag_based_bert_model.pth', max_recipes=231630):
|
9 |
+
# Set up device
|
10 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
+
|
12 |
+
# Load tokenizer
|
13 |
+
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
14 |
+
|
15 |
+
# Load the trained model
|
16 |
+
self.model = BertModel.from_pretrained('bert-base-uncased')
|
17 |
+
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
|
18 |
+
|
19 |
+
self.model.to(self.device)
|
20 |
+
self.model.eval()
|
21 |
+
|
22 |
+
# Load all the preprocessed files
|
23 |
+
self.max_recipes = max_recipes
|
24 |
+
#load recipe embeddings
|
25 |
+
self.recipe_embeddings = np.load(f'advanced_recipe_embeddings_{self.max_recipes}.npy')
|
26 |
+
#load recipes dataframe
|
27 |
+
with open(f'advanced_filtered_recipes_{self.max_recipes}.pkl', 'rb') as f:
|
28 |
+
self.recipes_df = pickle.load(f)
|
29 |
+
#load recipe statistics
|
30 |
+
with open(f'recipe_statistics_{self.max_recipes}.pkl', 'rb') as f:
|
31 |
+
self.recipe_stats = pickle.load(f)
|
32 |
+
|
33 |
+
|
34 |
+
def create_query_embedding(self, user_query):
|
35 |
+
|
36 |
+
structured_query = f"anchor: {user_query.lower()}"
|
37 |
+
|
38 |
+
# Tokenize the query
|
39 |
+
tokenized_query = self.tokenizer(
|
40 |
+
structured_query,
|
41 |
+
return_tensors='pt',
|
42 |
+
truncation=True,
|
43 |
+
max_length=128,
|
44 |
+
padding='max_length'
|
45 |
+
)
|
46 |
+
|
47 |
+
# Move to device
|
48 |
+
tokenized_query = tokenized_query.to(self.device)
|
49 |
+
|
50 |
+
# Get embedding from model
|
51 |
+
with torch.no_grad():
|
52 |
+
anchor_input_ids = tokenized_query['input_ids'].to(self.device)
|
53 |
+
anchor_attention_mask = tokenized_query['attention_mask'].to(self.device)
|
54 |
+
anchor_outputs = self.model(anchor_input_ids, anchor_attention_mask)
|
55 |
+
# Get CLS token embedding
|
56 |
+
anchor_embedding = anchor_outputs.last_hidden_state[:, 0, :]
|
57 |
+
# Move to CPU and convert to numpy
|
58 |
+
query_embedding_numpy = anchor_embedding.cpu().numpy().flatten()
|
59 |
+
|
60 |
+
return query_embedding_numpy
|
61 |
+
|
62 |
+
def calculate_similarities(self, query_embedding):
|
63 |
+
similarities = []
|
64 |
+
|
65 |
+
# Calculate cosine similarity for each recipe
|
66 |
+
for i in range(len(self.recipe_embeddings)):
|
67 |
+
recipe_embedding = self.recipe_embeddings[i]
|
68 |
+
|
69 |
+
# Calculate cosine similarity
|
70 |
+
#Cosine Similarity = (a · b) / (||a|| * ||b||)
|
71 |
+
dot_product = np.dot(recipe_embedding, query_embedding)
|
72 |
+
recipe_norm = np.linalg.norm(recipe_embedding)
|
73 |
+
query_norm = np.linalg.norm(query_embedding)
|
74 |
+
|
75 |
+
# Avoid division by zero
|
76 |
+
if recipe_norm > 0 and query_norm > 0:
|
77 |
+
similarity = dot_product / (recipe_norm * query_norm)
|
78 |
+
else:
|
79 |
+
similarity = 0.0
|
80 |
+
|
81 |
+
similarities.append(similarity)
|
82 |
+
|
83 |
+
return similarities
|
84 |
+
|
85 |
+
def filter_recipes_by_quality(self, min_rating=3.0, min_num_ratings=5):
|
86 |
+
#Get the indices of recipes that meet the quality criteria the user chose
|
87 |
+
filtered_recipe_indices = []
|
88 |
+
|
89 |
+
for i in range(len(self.recipes_df)):
|
90 |
+
recipe = self.recipes_df.iloc[i]
|
91 |
+
recipe_id = recipe['id']
|
92 |
+
|
93 |
+
if recipe_id in self.recipe_stats:
|
94 |
+
avg_rating, num_ratings, _ = self.recipe_stats[recipe_id]
|
95 |
+
|
96 |
+
if avg_rating >= min_rating and num_ratings >= min_num_ratings:
|
97 |
+
filtered_recipe_indices.append(i)
|
98 |
+
|
99 |
+
return filtered_recipe_indices
|
100 |
+
|
101 |
+
def rank_recipes_by_similarity_and_rating(self, similarities, recipe_indices):
|
102 |
+
recipe_scores = []
|
103 |
+
|
104 |
+
for recipe_index in recipe_indices:
|
105 |
+
recipe = self.recipes_df.iloc[recipe_index]
|
106 |
+
recipe_id = recipe['id']
|
107 |
+
|
108 |
+
semantic_score = similarities[recipe_index]
|
109 |
+
|
110 |
+
#if the recipe has no ratings, assume it is a poor choice and default the rating to 1.0
|
111 |
+
if recipe_id in self.recipe_stats:
|
112 |
+
avg_rating, _, _ = self.recipe_stats[recipe_id]
|
113 |
+
else:
|
114 |
+
avg_rating = 1.0
|
115 |
+
|
116 |
+
recipe_scores.append({
|
117 |
+
'recipe_index': recipe_index,
|
118 |
+
'recipe_id': recipe_id,
|
119 |
+
'semantic_score': semantic_score,
|
120 |
+
'avg_rating': avg_rating
|
121 |
+
})
|
122 |
+
|
123 |
+
return recipe_scores
|
124 |
+
|
125 |
+
def create_recipe_result(self, recipe_index, scores_info):
|
126 |
+
recipe = self.recipes_df.iloc[recipe_index]
|
127 |
+
recipe_id = recipe['id']
|
128 |
+
|
129 |
+
|
130 |
+
avg_rating, num_ratings, unique_users = self.recipe_stats[recipe_id]
|
131 |
+
|
132 |
+
|
133 |
+
# Create result structure mapping
|
134 |
+
result = {
|
135 |
+
'recipe_id': int(recipe_id),
|
136 |
+
'name': recipe['name'],
|
137 |
+
'ingredients': recipe['ingredients'],
|
138 |
+
'tags': recipe['tags'],
|
139 |
+
'minutes': int(recipe['minutes']),
|
140 |
+
'n_steps': int(recipe['n_steps']),
|
141 |
+
'description': recipe.get('description', ''),
|
142 |
+
'semantic_score': float(scores_info['semantic_score']),
|
143 |
+
'avg_rating': float(avg_rating),
|
144 |
+
'num_ratings': int(num_ratings),
|
145 |
+
'unique_users': int(unique_users)
|
146 |
+
}
|
147 |
+
|
148 |
+
result = json.dumps(result)
|
149 |
+
return result
|
150 |
+
|
151 |
+
def search_recipes(self, user_query, top_k=5, min_rating=3.0, min_num_ratings=5):
|
152 |
+
|
153 |
+
# Create embedding for user query
|
154 |
+
query_embedding = self.create_query_embedding(user_query)
|
155 |
+
|
156 |
+
# Calculate similarities between query and all recipes
|
157 |
+
similarities = self.calculate_similarities(query_embedding)
|
158 |
+
|
159 |
+
# Filter recipes by quality
|
160 |
+
filtered_recipe_indices = self.filter_recipes_by_quality(min_rating, min_num_ratings)
|
161 |
+
|
162 |
+
# Rank by semantic similarity and rating
|
163 |
+
recipe_scores = self.rank_recipes_by_similarity_and_rating(similarities, filtered_recipe_indices)
|
164 |
+
|
165 |
+
# Sort by semantic similarity, then by average rating
|
166 |
+
recipe_scores.sort(key=lambda x: (x['semantic_score'], x['avg_rating']), reverse=True)
|
167 |
+
|
168 |
+
# Get top results
|
169 |
+
top_results = recipe_scores[:top_k]
|
170 |
+
|
171 |
+
# Create result dictionaries
|
172 |
+
final_results = []
|
173 |
+
for score_info in top_results:
|
174 |
+
recipe_result = self.create_recipe_result(score_info['recipe_index'], score_info)
|
175 |
+
final_results.append(recipe_result)
|
176 |
+
|
177 |
+
return final_results
|
178 |
+
|
179 |
+
|
180 |
+
def search_for_recipes(user_query, top_k=5, min_rating=3.0, min_num_ratings=5):
|
181 |
+
search_system = RecipeSearchSystem()
|
182 |
+
results = search_system.search_recipes(
|
183 |
+
user_query=user_query,
|
184 |
+
top_k=top_k,
|
185 |
+
min_rating=min_rating,
|
186 |
+
min_num_ratings=min_num_ratings
|
187 |
+
)
|
188 |
+
|
189 |
+
return results
|
190 |
+
|
191 |
+
|
192 |
+
if __name__ == "__main__":
|
193 |
+
|
194 |
+
search_system = RecipeSearchSystem()
|
195 |
+
test_queries = [
|
196 |
+
# "chicken pasta italian quick dinner",
|
197 |
+
# "chocolate cake dessert brownie baked healthy",
|
198 |
+
# "healthy vegetarian salad tomato basil",
|
199 |
+
# "quick easy dinner",
|
200 |
+
# "beef steak",
|
201 |
+
"beef pasta",
|
202 |
+
"beef"
|
203 |
+
]
|
204 |
+
|
205 |
+
for query in test_queries:
|
206 |
+
print(f"Testing query: '{query}'")
|
207 |
+
|
208 |
+
results = search_system.search_recipes(
|
209 |
+
user_query=query,
|
210 |
+
top_k=3,
|
211 |
+
min_rating=3.5,
|
212 |
+
min_num_ratings=10
|
213 |
+
)
|
214 |
+
|
215 |
+
print(results)
|
216 |
+
print("Recipe search system testing complete!")
|
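A side note on calculate_similarities above: the per-recipe Python loop can be replaced by an equivalent vectorized NumPy version; a minimal sketch (same cosine formula, assuming recipe_embeddings is an (N, D) array and query_embedding a (D,) vector):

import numpy as np

def cosine_similarities(recipe_embeddings, query_embedding):
    # Dot product of every recipe embedding (one row each) with the query vector
    dots = recipe_embeddings @ query_embedding
    # Product of the row norms and the query norm
    norms = np.linalg.norm(recipe_embeddings, axis=1) * np.linalg.norm(query_embedding)
    sims = np.zeros(len(recipe_embeddings))
    nonzero = norms > 0  # guard against zero-norm embeddings, like the loop does
    sims[nonzero] = dots[nonzero] / norms[nonzero]
    return sims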
utils/.DS_Store
ADDED
Binary file (6.15 kB).
|
|
utils/__init__.py
ADDED
File without changes
|
utils/layout.py
ADDED
@@ -0,0 +1,33 @@
1 |
+
# layout.py
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
def set_custom_page_config():
|
5 |
+
st.set_page_config(
|
6 |
+
page_title="Smart Kitchen Assistant",
|
7 |
+
layout="wide",
|
8 |
+
initial_sidebar_state="expanded"
|
9 |
+
)
|
10 |
+
|
11 |
+
def render_header():
|
12 |
+
st.markdown("""
|
13 |
+
<div class="project-header">
|
14 |
+
<h1>Smart Kitchen Assistant</h1>
|
15 |
+
<p>CSE555 Final Project — Group 5: Saksham & Ahmed</p>
|
16 |
+
</div>
|
17 |
+
""", unsafe_allow_html=True)
|
18 |
+
|
19 |
+
def render_footer():
|
20 |
+
st.markdown("""
|
21 |
+
<div class="footer">
|
22 |
+
<p>Made with ❤️ by Saksham & Ahmed | CSE555 @ UB</p>
|
23 |
+
</div>
|
24 |
+
""", unsafe_allow_html=True)
|
25 |
+
|
26 |
+
def render_layout(content_function):
|
27 |
+
set_custom_page_config()
|
28 |
+
with open("assets/css/styles.css") as f:
|
29 |
+
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
|
30 |
+
|
31 |
+
render_header()
|
32 |
+
content_function()
|
33 |
+
render_footer()
|
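For context, a hypothetical page module (not part of this commit) could hand its content to render_layout like this:

import streamlit as st
from utils.layout import render_layout

def page_content():
    st.write("Page body goes here")

render_layout(page_content)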