Update app.py
app.py
CHANGED
@@ -1,13 +1,14 @@
 import os
 import gc
 import json
+import random
 import torch
 import asyncio
 import threading
 import time
 from fastapi import FastAPI, HTTPException, BackgroundTasks
 from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 import uvicorn
 from duckduckgo_search import DDGS
@@ -30,14 +31,17 @@ async def cleanup_memory(device: str):

 class GenerateRequest(BaseModel):
     input_text: str = ""
-    max_new_tokens: int =
-    temperature: float =
-    top_p: float =
-    top_k: int =
-    repetition_penalty: float = 1.
+    max_new_tokens: int = 200
+    temperature: float = Field(default_factory=lambda: round(random.uniform(0.5, 0.8), 2))
+    top_p: float = Field(default_factory=lambda: round(random.uniform(0.75, 0.95), 2))
+    top_k: int = Field(default_factory=lambda: random.randint(20, 60))
+    repetition_penalty: float = Field(default_factory=lambda: round(random.uniform(1.1, 1.8), 2))
+    frequency_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
+    presence_penalty: float = Field(default_factory=lambda: round(random.uniform(0.2, 0.7), 2))
+    seed: int = Field(default_factory=lambda: random.randint(0, 1000))
     do_sample: bool = True
     stream: bool = True  # Streaming by default
-    chunk_token_limit: int = 2  # At most 2 tokens per response
+    chunk_token_limit: int = 2  # At most 2 tokens per response (block)
     token_timeout: float = 0.0  # A timeout of 0 means no timeout
     stop_sequences: list[str] = []
     include_duckasgo: bool = False
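Note on the new defaults: each sampling field now uses Pydantic's `Field(default_factory=...)`, so the random value is drawn when a `GenerateRequest` instance is created rather than once at import time, and any value the client sends explicitly still overrides it. A minimal sketch of that behaviour (the `Params` model below is illustrative only, not part of app.py):

```python
import random
from pydantic import BaseModel, Field

class Params(BaseModel):
    # default_factory runs on every instantiation, so each request that omits
    # the field gets a fresh random draw
    temperature: float = Field(default_factory=lambda: round(random.uniform(0.5, 0.8), 2))
    top_k: int = Field(default_factory=lambda: random.randint(20, 60))

a = Params()                  # e.g. temperature=0.63, top_k=41
b = Params()                  # independent draw, e.g. temperature=0.71, top_k=27
c = Params(temperature=0.2)   # an explicit value overrides the factory
print(a.temperature, b.temperature, c.temperature)
```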
@@ -148,24 +152,27 @@ def generate_next_token(input_ids, past_key_values, gen_config, device):
 async def stream_text(request: GenerateRequest, device: str):
     """
     Generates text as a stream, sending each block with up to 2 independently generated tokens.
-    Each response is sent immediately with the "generated_text" field, which contains
-    If token_timeout is greater than 0 a timeout is applied; otherwise it waits without limit.
+    Each response is sent immediately with the "generated_text" field, which contains only those tokens (no accumulation).
+    If token_timeout is greater than 0 a timeout is applied; otherwise, it waits without limit.
     """
     global global_model, global_tokenizer, global_tokens
+    # Prepare the input and configure generation
     encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
     input_ids = encoded_input.input_ids
-    # Se
+    # Configure the GenerationConfig with the new random parameters
     gen_config = GenerationConfig(
         temperature=request.temperature,
         max_new_tokens=request.max_new_tokens,
         top_p=request.top_p,
         top_k=request.top_k,
         repetition_penalty=request.repetition_penalty,
+        frequency_penalty=request.frequency_penalty,
+        presence_penalty=request.presence_penalty,
         do_sample=request.do_sample,
     )
-    #
-
-    #
+    # Set the seed for generation
+    torch.manual_seed(request.seed)
+    # Variables for handling token blocks
     current_chunk = ""
     chunk_token_count = 0
     past_key_values = None
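Because `torch.manual_seed(request.seed)` is now called right before generation, a client that pins `seed` together with the other sampling parameters should get a repeatable draw. A small standalone sketch of that effect (not taken from app.py):

```python
import torch

logits = torch.tensor([[0.1, 0.2, 0.3, 0.4]])

def sample_token(seed: int) -> int:
    # Seeding the global RNG immediately before the draw makes it repeatable
    torch.manual_seed(seed)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

print(sample_token(42) == sample_token(42))  # True: same seed, same token index
print(sample_token(7))                       # a different seed may pick a different index
```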
@@ -185,17 +192,10 @@ async def stream_text(request: GenerateRequest, device: str):
         )
         token_id = next_token.item()
         token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
-
+        # Instead of accumulating, only the generated token is sent
         current_chunk += token_text
         chunk_token_count += 1
-        #
-        if request.stop_sequences:
-            for stop_seq in request.stop_sequences:
-                if stop_seq in all_tokens:
-                    yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
-                    await cleanup_memory(device)
-                    return
-        # If 2 tokens have been generated in the block, the current block is sent and reset
+        # If 2 tokens have been generated in this block, the current block is sent and reset
         if chunk_token_count >= request.chunk_token_limit:
             yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
             current_chunk = ""
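The block handling above is simply "collect tokens until the limit, flush, reset", with a final flush after the loop (see the next hunk). A standalone sketch of that pattern, using a plain list of decoded token strings as stand-in input:

```python
from typing import Iterable, Iterator

def flush_in_blocks(token_texts: Iterable[str], block_size: int = 2) -> Iterator[str]:
    # Mirrors stream_text's flush logic: emit a block every `block_size` tokens,
    # then emit whatever partial block is left at the end
    block: list[str] = []
    for text in token_texts:
        block.append(text)
        if len(block) >= block_size:
            yield "".join(block)
            block = []
    if block:
        yield "".join(block)

print(list(flush_in_blocks(["Ho", "la", ",", " mun", "do"])))  # ['Hola', ', mun', 'do']
```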
@@ -204,7 +204,7 @@ async def stream_text(request: GenerateRequest, device: str):
         input_ids = next_token
         if token_id == global_tokens["eos_token_id"]:
             break
-    #
+    # Send any pending partial block
     if current_chunk:
         yield "data: " + json.dumps({"generated_text": current_chunk}) + "\n\n"
     # If search inclusion is requested, it is sent at the end in the same format
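Each flushed block is emitted as a Server-Sent-Events style line, `data: {"generated_text": ...}` followed by a blank line. A minimal client sketch for consuming that stream; the host, port and `/generate` path are assumptions, since the route decorator is outside this diff:

```python
import json
import requests

payload = {"input_text": "Hola", "stream": True, "chunk_token_limit": 2}

# NOTE: host, port and the /generate path are assumed here, not taken from the diff
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        chunk = json.loads(line[len("data: "):])
        print(chunk["generated_text"], end="", flush=True)
```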
@@ -239,14 +239,19 @@ async def generate_text(request: GenerateRequest, background_tasks: BackgroundTa
     if global_model is None or global_tokenizer is None:
         raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
     device = "cuda" if torch.cuda.is_available() else "cpu"
+    # GenerationConfig setup (the parameters already assigned on the request are used)
     gen_config = GenerationConfig(
         temperature=request.temperature,
         max_new_tokens=request.max_new_tokens,
         top_p=request.top_p,
         top_k=request.top_k,
         repetition_penalty=request.repetition_penalty,
+        frequency_penalty=request.frequency_penalty,
+        presence_penalty=request.presence_penalty,
         do_sample=request.do_sample,
     )
+    # Set the seed for generation
+    torch.manual_seed(request.seed)
     try:
         if request.stream:
             generator = stream_text(request, device)