TimInf committed on
Commit
4f9d434
·
verified ·
1 Parent(s): b360d1c

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +24 -0
  2. app.py +221 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use an official Python image as the base
FROM python:3.10-slim-bullseye

# Set the working directory inside the container
WORKDIR /app

# Copy requirements.txt into the working directory
COPY requirements.txt .

# Install the Python dependencies
# --no-cache-dir: do not write pip cache files
# --upgrade: upgrade already-installed packages
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the rest of the application files (app.py, etc.) into the working directory
COPY . .

# Expose the port Uvicorn will listen on
# This is the default port for Hugging Face Spaces
EXPOSE 7860

# Start the FastAPI application with Uvicorn
# 'app:app' means: find the variable 'app' in the module 'app.py'
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
import torch
import numpy as np
import random
import json
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Load the RecipeBERT model (ingredient embeddings for similarity scoring).
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval()  # inference only — disables dropout etc.

# Load the T5 recipe-generation model (Flax weights, run via JAX).
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Token mapping: model markers in the generated text are rewritten to plain
# separators during post-processing (<sep> -> '--', <section> -> newline).
special_tokens = t5_tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}

# --- Helper functions (embeddings, similarity, ingredient selection) ---
31
+
32
def get_embedding(text):
    """Embed *text* with RecipeBERT via attention-mask-weighted mean pooling.

    Padding positions are excluded through the attention mask; the result is
    a 1-D tensor (the batch dimension is squeezed away).
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        model_out = bert_model(**encoded)
    hidden = model_out.last_hidden_state
    # Broadcast the mask over the hidden dimension so padded tokens contribute 0.
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden.size()).float()
    # Clamp the denominator to avoid division by zero for all-padding inputs.
    pooled = (hidden * mask).sum(1) / torch.clamp(mask.sum(1), min=1e-9)
    return pooled.squeeze(0)
42
+
43
def average_embedding(embedding_list):
    """Element-wise mean of the embeddings in a list of (name, tensor) pairs."""
    stacked = torch.stack([vector for _name, vector in embedding_list])
    return stacked.mean(dim=0)
46
+
47
def get_cosine_similarity(vec1, vec2):
    """Cosine similarity between two vectors.

    Accepts torch tensors (detached and converted to numpy) or any
    array-like; inputs are flattened so shapes like (1, d) and (d,)
    compare cleanly.

    Returns:
        float in [-1, 1]; 0.0 when either vector has zero norm
        (similarity is undefined — avoid division by zero).

    Fix: the original returned the int ``0`` on the zero-norm path but a
    numpy float otherwise; the return type is now consistently ``float``.
    """
    if torch.is_tensor(vec1):
        vec1 = vec1.detach().numpy()
    if torch.is_tensor(vec2):
        vec2 = vec2.detach().numpy()
    # np.asarray generalizes the helper to plain lists/tuples as well.
    vec1 = np.asarray(vec1).flatten()
    vec2 = np.asarray(vec2).flatten()
    norm_a = np.linalg.norm(vec1)
    norm_b = np.linalg.norm(vec2)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / (norm_a * norm_b))
57
+
58
def get_combined_scores(query_vector, embedding_list, all_good_embeddings, avg_weight=0.6):
    """Rank candidate ingredients by a weighted blend of two signals.

    For each (name, embedding) candidate:
      * similarity to *query_vector* (the centroid of the current selection),
      * mean pairwise similarity to every embedding in *all_good_embeddings*.
    The final score is ``avg_weight * centroid_sim + (1 - avg_weight) * mean_pair_sim``.

    Returns a list of (name, embedding, score) tuples sorted best-first.
    """
    scored = []
    for name, candidate_emb in embedding_list:
        centroid_sim = get_cosine_similarity(query_vector, candidate_emb)
        pair_sims = [get_cosine_similarity(good_emb, candidate_emb)
                     for _n, good_emb in all_good_embeddings]
        mean_pair_sim = sum(pair_sims) / len(pair_sims)
        score = avg_weight * centroid_sim + (1 - avg_weight) * mean_pair_sim
        scored.append((name, candidate_emb, score))
    return sorted(scored, key=lambda entry: entry[2], reverse=True)
68
+
69
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
    """Greedily pick up to *max_ingredients* ingredients that fit well together.

    Starts from the (deduplicated) required ingredients and repeatedly adds
    the available ingredient whose embedding scores highest against the
    current selection (see get_combined_scores). If nothing is required,
    one random available ingredient seeds the selection.
    """
    required_ingredients = list(set(required_ingredients))
    available_ingredients = list({i for i in available_ingredients if i not in required_ingredients})
    # No required ingredients: seed the selection with a random available one.
    if not required_ingredients and available_ingredients:
        seed = random.choice(available_ingredients)
        required_ingredients = [seed]
        available_ingredients = [i for i in available_ingredients if i != seed]
    # Nothing to select, or the required list already fills the budget.
    if not required_ingredients or len(required_ingredients) >= max_ingredients:
        return required_ingredients[:max_ingredients]
    if not available_ingredients:
        return required_ingredients
    selection = [(name, get_embedding(name)) for name in required_ingredients]
    pool = [(name, get_embedding(name)) for name in available_ingredients]
    slots = min(max_ingredients - len(selection), len(pool))
    for _ in range(slots):
        centroid = average_embedding(selection)
        ranked = get_combined_scores(centroid, pool, selection, avg_weight)
        if not ranked:
            break
        top_name, top_vector, _score = ranked[0]
        selection.append((top_name, top_vector))
        pool = [entry for entry in pool if entry[0] != top_name]
    return [name for name, _ in selection]
92
+
93
def skip_special_tokens(text, special_tokens):
    """Remove every occurrence of each special-token marker from *text*."""
    cleaned = text
    for marker in special_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned
96
+
97
def target_postprocessing(texts, special_tokens):
    """Clean raw T5 output: drop special tokens, then map section markers.

    Uses the module-level ``tokens_map`` (<sep> -> '--', <section> -> newline).
    Accepts a single string or a list; always returns a list of strings.
    """
    if not isinstance(texts, list):
        texts = [texts]
    cleaned = []
    for raw in texts:
        stripped = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            stripped = stripped.replace(marker, replacement)
        cleaned.append(stripped)
    return cleaned
105
+
106
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that a recipe uses roughly the expected number of ingredients.

    Blank/whitespace-only entries in *recipe_ingredients* are not counted.

    Args:
        recipe_ingredients: ingredient strings parsed from the generated recipe.
        expected_ingredients: the ingredient list that was requested.
        tolerance: maximum allowed absolute difference in counts (default 0 =
            exact match).

    Returns:
        True when ``|recipe_count - expected_count| <= tolerance``.

    Fix: the original compared with ``== tolerance``, which rejected exact
    matches whenever tolerance > 0; a tolerance is an upper bound, so the
    correct comparison is ``<=`` (identical behavior at the default 0).
    """
    recipe_count = len([ing for ing in recipe_ingredients if ing and ing.strip()])
    expected_count = len(expected_ingredients)
    return abs(recipe_count - expected_count) <= tolerance
110
+
111
def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe from an ingredient list with the Flax T5 model.

    Retries up to *max_retries* times — reshuffling the ingredient order on
    every retry — until the generated recipe uses exactly the expected
    number of ingredients (validate_recipe_ingredients with default
    tolerance 0). If no attempt validates, the last parsed attempt is
    returned; on repeated exceptions a German-language fallback dict is
    returned instead.

    Returns:
        dict with keys 'title' (str), 'ingredients' (list[str]) and
        'directions' (list[str]).
    """
    original_ingredients = ingredients_list.copy()
    for attempt in range(max_retries):
        try:
            # Reshuffle on retries: a different ingredient order can steer
            # sampling toward a recipe that uses all requested ingredients.
            if attempt > 0:
                current_ingredients = original_ingredients.copy()
                random.shuffle(current_ingredients)
            else:
                current_ingredients = ingredients_list
            ingredients_string = ", ".join(current_ingredients)
            prefix = "items: "  # prompt format expected by t5-recipe-generation
            # Sampling (top-k/top-p) so repeated attempts produce different recipes.
            generation_kwargs = {
                "max_length": 512, "min_length": 64, "do_sample": True,
                "top_k": 60, "top_p": 0.95
            }
            inputs = t5_tokenizer(
                prefix + ingredients_string, max_length=256, padding="max_length",
                truncation=True, return_tensors="jax"
            )
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, **generation_kwargs
            )
            generated = output_ids.sequences
            # Keep special tokens during decoding so <sep>/<section> can be
            # mapped to '--'/newline by target_postprocessing.
            generated_text = target_postprocessing(t5_tokenizer.batch_decode(generated, skip_special_tokens=False), special_tokens)[0]
            # Parse the "title: ...\ningredients: ...\ndirections: ..." layout.
            recipe = {}
            sections = generated_text.split("\n")
            for section in sections:
                section = section.strip()
                if section.startswith("title:"):
                    recipe["title"] = section.replace("title:", "").strip().capitalize()
                elif section.startswith("ingredients:"):
                    ingredients_text = section.replace("ingredients:", "").strip()
                    recipe["ingredients"] = [item.strip().capitalize() for item in ingredients_text.split("--") if item.strip()]
                elif section.startswith("directions:"):
                    directions_text = section.replace("directions:", "").strip()
                    recipe["directions"] = [step.strip().capitalize() for step in directions_text.split("--") if step.strip()]
            # Fallbacks for sections the model failed to emit.
            if "title" not in recipe:
                recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
            if "ingredients" not in recipe:
                recipe["ingredients"] = current_ingredients
            if "directions" not in recipe:
                recipe["directions"] = ["Keine Anweisungen generiert"]
            # Accept only an exact ingredient-count match; otherwise retry,
            # returning the final attempt as-is when retries are exhausted.
            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
                return recipe
            else:
                if attempt == max_retries - 1: return recipe
        except Exception as e:
            # NOTE(review): exceptions on non-final attempts are swallowed
            # silently — consider logging them for debuggability.
            if attempt == max_retries - 1:
                return {
                    "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
                    "ingredients": original_ingredients,
                    "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
                }
    # Defensive fallback (e.g. max_retries <= 0): never return None.
    return {
        "title": f"Rezept mit {original_ingredients[0] if original_ingredients else 'Zutaten'}",
        "ingredients": original_ingredients,
        "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
    }
169
+
170
# Core logic invoked by the FastAPI route.
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
    """Run the full recipe pipeline: select ingredients, then generate.

    Returns a dict with 'title', 'ingredients', 'directions' and
    'used_ingredients' on success, or a dict with a single 'error' key
    (German message) on failure — errors never propagate to the caller.
    """
    if not required_ingredients and not available_ingredients:
        return {"error": "Keine Zutaten angegeben"}
    try:
        chosen_ingredients = find_best_ingredients(
            required_ingredients, available_ingredients, max_ingredients
        )
        recipe = generate_recipe_with_t5(chosen_ingredients, max_retries)
        return {
            'title': recipe['title'],
            'ingredients': recipe['ingredients'],
            'directions': recipe['directions'],
            'used_ingredients': chosen_ingredients
        }
    except Exception as e:
        return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
191
+
192
# --- FastAPI implementation ---
app = FastAPI()  # ASGI application instance served by Uvicorn
194
+
195
class RecipeRequest(BaseModel):
    """Request body for the /generate_recipe endpoint."""
    # Ingredients that must appear in the recipe.
    required_ingredients: list[str] = []
    # Optional pool the selector may draw additional ingredients from.
    available_ingredients: list[str] = []
    # Upper bound on the total number of ingredients used.
    max_ingredients: int = 7
    # Maximum number of generation attempts before giving up.
    max_retries: int = 5
    # Backward compatibility: legacy clients send a top-level 'ingredients' field.
    ingredients: list[str] = []
202
+
203
@app.post("/generate_recipe")  # REST endpoint consumed by the Flutter app
async def generate_recipe_api(request_data: RecipeRequest):
    """Standard REST endpoint for the Flutter app.

    Accepts JSON directly (see RecipeRequest) and returns the pipeline
    result as JSON.
    """
    required = request_data.required_ingredients
    # Legacy payloads send everything in the top-level 'ingredients' field.
    if not required and request_data.ingredients:
        required = request_data.ingredients

    payload = process_recipe_request_logic(
        required,
        request_data.available_ingredients,
        request_data.max_ingredients,
        request_data.max_retries
    )
    return JSONResponse(content=payload)
220
+
221
+ print("INFO: FastAPI application script finished execution and defined 'app' variable.")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ numpy
4
+ jax
5
+ jaxlib
6
+ flax
7
+ fastapi
8
+ uvicorn
9
+ pydantic