ankz22's picture
Initial commit
d2454ee
raw
history blame
1.43 kB
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
from torchvision import transforms
# Hugging Face Hub checkpoint used to recognise fridge items in a photo.
INGREDIENT_MODEL_ID = "stchakman/Fridge_Items_Model"

# Image-classification pipeline that returns the five best-scoring labels.
# device=0 selects the first CUDA GPU when one is available; -1 means CPU.
ingredient_classifier = pipeline(
    task="image-classification",
    model=INGREDIENT_MODEL_ID,
    top_k=5,
    device=(0 if torch.cuda.is_available() else -1),
)
# Hugging Face Hub checkpoint of a T5 model fine-tuned to write recipes.
RECIPE_MODEL_ID = "flax-community/t5-recipe-generation"

# Text-to-text pipeline that turns an ingredient prompt into a recipe.
# It uses the same device rule as the ingredient classifier: the first CUDA
# GPU when one is available (device=0), otherwise the CPU (device=-1).
recipe_generator = pipeline(
    "text2text-generation",
    model=RECIPE_MODEL_ID,
    device=0 if torch.cuda.is_available() else -1,
)
# Random image perturbations:
#   - a horizontal flip with 50% probability,
#   - a rotation of up to +/-10 degrees,
#   - a brightness/contrast jitter of up to +/-20%.
# NOTE(review): these are training-style random augmentations, so every
# call on the same image can produce a different result.
augment = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
)
def generate_recipe(image: "Image.Image | None") -> str:
    """Detect ingredients in a photo and generate a recipe from them.

    The image is classified with ``ingredient_classifier`` (top-5 labels).
    The labels are then joined into a prompt for ``recipe_generator``, and
    both are returned as Markdown for the Gradio output component.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        A Markdown string listing the detected ingredients and the generated
        recipe, or a short hint asking for an image when ``image`` is None.
    """
    if image is None:
        # Gradio hands over None when nothing was uploaded; the classifier
        # cannot process it, so reply with a hint instead of a traceback.
        return "Veuillez envoyer une image d'ingredients."

    # Classify the photo exactly as uploaded. Random training-style
    # augmentation (flip / rotation / colour jitter) is deliberately not
    # applied here: being random, it could make the same photo yield
    # different ingredients on each submission.
    results = ingredient_classifier(image)
    ingredient_str = ", ".join(res["label"] for res in results)

    # The recipe checkpoint's model card feeds inputs as "items: a, b, c";
    # using that prefix keeps the prompt in the format it was trained on.
    prompt = f"items: {ingredient_str}"
    generated = recipe_generator(prompt, max_length=300, do_sample=True)[0]["generated_text"]

    # Per the model card, raw output may contain "<section>" / "<sep>"
    # markers between recipe parts / list items. Render them readably;
    # this is a no-op when they are absent.
    recipe = generated.replace("<section>", "\n\n").replace("<sep>", " --").strip()

    return f"### Ingredients detectes :\n{ingredient_str}\n\n### Recette generee :\n{recipe}"
# Gradio front end: a single PIL image upload goes in, and a Markdown recipe
# comes out. Flagging is disabled via allow_flagging="never".
interface = gr.Interface(
    generate_recipe,
    gr.Image(type="pil"),
    gr.Markdown(),
    allow_flagging="never",
    description="Envoyer une image d'ingredients pour recevoir une recette",
    title="🥕 Generateur de recettes 🧑‍🍳",
)

# Start the web server only when this file is executed directly, not on import.
if __name__ == "__main__":
    interface.launch()