Spaces:
Sleeping
Sleeping
File size: 9,381 Bytes
3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca f0ebec2 3df2cca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer # AutoModel entfernt
import torch # Beibehalten
import numpy as np # Beibehalten
import random
import json
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Lade RecipeBERT Modell (KOMPLETT ENTFERNT für diesen Schritt)
# bert_model_name = "alexdseo/RecipeBERT"
# bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
# bert_model = AutoModel.from_pretrained(bert_model_name)
# bert_model.eval()
# Load the T5 recipe-generation model and its tokenizer once at import time.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Token mapping applied to the decoded T5 output.
# `all_special_tokens` (plural) is the tokenizer property that lists every
# special-token string; the singular spelling raises AttributeError.
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}
# --- RecipeBERT-spezifische Funktionen sind entfernt oder vereinfacht ---
# get_embedding, average_embedding, get_cosine_similarity, get_combined_scores sind entfernt.
# find_best_ingredients (modifiziert, um KEINE Embeddings zu nutzen)
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """
    Pick the ingredient list that is handed to the recipe generator.

    No embeddings are used: the de-duplicated required ingredients are kept
    in the order the caller supplied them and are topped up with randomly
    sampled available ingredients until ``max_ingredients`` is reached.

    Args:
        required_ingredients: Ingredients that should be part of the recipe.
        available_ingredients: Extra ingredients that may be added; entries
            that are already required are ignored.
        max_ingredients: Upper bound on the number of returned ingredients.
        avg_weight: Unused by this simplified implementation; kept so that
            existing callers passing it keep working.

    Returns:
        A new list of ingredients: required ones first (original order),
        followed by the randomly chosen extras.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # reorder the items arbitrarily (string hashing is randomized), which
    # changes the generated prompt and decides which required ingredients
    # survive the truncation below.
    required_ingredients = list(dict.fromkeys(required_ingredients))
    already_required = set(required_ingredients)
    available_ingredients = list(dict.fromkeys(
        item for item in available_ingredients if item not in already_required
    ))

    # Special case: nothing is required, so promote one random available
    # ingredient to "required" to seed the recipe.
    if not required_ingredients and available_ingredients:
        random_ingredient = random.choice(available_ingredients)
        required_ingredients = [random_ingredient]
        available_ingredients = [i for i in available_ingredients if i != random_ingredient]

    # Capacity already reached: keep the first max_ingredients required items.
    if len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]

    # Nothing left to add.
    if not available_ingredients:
        return required_ingredients

    # Fill the remaining slots with a random sample of the available extras.
    num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
    return required_ingredients + random.sample(available_ingredients, num_to_add)
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token removed."""
    cleaned = text
    for marker in special_tokens:
        # Replacing with the empty string deletes the marker wherever it occurs.
        cleaned = cleaned.replace(marker, "")
    return cleaned
def target_postprocessing(texts, special_tokens):
    """
    Clean up decoded model output.

    A single value is treated as a one-element batch. For every text the
    special tokens are stripped first, then each key of the module-level
    ``tokens_map`` is replaced by its mapped value.

    Returns:
        A list with one cleaned string per input text.
    """
    batch = texts if isinstance(texts, list) else [texts]

    def _clean(raw):
        # Strip special tokens before substituting the structural markers.
        cleaned = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            cleaned = cleaned.replace(marker, replacement)
        return cleaned

    return [_clean(raw) for raw in batch]
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """
    Check that a recipe lists roughly the expected number of ingredients.

    Only the counts are compared, not the ingredient names. Empty or
    whitespace-only recipe entries are not counted.

    Args:
        recipe_ingredients: Ingredient strings taken from the generated recipe.
        expected_ingredients: Ingredients that were requested.
        tolerance: Maximum allowed absolute difference between both counts.

    Returns:
        True when the counts differ by at most ``tolerance``.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)
    # "<=" rather than "==": any difference up to the tolerance is acceptable,
    # including an exact match when tolerance > 0.
    return abs(recipe_count - expected_count) <= tolerance
def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when generation did not succeed."""
    return {
        "title": f"Rezept mit {ingredients[0] if ingredients else 'Zutaten'}",
        "ingredients": ingredients,
        "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
    }


def _parse_generated_recipe(generated_text):
    """Split post-processed model output into title, ingredients and directions."""
    recipe = {}
    for section in generated_text.split("\n"):
        section = section.strip()
        # Slice off only the leading label; a label word that reappears later
        # in the section text must stay untouched.
        if section.startswith("title:"):
            recipe["title"] = section[len("title:"):].strip().capitalize()
        elif section.startswith("ingredients:"):
            ingredients_text = section[len("ingredients:"):].strip()
            recipe["ingredients"] = [
                item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()
            ]
        elif section.startswith("directions:"):
            directions_text = section[len("directions:"):].strip()
            recipe["directions"] = [
                step.strip().capitalize() for step in directions_text.split("--") if step.strip()
            ]
    return recipe


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """
    Generate a recipe with the T5 recipe-generation model, with validation.

    Up to ``max_retries`` attempts are made. The first attempt uses the
    ingredients in the given order; later attempts use a shuffled copy. An
    attempt is accepted as soon as ``validate_recipe_ingredients`` approves
    the generated ingredient list; the last attempt is returned even when it
    does not validate.

    Args:
        ingredients_list: List of ingredient strings used to build the prompt.
        max_retries: Maximum number of generation attempts.

    Returns:
        A dict with the keys ``title``, ``ingredients`` and ``directions``.
        If the last attempt raises, or if ``max_retries`` is not positive, a
        placeholder recipe built from the original ingredients is returned.
    """
    # Local import keeps this block self-contained; failures were previously
    # swallowed without leaving any trace.
    import logging

    logger = logging.getLogger(__name__)
    original_ingredients = ingredients_list.copy()

    for attempt in range(max_retries):
        try:
            if attempt > 0:
                # Retry with a different ingredient order to vary the prompt.
                current_ingredients = original_ingredients.copy()
                random.shuffle(current_ingredients)
            else:
                current_ingredients = ingredients_list

            ingredients_string = ", ".join(current_ingredients)
            prefix = "items: "
            generation_kwargs = {
                "max_length": 512,
                "min_length": 64,
                "do_sample": True,
                "top_k": 60,
                "top_p": 0.95
            }
            inputs = t5_tokenizer(
                prefix + ingredients_string,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )
            generated = output_ids.sequences
            # Decode with special tokens kept; target_postprocessing strips
            # them and maps the section/separator markers.
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
                special_tokens
            )[0]

            recipe = _parse_generated_recipe(generated_text)
            # Fill in whatever the model output did not provide.
            recipe.setdefault("title", f"Rezept mit {', '.join(current_ingredients[:3])}")
            recipe.setdefault("ingredients", current_ingredients)
            recipe.setdefault("directions", ["Keine Anweisungen generiert"])

            is_last_attempt = attempt == max_retries - 1
            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients) or is_last_attempt:
                return recipe
        except Exception as exc:
            # Best effort: log the failure, then retry or fall back.
            logger.warning(
                "Recipe generation attempt %d of %d failed: %s",
                attempt + 1, max_retries, exc
            )
            if attempt == max_retries - 1:
                return _fallback_recipe(original_ingredients)

    # Only reached when no attempt was made (max_retries <= 0).
    return _fallback_recipe(original_ingredients)
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
    """
    Core logic for handling one recipe-generation request.

    Selects the ingredients via ``find_best_ingredients`` (no embeddings
    involved), generates a recipe for them and assembles the response.

    Returns:
        A dict with ``title``, ``ingredients``, ``directions`` and
        ``used_ingredients`` on success, or a dict with a single ``error``
        key when no ingredients were supplied or any step raised.
    """
    nothing_supplied = not (required_ingredients or available_ingredients)
    if nothing_supplied:
        return {"error": "Keine Zutaten angegeben"}

    try:
        chosen = find_best_ingredients(
            required_ingredients, available_ingredients, max_ingredients
        )
        generated = generate_recipe_with_t5(chosen, max_retries)
        response = {field: generated[field] for field in ('title', 'ingredients', 'directions')}
        response['used_ingredients'] = chosen
    except Exception as err:
        # Any failure is reported to the caller as an error payload.
        return {"error": f"Fehler bei der Rezeptgenerierung: {str(err)}"}
    return response
# --- FastAPI implementation ---
app = FastAPI(title="AI Recipe Generator API") # The FastAPI application instance the route decorators below attach to
class RecipeRequest(BaseModel):
    # JSON body accepted by the /generate_recipe endpoint; every field has a
    # default, so each of them may be omitted by the client.

    # Ingredients forwarded to the recipe logic as the "required" list.
    required_ingredients: list[str] = []
    # Additional ingredients the recipe logic may draw from.
    available_ingredients: list[str] = []
    # Cap on the number of ingredients selected for generation.
    max_ingredients: int = 7
    # Maximum number of generation attempts.
    max_retries: int = 5
    # Used by the endpoint in place of required_ingredients when that list is empty.
    ingredients: list[str] = [] # For backward compatibility
@app.post("/generate_recipe") # The API endpoint used by the Flutter app
async def generate_recipe_api(request_data: RecipeRequest):
    """
    Standard REST API endpoint for the Flutter app.
    Accepts JSON data directly and returns JSON directly.
    """
    # Fall back to the legacy `ingredients` field only when it carries data
    # and `required_ingredients` is empty.
    use_legacy_field = bool(request_data.ingredients) and not request_data.required_ingredients
    if use_legacy_field:
        required = request_data.ingredients
    else:
        required = request_data.required_ingredients

    payload = process_recipe_request_logic(
        required_ingredients=required,
        available_ingredients=request_data.available_ingredients,
        max_ingredients=request_data.max_ingredients,
        max_retries=request_data.max_retries,
    )
    return JSONResponse(content=payload)
# Optionaler Root-Endpunkt für Health-Checks
@app.get("/")
async def read_root():
    """Root endpoint that returns a fixed status message (usable as a health check)."""
    status = {"message": "AI Recipe Generator API is running (T5 only)!"}
    return status
# Runs at import time, after all definitions above have been executed.
print("INFO: FastAPI application script finished execution and defined 'app' variable.")
|