import json
import time
from typing import Dict, Tuple

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
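

# Food-101 class labels, in the index order used by the model's output layer.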
class_names = [
    'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap',
    'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
    'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake',
    'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame',
    'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel',
    'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari',
    'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon',
    'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream',
    'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels',
    'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',
    'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake',
    'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese',
    'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu',
    'tuna_tartare', 'waffles'
]
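

# Recipe metadata is split across two JSON files; merge them into a single lookup list.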
def load_and_concatenate_json(json_paths: Tuple[str, str]) -> list:
    combined_data = []
    for json_path in json_paths:
        with open(json_path, 'r') as f:
            data = json.load(f)
            combined_data.extend(data)
    return combined_data


json_paths = ('./demos/recipenet/recipesData_v1.json', './demos/recipenet/recipesData_v2.json')
combined_data = load_and_concatenate_json(json_paths)
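

# EfficientNet-B0 backbone with ImageNet pre-trained weights; its classifier
# head is replaced below with a 101-class head for Food-101.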
weights = models.EfficientNet_B0_Weights.DEFAULT
model = models.efficientnet_b0(weights=weights)

model.classifier = nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1280, out_features=101)
)

# Load the fine-tuned checkpoint and move the model to the target device.
state_dict = torch.load('./demos/recipenet/best_model.pth', map_location=device)
model.load_state_dict(state_dict)
model.to(device)

# Preprocessing transforms that match the pre-trained weights.
auto_transforms = weights.transforms()
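

# Match a predicted class label to its recipe entry (ingredients, URL, name) in the combined data.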
def get_recipe_details(label: str, combined_data: list) -> dict:
    label = label.replace("_", " ")
    for item in combined_data:
        if item["FoodName"].lower() == label.lower():
            return {
                "Ingredients": item["RawIngredients"],
                "RecipeURL": item["RecipeURL"],
                "RecipeName": item["RecipeName"]
            }
    return {}
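

# Run one image through the model and return per-class probabilities,
# inference time, and recipe details for the top prediction.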
def predict(img: Image.Image, model, class_names, device, combined_data: list) -> Tuple[Dict[str, float], float, dict]:
    start_time = time.time()

    # Preprocess the image and add a batch dimension.
    img = auto_transforms(img).unsqueeze(0).to(device)

    model.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(model(img), dim=1)

    # Map every class name to its predicted probability.
    pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}

    end_time = time.time()
    pred_time = end_time - start_time

    # Look up recipe details for the most probable class.
    max_label = max(pred_labels_and_probs, key=pred_labels_and_probs.get)
    max_prob = pred_labels_and_probs[max_label]

    recipe_details = get_recipe_details(max_label, combined_data)

    return pred_labels_and_probs, pred_time, recipe_details
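

# Gradio passes only the image; bind the model, class names, device, and
# recipe data here and unpack the recipe fields for the output components.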
def predict_wrapper(img: Image.Image):
    pred_labels_and_probs, pred_time, recipe_details = predict(img, model, class_names, device, combined_data)

    recipe = recipe_details.get('RecipeName', '')
    website = recipe_details.get('RecipeURL', '')
    ingredients = recipe_details.get('Ingredients', [])
    # Estimate preparation time as 10 minutes per ingredient.
    prep_time = len(ingredients) * 10
    ingredients = ', '.join(ingredients)

    return pred_labels_and_probs, pred_time, recipe, prep_time, website, ingredients
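

# Custom CSS for the demo's background gradients and input styling.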
css = """
.gradio-container, .svelte-vt1mxs.gap.panel {
    background: repeating-linear-gradient(
        to top,
        rgba(255, 255, 255, 0.03) 0px 2px,
        transparent 2px 4px
    ),
    linear-gradient(to bottom, #200933 75%, #3d0b43);
    color: #d5d7de;
    font-family: sans-serif;
}

gradio-app[control_page_title="true"][embed="false"][eager="true"] {
    background-color: #271139 !important;
    background-image: linear-gradient(to bottom, transparent 95%, #000 5%);
}

.gradio-root {
    overflow: hidden;
    box-shadow: 0px 0px 10px 0px rgba(0, 0, 0, 0.1);
}

.gradio-control {
    border: none;
}

.gradio-input {
    background-color: rgba(255, 255, 255, 0.1);
    border: 1px solid rgba(255, 255, 255, 0.2);
}
"""
demo = gr.Interface(
    fn=predict_wrapper,
    inputs=gr.Image(type='pil'),
    outputs=[
        gr.Label(num_top_classes=3, label='Predictions'),
        gr.Number(label='Prediction Time (s)'),
        gr.Textbox(label='Dish Name', elem_id="textbox_id"),
        gr.Textbox(label='Preparation Time (in minutes)'),
        gr.Textbox(label='Recipe available at'),
        gr.Textbox(label='Ingredients')
    ],
    description="recipenet",
    article="Made using EfficientNet-B0",
    css=css,
    examples=[
        ["./demos/recipenet/examples/pic1.jpg"],
        ["./demos/recipenet/examples/pic2.jpg"],
        ["./demos/recipenet/examples/pic3.jpg"]
    ]
)

demo.launch(debug=False, share=False)