Update app.py
app.py CHANGED
@@ -27,7 +27,6 @@ async def cleanup_memory(device: str):
     gc.collect()
     if device == "cuda":
         torch.cuda.empty_cache()
-    await asyncio.sleep(0.01)
 
 class GenerateRequest(BaseModel):
     input_text: str = ""
@@ -42,7 +41,7 @@ class GenerateRequest(BaseModel):
     do_sample: bool = True
     stream: bool = True  # Streaming by default
     chunk_token_limit: int = 2  # At most 2 tokens per chunk
-
+    # token_timeout and the other waits are removed
     stop_sequences: list[str] = []
     include_duckasgo: bool = False
 
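For reference, a hypothetical /generate request body using only the fields visible in this diff; other fields of GenerateRequest and their defaults are not shown here, so treat this as a sketch rather than the full schema:

# Hypothetical example payload for POST /generate, built only from fields
# that appear in this diff.
payload = {
    "input_text": "Escribe un haiku sobre el mar.",
    "stream": True,               # streaming is the default
    "chunk_token_limit": 2,       # at most 2 tokens per streamed chunk
    "max_new_tokens": 64,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
    "do_sample": True,
    "stop_sequences": [],
    "include_duckasgo": False,
}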
@@ -73,8 +72,7 @@ async def load_global_model():
 @app.get("/", response_class=HTMLResponse)
 async def index():
     """
-    Root endpoint that returns a simple HTML page
-    while responses are generated in parallel.
+    Root endpoint that returns a simple HTML page.
     """
     html_content = """
     <html>
@@ -83,7 +81,6 @@ async def index():
     </head>
     <body>
         <h1>Bienvenido al Generador de Texto</h1>
-        <p>El sistema utiliza streaming por defecto para generar respuestas rápidamente.</p>
         <p>Prueba los endpoints <code>/generate</code> o <code>/duckasgo</code>.</p>
     </body>
     </html>
@@ -100,7 +97,6 @@ async def health():
 async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
     """
     Performs a DuckDuckGo search and returns a summary of the results.
-    It runs on every call without caching the results.
     """
     try:
         with DDGS() as ddgs:
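The diff truncates perform_duckasgo_search right after the DDGS context manager opens. A minimal sketch of what such a helper can look like with the duckduckgo_search package; the result keys and the summary formatting are assumptions, not the code in app.py:

from duckduckgo_search import DDGS

def search_summary(query: str, max_results: int = 3) -> str:
    # Plain-text search; join titles and links into a short summary string.
    with DDGS() as ddgs:
        results = ddgs.text(query, max_results=max_results)
    lines = [f"{r.get('title')} - {r.get('href')}" for r in results]
    return "\n".join(lines) if lines else "No results found."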
@@ -121,8 +117,6 @@ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
 def generate_next_token(input_ids, past_key_values, gen_config, device):
     """
     Synchronous function that generates the next token using the model.
-    It also returns the log-probability of the selected token.
-    It is invoked in parallel via asyncio.to_thread.
     """
     with torch.no_grad():
         outputs = global_model(
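generate_next_token is also cut off here. A standalone sketch of the usual single-step decoding pattern the removed docstring describes (sample one token, reuse the KV cache, return the token's log-probability); this is an illustration under those assumptions, not the function's actual body:

import torch

def next_token_step(model, input_ids, past_key_values=None, temperature=1.0):
    # One decoding step for a Hugging Face causal LM: feed the newest token ids,
    # reuse the cached key/values, sample a token, and compute its log-probability.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
    logits = outputs.logits[:, -1, :] / temperature
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)  # shape (batch, 1)
    logprob = torch.log(probs.gather(1, next_token)).squeeze().item()
    return next_token, outputs.past_key_values, logprob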
@@ -152,15 +146,14 @@ def generate_next_token(input_ids, past_key_values, gen_config, device):
 async def stream_text(request: GenerateRequest, device: str):
     """
     Generates text as a stream, sending each chunk with up to 'chunk_token_limit' tokens.
-    Generation continues
+    Generation continues until the end-of-sequence (eos) token is detected or a total limit is reached.
     """
     global global_model, global_tokenizer, global_tokens
-    # Prepare the input and configure generation
     encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
     input_ids = encoded_input.input_ids
     gen_config = GenerationConfig(
         temperature=request.temperature,
-        max_new_tokens=request.max_new_tokens,
+        max_new_tokens=request.max_new_tokens,
         top_p=request.top_p,
         top_k=request.top_k,
         repetition_penalty=request.repetition_penalty,
@@ -176,44 +169,28 @@ async def stream_text(request: GenerateRequest, device: str):
     max_total_tokens = request.max_new_tokens if request.max_new_tokens > 0 else 1000
 
     while True:
-
-        try:
-            next_token, past_key_values, token_logprob = await asyncio.wait_for(
-                asyncio.to_thread(generate_next_token, input_ids, past_key_values, gen_config, device),
-                timeout=request.token_timeout
-            )
-        except asyncio.TimeoutError:
-            yield "data: " + json.dumps({"generated_text": "[Token generation timeout, continuing...]"}) + "\n\n"
-            continue
-        else:
-            next_token, past_key_values, token_logprob = await asyncio.to_thread(
-                generate_next_token, input_ids, past_key_values, gen_config, device
-            )
-
+        next_token, past_key_values, token_logprob = await asyncio.to_thread(
+            generate_next_token, input_ids, past_key_values, gen_config, device
+        )
         token_id = next_token.item()
         token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
         current_chunk += token_text
         chunk_token_count += 1
         token_count += 1
 
-        # Send the current chunk when the per-chunk token limit is reached
         if chunk_token_count >= request.chunk_token_limit:
             yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
             current_chunk = ""
             chunk_token_count = 0
 
-        await asyncio.sleep(0)
         input_ids = next_token
 
-        # Termination condition: the eos token is detected or the total limit is reached
         if token_id == global_tokens["eos_token_id"] or token_count >= max_total_tokens:
             break
 
-    # Send any pending partial chunk
     if current_chunk:
         yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
 
-    # Include search results if requested
     if request.include_duckasgo:
         search_summary = await perform_duckasgo_search(request.input_text)
         yield "data: " + json.dumps({"generated_text": search_summary}) + "\n\n"
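With the timeout handling removed, the loop now awaits one token per iteration and yields Server-Sent Events of the "data: {...}" form shown above. A small client sketch that consumes that stream, assuming the server from this file is running locally (uvicorn on port 7860, see below) and that /generate streams these events:

import json
import requests

# Open the streaming request and accumulate the generated_text of each event.
resp = requests.post(
    "http://localhost:7860/generate",
    json={"input_text": "Hola, ¿qué es el streaming?", "stream": True},
    stream=True,
)
pieces = []
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data: "):
        event = json.loads(line[len("data: "):])
        pieces.append(event["generated_text"])
print("".join(pieces))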
@@ -222,7 +199,6 @@ async def stream_text(request: GenerateRequest, device: str):
 def synchronous_generation(encoded_input, gen_config, device):
     """
     Synchronous function for full generation in non-streaming mode.
-    It runs in parallel via asyncio.to_thread.
     """
     with torch.no_grad():
         output = global_model.generate(
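The removed docstring line noted that this function is run via asyncio.to_thread. A hypothetical wrapper illustrating that pattern; it assumes app.py's synchronous_generation and is not the actual endpoint code:

import asyncio

async def generate_without_blocking(encoded_input, gen_config, device):
    # Offload the blocking model.generate call to a worker thread so the
    # FastAPI event loop can keep serving other requests in the meantime.
    return await asyncio.to_thread(synchronous_generation, encoded_input, gen_config, device)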
@@ -238,7 +214,7 @@ def synchronous_generation(encoded_input, gen_config, device):
 async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
     """
     Endpoint for text generation.
-    In streaming mode, chunks of up to 'chunk_token_limit' tokens are sent
+    In streaming mode, chunks of up to 'chunk_token_limit' tokens are sent.
     In non-streaming mode, the full response is returned split into chunks.
     """
     global global_model, global_tokenizer, global_tokens
@@ -273,7 +249,6 @@ async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
             full_generated_text = full_generated_text.split(stop_seq)[0]
             break
 
-    # Split the response into token chunks according to chunk_token_limit
     final_token_ids = global_tokenizer.encode(full_generated_text, add_special_tokens=False)
     chunks = []
     for i in range(0, len(final_token_ids), request.chunk_token_limit):
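A self-contained sketch of the chunking step whose loop starts in the last context line above; the loop body is not shown in the diff, so the decode call is an assumption:

def split_into_chunks(tokenizer, text: str, chunk_token_limit: int) -> list[str]:
    # Hypothetical helper mirroring the loop above: re-encode the final text, then
    # decode it back in slices of at most chunk_token_limit tokens each.
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return [
        tokenizer.decode(token_ids[i:i + chunk_token_limit], skip_special_tokens=True)
        for i in range(0, len(token_ids), chunk_token_limit)
    ]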
@@ -295,7 +270,6 @@ async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
 async def duckasgo_search(request: DuckasgoRequest):
     """
     Endpoint for DuckDuckGo searches.
-    It runs on every request without caching the results.
     """
     try:
         with DDGS() as ddgs:
@@ -310,7 +284,6 @@ def run_server():
     uvicorn.run(app, host="0.0.0.0", port=7860)
 
 if __name__ == "__main__":
-    # Start the server in a separate thread to allow concurrent tasks.
     server_thread = threading.Thread(target=run_server, daemon=True)
     server_thread.start()
     while True: