sakshamlakhera commited on
Commit
b274faf
·
1 Parent(s): 2932a64

Home update

Browse files
Files changed (2) hide show
  1. Home.py +67 -20
  2. config.py +1 -1
Home.py CHANGED
@@ -1,52 +1,102 @@
1
  import streamlit as st
2
  from PIL import Image
3
- from model.classifier import get_model, predict
4
  from model.search_script import search_for_recipes
5
  import streamlit.components.v1 as components
6
- import time
7
  import base64
8
-
9
  from utils.layout import render_layout
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  @st.cache_resource
12
  def load_model():
13
  return get_model()
14
 
 
 
 
 
 
 
15
  def classification_and_recommendation_page():
16
- st.markdown("## 🖼️ Task A: Image Classification + 🍽️ Recipe Recommendation")
17
  st.markdown("""
18
- <div class="about-box">
 
 
 
 
 
 
 
 
 
19
  Upload one or more food images. This module classifies each image into
20
- <b>Onion, Pear, Strawberry, or Tomato</b> using EfficientNet-B0, and then recommends recipes
21
- based on the combined classification results.
22
- </div>
 
 
 
 
 
 
23
  """, unsafe_allow_html=True)
24
 
 
 
25
  model = load_model()
26
 
27
- # --- Upload and classify ---
28
  uploaded_files = st.file_uploader("📤 Upload images (JPG/PNG)", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
29
 
30
  if "uploaded_images" not in st.session_state:
31
  st.session_state.uploaded_images = []
32
  if "image_tags" not in st.session_state:
33
  st.session_state.image_tags = {}
 
 
34
 
35
  if uploaded_files:
36
  for img_file in uploaded_files:
37
  if img_file.name not in [img.name for img in st.session_state.uploaded_images]:
38
  img = Image.open(img_file).convert("RGB")
39
- label, _ = predict(img, model)
 
 
 
 
 
 
 
 
40
  st.session_state.uploaded_images.append(img_file)
41
  st.session_state.image_tags[img_file.name] = label
 
42
 
43
- # --- Show grid of classified images ---
 
 
 
 
44
  if st.session_state.uploaded_images:
45
  html = """
46
  <style>
47
  .image-grid { display: flex; flex-wrap: wrap; gap: 12px; margin-top: 10px; }
48
  .image-card {
49
- width: 140px; height: 180px;
50
  border: 1px solid #ccc; border-radius: 10px;
51
  overflow: hidden; text-align: center;
52
  font-size: 13px; position: relative;
@@ -56,31 +106,28 @@ def classification_and_recommendation_page():
56
  max-width: 100%; max-height: 110px;
57
  object-fit: contain; margin-top: 5px;
58
  }
59
- .remove-btn {
60
- position: absolute; top: 2px; right: 6px;
61
- color: #d33; background: #fff;
62
- border: none; cursor: pointer; font-size: 16px;
63
- }
64
  </style>
65
  <div class="image-grid">
66
  """
67
 
68
  for img in st.session_state.uploaded_images:
69
  label = st.session_state.image_tags.get(img.name, "unknown")
 
 
70
  img_b64 = base64.b64encode(img.getvalue()).decode()
 
71
  html += f"""
72
  <div class="image-card">
73
  <img src="data:image/png;base64,{img_b64}" />
74
- <div><b>{label.upper()}</b></div>
75
  <div style="color:gray; font-size:11px;">{img.name}</div>
76
  </div>
77
  """
78
 
79
  html += "</div>"
80
  grid_rows = ((len(st.session_state.uploaded_images) - 1) // 5 + 1)
81
- components.html(html, height=200 * grid_rows + 20, scrolling=True)
82
 
83
- # --- Recipe Search ---
84
  st.markdown("---")
85
  st.markdown("## 🔍 Recipe Recommendation")
86
 
 
1
  import streamlit as st
2
  from PIL import Image
3
+ from model.classifier import get_model, predict, get_model_by_name
4
  from model.search_script import search_for_recipes
5
  import streamlit.components.v1 as components
 
6
  import base64
7
+ import config as config
8
  from utils.layout import render_layout
9
 
10
+ MODEL_PATH_MAP = {
11
+ "Onion": config.MODEL_PATH_ONION,
12
+ "Pear": config.MODEL_PATH_PEAR,
13
+ "Strawberry": config.MODEL_PATH_STRAWBERRY,
14
+ "Tomato": config.MODEL_PATH_TOMATO
15
+ }
16
+
17
+ VARIATION_CLASS_MAP = {
18
+ "Onion": ['halved', 'sliced', 'whole'],
19
+ "Strawberry": ['Hulled', 'sliced', 'whole'],
20
+ "Tomato": ['diced', 'vines', 'whole'],
21
+ "Pear": ['halved', 'sliced', 'whole']
22
+ }
23
+
24
  @st.cache_resource
25
  def load_model():
26
  return get_model()
27
 
28
+ @st.cache_resource
29
+ def load_model_variation(product_name):
30
+ model_path = MODEL_PATH_MAP[product_name]
31
+ num_classes = len(VARIATION_CLASS_MAP[product_name])
32
+ return get_model_by_name(model_path, num_classes=num_classes)
33
+
34
  def classification_and_recommendation_page():
35
+ st.markdown("## 🍽️ Recipe Recommendation System")
36
  st.markdown("""
37
+ <div style='
38
+ background-color: #f9f9f9;
39
+ border-left: 6px solid #4CAF50;
40
+ padding: 16px;
41
+ border-radius: 10px;
42
+ font-size: 15px;
43
+ line-height: 1.6;
44
+ '>
45
+ <b>📚 Recipe Recommendation Guide</b><br><br>
46
+
47
  Upload one or more food images. This module classifies each image into
48
+ <b>Onion, Pear, Strawberry, or Tomato</b> using <b>EfficientNet-B0</b>, and recommends recipes
49
+ based on the combined classification results.<br><br>
50
+
51
+ <b>Steps:</b><br>
52
+ 1️⃣ Upload images (single or multiple) of produce, or directly add tags for recipe search.<br>
53
+ 2️⃣ Once uploaded, the corresponding produce tag will be automatically added to the search.<br>
54
+ 3️⃣ Use the sliders to choose the number of results and minimum recipe rating.<br>
55
+ 4️⃣ Click <b>"Search Recipe"</b> to view personalized recommendations.
56
+ </div></br>
57
  """, unsafe_allow_html=True)
58
 
59
+
60
+
61
  model = load_model()
62
 
 
63
  uploaded_files = st.file_uploader("📤 Upload images (JPG/PNG)", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
64
 
65
  if "uploaded_images" not in st.session_state:
66
  st.session_state.uploaded_images = []
67
  if "image_tags" not in st.session_state:
68
  st.session_state.image_tags = {}
69
+ if "image_variations" not in st.session_state:
70
+ st.session_state.image_variations = {}
71
 
72
  if uploaded_files:
73
  for img_file in uploaded_files:
74
  if img_file.name not in [img.name for img in st.session_state.uploaded_images]:
75
  img = Image.open(img_file).convert("RGB")
76
+ label, main_class_prob = predict(img, model)
77
+
78
+ variation = None
79
+ if label in VARIATION_CLASS_MAP:
80
+ variation_model = load_model_variation(label)
81
+ class_labels = VARIATION_CLASS_MAP[label]
82
+ variation_label, var_conf = predict(img, variation_model, class_labels=class_labels)
83
+ variation = f"{variation_label} ({var_conf*main_class_prob* 100:.1f}%)"
84
+
85
  st.session_state.uploaded_images.append(img_file)
86
  st.session_state.image_tags[img_file.name] = label
87
+ st.session_state.image_variations[img_file.name] = variation
88
 
89
+ current_file_names = [f.name for f in uploaded_files] if uploaded_files else []
90
+ st.session_state.uploaded_images = [f for f in st.session_state.uploaded_images if f.name in current_file_names]
91
+ st.session_state.image_tags = {k: v for k, v in st.session_state.image_tags.items() if k in current_file_names}
92
+ st.session_state.image_variations = {k: v for k, v in st.session_state.image_variations.items() if k in current_file_names}
93
+
94
  if st.session_state.uploaded_images:
95
  html = """
96
  <style>
97
  .image-grid { display: flex; flex-wrap: wrap; gap: 12px; margin-top: 10px; }
98
  .image-card {
99
+ width: 140px; height: 200px;
100
  border: 1px solid #ccc; border-radius: 10px;
101
  overflow: hidden; text-align: center;
102
  font-size: 13px; position: relative;
 
106
  max-width: 100%; max-height: 110px;
107
  object-fit: contain; margin-top: 5px;
108
  }
 
 
 
 
 
109
  </style>
110
  <div class="image-grid">
111
  """
112
 
113
  for img in st.session_state.uploaded_images:
114
  label = st.session_state.image_tags.get(img.name, "unknown")
115
+ variation = st.session_state.image_variations.get(img.name, "")
116
+ combined_label = f"{label.upper()} </br> {variation}" if variation else label.upper()
117
  img_b64 = base64.b64encode(img.getvalue()).decode()
118
+
119
  html += f"""
120
  <div class="image-card">
121
  <img src="data:image/png;base64,{img_b64}" />
122
+ <div style="margin-top: 5px; font-weight: bold; font-size: 13px;">{combined_label}</div>
123
  <div style="color:gray; font-size:11px;">{img.name}</div>
124
  </div>
125
  """
126
 
127
  html += "</div>"
128
  grid_rows = ((len(st.session_state.uploaded_images) - 1) // 5 + 1)
129
+ components.html(html, height=200 * grid_rows + 40, scrolling=True)
130
 
 
131
  st.markdown("---")
132
  st.markdown("## 🔍 Recipe Recommendation")
133
 
config.py CHANGED
@@ -1,4 +1,4 @@
1
- CLASS_LABELS = ['onion', 'pear', 'strawberry', 'tomato']
2
 
3
  MODEL_PATH = "assets/modelWeights/best_model_v1.pth"
4
  MODEL_PATH_ONION = "assets/modelWeights/best_model_onion_v1.pth"
 
1
+ CLASS_LABELS = ['Onion', 'Pear', 'Strawberry', 'Tomato']
2
 
3
  MODEL_PATH = "assets/modelWeights/best_model_v1.pth"
4
  MODEL_PATH_ONION = "assets/modelWeights/best_model_onion_v1.pth"