TimInf committed on
Commit b968ba2 (verified)
1 Parent(s): c436cf7

Update app.py

Files changed (1)
  1. app.py +186 -78
app.py CHANGED
@@ -7,28 +7,27 @@ from fastapi import FastAPI
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel

- # Load RecipeBERT model
  bert_model_name = "alexdseo/RecipeBERT"
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
  bert_model = AutoModel.from_pretrained(bert_model_name)
- bert_model.eval()  # Put the model into evaluation mode

- # Load T5 recipe generation model (NEWLY added)
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
- t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)  # The model is now loaded as well

- # Token mapping for the T5 model output (kept here even though T5 does not actively generate yet)
  special_tokens = t5_tokenizer.all_special_tokens
  tokens_map = {
      "<sep>": "--",
      "<section>": "\n"
  }

-
- # --- RecipeBERT-specific functions (unchanged) ---
  def get_embedding(text):
-     """Computes the embedding for a text via mean pooling over all tokens."""
      inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
      with torch.no_grad():
          outputs = bert_model(**inputs)
@@ -39,82 +38,187 @@ def get_embedding(text)
      sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
      return (sum_embeddings / sum_mask).squeeze(0)

-
  def average_embedding(embedding_list):
-     """Computes the average of a list of embeddings."""
-     tensors = torch.stack(embedding_list)
      return tensors.mean(dim=0)

-
  def get_cosine_similarity(vec1, vec2):
-     """Computes the cosine similarity between two vectors."""
-     if torch.is_tensor(vec1): vec1 = vec1.detach().numpy()
-     if torch.is_tensor(vec2): vec2 = vec2.detach().numpy()
      vec1 = vec1.flatten()
      vec2 = vec2.flatten()
      dot_product = np.dot(vec1, vec2)
      norm_a = np.linalg.norm(vec1)
      norm_b = np.linalg.norm(vec2)
-     if norm_a == 0 or norm_b == 0: return 0
      return dot_product / (norm_a * norm_b)

-
- # find_best_ingredients (unchanged, uses RecipeBERT to pick ONE most similar ingredient)
- def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6):
      """
-     Finds the best ingredients: all required ones plus ONE most similar ingredient from the available ones.
      """
      required_ingredients = list(set(required_ingredients))
      available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
-
-     final_ingredients = required_ingredients.copy()
-
-     # Only if there is still room and additional ingredients are available
-     if len(final_ingredients) < max_ingredients and len(available_ingredients) > 0:
-         if final_ingredients:
-             required_embeddings = [get_embedding(ing) for ing in required_ingredients]
-             avg_required_embedding = average_embedding(required_embeddings)
-
-             best_additional_ingredient = None
-             highest_similarity = -1.0
-
-             for avail_ing in available_ingredients:
-                 avail_embedding = get_embedding(avail_ing)
-                 similarity = get_cosine_similarity(avg_required_embedding, avail_embedding)
-                 if similarity > highest_similarity:
-                     highest_similarity = similarity
-                     best_additional_ingredient = avail_ing
-
-             if best_additional_ingredient:
-                 final_ingredients.append(best_additional_ingredient)
-                 print(
-                     f"INFO: Added '{best_additional_ingredient}' (similarity: {highest_similarity:.2f}) as most similar.")
-         else:
-             random_ingredient = random.choice(available_ingredients)
-             final_ingredients.append(random_ingredient)
-             print(f"INFO: No required ingredients. Added random available ingredient: '{random_ingredient}'.")
-
-     return final_ingredients[:max_ingredients]
-
-
- # mock_generate_recipe (ADJUSTED to confirm that BOTH models are loaded)
- def mock_generate_recipe(ingredients_list):
-     """Generates a mock recipe and confirms that both models were loaded."""
-     title = f"Rezepttest mit {', '.join(ingredients_list[:3])}" if ingredients_list else "Testrezept"
      return {
-         "title": title,
-         "ingredients": ingredients_list,
-         "directions": [
-             "Dies ist ein Testrezept.",
-             "RecipeBERT und T5-Modell wurden beide erfolgreich geladen!",
-             "Die Zutaten wurden mit RecipeBERT-Intelligenz ausgewählt.",
-             f"Basierend auf deinen Eingaben wurde '{ingredients_list[-1]}' als ähnlichste Zutat hinzugefügt." if len(
-                 ingredients_list) > 1 else "Keine zusätzliche Zutat hinzugefügt."
-         ],
-         "used_ingredients": ingredients_list
      }

-
  def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
      """
      Core logic for processing a recipe generation request.
@@ -125,9 +229,8 @@ def process_recipe_request_logic(required_ingredients, available_ingredients, ma
          optimized_ingredients = find_best_ingredients(
              required_ingredients, available_ingredients, max_ingredients
          )
-
-         recipe = mock_generate_recipe(optimized_ingredients)  # Call the mock generation function
-
          result = {
              'title': recipe['title'],
             'ingredients': recipe['ingredients'],
@@ -138,21 +241,25 @@ def process_recipe_request_logic(required_ingredients, available_ingredients, ma
      except Exception as e:
          return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}

-
  # --- FastAPI implementation ---
- app = FastAPI(title="AI Recipe Generator API (Both Models Loaded Test)")
-

  class RecipeRequest(BaseModel):
      required_ingredients: list[str] = []
      available_ingredients: list[str] = []
      max_ingredients: int = 7
      max_retries: int = 5
-     ingredients: list[str] = []  # For backwards compatibility
-

- @app.post("/generate_recipe")  # The API endpoint for Flutter
  async def generate_recipe_api(request_data: RecipeRequest):
      final_required_ingredients = request_data.required_ingredients
      if not final_required_ingredients and request_data.ingredients:
          final_required_ingredients = request_data.ingredients
@@ -165,10 +272,11 @@ async def generate_recipe_api(request_data: RecipeRequest):
      )
      return JSONResponse(content=result_dict)

-
  @app.get("/")
  async def read_root():
-     return {"message": "AI Recipe Generator API is running (Both models loaded for test)!"}  # Adjusted message


- print("INFO: FastAPI application script finished execution and defined 'app' variable.")
@@ -7,28 +7,27 @@ from fastapi import FastAPI
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel

+ # Load RecipeBERT model (for semantic ingredient combination)
  bert_model_name = "alexdseo/RecipeBERT"
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
  bert_model = AutoModel.from_pretrained(bert_model_name)
+ bert_model.eval()  # Put the model into evaluation mode

+ # Load T5 recipe generation model
  MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
  t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
+ t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

+ # Token mapping for the T5 model output
  special_tokens = t5_tokenizer.all_special_tokens
  tokens_map = {
      "<sep>": "--",
      "<section>": "\n"
  }

+ # --- RecipeBERT-specific functions ---
  def get_embedding(text):
+     """Computes the embedding for a text via mean pooling over all tokens"""
      inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
      with torch.no_grad():
          outputs = bert_model(**inputs)
@@ -39,82 +38,187 @@ def get_embedding(text)
      sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
      return (sum_embeddings / sum_mask).squeeze(0)

  def average_embedding(embedding_list):
+     """Computes the average of a list of embeddings"""
+     # Ensure that embedding_list contains (name, embedding) tuples
+     tensors = torch.stack([emb for _, emb in embedding_list])
      return tensors.mean(dim=0)

  def get_cosine_similarity(vec1, vec2):
+     """Computes the cosine similarity between two vectors"""
+     if torch.is_tensor(vec1):
+         vec1 = vec1.detach().numpy()
+     if torch.is_tensor(vec2):
+         vec2 = vec2.detach().numpy()
      vec1 = vec1.flatten()
      vec2 = vec2.flatten()
      dot_product = np.dot(vec1, vec2)
      norm_a = np.linalg.norm(vec1)
      norm_b = np.linalg.norm(vec2)
+     if norm_a == 0 or norm_b == 0:
+         return 0
      return dot_product / (norm_a * norm_b)

+ def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
+     """Computes a combined score from the similarity to the average and to the individual ingredients"""
+     results = []
+     for name, emb in embedding_list:
+         avg_similarity = get_cosine_similarity(query_vector, emb)
+         individual_similarities = [get_cosine_similarity(good_emb, emb)
+                                    for _, good_emb in all_good_embeddings]
+         avg_individual_similarity = sum(individual_similarities) / len(individual_similarities) if individual_similarities else 0
+         combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
+         results.append((name, emb, combined_score))
+     results.sort(key=lambda x: x[2], reverse=True)
+     return results
+
+ # The complete find_best_ingredients function that was provided
+ def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
      """
+     Finds the best ingredients based on RecipeBERT embeddings.
      """
      required_ingredients = list(set(required_ingredients))
      available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
+
+     if not required_ingredients and available_ingredients:
+         random_ingredient = random.choice(available_ingredients)
+         required_ingredients = [random_ingredient]
+         available_ingredients = [i for i in available_ingredients if i != random_ingredient]
+         print(f"No required ingredients provided. Randomly selected: {random_ingredient}")
+
+     if not required_ingredients or len(required_ingredients) >= max_ingredients:
+         return required_ingredients[:max_ingredients]
+
+     if not available_ingredients:
+         return required_ingredients
+
+     embed_required = [(e, get_embedding(e)) for e in required_ingredients]
+     embed_available = [(e, get_embedding(e)) for e in available_ingredients]
+
+     num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
+
+     final_ingredients = embed_required.copy()
+
+     for _ in range(num_to_add):
+         avg = average_embedding(final_ingredients)
+         candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
+
+         if not candidates:
+             break
+
+         best_name, best_embedding, _ = candidates[0]
+
+         final_ingredients.append((best_name, best_embedding))
+
+         embed_available = [item for item in embed_available if item[0] != best_name]
+
+     return [name for name, _ in final_ingredients]
+
+ def skip_special_tokens(text, special_tokens):
+     """Removes special tokens from the text"""
+     for token in special_tokens:
+         text = text.replace(token, "")
+     return text
+
+ def target_postprocessing(texts, special_tokens):
+     """Post-processes generated text"""
+     if not isinstance(texts, list):
+         texts = [texts]
+     new_texts = []
+     for text in texts:
+         text = skip_special_tokens(text, special_tokens)
+         for k, v in tokens_map.items():
+             text = text.replace(k, v)
+         new_texts.append(text)
+     return new_texts
+
+ def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
+     """
+     Validates whether the recipe contains roughly the expected ingredients.
+     """
+     recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
+     expected_count = len(expected_ingredients)
+     return abs(recipe_count - expected_count) == tolerance
+
+ def generate_recipe_with_t5(ingredients_list, max_retries=5):
+     """Generates a recipe with the T5 recipe generation model, with validation."""
+     original_ingredients = ingredients_list.copy()
+     for attempt in range(max_retries):
+         try:
+             if attempt > 0:
+                 current_ingredients = original_ingredients.copy()
+                 random.shuffle(current_ingredients)
+             else:
+                 current_ingredients = ingredients_list
+             ingredients_string = ", ".join(current_ingredients)
+             prefix = "items: "
+             generation_kwargs = {
+                 "max_length": 512,
+                 "min_length": 64,
+                 "do_sample": True,
+                 "top_k": 60,
+                 "top_p": 0.95
+             }
+             print(f"Attempt {attempt + 1}: {prefix + ingredients_string}")  # Debug print
+             inputs = t5_tokenizer(
+                 prefix + ingredients_string,
+                 max_length=256,
+                 padding="max_length",
+                 truncation=True,
+                 return_tensors="jax"
+             )
+             output_ids = t5_model.generate(
+                 input_ids=inputs.input_ids,
+                 attention_mask=inputs.attention_mask,
+                 **generation_kwargs
+             )
+             generated = output_ids.sequences
+             generated_text = target_postprocessing(
+                 t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
+                 special_tokens
+             )[0]
+             recipe = {}
+             sections = generated_text.split("\n")
+             for section in sections:
+                 section = section.strip()
+                 if section.startswith("title:"):
+                     recipe["title"] = section.replace("title:", "").strip().capitalize()
+                 elif section.startswith("ingredients:"):
+                     ingredients_text = section.replace("ingredients:", "").strip()
+                     recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()]
+                 elif section.startswith("directions:"):
+                     directions_text = section.replace("directions:", "").strip()
+                     recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
+
+             if "title" not in recipe:
+                 recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
+             if "ingredients" not in recipe:
+                 recipe["ingredients"] = current_ingredients
+             if "directions" not in recipe:
+                 recipe["directions"] = ["Keine Anweisungen generiert"]
+
+             if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
+                 print(f"Success on attempt {attempt + 1}: Recipe has correct number of ingredients")  # Debug print
+                 return recipe
+             else:
+                 print(f"Attempt {attempt + 1} failed: Expected {len(original_ingredients)} ingredients, got {len(recipe['ingredients'])}")  # Debug print
+                 if attempt == max_retries - 1:
+                     print("Max retries reached, returning last generated recipe")  # Debug print
+                     return recipe
+         except Exception as e:
+             print(f"Error in recipe generation attempt {attempt + 1}: {str(e)}")  # Debug print
+             if attempt == max_retries - 1:
+                 return {
+                     "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
+                     "ingredients": original_ingredients,
+                     "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
+                 }
      return {
+         "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
+         "ingredients": original_ingredients,
+         "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
      }

  def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
      """
      Core logic for processing a recipe generation request.
 
@@ -125,9 +229,8 @@ def process_recipe_request_logic(required_ingredients, available_ingredients, ma
          optimized_ingredients = find_best_ingredients(
              required_ingredients, available_ingredients, max_ingredients
          )
+         # FIXED: call the real T5 generation function
+         recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
          result = {
              'title': recipe['title'],
             'ingredients': recipe['ingredients'],
 
@@ -138,21 +241,25 @@ def process_recipe_request_logic(required_ingredients, available_ingredients, ma
      except Exception as e:
          return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}

  # --- FastAPI implementation ---
+ app = FastAPI(title="AI Recipe Generator API")  # Without Gradio-specific title additions

  class RecipeRequest(BaseModel):
      required_ingredients: list[str] = []
      available_ingredients: list[str] = []
      max_ingredients: int = 7
      max_retries: int = 5
+     # Optional: for backwards compatibility if 'ingredients' is sent as a top-level field
+     ingredients: list[str] = []

+ @app.post("/generate_recipe")  # The API endpoint for Flutter
  async def generate_recipe_api(request_data: RecipeRequest):
+     """
+     Standard REST API endpoint for the Flutter app.
+     Accepts JSON data directly and returns JSON directly.
+     """
+     # If required_ingredients is empty but ingredients are present,
+     # use ingredients for backwards compatibility.
      final_required_ingredients = request_data.required_ingredients
      if not final_required_ingredients and request_data.ingredients:
          final_required_ingredients = request_data.ingredients
 
@@ -165,10 +272,11 @@ async def generate_recipe_api(request_data: RecipeRequest):
      )
      return JSONResponse(content=result_dict)

  @app.get("/")
  async def read_root():
+     return {"message": "AI Recipe Generator API is running (FastAPI only)!"}  # Adjusted message

+ # There is NO Gradio mount or Gradio launch command here
+ # The `app` object is a pure FastAPI instance
+ print("INFO: Pure FastAPI application script finished execution and defined 'app' variable.")
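For reference, a minimal sketch of how a client (for example the Flutter app mentioned in the code comments) could call the new /generate_recipe endpoint. This is an editor's illustration, not part of the commit: the base URL assumes a local uvicorn run of this app, and the response fields follow the result dict built in process_recipe_request_logic (errors come back as {"error": ...}).

import requests

payload = {
    "required_ingredients": ["chicken", "rice"],
    "available_ingredients": ["onion", "garlic", "lemon"],
    "max_ingredients": 5,
    "max_retries": 3,
}
# Assumed local development URL; replace with the actual host of the deployed FastAPI app.
response = requests.post("http://127.0.0.1:8000/generate_recipe", json=payload, timeout=300)
response.raise_for_status()
recipe = response.json()
print(recipe.get("title"))
print(recipe.get("ingredients"))

Sending only the legacy "ingredients" field instead of "required_ingredients" should also work, since generate_recipe_api falls back to it for backwards compatibility.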