Hjgugugjhuhjggg committed
Commit fe94180 (verified)
1 Parent(s): 3016722

Update app.py

Files changed (1)
  1. app.py +98 -36
app.py CHANGED
@@ -4,28 +4,34 @@ import torch
 import asyncio
 import threading
 import time
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi import FastAPI, HTTPException, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
 from pydantic import BaseModel
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 import uvicorn
 from duckduckgo_search import DDGS
+from concurrent.futures import ThreadPoolExecutor
 
+# Name of the model to load, plus global variables
 MODEL_NAME = "lilmeaty/my_xdd"
 global_model = None
 global_tokenizer = None
 global_tokens = {}
 
-# Dictionaries for caching responses
+# Caches for responses and searches
 global_response_cache = {}
 duckasgo_response_cache = {}
 
+# Executor for running tasks in parallel
+executor = ThreadPoolExecutor(max_workers=4)
+
 async def cleanup_memory(device: str):
     gc.collect()
     if device == "cuda":
         torch.cuda.empty_cache()
     await asyncio.sleep(0.01)
 
+# Data model for generation requests
 class GenerateRequest(BaseModel):
     input_text: str = ""
     max_new_tokens: int = 200
@@ -39,6 +45,7 @@ class GenerateRequest(BaseModel):
     stop_sequences: list[str] = []
     include_duckasgo: bool = False
 
+# Data model for DuckDuckGo searches
 class DuckasgoRequest(BaseModel):
     query: str
 
@@ -46,6 +53,9 @@ app = FastAPI()
 
 @app.on_event("startup")
 async def load_global_model():
+    """
+    Loads the model and tokenizer at application startup.
+    """
     global global_model, global_tokenizer, global_tokens
     config = AutoConfig.from_pretrained(MODEL_NAME)
     global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
@@ -60,8 +70,28 @@ async def load_global_model():
     global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
     print(f"Modelo {MODEL_NAME} cargado correctamente en {device}.")
 
+@app.get("/", response_class=HTMLResponse)
+async def index():
+    """
+    Root endpoint, so the site can be visited while other operations run.
+    """
+    html_content = """
+    <html>
+      <head>
+        <title>Mi Sitio de Generación de Texto</title>
+      </head>
+      <body>
+        <h1>Bienvenido a Mi Sitio</h1>
+        <p>Puedes enviar peticiones a <code>/generate</code> o <code>/duckasgo</code> sin afectar la navegación.</p>
+      </body>
+    </html>
+    """
+    return HTMLResponse(content=html_content, status_code=200)
+
 async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
-    # Check the cache first
+    """
+    Searches DuckDuckGo, using a cache to speed up repeated queries.
+    """
     if query in duckasgo_response_cache:
         return duckasgo_response_cache[query]
     try:
@@ -78,11 +108,38 @@ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
             url = res.get("href", "Sin URL")
             snippet = res.get("body", "")
             result_text += f"{idx}. {title}\n URL: {url}\n {snippet}\n"
-        # Save to the cache
         duckasgo_response_cache[query] = result_text
         return result_text
 
+def generate_next_token(input_ids, past_key_values, gen_config, device):
+    """
+    Synchronous function that generates the next token with the model.
+    Invoked in parallel via asyncio.to_thread.
+    """
+    with torch.no_grad():
+        outputs = global_model(
+            input_ids,
+            past_key_values=past_key_values,
+            use_cache=True,
+            return_dict=True
+        )
+    logits = outputs.logits[:, -1, :]
+    past_key_values = outputs.past_key_values
+    if gen_config.do_sample:
+        logits = logits / gen_config.temperature
+        if gen_config.top_k and gen_config.top_k > 0:
+            topk_values, _ = torch.topk(logits, k=gen_config.top_k)
+            logits[logits < topk_values[:, [-1]]] = -float('Inf')
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+    else:
+        next_token = torch.argmax(logits, dim=-1, keepdim=True)
+    return next_token, past_key_values
+
 async def stream_text(request: GenerateRequest, device: str):
+    """
+    Streams generated text, producing each token via a parallel call.
+    """
     global global_model, global_tokenizer, global_tokens
     encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
     input_ids = encoded_input.input_ids
@@ -99,24 +156,10 @@ async def stream_text(request: GenerateRequest, device: str):
     )
     past_key_values = None
     for _ in range(request.max_new_tokens):
-        with torch.no_grad():
-            outputs = global_model(
-                input_ids,
-                past_key_values=past_key_values,
-                use_cache=True,
-                return_dict=True
-            )
-        logits = outputs.logits[:, -1, :]
-        past_key_values = outputs.past_key_values
-        if gen_config.do_sample:
-            logits = logits / gen_config.temperature
-            if gen_config.top_k and gen_config.top_k > 0:
-                topk_values, _ = torch.topk(logits, k=gen_config.top_k)
-                logits[logits < topk_values[:, [-1]]] = -float('Inf')
-            probs = torch.nn.functional.softmax(logits, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)
-        else:
-            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+        # Generate the next token in a worker thread
+        next_token, past_key_values = await asyncio.to_thread(
+            generate_next_token, input_ids, past_key_values, gen_config, device
+        )
         token_id = next_token.item()
         token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
         accumulated_text += token_text
@@ -143,8 +186,27 @@ async def stream_text(request: GenerateRequest, device: str):
         yield "\n" + search_summary
     await cleanup_memory(device)
 
+def synchronous_generation(encoded_input, gen_config, device, request: GenerateRequest):
+    """
+    Synchronous function for full, non-streaming generation.
+    Executed in parallel via asyncio.to_thread.
+    """
+    with torch.no_grad():
+        output = global_model.generate(
+            **encoded_input,
+            generation_config=gen_config,
+            return_dict_in_generate=True,
+            output_scores=True,
+            return_legacy_cache=True  # Parameter added to optimize internal cache use
+        )
+    return output
+
 @app.post("/generate")
-async def generate_text(request: GenerateRequest):
+async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
+    """
+    Text generation endpoint. Generation (streaming or full) runs in parallel,
+    and responses are cached to speed up repeated requests.
+    """
     global global_model, global_tokenizer, global_tokens, global_response_cache
     if global_model is None or global_tokenizer is None:
         raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
@@ -157,9 +219,7 @@ async def generate_text(request: GenerateRequest):
         repetition_penalty=request.repetition_penalty,
         do_sample=request.do_sample,
     )
-    # Only cache when the request is not streamed
     if not request.stream:
-        # Build a key from the relevant parameters
         cache_key = (
            request.input_text,
            request.max_new_tokens,
@@ -175,18 +235,15 @@ async def generate_text(request: GenerateRequest):
            return {"generated_text": global_response_cache[cache_key]}
     try:
         if request.stream:
+            # Streaming generation runs in parallel
            generator = stream_text(request, device)
            return StreamingResponse(generator, media_type="text/plain")
         else:
            encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
-            with torch.no_grad():
-                output = global_model.generate(
-                    **encoded_input,
-                    generation_config=gen_config,
-                    return_dict_in_generate=True,
-                    output_scores=True,
-                    return_legacy_cache=True  # Requested parameter added
-                )
+            # Run the full generation in a separate thread
+            output = await asyncio.to_thread(
+                synchronous_generation, encoded_input, gen_config, device, request
+            )
            input_length = encoded_input["input_ids"].shape[-1]
            generated_text = global_tokenizer.decode(
                output.sequences[0][input_length:],
@@ -201,14 +258,17 @@ async def generate_text(request: GenerateRequest):
            search_summary = await perform_duckasgo_search(request.input_text)
            generated_text += "\n" + search_summary
        await cleanup_memory(device)
-        # Cache the generated response
        global_response_cache[cache_key] = generated_text
+        background_tasks.add_task(lambda: print("Generación completada y cacheada."))
        return {"generated_text": generated_text}
     except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error durante la generación: {e}")
 
 @app.post("/duckasgo")
 async def duckasgo_search(request: DuckasgoRequest):
+    """
+    DuckDuckGo search endpoint, with caching to speed up responses.
+    """
     global duckasgo_response_cache
     if request.query in duckasgo_response_cache:
         return JSONResponse(content={"query": request.query, "results": duckasgo_response_cache[request.query]})
@@ -217,13 +277,15 @@ async def duckasgo_search(request: DuckasgoRequest):
        results = ddgs.text(request.query, max_results=10)
        if not results:
            results = []
-        # Cache the response
        duckasgo_response_cache[request.query] = results
        return JSONResponse(content={"query": request.query, "results": results})
     except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error en la búsqueda: {e}")
 
 def run_server():
+    """
+    Starts the Uvicorn server.
+    """
     uvicorn.run(app, host="0.0.0.0", port=7860)
 
 if __name__ == "__main__":
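The core pattern in this change is moving blocking model calls off the asyncio event loop with `asyncio.to_thread`, which keeps the `/` page and other requests responsive while tokens are generated. A minimal, self-contained sketch of that pattern, with a hypothetical `blocking_step` standing in for one model forward pass:

```python
import asyncio
import time

def blocking_step(x: int) -> int:
    # Hypothetical stand-in for one blocking model forward pass
    time.sleep(0.1)
    return x + 1

async def main():
    # await asyncio.to_thread(...) runs the blocking call in a worker
    # thread, so other coroutines (e.g. the "/" endpoint) keep running.
    state = 0
    for _ in range(3):
        state = await asyncio.to_thread(blocking_step, state)
    print(state)  # prints 3

asyncio.run(main())
```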
 
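For reference, a minimal client sketch against these endpoints; it assumes the app is running locally on port 7860 (the port `run_server()` binds), with field names following the `GenerateRequest` and `DuckasgoRequest` models above:

```python
import requests

BASE = "http://localhost:7860"  # assumes a local deployment of this app

# Non-streaming generation; an identical repeat request should be served from the cache
resp = requests.post(f"{BASE}/generate", json={
    "input_text": "Hola, ¿cómo estás?",
    "max_new_tokens": 50,
    "stream": False,
})
print(resp.json()["generated_text"])

# Streaming generation: tokens arrive as plain-text chunks
with requests.post(f"{BASE}/generate",
                   json={"input_text": "Érase una vez", "stream": True},
                   stream=True) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)

# DuckDuckGo search with server-side caching
resp = requests.post(f"{BASE}/duckasgo", json={"query": "FastAPI streaming"})
print(resp.json()["results"])
```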