DockerRecipe / app.py
TimInf's picture
Update app.py
b7333e0 verified
raw
history blame
7.13 kB
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
import torch
import numpy as np
import random
import json
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Lade RecipeBERT Modell
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval() # Setze das Modell in den Evaluationsmodus
# Lade T5 Rezeptgenerierungsmodell (NEU hinzugefügt)
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH) # Modell wird jetzt auch geladen
# Token Mapping für die T5 Modell-Ausgabe (bleibt hier, obwohl T5 noch nicht aktiv generiert)
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
"<sep>": "--",
"<section>": "\n"
}
# --- RecipeBERT-spezifische Funktionen (unverändert) ---
def get_embedding(text):
"""Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens."""
inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = bert_model(**inputs)
attention_mask = inputs['attention_mask']
token_embeddings = outputs.last_hidden_state
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return (sum_embeddings / sum_mask).squeeze(0)
def average_embedding(embedding_list):
"""Berechnet den Durchschnitt einer Liste von Embeddings."""
tensors = torch.stack(embedding_list)
return tensors.mean(dim=0)
def get_cosine_similarity(vec1, vec2):
"""Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren."""
if torch.is_tensor(vec1): vec1 = vec1.detach().numpy()
if torch.is_tensor(vec2): vec2 = vec2.detach().numpy()
vec1 = vec1.flatten()
vec2 = vec2.flatten()
dot_product = np.dot(vec1, vec2)
norm_a = np.linalg.norm(vec1)
norm_b = np.linalg.norm(vec2)
if norm_a == 0 or norm_b == 0: return 0
return dot_product / (norm_a * norm_b)
# find_best_ingredients (unverändert, nutzt RecipeBERT für eine ähnlichste Zutat)
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6):
"""
Findet die besten Zutaten: Alle benötigten + EINE ähnlichste aus den verfügbaren Zutaten.
"""
required_ingredients = list(set(required_ingredients))
available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
final_ingredients = required_ingredients.copy()
# Nur wenn wir noch Platz haben und zusätzliche Zutaten verfügbar sind
if len(final_ingredients) < max_ingredients and len(available_ingredients) > 0:
if final_ingredients:
required_embeddings = [get_embedding(ing) for ing in required_ingredients]
avg_required_embedding = average_embedding(required_embeddings)
best_additional_ingredient = None
highest_similarity = -1.0
for avail_ing in available_ingredients:
avail_embedding = get_embedding(avail_ing)
similarity = get_cosine_similarity(avg_required_embedding, avail_embedding)
if similarity > highest_similarity:
highest_similarity = similarity
best_additional_ingredient = avail_ing
if best_additional_ingredient:
final_ingredients.append(best_additional_ingredient)
print(f"INFO: Added '{best_additional_ingredient}' (similarity: {highest_similarity:.2f}) as most similar.")
else:
random_ingredient = random.choice(available_ingredients)
final_ingredients.append(random_ingredient)
print(f"INFO: No required ingredients. Added random available ingredient: '{random_ingredient}'.")
return final_ingredients[:max_ingredients]
# mock_generate_recipe (ANGEPASST, um zu bestätigen, dass BEIDE Modelle geladen sind)
def mock_generate_recipe(ingredients_list):
"""Generiert ein Mock-Rezept und bestätigt das Laden beider Modelle."""
title = f"Rezepttest mit {', '.join(ingredients_list[:3])}" if ingredients_list else "Testrezept"
return {
"title": title,
"ingredients": ingredients_list,
"directions": [
"Dies ist ein Testrezept.",
"RecipeBERT und T5-Modell wurden beide erfolgreich geladen!",
"Die Zutaten wurden mit RecipeBERT-Intelligenz ausgewählt.",
f"Basierend auf deinen Eingaben wurde '{ingredients_list[-1]}' als ähnlichste Zutat hinzugefügt." if len(ingredients_list) > 1 else "Keine zusätzliche Zutat hinzugefügt."
],
"used_ingredients": ingredients_list
}
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
"""
Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
"""
if not required_ingredients and not available_ingredients:
return {"error": "Keine Zutaten angegeben"}
try:
optimized_ingredients = find_best_ingredients(
required_ingredients, available_ingredients, max_ingredients
)
recipe = mock_generate_recipe(optimized_ingredients) # Rufe die Mock-Generierungsfunktion auf
result = {
'title': recipe['title'],
'ingredients': recipe['ingredients'],
'directions': recipe['directions'],
'used_ingredients': optimized_ingredients
}
return result
except Exception as e:
return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
# --- FastAPI-Implementierung ---
app = FastAPI(title="AI Recipe Generator API (Both Models Loaded Test)")
class RecipeRequest(BaseModel):
required_ingredients: list[str] = []
available_ingredients: list[str] = []
max_ingredients: int = 7
max_retries: int = 5
ingredients: list[str] = [] # Für Abwärtskompatibilität
@app.post("/generate_recipe") # Der API-Endpunkt für Flutter
async def generate_recipe_api(request_data: RecipeRequest):
final_required_ingredients = request_data.required_ingredients
if not final_required_ingredients and request_data.ingredients:
final_required_ingredients = request_data.ingredients
result_dict = process_recipe_request_logic(
final_required_ingredients,
request_data.available_ingredients,
request_data.max_ingredients,
request_data.max_retries
)
return JSONResponse(content=result_dict)
@app.get("/")
async def read_root():
return {"message": "AI Recipe Generator API is running (Both models loaded for test)!"} # Angepasste Nachricht
print("INFO: FastAPI application script finished execution and defined 'app' variable.")