from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel import torch import numpy as np import random import json from fastapi import FastAPI from fastapi.responses import JSONResponse from pydantic import BaseModel # Lade RecipeBERT Modell bert_model_name = "alexdseo/RecipeBERT" bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name) bert_model = AutoModel.from_pretrained(bert_model_name) bert_model.eval() # Setze das Modell in den Evaluationsmodus # Lade T5 Rezeptgenerierungsmodell MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation" t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True) t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH) # Modell wird jetzt auch geladen # Token Mapping für die T5 Modell-Ausgabe special_tokens = t5_tokenizer.all_special_tokens tokens_map = { "": "--", "
": "\n" } # --- RecipeBERT-spezifische Funktionen (unverändert) --- def get_embedding(text): """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens.""" inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True) with torch.no_grad(): outputs = bert_model(**inputs) attention_mask = inputs['attention_mask'] token_embeddings = outputs.last_hidden_state input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) return (sum_embeddings / sum_mask).squeeze(0) def average_embedding(embedding_list): """Berechnet den Durchschnitt einer Liste von Embeddings.""" tensors = torch.stack(embedding_list) return tensors.mean(dim=0) def get_cosine_similarity(vec1, vec2): """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren.""" if torch.is_tensor(vec1): vec1 = vec1.detach().numpy() if torch.is_tensor(vec2): vec2 = vec2.detach().numpy() vec1 = vec1.flatten() vec2 = vec2.flatten() dot_product = np.dot(vec1, vec2) norm_a = np.linalg.norm(vec1) norm_b = np.linalg.norm(vec2) if norm_a == 0 or norm_b == 0: return 0 return dot_product / (norm_a * norm_b) # find_best_ingredients (unverändert, nutzt RecipeBERT für eine ähnlichste Zutat) def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6): """ Findet die besten Zutaten: Alle benötigten + EINE ähnlichste aus den verfügbaren Zutaten. """ required_ingredients = list(set(required_ingredients)) available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients])) final_ingredients = required_ingredients.copy() # Nur wenn wir noch Platz haben und zusätzliche Zutaten verfügbar sind if len(final_ingredients) < max_ingredients and len(available_ingredients) > 0: if final_ingredients: required_embeddings = [get_embedding(ing) for ing in required_ingredients] avg_required_embedding = average_embedding(required_embeddings) best_additional_ingredient = None highest_similarity = -1.0 for avail_ing in available_ingredients: avail_embedding = get_embedding(avail_ing) similarity = get_cosine_similarity(avg_required_embedding, avail_embedding) if similarity > highest_similarity: highest_similarity = similarity best_additional_ingredient = avail_ing if best_additional_ingredient: final_ingredients.append(best_additional_ingredient) print(f"INFO: Added '{best_additional_ingredient}' (similarity: {highest_similarity:.2f}) as most similar.") else: random_ingredient = random.choice(available_ingredients) final_ingredients.append(random_ingredient) print(f"INFO: No required ingredients. Added random available ingredient: '{random_ingredient}'.") return final_ingredients[:max_ingredients] # skip_special_tokens (unverändert, wird von generate_recipe_with_t5 genutzt) def skip_special_tokens(text, special_tokens): """Entfernt spezielle Tokens aus dem Text""" for token in special_tokens: text = text.replace(token, "") return text # target_postprocessing (unverändert, wird von generate_recipe_with_t5 genutzt) def target_postprocessing(texts, special_tokens): """Post-processed generierten Text""" if not isinstance(texts, list): texts = [texts] new_texts = [] for text in texts: text = skip_special_tokens(text, special_tokens) for k, v in tokens_map.items(): text = text.replace(k, v) new_texts.append(text) return new_texts # validate_recipe_ingredients (unverändert, wird von generate_recipe_with_t5 genutzt) def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0): """ Validiert, ob das Rezept ungefähr die erwarteten Zutaten enthält. """ recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()]) expected_count = len(expected_ingredients) return abs(recipe_count - expected_count) == tolerance # generate_recipe_with_t5 (jetzt AKTIVIERT) def generate_recipe_with_t5(ingredients_list, max_retries=5): """Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung.""" original_ingredients = ingredients_list.copy() for attempt in range(max_retries): try: # Für Wiederholungsversuche nach dem ersten Versuch, mische die Zutaten if attempt > 0: current_ingredients = original_ingredients.copy() random.shuffle(current_ingredients) else: current_ingredients = ingredients_list # Formatiere Zutaten als kommaseparierten String ingredients_string = ", ".join(current_ingredients) prefix = "items: " # Generationseinstellungen generation_kwargs = { "max_length": 512, "min_length": 64, "do_sample": True, "top_k": 60, "top_p": 0.95 } # print(f"Versuch {attempt + 1}: {prefix + ingredients_string}") # Tokenisiere Eingabe inputs = t5_tokenizer( prefix + ingredients_string, max_length=256, padding="max_length", truncation=True, return_tensors="jax" ) # Generiere Text output_ids = t5_model.generate( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, **generation_kwargs ) # Dekodieren und Nachbearbeiten generated = output_ids.sequences generated_text = target_postprocessing( t5_tokenizer.batch_decode(generated, skip_special_tokens=False), special_tokens )[0] # Abschnitte parsen recipe = {} sections = generated_text.split("\n") for section in sections: section = section.strip() if section.startswith("title:"): recipe["title"] = section.replace("title:", "").strip().capitalize() elif section.startswith("ingredients:"): ingredients_text = section.replace("ingredients:", "").strip() recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()] elif section.startswith("directions:"): directions_text = section.replace("directions:", "").strip() recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()] # Wenn der Titel fehlt, erstelle einen if "title" not in recipe: recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}" # Stelle sicher, dass alle Abschnitte existieren if "ingredients" not in recipe: recipe["ingredients"] = current_ingredients if "directions" not in recipe: recipe["directions"] = ["Keine Anweisungen generiert"] # Validiere das Rezept if validate_recipe_ingredients(recipe["ingredients"], original_ingredients): # print(f"Erfolg bei Versuch {attempt + 1}: Rezept hat die richtige Anzahl von Zutaten") return recipe else: # print(f"Versuch {attempt + 1} fehlgeschlagen: Erwartet {len(original_ingredients)} Zutaten, erhalten {len(recipe['ingredients'])}") if attempt == max_retries - 1: # print("Maximale Wiederholungsversuche erreicht, letztes generiertes Rezept wird zurückgegeben") return recipe except Exception as e: # print(f"Fehler bei der Rezeptgenerierung Versuch {attempt + 1}: {str(e)}") if attempt == max_retries - 1: return { "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}", "ingredients": original_ingredients, "directions": ["Fehler beim Generieren der Rezeptanweisungen"] } # Fallback (sollte nicht erreicht werden) return { "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}", "ingredients": original_ingredients, "directions": ["Fehler beim Generieren der Rezeptanweisungen"] } # process_recipe_request_logic (JETZT RUFT generate_recipe_with_t5 auf) def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries): """ Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage. Ausgelagert, um von verschiedenen Endpunkten aufgerufen zu werden. """ if not required_ingredients and not available_ingredients: return {"error": "Keine Zutaten angegeben"} try: # Optimale Zutaten finden (mit RecipeBERT) optimized_ingredients = find_best_ingredients( required_ingredients, available_ingredients, max_ingredients ) # Rezept mit optimierten Zutaten generieren (JETZT MIT T5!) recipe = generate_recipe_with_t5(optimized_ingredients, max_retries) # Ergebnis formatieren result = { 'title': recipe['title'], 'ingredients': recipe['ingredients'], 'directions': recipe['directions'], 'used_ingredients': optimized_ingredients } return result except Exception as e: return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"} # --- FastAPI-Implementierung --- app = FastAPI(title="AI Recipe Generator API (Full Functionality)") class RecipeRequest(BaseModel): required_ingredients: list[str] = [] available_ingredients: list[str] = [] max_ingredients: int = 7 max_retries: int = 5 ingredients: list[str] = [] # Für Abwärtskompatibilität @app.post("/generate_recipe") # Der API-Endpunkt für Flutter async def generate_recipe_api(request_data: RecipeRequest): final_required_ingredients = request_data.required_ingredients if not final_required_ingredients and request_data.ingredients: final_required_ingredients = request_data.ingredients result_dict = process_recipe_request_logic( final_required_ingredients, request_data.available_ingredients, request_data.max_ingredients, request_data.max_retries ) return JSONResponse(content=result_dict) @app.get("/") async def read_root(): return {"message": "AI Recipe Generator API is running (Full functionality activated)!"} # Angepasste Nachricht print("INFO: FastAPI application script finished execution and defined 'app' variable.")