# DockerRecipe / app.py
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
import torch
import numpy as np
import random
import json
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Load the RecipeBERT model (for semantic ingredient combination)
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()  # put the model into evaluation mode
# Load the T5 recipe generation model
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
# Token mapping for the T5 model output
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
"<sep>": "--",
"<section>": "\n"
}
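# The raw T5 output is a single line such as (illustrative example, not real model output):
#   "title: pasta salad <section> ingredients: pasta <sep> tomatoes <section> directions: boil pasta <sep> mix"
# tokens_map rewrites "<sep>" to "--" and "<section>" to newlines so that the parsing
# code below can split sections on "\n" and list items on "--".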
# --- RecipeBERT-specific functions ---
def get_embedding(text):
"""Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens"""
inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = bert_model(**inputs)
attention_mask = inputs['attention_mask']
token_embeddings = outputs.last_hidden_state
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return (sum_embeddings / sum_mask).squeeze(0)
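# Note: the result is a 1-D tensor of size hidden_size (typically 768 for a
# BERT-base style model such as RecipeBERT), i.e. one mean-pooled vector per input text.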
def average_embedding(embedding_list):
"""Berechnet den Durchschnitt einer Liste von Embeddings"""
# Sicherstellen, dass embedding_list Tupel von (Name, Embedding) enthält
tensors = torch.stack([emb for _, emb in embedding_list])
return tensors.mean(dim=0)
def get_cosine_similarity(vec1, vec2):
"""Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren"""
if torch.is_tensor(vec1):
vec1 = vec1.detach().numpy()
if torch.is_tensor(vec2):
vec2 = vec2.detach().numpy()
vec1 = vec1.flatten()
vec2 = vec2.flatten()
dot_product = np.dot(vec1, vec2)
norm_a = np.linalg.norm(vec1)
norm_b = np.linalg.norm(vec2)
if norm_a == 0 or norm_b == 0:
return 0
return dot_product / (norm_a * norm_b)
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
"""Berechnet einen kombinierten Score unter Berücksichtigung der Ähnlichkeit zum Durchschnitt und zu einzelnen Zutaten"""
results = []
for name, emb in embedding_list:
avg_similarity = get_cosine_similarity(query_vector, emb)
individual_similarities = [get_cosine_similarity(good_emb, emb)
for _, good_emb in all_good_embeddings]
avg_individual_similarity = sum(individual_similarities) / len(individual_similarities) if individual_similarities else 0
combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
results.append((name, emb, combined_score))
results.sort(key=lambda x: x[2], reverse=True)
return results
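# Worked example (hypothetical values): with avg_weight=0.6, a candidate whose similarity
# to the average embedding is 0.8 and whose mean similarity to the individual selected
# ingredients is 0.5 receives a combined score of 0.6*0.8 + 0.4*0.5 = 0.68.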
# Full find_best_ingredients function
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
"""
    Finds the best ingredients based on RecipeBERT embeddings.
"""
required_ingredients = list(set(required_ingredients))
available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
if not required_ingredients and available_ingredients:
random_ingredient = random.choice(available_ingredients)
required_ingredients = [random_ingredient]
available_ingredients = [i for i in available_ingredients if i != random_ingredient]
print(f"No required ingredients provided. Randomly selected: {random_ingredient}")
if not required_ingredients or len(required_ingredients) >= max_ingredients:
return required_ingredients[:max_ingredients]
if not available_ingredients:
return required_ingredients
embed_required = [(e, get_embedding(e)) for e in required_ingredients]
embed_available = [(e, get_embedding(e)) for e in available_ingredients]
num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
final_ingredients = embed_required.copy()
for _ in range(num_to_add):
avg = average_embedding(final_ingredients)
candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
if not candidates:
break
best_name, best_embedding, _ = candidates[0]
final_ingredients.append((best_name, best_embedding))
embed_available = [item for item in embed_available if item[0] != best_name]
return [name for name, _ in final_ingredients]
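# Illustrative call (hypothetical ingredient names):
#   find_best_ingredients(["chicken", "rice"], ["soy sauce", "broccoli", "chocolate"], max_ingredients=4)
# keeps the required ingredients and greedily adds the available ingredients whose
# combined score against the current selection is highest, up to max_ingredients.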
def skip_special_tokens(text, special_tokens):
"""Entfernt spezielle Tokens aus dem Text"""
for token in special_tokens:
text = text.replace(token, "")
return text
def target_postprocessing(texts, special_tokens):
"""Post-processed generierten Text"""
if not isinstance(texts, list):
texts = [texts]
new_texts = []
for text in texts:
text = skip_special_tokens(text, special_tokens)
for k, v in tokens_map.items():
text = text.replace(k, v)
new_texts.append(text)
return new_texts
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
"""
    Validates whether the recipe contains approximately the expected number of ingredients.
"""
recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
expected_count = len(expected_ingredients)
    return abs(recipe_count - expected_count) <= tolerance
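# With the default tolerance=0 the generated recipe must contain exactly as many
# ingredients as were requested; tolerance=1 would also accept a deviation of one.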
def generate_recipe_with_t5(ingredients_list, max_retries=5):
"""Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung."""
original_ingredients = ingredients_list.copy()
for attempt in range(max_retries):
try:
if attempt > 0:
current_ingredients = original_ingredients.copy()
random.shuffle(current_ingredients)
else:
current_ingredients = ingredients_list
ingredients_string = ", ".join(current_ingredients)
prefix = "items: "
generation_kwargs = {
"max_length": 512,
"min_length": 64,
"do_sample": True,
"top_k": 60,
"top_p": 0.95
}
print(f"Attempt {attempt + 1}: {prefix + ingredients_string}") # Debug-Print
inputs = t5_tokenizer(
prefix + ingredients_string,
max_length=256,
padding="max_length",
truncation=True,
return_tensors="jax"
)
output_ids = t5_model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
**generation_kwargs
)
generated = output_ids.sequences
generated_text = target_postprocessing(
t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
special_tokens
)[0]
recipe = {}
sections = generated_text.split("\n")
for section in sections:
section = section.strip()
if section.startswith("title:"):
recipe["title"] = section.replace("title:", "").strip().capitalize()
elif section.startswith("ingredients:"):
ingredients_text = section.replace("ingredients:", "").strip()
recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()]
elif section.startswith("directions:"):
directions_text = section.replace("directions:", "").strip()
recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
if "title" not in recipe:
recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
if "ingredients" not in recipe:
recipe["ingredients"] = current_ingredients
if "directions" not in recipe:
recipe["directions"] = ["Keine Anweisungen generiert"]
if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
print(f"Success on attempt {attempt + 1}: Recipe has correct number of ingredients") # Debug-Print
return recipe
else:
print(f"Attempt {attempt + 1} failed: Expected {len(original_ingredients)} ingredients, got {len(recipe['ingredients'])}") # Debug-Print
if attempt == max_retries - 1:
print("Max retries reached, returning last generated recipe") # Debug-Print
return recipe
except Exception as e:
print(f"Error in recipe generation attempt {attempt + 1}: {str(e)}") # Debug-Print
if attempt == max_retries - 1:
return {
"title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
"ingredients": original_ingredients,
"directions": ["Fehler beim Generieren der Rezeptanweisungen"]
}
return {
"title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
"ingredients": original_ingredients,
"directions": ["Fehler beim Generieren der Rezeptanweisungen"]
}
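# The returned dict always has the shape
#   {"title": str, "ingredients": list[str], "directions": list[str]}
# either parsed from the model output or filled with the fallback values above.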
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
"""
    Core logic for processing a recipe generation request.
"""
if not required_ingredients and not available_ingredients:
return {"error": "Keine Zutaten angegeben"}
try:
optimized_ingredients = find_best_ingredients(
required_ingredients, available_ingredients, max_ingredients
)
        # Call the T5 generation function
recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
result = {
'title': recipe['title'],
'ingredients': recipe['ingredients'],
'directions': recipe['directions'],
'used_ingredients': optimized_ingredients
}
return result
except Exception as e:
return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
# --- FastAPI implementation ---
app = FastAPI(title="AI Recipe Generator API")  # plain FastAPI app, no Gradio-specific title additions
class RecipeRequest(BaseModel):
required_ingredients: list[str] = []
available_ingredients: list[str] = []
max_ingredients: int = 7
max_retries: int = 5
    # Optional: for backwards compatibility if 'ingredients' is sent as a top-level field
ingredients: list[str] = []
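# Example request body (hypothetical ingredients):
#   {
#     "required_ingredients": ["chicken", "rice"],
#     "available_ingredients": ["soy sauce", "broccoli"],
#     "max_ingredients": 5,
#     "max_retries": 3
#   }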
@app.post("/generate_recipe") # Der API-Endpunkt für Flutter
async def generate_recipe_api(request_data: RecipeRequest):
"""
    Standard REST API endpoint for the Flutter app.
    Accepts JSON data directly and returns JSON directly.
"""
    # If required_ingredients is empty but ingredients is provided,
    # use ingredients for backwards compatibility.
final_required_ingredients = request_data.required_ingredients
if not final_required_ingredients and request_data.ingredients:
final_required_ingredients = request_data.ingredients
result_dict = process_recipe_request_logic(
final_required_ingredients,
request_data.available_ingredients,
request_data.max_ingredients,
request_data.max_retries
)
return JSONResponse(content=result_dict)
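# Example call (sketch; host and port depend on how the container is run,
# 7860 is common on Hugging Face Spaces):
#   curl -X POST http://localhost:7860/generate_recipe \
#        -H "Content-Type: application/json" \
#        -d '{"required_ingredients": ["chicken", "rice"], "available_ingredients": ["broccoli"]}'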
@app.get("/")
async def read_root():
return {"message": "AI Recipe Generator API is running (FastAPI only)!"} # Angepasste Nachricht
# There is NO Gradio mount or Gradio launch command here.
# The `app` object is a plain FastAPI instance.
print("INFO: Pure FastAPI application script finished execution and defined 'app' variable.")