TimInf commited on
Commit
c436cf7
·
verified ·
1 Parent(s): c59cc6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -36
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoTokenizer, AutoModel
2
  import torch
3
  import numpy as np
4
  import random
@@ -7,16 +7,26 @@ from fastapi import FastAPI
7
  from fastapi.responses import JSONResponse
8
  from pydantic import BaseModel
9
 
10
- # Lade NUR RecipeBERT Modell
11
  bert_model_name = "alexdseo/RecipeBERT"
12
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
13
  bert_model = AutoModel.from_pretrained(bert_model_name)
14
- bert_model.eval() # Setze das Modell in den Evaluationsmodus
15
 
16
- # T5-Modell und -Logik KOMPLETT ENTFERNT für diesen Schritt
17
- # special_tokens und tokens_map sind nicht mehr relevant, bleiben aber als Kommentar
 
 
18
 
19
- # --- RecipeBERT-spezifische Funktionen ---
 
 
 
 
 
 
 
 
20
  def get_embedding(text):
21
  """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens."""
22
  inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -29,11 +39,13 @@ def get_embedding(text):
29
  sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
30
  return (sum_embeddings / sum_mask).squeeze(0)
31
 
 
32
  def average_embedding(embedding_list):
33
  """Berechnet den Durchschnitt einer Liste von Embeddings."""
34
- tensors = torch.stack(embedding_list) # embedding_list enthält hier direkt die Tensoren
35
  return tensors.mean(dim=0)
36
 
 
37
  def get_cosine_similarity(vec1, vec2):
38
  """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren."""
39
  if torch.is_tensor(vec1): vec1 = vec1.detach().numpy()
@@ -47,7 +59,7 @@ def get_cosine_similarity(vec1, vec2):
47
  return dot_product / (norm_a * norm_b)
48
 
49
 
50
- # find_best_ingredients (modifiziert, um die ähnlichste Zutat mit RecipeBERT zu finden)
51
  def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6):
52
  """
53
  Findet die besten Zutaten: Alle benötigten + EINE ähnlichste aus den verfügbaren Zutaten.
@@ -60,87 +72,86 @@ def find_best_ingredients(required_ingredients, available_ingredients, max_ingre
60
  # Nur wenn wir noch Platz haben und zusätzliche Zutaten verfügbar sind
61
  if len(final_ingredients) < max_ingredients and len(available_ingredients) > 0:
62
  if final_ingredients:
63
- # Berechne den Durchschnitts-Embedding der benötigten Zutaten
64
  required_embeddings = [get_embedding(ing) for ing in required_ingredients]
65
  avg_required_embedding = average_embedding(required_embeddings)
66
-
67
  best_additional_ingredient = None
68
  highest_similarity = -1.0
69
 
70
- # Finde die ähnlichste Zutat aus den verfügbaren
71
  for avail_ing in available_ingredients:
72
  avail_embedding = get_embedding(avail_ing)
73
  similarity = get_cosine_similarity(avg_required_embedding, avail_embedding)
74
  if similarity > highest_similarity:
75
  highest_similarity = similarity
76
  best_additional_ingredient = avail_ing
77
-
78
  if best_additional_ingredient:
79
  final_ingredients.append(best_additional_ingredient)
80
- print(f"INFO: Added '{best_additional_ingredient}' (similarity: {highest_similarity:.2f}) as most similar.")
 
81
  else:
82
- # Wenn keine benötigten Zutaten, wähle zufällig eine aus den verfügbaren (wie zuvor)
83
  random_ingredient = random.choice(available_ingredients)
84
  final_ingredients.append(random_ingredient)
85
  print(f"INFO: No required ingredients. Added random available ingredient: '{random_ingredient}'.")
86
 
87
- # Begrenze auf max_ingredients, falls durch Zufall/ähnlichster Auswahl zu viele hinzugefügt wurden
88
  return final_ingredients[:max_ingredients]
89
 
90
 
91
- # mock_generate_recipe (bleibt gleich)
92
  def mock_generate_recipe(ingredients_list):
93
- """Generiert ein Mock-Rezept, da T5-Modell entfernt ist."""
94
- title = f"Einfaches Rezept mit {', '.join(ingredients_list[:3])}" if ingredients_list else "Einfaches Testrezept"
95
  return {
96
  "title": title,
97
- "ingredients": ingredients_list, # Die "generierten" Zutaten sind einfach die Eingabe
98
  "directions": [
99
- "Dies ist ein generierter Text von RecipeBERT (ohne T5).",
100
- "Das Laden des RecipeBERT-Modells war erfolgreich!",
101
- f"Basierend auf deinen Eingaben wurde '{ingredients_list[-1]}' als ähnlichste Zutat hinzugefügt." if len(ingredients_list) > 1 else "Keine zusätzliche Zutat hinzugefügt."
 
 
102
  ],
103
- "used_ingredients": ingredients_list # In diesem Mock-Fall sind alle "used"
104
  }
105
 
106
 
107
  def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
108
  """
109
  Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
110
- Für diesen Test wird nur RecipeBERT zum Laden getestet und ein Mock-Rezept zurückgegeben.
111
  """
112
  if not required_ingredients and not available_ingredients:
113
  return {"error": "Keine Zutaten angegeben"}
114
  try:
115
- # Hier wird die neue find_best_ingredients verwendet
116
  optimized_ingredients = find_best_ingredients(
117
  required_ingredients, available_ingredients, max_ingredients
118
  )
119
-
120
- # Rufe die Mock-Generierungsfunktion auf
121
- recipe = mock_generate_recipe(optimized_ingredients)
122
-
123
  result = {
124
  'title': recipe['title'],
125
  'ingredients': recipe['ingredients'],
126
  'directions': recipe['directions'],
127
- 'used_ingredients': optimized_ingredients # Jetzt wirklich die vom find_best_ingredients
128
  }
129
  return result
130
  except Exception as e:
131
  return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
132
 
 
133
  # --- FastAPI-Implementierung ---
134
- app = FastAPI(title="AI Recipe Generator API (RecipeBERT Only Test)")
 
135
 
136
  class RecipeRequest(BaseModel):
137
  required_ingredients: list[str] = []
138
  available_ingredients: list[str] = []
139
  max_ingredients: int = 7
140
- max_retries: int = 5 # Wird hier nicht direkt genutzt, aber im Payload beibehalten
141
- ingredients: list[str] = [] # Für Abwärtskompatibilität
142
 
143
- @app.post("/generate_recipe") # Der API-Endpunkt für Flutter
 
144
  async def generate_recipe_api(request_data: RecipeRequest):
145
  final_required_ingredients = request_data.required_ingredients
146
  if not final_required_ingredients and request_data.ingredients:
@@ -150,12 +161,14 @@ async def generate_recipe_api(request_data: RecipeRequest):
150
  final_required_ingredients,
151
  request_data.available_ingredients,
152
  request_data.max_ingredients,
153
- request_data.max_retries # max_retries wird nur an die Logik übergeben, aber nicht verwendet
154
  )
155
  return JSONResponse(content=result_dict)
156
 
 
157
  @app.get("/")
158
  async def read_root():
159
- return {"message": "AI Recipe Generator API is running (RecipeBERT only, 1 similar ingredient)!"} # Angepasste Nachricht
 
160
 
161
  print("INFO: FastAPI application script finished execution and defined 'app' variable.")
 
1
+ from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
2
  import torch
3
  import numpy as np
4
  import random
 
7
  from fastapi.responses import JSONResponse
8
  from pydantic import BaseModel
9
 
10
+ # Lade RecipeBERT Modell
11
  bert_model_name = "alexdseo/RecipeBERT"
12
  bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
13
  bert_model = AutoModel.from_pretrained(bert_model_name)
14
+ bert_model.eval() # Setze das Modell in den Evaluationsmodus
15
 
16
+ # Lade T5 Rezeptgenerierungsmodell (NEU hinzugefügt)
17
+ MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
18
+ t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
19
+ t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH) # Modell wird jetzt auch geladen
20
 
21
+ # Token Mapping für die T5 Modell-Ausgabe (bleibt hier, obwohl T5 noch nicht aktiv generiert)
22
+ special_tokens = t5_tokenizer.all_special_tokens
23
+ tokens_map = {
24
+ "<sep>": "--",
25
+ "<section>": "\n"
26
+ }
27
+
28
+
29
+ # --- RecipeBERT-spezifische Funktionen (unverändert) ---
30
  def get_embedding(text):
31
  """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens."""
32
  inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
39
  sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
40
  return (sum_embeddings / sum_mask).squeeze(0)
41
 
42
+
43
  def average_embedding(embedding_list):
44
  """Berechnet den Durchschnitt einer Liste von Embeddings."""
45
+ tensors = torch.stack(embedding_list)
46
  return tensors.mean(dim=0)
47
 
48
+
49
  def get_cosine_similarity(vec1, vec2):
50
  """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren."""
51
  if torch.is_tensor(vec1): vec1 = vec1.detach().numpy()
 
59
  return dot_product / (norm_a * norm_b)
60
 
61
 
62
+ # find_best_ingredients (unverändert, nutzt RecipeBERT für eine ähnlichste Zutat)
63
  def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6):
64
  """
65
  Findet die besten Zutaten: Alle benötigten + EINE ähnlichste aus den verfügbaren Zutaten.
 
72
  # Nur wenn wir noch Platz haben und zusätzliche Zutaten verfügbar sind
73
  if len(final_ingredients) < max_ingredients and len(available_ingredients) > 0:
74
  if final_ingredients:
 
75
  required_embeddings = [get_embedding(ing) for ing in required_ingredients]
76
  avg_required_embedding = average_embedding(required_embeddings)
77
+
78
  best_additional_ingredient = None
79
  highest_similarity = -1.0
80
 
 
81
  for avail_ing in available_ingredients:
82
  avail_embedding = get_embedding(avail_ing)
83
  similarity = get_cosine_similarity(avg_required_embedding, avail_embedding)
84
  if similarity > highest_similarity:
85
  highest_similarity = similarity
86
  best_additional_ingredient = avail_ing
87
+
88
  if best_additional_ingredient:
89
  final_ingredients.append(best_additional_ingredient)
90
+ print(
91
+ f"INFO: Added '{best_additional_ingredient}' (similarity: {highest_similarity:.2f}) as most similar.")
92
  else:
 
93
  random_ingredient = random.choice(available_ingredients)
94
  final_ingredients.append(random_ingredient)
95
  print(f"INFO: No required ingredients. Added random available ingredient: '{random_ingredient}'.")
96
 
 
97
  return final_ingredients[:max_ingredients]
98
 
99
 
100
+ # mock_generate_recipe (ANGEPASST, um zu bestätigen, dass BEIDE Modelle geladen sind)
101
  def mock_generate_recipe(ingredients_list):
102
+ """Generiert ein Mock-Rezept und bestätigt das Laden beider Modelle."""
103
+ title = f"Rezepttest mit {', '.join(ingredients_list[:3])}" if ingredients_list else "Testrezept"
104
  return {
105
  "title": title,
106
+ "ingredients": ingredients_list,
107
  "directions": [
108
+ "Dies ist ein Testrezept.",
109
+ "RecipeBERT und T5-Modell wurden beide erfolgreich geladen!",
110
+ "Die Zutaten wurden mit RecipeBERT-Intelligenz ausgewählt.",
111
+ f"Basierend auf deinen Eingaben wurde '{ingredients_list[-1]}' als ähnlichste Zutat hinzugefügt." if len(
112
+ ingredients_list) > 1 else "Keine zusätzliche Zutat hinzugefügt."
113
  ],
114
+ "used_ingredients": ingredients_list
115
  }
116
 
117
 
118
  def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
119
  """
120
  Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
 
121
  """
122
  if not required_ingredients and not available_ingredients:
123
  return {"error": "Keine Zutaten angegeben"}
124
  try:
 
125
  optimized_ingredients = find_best_ingredients(
126
  required_ingredients, available_ingredients, max_ingredients
127
  )
128
+
129
+ recipe = mock_generate_recipe(optimized_ingredients) # Rufe die Mock-Generierungsfunktion auf
130
+
 
131
  result = {
132
  'title': recipe['title'],
133
  'ingredients': recipe['ingredients'],
134
  'directions': recipe['directions'],
135
+ 'used_ingredients': optimized_ingredients
136
  }
137
  return result
138
  except Exception as e:
139
  return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
140
 
141
+
142
  # --- FastAPI-Implementierung ---
143
+ app = FastAPI(title="AI Recipe Generator API (Both Models Loaded Test)")
144
+
145
 
146
  class RecipeRequest(BaseModel):
147
  required_ingredients: list[str] = []
148
  available_ingredients: list[str] = []
149
  max_ingredients: int = 7
150
+ max_retries: int = 5
151
+ ingredients: list[str] = [] # Für Abwärtskompatibilität
152
 
153
+
154
+ @app.post("/generate_recipe") # Der API-Endpunkt für Flutter
155
  async def generate_recipe_api(request_data: RecipeRequest):
156
  final_required_ingredients = request_data.required_ingredients
157
  if not final_required_ingredients and request_data.ingredients:
 
161
  final_required_ingredients,
162
  request_data.available_ingredients,
163
  request_data.max_ingredients,
164
+ request_data.max_retries
165
  )
166
  return JSONResponse(content=result_dict)
167
 
168
+
169
  @app.get("/")
170
  async def read_root():
171
+ return {"message": "AI Recipe Generator API is running (Both models loaded for test)!"} # Angepasste Nachricht
172
+
173
 
174
  print("INFO: FastAPI application script finished execution and defined 'app' variable.")