TimInf commited on
Commit
f0ebec2
·
verified ·
1 Parent(s): 1cc7a66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -221
app.py CHANGED
@@ -1,221 +1,37 @@
1
- from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
2
- import torch
3
- import numpy as np
4
- import random
5
- import json
6
- from fastapi import FastAPI
7
- from fastapi.responses import JSONResponse
8
- from pydantic import BaseModel
9
-
10
- # Lade RecipeBERT Modell
11
- bert_model_name = "alexdseo/RecipeBERT"
12
- bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
13
- bert_model = AutoModel.from_pretrained(bert_model_name)
14
- bert_model.eval()
15
-
16
- # Lade T5 Rezeptgenerierungsmodell
17
- MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
18
- t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
19
- t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
20
-
21
- # Token Mapping
22
- special_tokens = t5_tokenizer.all_special_tokens
23
- tokens_map = {
24
- "<sep>": "--",
25
- "<section>": "\n"
26
- }
27
-
28
- # --- Deine Helper-Funktionen (get_embedding, average_embedding, get_cosine_similarity, etc.) ---
29
- # Kopiere alle diese Funktionen von deinem aktuellen app.py hierher.
30
- # Ich kürze sie hier aus Platzgründen, aber sie müssen vollständig eingefügt werden.
31
-
32
- def get_embedding(text):
33
- inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
34
- with torch.no_grad():
35
- outputs = bert_model(**inputs)
36
- attention_mask = inputs['attention_mask']
37
- token_embeddings = outputs.last_hidden_state
38
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
39
- sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
40
- sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
41
- return (sum_embeddings / sum_mask).squeeze(0)
42
-
43
- def average_embedding(embedding_list):
44
- tensors = torch.stack([emb for _, emb in embedding_list])
45
- return tensors.mean(dim=0)
46
-
47
- def get_cosine_similarity(vec1, vec2):
48
- if torch.is_tensor(vec1): vec1 = vec1.detach().numpy()
49
- if torch.is_tensor(vec2): vec2 = vec2.detach().numpy()
50
- vec1 = vec1.flatten()
51
- vec2 = vec2.flatten()
52
- dot_product = np.dot(vec1, vec2)
53
- norm_a = np.linalg.norm(vec1)
54
- norm_b = np.linalg.norm(vec2)
55
- if norm_a == 0 or norm_b == 0: return 0
56
- return dot_product / (norm_a * norm_b)
57
-
58
- def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
59
- results = []
60
- for name, emb in embedding_list:
61
- avg_similarity = get_cosine_similarity(query_vector, emb)
62
- individual_similarities = [get_cosine_similarity(good_emb, emb) for _, good_emb in all_good_embeddings]
63
- avg_individual_similarity = sum(individual_similarities) / len(individual_similarities)
64
- combined_score = avg_weight * avg_similarity + (1 - avg_weight) * avg_individual_similarity
65
- results.append((name, emb, combined_score))
66
- results.sort(key=lambda x: x[2], reverse=True)
67
- return results
68
-
69
- def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
70
- required_ingredients = list(set(required_ingredients))
71
- available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
72
- if not required_ingredients and available_ingredients:
73
- random_ingredient = random.choice(available_ingredients)
74
- required_ingredients = [random_ingredient]
75
- available_ingredients = [i for i in available_ingredients if i != random_ingredient]
76
- if not required_ingredients or len(required_ingredients) >= max_ingredients:
77
- return required_ingredients[:max_ingredients]
78
- if not available_ingredients:
79
- return required_ingredients
80
- embed_required = [(e, get_embedding(e)) for e in required_ingredients]
81
- embed_available = [(e, get_embedding(e)) for e in available_ingredients]
82
- num_to_add = min(max_ingredients - len(required_ingredients), len(available_ingredients))
83
- final_ingredients = embed_required.copy()
84
- for _ in range(num_to_add):
85
- avg = average_embedding(final_ingredients)
86
- candidates = get_combined_scores(avg, embed_available, final_ingredients, avg_weight)
87
- if not candidates: break
88
- best_name, best_embedding, _ = candidates[0]
89
- final_ingredients.append((best_name, best_embedding))
90
- embed_available = [item for item in embed_available if item[0] != best_name]
91
- return [name for name, _ in final_ingredients]
92
-
93
- def skip_special_tokens(text, special_tokens):
94
- for token in special_tokens: text = text.replace(token, "")
95
- return text
96
-
97
- def target_postprocessing(texts, special_tokens):
98
- if not isinstance(texts, list): texts = [texts]
99
- new_texts = []
100
- for text in texts:
101
- text = skip_special_tokens(text, special_tokens)
102
- for k, v in tokens_map.items(): text = text.replace(k, v)
103
- new_texts.append(text)
104
- return new_texts
105
-
106
- def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
107
- recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
108
- expected_count = len(expected_ingredients)
109
- return abs(recipe_count - expected_count) == tolerance
110
-
111
- def generate_recipe_with_t5(ingredients_list, max_retries=5):
112
- original_ingredients = ingredients_list.copy()
113
- for attempt in range(max_retries):
114
- try:
115
- if attempt > 0:
116
- current_ingredients = original_ingredients.copy()
117
- random.shuffle(current_ingredients)
118
- else:
119
- current_ingredients = ingredients_list
120
- ingredients_string = ", ".join(current_ingredients)
121
- prefix = "items: "
122
- generation_kwargs = {
123
- "max_length": 512, "min_length": 64, "do_sample": True,
124
- "top_k": 60, "top_p": 0.95
125
- }
126
- inputs = t5_tokenizer(
127
- prefix + ingredients_string, max_length=256, padding="max_length",
128
- truncation=True, return_tensors="jax"
129
- )
130
- output_ids = t5_model.generate(
131
- input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, **generation_kwargs
132
- )
133
- generated = output_ids.sequences
134
- generated_text = target_postprocessing(t5_tokenizer.batch_decode(generated, skip_special_tokens=False), special_tokens)[0]
135
- recipe = {}
136
- sections = generated_text.split("\n")
137
- for section in sections:
138
- section = section.strip()
139
- if section.startswith("title:"):
140
- recipe["title"] = section.replace("title:", "").strip().capitalize()
141
- elif section.startswith("ingredients:"):
142
- ingredients_text = section.replace("ingredients:", "").strip()
143
- recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()]
144
- elif section.startswith("directions:"):
145
- directions_text = section.replace("directions:", "").strip()
146
- recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
147
- if "title" not in recipe:
148
- recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
149
- if "ingredients" not in recipe:
150
- recipe["ingredients"] = current_ingredients
151
- if "directions" not in recipe:
152
- recipe["directions"] = ["Keine Anweisungen generiert"]
153
- if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
154
- return recipe
155
- else:
156
- if attempt == max_retries - 1: return recipe
157
- except Exception as e:
158
- if attempt == max_retries - 1:
159
- return {
160
- "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
161
- "ingredients": original_ingredients,
162
- "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
163
- }
164
- return {
165
- "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
166
- "ingredients": original_ingredients,
167
- "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
168
- }
169
-
170
- # Kernlogik, die von der FastAPI-Route aufgerufen wird
171
- def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
172
- """
173
- Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
174
- """
175
- if not required_ingredients and not available_ingredients:
176
- return {"error": "Keine Zutaten angegeben"}
177
- try:
178
- optimized_ingredients = find_best_ingredients(
179
- required_ingredients, available_ingredients, max_ingredients
180
- )
181
- recipe = generate_recipe_with_t5(optimized_ingredients, max_retries)
182
- result = {
183
- 'title': recipe['title'],
184
- 'ingredients': recipe['ingredients'],
185
- 'directions': recipe['directions'],
186
- 'used_ingredients': optimized_ingredients
187
- }
188
- return result
189
- except Exception as e:
190
- return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
191
-
192
- # --- FastAPI-Implementierung ---
193
- app = FastAPI() # Deine FastAPI-Instanz
194
-
195
- class RecipeRequest(BaseModel):
196
- required_ingredients: list[str] = []
197
- available_ingredients: list[str] = []
198
- max_ingredients: int = 7
199
- max_retries: int = 5
200
- # Abwärtskompatibilität: Falls 'ingredients' als Top-Level-Feld gesendet wird
201
- ingredients: list[str] = []
202
-
203
- @app.post("/generate_recipe") # Der API-Endpunkt für Flutter
204
- async def generate_recipe_api(request_data: RecipeRequest):
205
- """
206
- Standard-REST-API-Endpunkt für die Flutter-App.
207
- Nimmt direkt JSON-Daten an und gibt direkt JSON zurück.
208
- """
209
- final_required_ingredients = request_data.required_ingredients
210
- if not final_required_ingredients and request_data.ingredients:
211
- final_required_ingredients = request_data.ingredients
212
-
213
- result_dict = process_recipe_request_logic(
214
- final_required_ingredients,
215
- request_data.available_ingredients,
216
- request_data.max_ingredients,
217
- request_data.max_retries
218
- )
219
- return JSONResponse(content=result_dict)
220
-
221
- print("INFO: FastAPI application script finished execution and defined 'app' variable.")
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel # Bleibt, da FastAPI es für Request Body Parsing nutzt
4
+
5
+ # Deine FastAPI-Instanz
6
+ app = FastAPI(title="Minimal FastAPI Test App")
7
+
8
+ # Eine einfache Request-Modell-Klasse (auch wenn wir sie hier nicht wirklich nutzen,
9
+ # zeigt es, dass Pydantic funktioniert)
10
+ class TestRequest(BaseModel):
11
+ message: str = "Hello"
12
+
13
+ # Ein einfacher "Hello World"-Endpunkt, der auf POST-Anfragen reagiert
14
+ @app.post("/test")
15
+ async def test_api_endpoint(request_data: TestRequest):
16
+ """
17
+ Ein einfacher Test-Endpunkt.
18
+ Gibt eine Begrüßungsnachricht zurück, die die empfangene Nachricht enthält.
19
+ """
20
+ print(f"INFO: Received test request with message: {request_data.message}") # Log für den Space
21
+ return JSONResponse(content={"status": "success", "message": f"Hello from FastAPI! You sent: {request_data.message}"})
22
+
23
+ # Ein optionaler Root-Endpunkt (oft gut für Health-Checks)
24
+ @app.get("/")
25
+ async def read_root():
26
+ """
27
+ Root-Endpunkt für grundlegenden Health-Check.
28
+ """
29
+ return {"message": "FastAPI is running!"}
30
+
31
+ print("INFO: FastAPI application script finished execution and defined 'app' variable.")
32
+
33
+ # Dieser Block wird in Hugging Face Spaces nicht direkt ausgeführt, da Uvicorn
34
+ # die App direkt lädt, aber er ist für lokale Tests nützlich.
35
+ if __name__ == "__main__":
36
+ import uvicorn
37
+ uvicorn.run(app, host="0.0.0.0", port=7860) # Lokaler Port 7860, wie in Space