Hjgugugjhuhjggg committed
Commit 28fbc8a (verified)
1 Parent(s): 5624ca8

Update app.py

Files changed (1)
  1. app.py +172 -43
app.py CHANGED
@@ -7,17 +7,20 @@ import asyncio
 import threading
 import time
 from fastapi import FastAPI, HTTPException, BackgroundTasks
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
 from pydantic import BaseModel, Field
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 import uvicorn
 from duckduckgo_search import DDGS
 from concurrent.futures import ThreadPoolExecutor
 
+# Model name and global variables
 MODEL_NAME = "lilmeaty/my_xdd"
 global_model = None
 global_tokenizer = None
 global_tokens = {}
+
+# Executor for running tasks in parallel
 executor = ThreadPoolExecutor(max_workers=4)
 
 async def cleanup_memory(device: str):
@@ -37,20 +40,28 @@ class GenerateRequest(BaseModel):
     presence_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
     seed: int = Field(default_factory=lambda: random.randint(0, 1000))
     do_sample: bool = True
-    stream: bool = True
-    chunk_token_limit: int = 2
-    token_timeout: float = 0.0
+    stream: bool = True  # Streaming by default
+    chunk_token_limit: int = 2  # At most 2 tokens per response block
+    token_timeout: float = 0.0  # A timeout of 0 means no timeout
     stop_sequences: list[str] = []
     include_duckasgo: bool = False
 
+class DuckasgoRequest(BaseModel):
+    query: str
+
 app = FastAPI()
 
 @app.on_event("startup")
 async def load_global_model():
+    """
+    Loads the global model and tokenizer when the application starts.
+    """
     global global_model, global_tokenizer, global_tokens
     config = AutoConfig.from_pretrained(MODEL_NAME)
     global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
-    global_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, torch_dtype=torch.float16)
+    global_model = AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME, config=config, torch_dtype=torch.float16
+    )
     if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
         global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -59,9 +70,67 @@ async def load_global_model():
     global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
     print(f"Model {MODEL_NAME} loaded successfully on {device}.")
 
+@app.get("/", response_class=HTMLResponse)
+async def index():
+    """
+    Root endpoint returning a simple HTML page, so the site stays browsable
+    while responses are generated in parallel.
+    """
+    html_content = """
+    <html>
+        <head>
+            <title>Text Generation - Streaming by Default</title>
+        </head>
+        <body>
+            <h1>Welcome to the Text Generator</h1>
+            <p>The system uses streaming by default to generate responses quickly.</p>
+            <p>Try the <code>/generate</code> or <code>/duckasgo</code> endpoints.</p>
+        </body>
+    </html>
+    """
+    return HTMLResponse(content=html_content, status_code=200)
+
+@app.get("/health")
+async def health():
+    """
+    Health endpoint for checking server status.
+    """
+    return {"status": "ok"}
+
+async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
+    """
+    Runs a DuckDuckGo search and returns a summary of the results.
+    Executed on every call; results are not cached.
+    """
+    try:
+        with DDGS() as ddgs:
+            results = ddgs.text(query, max_results=max_results)
+    except Exception as e:
+        return f"DuckDuckGo search error: {e}"
+    if not results:
+        result_text = "No results found on DuckDuckGo."
+    else:
+        result_text = "\nSearch results (DuckDuckGo):\n"
+        for idx, res in enumerate(results, start=1):
+            title = res.get("title", "No title")
+            url = res.get("href", "No URL")
+            snippet = res.get("body", "")
+            result_text += f"{idx}. {title}\n   URL: {url}\n   {snippet}\n"
+    return result_text
+
 def generate_next_token(input_ids, past_key_values, gen_config, device):
+    """
+    Synchronous function that generates the next token using the model.
+    It also returns the log-probability of the selected token.
+    Invoked in parallel via asyncio.to_thread.
+    """
     with torch.no_grad():
-        outputs = global_model(input_ids, past_key_values=past_key_values, use_cache=True, return_dict=True)
+        outputs = global_model(
+            input_ids,
+            past_key_values=past_key_values,
+            use_cache=True,
+            return_dict=True
+        )
     logits = outputs.logits[:, -1, :]
     past_key_values = outputs.past_key_values
     if gen_config.do_sample:
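
The hunk cuts off just before the sampling branch (if gen_config.do_sample:), which this commit leaves unchanged and which is not shown here. For orientation only, a minimal temperature/top-k sampling step over the last-position logits could look like the sketch below; it is an illustrative stand-in, not the actual code in app.py:

    import torch

    def pick_next_token(logits, do_sample, temperature=1.0, top_k=50):
        # Greedy decoding when sampling is disabled
        if not do_sample:
            return torch.argmax(logits, dim=-1, keepdim=True)
        # Temperature scaling, then restrict to the top-k candidates
        scaled = logits / max(temperature, 1e-5)
        top_vals, top_idx = torch.topk(scaled, k=min(top_k, scaled.shape[-1]), dim=-1)
        probs = torch.softmax(top_vals, dim=-1)
        choice = torch.multinomial(probs, num_samples=1)
        return top_idx.gather(-1, choice)
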
@@ -81,27 +150,52 @@ def generate_next_token(input_ids, past_key_values, gen_config, device):
     return next_token, past_key_values, token_logprob.item()
 
 async def stream_text(request: GenerateRequest, device: str):
+    """
+    Streams generated text, emitting each block of up to 2 independently generated tokens.
+    Every response is sent immediately, and its "generated_text" field contains only those tokens (no accumulation).
+    If token_timeout is greater than 0 a timeout is applied; otherwise generation waits without limit.
+    """
     global global_model, global_tokenizer, global_tokens
+    # Prepare the input and configure generation
     encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
     input_ids = encoded_input.input_ids
-    gen_config = GenerationConfig(temperature=request.temperature, max_new_tokens=request.max_new_tokens, top_p=request.top_p, top_k=request.top_k, repetition_penalty=request.repetition_penalty, frequency_penalty=request.frequency_penalty, presence_penalty=request.presence_penalty, do_sample=request.do_sample)
+    # Build the GenerationConfig from the request parameters
+    gen_config = GenerationConfig(
+        temperature=request.temperature,
+        max_new_tokens=request.max_new_tokens,
+        top_p=request.top_p,
+        top_k=request.top_k,
+        repetition_penalty=request.repetition_penalty,
+        frequency_penalty=request.frequency_penalty,
+        presence_penalty=request.presence_penalty,
+        do_sample=request.do_sample,
+    )
+    # Seed the generation
     torch.manual_seed(request.seed)
+    # State for token-block handling
     current_chunk = ""
     chunk_token_count = 0
     past_key_values = None
     for _ in range(request.max_new_tokens):
         if request.token_timeout > 0:
             try:
-                next_token, past_key_values, token_logprob = await asyncio.wait_for(asyncio.to_thread(generate_next_token, input_ids, past_key_values, gen_config, device), timeout=request.token_timeout)
+                next_token, past_key_values, token_logprob = await asyncio.wait_for(
+                    asyncio.to_thread(generate_next_token, input_ids, past_key_values, gen_config, device),
+                    timeout=request.token_timeout
+                )
             except asyncio.TimeoutError:
                 yield "data: " + json.dumps({"generated_text": "[Token generation timeout, continuing...]"}) + "\n\n"
                 continue
         else:
-            next_token, past_key_values, token_logprob = await asyncio.to_thread(generate_next_token, input_ids, past_key_values, gen_config, device)
+            next_token, past_key_values, token_logprob = await asyncio.to_thread(
+                generate_next_token, input_ids, past_key_values, gen_config, device
+            )
         token_id = next_token.item()
         token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
+        # Accumulate the generated token
         current_chunk += token_text
        chunk_token_count += 1
+        # Once 'chunk_token_limit' tokens have accumulated in this block, emit it and reset
         if chunk_token_count >= request.chunk_token_limit:
             yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
             current_chunk = ""
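
Each streamed block is a data: {...} line followed by a blank line, which matches server-sent-events framing. A minimal client sketch, assuming the elided stream branch of /generate wraps stream_text in a StreamingResponse and the server runs locally on port 7860 (requests is an extra dependency, not used by app.py itself):

    import json
    import requests  # assumption: installed separately

    with requests.post(
        "http://localhost:7860/generate",
        json={"input_text": "Hello", "stream": True},
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            # Each event arrives as: data: {"generated_text": "..."}
            if line and line.startswith("data: "):
                chunk = json.loads(line[len("data: "):])
                print(chunk["generated_text"], end="", flush=True)
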
@@ -110,42 +204,54 @@ async def stream_text(request: GenerateRequest, device: str):
         input_ids = next_token
         if token_id == global_tokens["eos_token_id"]:
             break
+    # Emit any pending partial block
     if current_chunk:
         yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
+    # If a search was requested, send it at the end in the same format
     if request.include_duckasgo:
         search_summary = await perform_duckasgo_search(request.input_text)
         yield "data: " + json.dumps({"generated_text": search_summary}) + "\n\n"
     await cleanup_memory(device)
 
-async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
-    try:
-        with DDGS() as ddgs:
-            results = ddgs.text(query, max_results=max_results)
-    except Exception as e:
-        return f"DuckDuckGo search error: {e}"
-    if not results:
-        result_text = "No results found on DuckDuckGo."
-    else:
-        result_text = "\nSearch results (DuckDuckGo):\n"
-        for idx, res in enumerate(results, start=1):
-            title = res.get("title", "No title")
-            url = res.get("href", "No URL")
-            snippet = res.get("body", "")
-            result_text += f"{idx}. {title}\n   URL: {url}\n   {snippet}\n"
-    return result_text
-
 def synchronous_generation(encoded_input, gen_config, device):
+    """
+    Synchronous function for full generation in non-streaming mode.
+    Runs in a worker thread via asyncio.to_thread.
+    """
     with torch.no_grad():
-        output = global_model.generate(**encoded_input, generation_config=gen_config, return_dict_in_generate=True, output_scores=True, return_legacy_cache=True)
+        output = global_model.generate(
+            **encoded_input,
+            generation_config=gen_config,
+            return_dict_in_generate=True,
+            output_scores=True,
+            return_legacy_cache=True
+        )
     return output
 
 @app.post("/generate")
 async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
+    """
+    Text generation endpoint.
+    Streaming is used by default, emitting each block (at most 2 tokens) as soon as it is ready.
+    Each response includes only the "generated_text" field with the tokens generated in that block.
+    In non-streaming mode the response is split into token blocks if it exceeds the limit.
+    """
     global global_model, global_tokenizer, global_tokens
     if global_model is None or global_tokenizer is None:
         raise HTTPException(status_code=500, detail="The model has not been loaded correctly.")
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    gen_config = GenerationConfig(temperature=request.temperature, max_new_tokens=request.max_new_tokens, top_p=request.top_p, top_k=request.top_k, repetition_penalty=request.repetition_penalty, frequency_penalty=request.frequency_penalty, presence_penalty=request.presence_penalty, do_sample=request.do_sample)
+    # Build the GenerationConfig
+    gen_config = GenerationConfig(
+        temperature=request.temperature,
+        max_new_tokens=request.max_new_tokens,
+        top_p=request.top_p,
+        top_k=request.top_k,
+        repetition_penalty=request.repetition_penalty,
+        frequency_penalty=request.frequency_penalty,
+        presence_penalty=request.presence_penalty,
+        do_sample=request.do_sample,
+    )
+    # Seed the generation
     torch.manual_seed(request.seed)
     try:
         if request.stream:
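
Seeding with torch.manual_seed(request.seed) right before generation is what makes a sampled run repeatable: the same seed and input should reproduce the same tokens, assuming the same device, dtype, and library versions. A tiny self-contained demonstration of the principle:

    import torch

    torch.manual_seed(123)
    a = torch.multinomial(torch.tensor([0.1, 0.2, 0.7]), num_samples=5, replacement=True)
    torch.manual_seed(123)
    b = torch.multinomial(torch.tensor([0.1, 0.2, 0.7]), num_samples=5, replacement=True)
    assert torch.equal(a, b)  # identical draws under the same seed
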
@@ -155,33 +261,56 @@ async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
             encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
             output = await asyncio.to_thread(synchronous_generation, encoded_input, gen_config, device)
             input_length = encoded_input["input_ids"].shape[-1]
-            full_text = global_tokenizer.decode(output.sequences[0][input_length:], skip_special_tokens=True)
-            tokens = global_tokenizer.tokenize(full_text)
-            chunks = []
-            for i in range(0, len(tokens), request.chunk_token_limit):
-                chunk_tokens = tokens[i:i+request.chunk_token_limit]
-                chunk_text = global_tokenizer.convert_tokens_to_string(chunk_tokens)
-                chunks.append(chunk_text)
+            # Decode the generated part (excluding the input)
+            full_generated_text = global_tokenizer.decode(
+                output.sequences[0][input_length:], skip_special_tokens=True
+            )
+            # Apply stop_sequences if any were defined
             if request.stop_sequences:
                 for stop_seq in request.stop_sequences:
-                    for idx, chunk in enumerate(chunks):
-                        if stop_seq in chunk:
-                            chunks[idx] = chunk.split(stop_seq)[0]
-                            chunks = chunks[:idx+1]
-                            break
+                    if stop_seq in full_generated_text:
+                        full_generated_text = full_generated_text.split(stop_seq)[0]
+                        break
+
+            # Split the response into token blocks according to chunk_token_limit
+            final_token_ids = global_tokenizer.encode(full_generated_text, add_special_tokens=False)
+            chunks = []
+            for i in range(0, len(final_token_ids), request.chunk_token_limit):
+                chunk_ids = final_token_ids[i:i+request.chunk_token_limit]
+                chunk_text = global_tokenizer.decode(chunk_ids, skip_special_tokens=True)
+                chunks.append(chunk_text)
+
+            # If a DuckDuckGo search was requested, append it as an extra block
             if request.include_duckasgo:
                 search_summary = await perform_duckasgo_search(request.input_text)
-                chunks.append(search_summary)
+                chunks.append("\n" + search_summary)
+
             await cleanup_memory(device)
             background_tasks.add_task(lambda: print("Generation finished."))
-            return {"chunks": chunks}
+            return {"generated_chunks": chunks}
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error during generation: {e}")
 
+@app.post("/duckasgo")
+async def duckasgo_search(request: DuckasgoRequest):
+    """
+    Endpoint for DuckDuckGo searches.
+    Executed on every request; the final response is not cached.
+    """
+    try:
+        with DDGS() as ddgs:
+            results = ddgs.text(request.query, max_results=10)
+        if not results:
+            results = []
+        return JSONResponse(content={"query": request.query, "results": results})
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Search error: {e}")
+
 def run_server():
     uvicorn.run(app, host="0.0.0.0", port=7860)
 
 if __name__ == "__main__":
+    # Start the server in a separate thread to allow other concurrent tasks.
     server_thread = threading.Thread(target=run_server, daemon=True)
     server_thread.start()
     while True:
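
In non-streaming mode the endpoint now returns generated_chunks (renamed from chunks), each entry decoded from at most chunk_token_limit token ids. A call sketch under the same local-server assumption as above:

    import requests  # assumption: installed separately

    resp = requests.post(
        "http://localhost:7860/generate",
        json={"input_text": "Hello", "stream": False, "chunk_token_limit": 4},
    )
    resp.raise_for_status()
    print(resp.json()["generated_chunks"])  # list of short text blocks; exact split depends on the tokenizer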
 
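
The new /duckasgo endpoint takes a JSON body matching DuckasgoRequest and relays up to 10 raw search results. A quick sketch, again assuming a local server on port 7860:

    import requests  # assumption: installed separately

    resp = requests.post(
        "http://localhost:7860/duckasgo",
        json={"query": "fastapi streaming response"},
    )
    for res in resp.json()["results"]:
        # Result keys come from duckduckgo_search: title, href, body
        print(res.get("title"), "->", res.get("href"))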