Hjgugugjhuhjggg committed
Commit 3a605c4 (verified)
1 Parent(s): 6b3664f

Update app.py

Files changed (1)
  1. app.py +28 -23
app.py CHANGED
@@ -1,13 +1,14 @@
 import os
 import gc
 import json
+import random
 import torch
 import asyncio
 import threading
 import time
 from fastapi import FastAPI, HTTPException, BackgroundTasks
 from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 import uvicorn
 from duckduckgo_search import DDGS
@@ -30,14 +31,17 @@ async def cleanup_memory(device: str):
 
 class GenerateRequest(BaseModel):
     input_text: str = ""
-    max_new_tokens: int = 2
-    temperature: float = 1.0
-    top_p: float = 1.0
-    top_k: int = 50
-    repetition_penalty: float = 1.0
+    max_new_tokens: int = 200
+    temperature: float = Field(default_factory=lambda: round(random.uniform(0.5, 0.8), 2))
+    top_p: float = Field(default_factory=lambda: round(random.uniform(0.75, 0.95), 2))
+    top_k: int = Field(default_factory=lambda: random.randint(20, 60))
+    repetition_penalty: float = Field(default_factory=lambda: round(random.uniform(1.1, 1.8), 2))
+    frequency_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
+    presence_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
+    seed: int = Field(default_factory=lambda: random.randint(0, 1000))
     do_sample: bool = True
     stream: bool = True  # Streaming by default
-    chunk_token_limit: int = 2  # At most 2 tokens per response
+    chunk_token_limit: int = 2  # At most 2 tokens per response (block)
     token_timeout: float = 0.0  # A timeout of 0 means no timeout
     stop_sequences: list[str] = []
     include_duckasgo: bool = False
@@ -148,24 +152,27 @@ def generate_next_token(input_ids, past_key_values, gen_config, device):
 async def stream_text(request: GenerateRequest, device: str):
     """
     Generates text as a stream, sending each block of up to 2 independently generated tokens.
-    Each response is sent immediately with the "generated_text" field containing just those tokens (no accumulation).
-    If token_timeout is greater than 0 a timeout is applied; otherwise it waits without limit.
+    Each response is sent immediately with the "generated_text" field containing only those tokens (no accumulation).
+    If token_timeout is greater than 0 a timeout is applied; otherwise, it waits without limit.
     """
     global global_model, global_tokenizer, global_tokens
+    # Prepare the input and set up generation
     encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
     input_ids = encoded_input.input_ids
-    # Create a configuration object for generation
+    # Configure the GenerationConfig with the new random parameters
    gen_config = GenerationConfig(
         temperature=request.temperature,
         max_new_tokens=request.max_new_tokens,
         top_p=request.top_p,
         top_k=request.top_k,
         repetition_penalty=request.repetition_penalty,
+        frequency_penalty=request.frequency_penalty,
+        presence_penalty=request.presence_penalty,
         do_sample=request.do_sample,
     )
-    # Variable for internal accumulation (used to check stop sequences)
-    all_tokens = ""
-    # Variable for the current block (at most 2 tokens per block)
+    # Fix the seed for generation
+    torch.manual_seed(request.seed)
+    # Variables for handling token blocks
     current_chunk = ""
     chunk_token_count = 0
     past_key_values = None
@@ -185,17 +192,10 @@ async def stream_text(request: GenerateRequest, device: str):
         )
         token_id = next_token.item()
         token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
-        all_tokens += token_text
+        # Instead of accumulating, only the generated token is sent
         current_chunk += token_text
         chunk_token_count += 1
-        # Check whether any stop sequence appears in the accumulated text (to stop generation)
-        if request.stop_sequences:
-            for stop_seq in request.stop_sequences:
-                if stop_seq in all_tokens:
-                    yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
-                    await cleanup_memory(device)
-                    return
-        # If 2 tokens have been generated in the block, send the current block and reset
+        # If 2 tokens have been generated in this block, send the current block and reset
         if chunk_token_count >= request.chunk_token_limit:
             yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
             current_chunk = ""
@@ -204,7 +204,7 @@ async def stream_text(request: GenerateRequest, device: str):
         input_ids = next_token
         if token_id == global_tokens["eos_token_id"]:
             break
-    # If a partial block remains, send it
+    # Send any pending partial block
     if current_chunk:
         yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
     # If search inclusion is requested, it is sent at the end in the same format
@@ -239,14 +239,19 @@ async def generate_text(request: GenerateRequest, background_tasks: BackgroundTa
     if global_model is None or global_tokenizer is None:
         raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
     device = "cuda" if torch.cuda.is_available() else "cpu"
+    # GenerationConfig setup (the parameters already assigned on the request are used)
     gen_config = GenerationConfig(
         temperature=request.temperature,
         max_new_tokens=request.max_new_tokens,
         top_p=request.top_p,
         top_k=request.top_k,
         repetition_penalty=request.repetition_penalty,
+        frequency_penalty=request.frequency_penalty,
+        presence_penalty=request.presence_penalty,
         do_sample=request.do_sample,
     )
+    # Fix the seed for generation
+    torch.manual_seed(request.seed)
     try:
         if request.stream:
             generator = stream_text(request, device)
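
The core of this change is the switch from fixed sampling defaults to per-request randomized ones via pydantic's Field(default_factory=...). Below is a minimal standalone sketch of that pattern; the field names and ranges mirror the diff, but the stripped-down model is only an illustration, not the full GenerateRequest from app.py.

import random
from pydantic import BaseModel, Field

class SamplingDefaults(BaseModel):
    # Illustrative stand-in for the randomized fields added to GenerateRequest.
    # Each default_factory runs once per instantiation, so every request that
    # omits a field gets a freshly sampled value.
    temperature: float = Field(default_factory=lambda: round(random.uniform(0.5, 0.8), 2))
    top_p: float = Field(default_factory=lambda: round(random.uniform(0.75, 0.95), 2))
    top_k: int = Field(default_factory=lambda: random.randint(20, 60))
    seed: int = Field(default_factory=lambda: random.randint(0, 1000))

a = SamplingDefaults()
b = SamplingDefaults(temperature=0.7)  # an explicit value overrides the factory
print(a.temperature, a.top_p, a.top_k, a.seed)
print(b.temperature, b.top_p, b.top_k, b.seed)

Because the factories run at validation time, two identical payloads that omit these fields will generally be validated into different sampling settings, while torch.manual_seed(request.seed) keeps an individual request reproducible if the caller pins seed explicitly.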
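On the streaming path, the generator yields "data: {\"generated_text\": ...}\n\n" events carrying at most chunk_token_limit tokens each. A rough client-side sketch follows; the endpoint URL and port are assumptions, since the route decorator for generate_text is not part of this diff.

import json
import requests

# Assumed URL: the diff does not show the FastAPI route or port for generate_text.
URL = "http://localhost:7860/generate"

payload = {"input_text": "Hola", "stream": True}  # omitted fields fall back to the randomized defaults
with requests.post(URL, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Each event line looks like: data: {"generated_text": "..."}
        if not line or not line.startswith("data: "):
            continue
        chunk = json.loads(line[len("data: "):])
        print(chunk["generated_text"], end="", flush=True)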