Hjgugugjhuhjggg commited on
Commit
c9f727d
·
verified ·
1 Parent(s): a542983

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -35
app.py CHANGED
@@ -27,7 +27,6 @@ async def cleanup_memory(device: str):
27
  gc.collect()
28
  if device == "cuda":
29
  torch.cuda.empty_cache()
30
- await asyncio.sleep(0.01)
31
 
32
  class GenerateRequest(BaseModel):
33
  input_text: str = ""
@@ -42,7 +41,7 @@ class GenerateRequest(BaseModel):
42
  do_sample: bool = True
43
  stream: bool = True # Streaming por defecto
44
  chunk_token_limit: int = 2 # Máximo 2 tokens por bloque
45
- token_timeout: float = 0.0 # Timeout en 0: sin timeout
46
  stop_sequences: list[str] = []
47
  include_duckasgo: bool = False
48
 
@@ -73,8 +72,7 @@ async def load_global_model():
73
  @app.get("/", response_class=HTMLResponse)
74
  async def index():
75
  """
76
- Endpoint raíz que devuelve una página HTML simple para permitir la navegación
77
- mientras se generan respuestas en paralelo.
78
  """
79
  html_content = """
80
  <html>
@@ -83,7 +81,6 @@ async def index():
83
  </head>
84
  <body>
85
  <h1>Bienvenido al Generador de Texto</h1>
86
- <p>El sistema utiliza streaming por defecto para generar respuestas rápidamente.</p>
87
  <p>Prueba los endpoints <code>/generate</code> o <code>/duckasgo</code>.</p>
88
  </body>
89
  </html>
@@ -100,7 +97,6 @@ async def health():
100
  async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
101
  """
102
  Realiza una búsqueda en DuckDuckGo y retorna un resumen de los resultados.
103
- Se ejecuta en cada llamada sin almacenar resultados en caché.
104
  """
105
  try:
106
  with DDGS() as ddgs:
@@ -121,8 +117,6 @@ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
121
  def generate_next_token(input_ids, past_key_values, gen_config, device):
122
  """
123
  Función síncrona que genera el siguiente token utilizando el modelo.
124
- Retorna además el log-probability del token seleccionado.
125
- Se invoca en paralelo mediante asyncio.to_thread.
126
  """
127
  with torch.no_grad():
128
  outputs = global_model(
@@ -152,15 +146,14 @@ def generate_next_token(input_ids, past_key_values, gen_config, device):
152
  async def stream_text(request: GenerateRequest, device: str):
153
  """
154
  Genera texto de forma streaming, enviando cada bloque con hasta 'chunk_token_limit' tokens.
155
- Se continúa en un loop hasta detectar el token de finalización (eos) o alcanzar un límite total.
156
  """
157
  global global_model, global_tokenizer, global_tokens
158
- # Prepara la entrada y configura la generación
159
  encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
160
  input_ids = encoded_input.input_ids
161
  gen_config = GenerationConfig(
162
  temperature=request.temperature,
163
- max_new_tokens=request.max_new_tokens, # Límite total opcional
164
  top_p=request.top_p,
165
  top_k=request.top_k,
166
  repetition_penalty=request.repetition_penalty,
@@ -176,44 +169,28 @@ async def stream_text(request: GenerateRequest, device: str):
176
  max_total_tokens = request.max_new_tokens if request.max_new_tokens > 0 else 1000
177
 
178
  while True:
179
- if request.token_timeout > 0:
180
- try:
181
- next_token, past_key_values, token_logprob = await asyncio.wait_for(
182
- asyncio.to_thread(generate_next_token, input_ids, past_key_values, gen_config, device),
183
- timeout=request.token_timeout
184
- )
185
- except asyncio.TimeoutError:
186
- yield "data: " + json.dumps({"generated_text": "[Token generation timeout, continuing...]"} ) + "\n\n"
187
- continue
188
- else:
189
- next_token, past_key_values, token_logprob = await asyncio.to_thread(
190
- generate_next_token, input_ids, past_key_values, gen_config, device
191
- )
192
-
193
  token_id = next_token.item()
194
  token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
195
  current_chunk += token_text
196
  chunk_token_count += 1
197
  token_count += 1
198
 
199
- # Envía el bloque actual cuando se alcanza el límite de tokens por bloque
200
  if chunk_token_count >= request.chunk_token_limit:
201
  yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
202
  current_chunk = ""
203
  chunk_token_count = 0
204
 
205
- await asyncio.sleep(0)
206
  input_ids = next_token
207
 
208
- # Condición de terminación: se detecta el token eos o se alcanza el límite total
209
  if token_id == global_tokens["eos_token_id"] or token_count >= max_total_tokens:
210
  break
211
 
212
- # Envía cualquier bloque parcial pendiente
213
  if current_chunk:
214
  yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
215
 
216
- # Incluye resultados de búsqueda si se solicita
217
  if request.include_duckasgo:
218
  search_summary = await perform_duckasgo_search(request.input_text)
219
  yield "data: " + json.dumps({"generated_text": search_summary}) + "\n\n"
@@ -222,7 +199,6 @@ async def stream_text(request: GenerateRequest, device: str):
222
  def synchronous_generation(encoded_input, gen_config, device):
223
  """
224
  Función síncrona para generación completa en modo no streaming.
225
- Se ejecuta en paralelo mediante asyncio.to_thread.
226
  """
227
  with torch.no_grad():
228
  output = global_model.generate(
@@ -238,7 +214,7 @@ def synchronous_generation(encoded_input, gen_config, device):
238
  async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
239
  """
240
  Endpoint para la generación de texto.
241
- En modo streaming se envían bloques de hasta 'chunk_token_limit' tokens hasta finalizar.
242
  En modo no streaming se devuelve la respuesta completa dividida en bloques.
243
  """
244
  global global_model, global_tokenizer, global_tokens
@@ -273,7 +249,6 @@ async def generate_text(request: GenerateRequest, background_tasks: BackgroundTa
273
  full_generated_text = full_generated_text.split(stop_seq)[0]
274
  break
275
 
276
- # Dividir la respuesta en bloques de tokens según chunk_token_limit
277
  final_token_ids = global_tokenizer.encode(full_generated_text, add_special_tokens=False)
278
  chunks = []
279
  for i in range(0, len(final_token_ids), request.chunk_token_limit):
@@ -295,7 +270,6 @@ async def generate_text(request: GenerateRequest, background_tasks: BackgroundTa
295
  async def duckasgo_search(request: DuckasgoRequest):
296
  """
297
  Endpoint para búsquedas en DuckDuckGo.
298
- Se ejecuta en cada petición sin almacenar resultados en caché.
299
  """
300
  try:
301
  with DDGS() as ddgs:
@@ -310,7 +284,6 @@ def run_server():
310
  uvicorn.run(app, host="0.0.0.0", port=7860)
311
 
312
  if __name__ == "__main__":
313
- # Inicia el servidor en un hilo separado para permitir tareas concurrentes.
314
  server_thread = threading.Thread(target=run_server, daemon=True)
315
  server_thread.start()
316
  while True:
 
27
  gc.collect()
28
  if device == "cuda":
29
  torch.cuda.empty_cache()
 
30
 
31
  class GenerateRequest(BaseModel):
32
  input_text: str = ""
 
41
  do_sample: bool = True
42
  stream: bool = True # Streaming por defecto
43
  chunk_token_limit: int = 2 # Máximo 2 tokens por bloque
44
+ # Se eliminan token_timeout y demás esperas
45
  stop_sequences: list[str] = []
46
  include_duckasgo: bool = False
47
 
 
72
  @app.get("/", response_class=HTMLResponse)
73
  async def index():
74
  """
75
+ Endpoint raíz que devuelve una página HTML simple.
 
76
  """
77
  html_content = """
78
  <html>
 
81
  </head>
82
  <body>
83
  <h1>Bienvenido al Generador de Texto</h1>
 
84
  <p>Prueba los endpoints <code>/generate</code> o <code>/duckasgo</code>.</p>
85
  </body>
86
  </html>
 
97
  async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
98
  """
99
  Realiza una búsqueda en DuckDuckGo y retorna un resumen de los resultados.
 
100
  """
101
  try:
102
  with DDGS() as ddgs:
 
117
  def generate_next_token(input_ids, past_key_values, gen_config, device):
118
  """
119
  Función síncrona que genera el siguiente token utilizando el modelo.
 
 
120
  """
121
  with torch.no_grad():
122
  outputs = global_model(
 
146
  async def stream_text(request: GenerateRequest, device: str):
147
  """
148
  Genera texto de forma streaming, enviando cada bloque con hasta 'chunk_token_limit' tokens.
149
+ Se continúa generando hasta detectar el token de finalización (eos) o alcanzar un límite total.
150
  """
151
  global global_model, global_tokenizer, global_tokens
 
152
  encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
153
  input_ids = encoded_input.input_ids
154
  gen_config = GenerationConfig(
155
  temperature=request.temperature,
156
+ max_new_tokens=request.max_new_tokens,
157
  top_p=request.top_p,
158
  top_k=request.top_k,
159
  repetition_penalty=request.repetition_penalty,
 
169
  max_total_tokens = request.max_new_tokens if request.max_new_tokens > 0 else 1000
170
 
171
  while True:
172
+ next_token, past_key_values, token_logprob = await asyncio.to_thread(
173
+ generate_next_token, input_ids, past_key_values, gen_config, device
174
+ )
 
 
 
 
 
 
 
 
 
 
 
175
  token_id = next_token.item()
176
  token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
177
  current_chunk += token_text
178
  chunk_token_count += 1
179
  token_count += 1
180
 
 
181
  if chunk_token_count >= request.chunk_token_limit:
182
  yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
183
  current_chunk = ""
184
  chunk_token_count = 0
185
 
 
186
  input_ids = next_token
187
 
 
188
  if token_id == global_tokens["eos_token_id"] or token_count >= max_total_tokens:
189
  break
190
 
 
191
  if current_chunk:
192
  yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
193
 
 
194
  if request.include_duckasgo:
195
  search_summary = await perform_duckasgo_search(request.input_text)
196
  yield "data: " + json.dumps({"generated_text": search_summary}) + "\n\n"
 
199
  def synchronous_generation(encoded_input, gen_config, device):
200
  """
201
  Función síncrona para generación completa en modo no streaming.
 
202
  """
203
  with torch.no_grad():
204
  output = global_model.generate(
 
214
  async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
215
  """
216
  Endpoint para la generación de texto.
217
+ En modo streaming se envían bloques de hasta 'chunk_token_limit' tokens.
218
  En modo no streaming se devuelve la respuesta completa dividida en bloques.
219
  """
220
  global global_model, global_tokenizer, global_tokens
 
249
  full_generated_text = full_generated_text.split(stop_seq)[0]
250
  break
251
 
 
252
  final_token_ids = global_tokenizer.encode(full_generated_text, add_special_tokens=False)
253
  chunks = []
254
  for i in range(0, len(final_token_ids), request.chunk_token_limit):
 
270
  async def duckasgo_search(request: DuckasgoRequest):
271
  """
272
  Endpoint para búsquedas en DuckDuckGo.
 
273
  """
274
  try:
275
  with DDGS() as ddgs:
 
284
  uvicorn.run(app, host="0.0.0.0", port=7860)
285
 
286
  if __name__ == "__main__":
 
287
  server_thread = threading.Thread(target=run_server, daemon=True)
288
  server_thread.start()
289
  while True: