TimInf commited on
Commit
908127c
·
verified ·
1 Parent(s): 0f17420

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -157
app.py CHANGED
@@ -1,178 +1,105 @@
1
- from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer # AutoModel entfernt
2
- import torch # Beibehalten
3
- import numpy as np # Beibehalten
4
  import random
5
  import json
6
  from fastapi import FastAPI
7
  from fastapi.responses import JSONResponse
8
  from pydantic import BaseModel
9
 
10
- # Lade RecipeBERT Modell (KOMPLETT ENTFERNT für diesen Schritt)
11
- # bert_model_name = "alexdseo/RecipeBERT"
12
- # bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
13
- # bert_model = AutoModel.from_pretrained(bert_model_name)
14
- # bert_model.eval()
15
-
16
- # Lade T5 Rezeptgenerierungsmodell
17
- MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
18
- t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
19
- t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
20
-
21
- # Token Mapping für die T5 Modell-Ausgabe
22
- special_tokens = t5_tokenizer.all_special_token
23
- tokens_map = {
24
- "<sep>": "--",
25
- "<section>": "\n"
26
- }
27
-
28
- # --- RecipeBERT-spezifische Funktionen sind entfernt oder vereinfacht ---
29
- # get_embedding, average_embedding, get_cosine_similarity, get_combined_scores sind entfernt.
30
-
31
- # find_best_ingredients (modifiziert, um KEINE Embeddings zu nutzen)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
33
  """
34
- Findet die besten Zutaten. Für diesen einfachen Test wird nur
35
- die Liste der benötigten Zutaten um zufällig ausgewählte
36
- verfügbare Zutaten ergänzt, OHNE Embeddings zu nutzen.
37
  """
38
  required_ingredients = list(set(required_ingredients))
39
  available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
40
 
41
- # Sonderfall: Wenn keine benötigten Zutaten vorhanden sind, wähle zufällig eine aus den verfügbaren Zutaten
42
- if not required_ingredients and available_ingredients:
43
- random_ingredient = random.choice(available_ingredients)
44
- required_ingredients = [random_ingredient]
45
- available_ingredients = [i for i in available_ingredients if i != random_ingredient]
46
-
47
- # Wenn bereits maximale Kapazität erreicht ist
48
- if len(required_ingredients) >= max_ingredients:
49
- return required_ingredients[:max_ingredients]
50
-
51
- # Wenn keine zusätzlichen Zutaten verfügbar sind
52
- if not available_ingredients:
53
- return required_ingredients
54
-
55
- # Füge zufällig weitere Zutaten hinzu, bis max_ingredients erreicht ist
56
- current_ingredients = required_ingredients.copy()
57
- num_to_add = min(max_ingredients - len(current_ingredients), len(available_ingredients))
58
-
59
- # Wähle zufällig aus den verfügbaren Zutaten
60
- selected_from_available = random.sample(available_ingredients, num_to_add)
61
- current_ingredients.extend(selected_from_available)
62
-
63
- return current_ingredients
64
-
65
-
66
- def skip_special_tokens(text, special_tokens):
67
- """Entfernt spezielle Tokens aus dem Text"""
68
- for token in special_tokens:
69
- text = text.replace(token, "")
70
- return text
71
-
72
- def target_postprocessing(texts, special_tokens):
73
- """Post-processed generierten Text"""
74
- if not isinstance(texts, list):
75
- texts = [texts]
76
- new_texts = []
77
- for text in texts:
78
- text = skip_special_tokens(text, special_tokens)
79
- for k, v in tokens_map.items():
80
- text = text.replace(k, v)
81
- new_texts.append(text)
82
- return new_texts
83
-
84
- def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
85
- """
86
- Validiert, ob das Rezept ungefähr die erwarteten Zutaten enthält.
87
- """
88
- recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
89
- expected_count = len(expected_ingredients)
90
- return abs(recipe_count - expected_count) == tolerance
91
-
92
- def generate_recipe_with_t5(ingredients_list, max_retries=5):
93
- """Generiert ein Rezept mit dem T5 Rezeptgenerierungsmodell mit Validierung."""
94
- original_ingredients = ingredients_list.copy()
95
- for attempt in range(max_retries):
96
- try:
97
- if attempt > 0:
98
- current_ingredients = original_ingredients.copy()
99
- random.shuffle(current_ingredients)
100
- else:
101
- current_ingredients = ingredients_list
102
- ingredients_string = ", ".join(current_ingredients)
103
- prefix = "items: "
104
- generation_kwargs = {
105
- "max_length": 512,
106
- "min_length": 64,
107
- "do_sample": True,
108
- "top_k": 60,
109
- "top_p": 0.95
110
- }
111
- inputs = t5_tokenizer(
112
- prefix + ingredients_string,
113
- max_length=256,
114
- padding="max_length",
115
- truncation=True,
116
- return_tensors="jax"
117
- )
118
- output_ids = t5_model.generate(
119
- input_ids=inputs.input_ids,
120
- attention_mask=inputs.attention_mask,
121
- **generation_kwargs
122
- )
123
- generated = output_ids.sequences
124
- generated_text = target_postprocessing(
125
- t5_tokenizer.batch_decode(generated, skip_special_tokens=False),
126
- special_tokens
127
- )[0]
128
- recipe = {}
129
- sections = generated_text.split("\n")
130
- for section in sections:
131
- section = section.strip()
132
- if section.startswith("title:"):
133
- recipe["title"] = section.replace("title:", "").strip().capitalize()
134
- elif section.startswith("ingredients:"):
135
- ingredients_text = section.replace("ingredients:", "").strip()
136
- recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()]
137
- elif section.startswith("directions:"):
138
- directions_text = section.replace("directions:", "").strip()
139
- recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
140
- if "title" not in recipe:
141
- recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
142
- if "ingredients" not in recipe:
143
- recipe["ingredients"] = current_ingredients
144
- if "directions" not in recipe:
145
- recipe["directions"] = ["Keine Anweisungen generiert"]
146
- if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
147
- return recipe
148
- else:
149
- if attempt == max_retries - 1:
150
- return recipe
151
- except Exception as e:
152
- if attempt == max_retries - 1:
153
- return {
154
- "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
155
- "ingredients": original_ingredients,
156
- "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
157
- }
158
  return {
159
- "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
160
- "ingredients": original_ingredients,
161
- "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
 
 
 
 
162
  }
163
 
 
164
  def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
165
  """
166
  Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
 
167
  """
168
  if not required_ingredients and not available_ingredients:
169
  return {"error": "Keine Zutaten angegeben"}
170
  try:
171
- # Hier wird die vereinfachte find_best_ingredients verwendet, die KEINE Embeddings nutzt.
172
  optimized_ingredients = find_best_ingredients(
173
  required_ingredients, available_ingredients, max_ingredients
174
  )
175
- recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
 
 
 
176
  result = {
177
  'title': recipe['title'],
178
  'ingredients': recipe['ingredients'],
@@ -184,7 +111,7 @@ def process_recipe_request_logic(required_ingredients, available_ingredients, ma
184
  return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
185
 
186
  # --- FastAPI-Implementierung ---
187
- app = FastAPI(title="AI Recipe Generator API") # Deine FastAPI-Instanz
188
 
189
  class RecipeRequest(BaseModel):
190
  required_ingredients: list[str] = []
@@ -195,10 +122,6 @@ class RecipeRequest(BaseModel):
195
 
196
  @app.post("/generate_recipe") # Der API-Endpunkt für Flutter
197
  async def generate_recipe_api(request_data: RecipeRequest):
198
- """
199
- Standard-REST-API-Endpunkt für die Flutter-App.
200
- Nimmt direkt JSON-Daten an und gibt direkt JSON zurück.
201
- """
202
  final_required_ingredients = request_data.required_ingredients
203
  if not final_required_ingredients and request_data.ingredients:
204
  final_required_ingredients = request_data.ingredients
@@ -211,9 +134,8 @@ async def generate_recipe_api(request_data: RecipeRequest):
211
  )
212
  return JSONResponse(content=result_dict)
213
 
214
- # Optionaler Root-Endpunkt für Health-Checks
215
  @app.get("/")
216
  async def read_root():
217
- return {"message": "AI Recipe Generator API is running (T5 only)!"} # Angepasste Nachricht
218
 
219
  print("INFO: FastAPI application script finished execution and defined 'app' variable.")
 
1
+ from transformers import AutoTokenizer, AutoModel # Entfernt: FlaxAutoModelForSeq2SeqLM
2
+ import torch
3
+ import numpy as np
4
  import random
5
  import json
6
  from fastapi import FastAPI
7
  from fastapi.responses import JSONResponse
8
  from pydantic import BaseModel
9
 
10
+ # Lade NUR RecipeBERT Modell
11
+ bert_model_name = "alexdseo/RecipeBERT"
12
+ bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
13
+ bert_model = AutoModel.from_pretrained(bert_model_name)
14
+ bert_model.eval() # Setze das Modell in den Evaluationsmodus
15
+
16
+ # T5-Modell und -Logik KOMPLETT ENTFERNT für diesen Schritt
17
+ # special_tokens und tokens_map sind nicht mehr relevant, bleiben aber als Kommentar
18
+
19
+ # --- RecipeBERT-spezifische Funktionen (die jetzt die Kernlogik sind) ---
20
+ def get_embedding(text):
21
+ """Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens."""
22
+ inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
23
+ with torch.no_grad():
24
+ outputs = bert_model(**inputs)
25
+ attention_mask = inputs['attention_mask']
26
+ token_embeddings = outputs.last_hidden_state
27
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
28
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
29
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
30
+ return (sum_embeddings / sum_mask).squeeze(0)
31
+
32
+ def get_cosine_similarity(vec1, vec2):
33
+ """Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren."""
34
+ if torch.is_tensor(vec1): vec1 = vec1.detach().numpy()
35
+ if torch.is_tensor(vec2): vec2 = vec2.detach().numpy()
36
+ vec1 = vec1.flatten()
37
+ vec2 = vec2.flatten()
38
+ dot_product = np.dot(vec1, vec2)
39
+ norm_a = np.linalg.norm(vec1)
40
+ norm_b = np.linalg.norm(vec2)
41
+ if norm_a == 0 or norm_b == 0: return 0
42
+ return dot_product / (norm_a * norm_b)
43
+
44
+
45
+ # find_best_ingredients (modifiziert, um KEINE Embeddings für T5-ähnliche Auswahl zu nutzen,
46
+ # sondern nur grundlegende Zutatenbearbeitung und Optionalen Test für RecipeBERT-Laden)
47
  def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
48
  """
49
+ Für diesen Test: Gibt einfach die benötigten Zutaten plus ein paar zufällige verfügbare Zutaten zurück.
50
+ Die semantische Auswahl von RecipeBERT ist hier nicht aktiv (da T5-Generierung fehlt).
 
51
  """
52
  required_ingredients = list(set(required_ingredients))
53
  available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
54
 
55
+ final_ingredients = required_ingredients.copy()
56
+ num_to_add = min(max_ingredients - len(final_ingredients), len(available_ingredients))
57
+ if num_to_add > 0:
58
+ final_ingredients.extend(random.sample(available_ingredients, num_to_add))
59
+
60
+ # Optional: Ein kleiner Test-Print, ob RecipeBERT erfolgreich geladen wurde
61
+ try:
62
+ if final_ingredients:
63
+ # Versuche ein Embedding für die erste Zutat zu generieren
64
+ test_embedding = get_embedding(final_ingredients[0])
65
+ print(f"INFO: Successfully generated embedding for '{final_ingredients[0]}'. RecipeBERT is loaded.")
66
+ else:
67
+ print("INFO: No ingredients to test embedding with.")
68
+ except Exception as e:
69
+ print(f"ERROR: RecipeBERT embedding test failed: {e}")
70
+
71
+ return final_ingredients
72
+
73
+ # mock_generate_recipe (ersetzt generate_recipe_with_t5)
74
+ def mock_generate_recipe(ingredients_list):
75
+ """Generiert ein Mock-Rezept, da T5-Modell entfernt ist."""
76
+ title = f"Einfaches Rezept mit {', '.join(ingredients_list[:3])}" if ingredients_list else "Einfaches Testrezept"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  return {
78
+ "title": title,
79
+ "ingredients": ingredients_list, # Die "generierten" Zutaten sind einfach die Eingabe
80
+ "directions": [
81
+ "Dies ist ein generierter Text von RecipeBERT (ohne T5).",
82
+ "Fügen Sie Ihre Zutaten zusammen und kochen Sie es nach Belieben.",
83
+ "Das Laden des RecipeBERT-Modells war erfolgreich!"
84
+ ]
85
  }
86
 
87
+
88
  def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
89
  """
90
  Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
91
+ Für diesen Test wird nur RecipeBERT zum Laden getestet und ein Mock-Rezept zurückgegeben.
92
  """
93
  if not required_ingredients and not available_ingredients:
94
  return {"error": "Keine Zutaten angegeben"}
95
  try:
 
96
  optimized_ingredients = find_best_ingredients(
97
  required_ingredients, available_ingredients, max_ingredients
98
  )
99
+
100
+ # Rufe die Mock-Generierungsfunktion auf
101
+ recipe = mock_generate_recipe(optimized_ingredients)
102
+
103
  result = {
104
  'title': recipe['title'],
105
  'ingredients': recipe['ingredients'],
 
111
  return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
112
 
113
  # --- FastAPI-Implementierung ---
114
+ app = FastAPI(title="AI Recipe Generator API (RecipeBERT Only Test)")
115
 
116
  class RecipeRequest(BaseModel):
117
  required_ingredients: list[str] = []
 
122
 
123
  @app.post("/generate_recipe") # Der API-Endpunkt für Flutter
124
  async def generate_recipe_api(request_data: RecipeRequest):
 
 
 
 
125
  final_required_ingredients = request_data.required_ingredients
126
  if not final_required_ingredients and request_data.ingredients:
127
  final_required_ingredients = request_data.ingredients
 
134
  )
135
  return JSONResponse(content=result_dict)
136
 
 
137
  @app.get("/")
138
  async def read_root():
139
+ return {"message": "AI Recipe Generator API is running (RecipeBERT only)!"} # Angepasste Nachricht
140
 
141
  print("INFO: FastAPI application script finished execution and defined 'app' variable.")