Update app.py
app.py (CHANGED)
@@ -4,28 +4,34 @@ import torch
 import asyncio
 import threading
 import time
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi import FastAPI, HTTPException, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
 from pydantic import BaseModel
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 import uvicorn
 from duckduckgo_search import DDGS
+from concurrent.futures import ThreadPoolExecutor
 
+# Model name to load and global variables
 MODEL_NAME = "lilmeaty/my_xdd"
 global_model = None
 global_tokenizer = None
 global_tokens = {}
 
-#
+# Caches for generated responses and searches
 global_response_cache = {}
 duckasgo_response_cache = {}
 
+# Executor for running tasks in parallel
+executor = ThreadPoolExecutor(max_workers=4)
+
 async def cleanup_memory(device: str):
     gc.collect()
     if device == "cuda":
         torch.cuda.empty_cache()
     await asyncio.sleep(0.01)
 
+# Data model for generation requests
 class GenerateRequest(BaseModel):
     input_text: str = ""
     max_new_tokens: int = 200
@@ -39,6 +45,7 @@ class GenerateRequest(BaseModel):
     stop_sequences: list[str] = []
     include_duckasgo: bool = False
 
+# Data model for DuckDuckGo search requests
 class DuckasgoRequest(BaseModel):
     query: str
 
@@ -46,6 +53,9 @@ app = FastAPI()
 
 @app.on_event("startup")
 async def load_global_model():
+    """
+    Load the model and tokenizer when the application starts.
+    """
     global global_model, global_tokenizer, global_tokens
     config = AutoConfig.from_pretrained(MODEL_NAME)
     global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
@@ -60,8 +70,28 @@ async def load_global_model():
     global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
     print(f"Modelo {MODEL_NAME} cargado correctamente en {device}.")
 
+@app.get("/", response_class=HTMLResponse)
+async def index():
+    """
+    Root endpoint so the site can be visited while other operations are running.
+    """
+    html_content = """
+    <html>
+      <head>
+        <title>Mi Sitio de Generación de Texto</title>
+      </head>
+      <body>
+        <h1>Bienvenido a Mi Sitio</h1>
+        <p>Puedes enviar peticiones a <code>/generate</code> o <code>/duckasgo</code> sin afectar la navegación.</p>
+      </body>
+    </html>
+    """
+    return HTMLResponse(content=html_content, status_code=200)
+
 async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
-
+    """
+    Run a DuckDuckGo search, using a cache to speed up repeated queries.
+    """
     if query in duckasgo_response_cache:
         return duckasgo_response_cache[query]
     try:
@@ -78,11 +108,38 @@ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
         url = res.get("href", "Sin URL")
         snippet = res.get("body", "")
         result_text += f"{idx}. {title}\n URL: {url}\n {snippet}\n"
-    # Store the result in the cache
     duckasgo_response_cache[query] = result_text
     return result_text
 
+def generate_next_token(input_ids, past_key_values, gen_config, device):
+    """
+    Synchronous function that generates the next token with the model.
+    Invoked in parallel via asyncio.to_thread.
+    """
+    with torch.no_grad():
+        outputs = global_model(
+            input_ids,
+            past_key_values=past_key_values,
+            use_cache=True,
+            return_dict=True
+        )
+    logits = outputs.logits[:, -1, :]
+    past_key_values = outputs.past_key_values
+    if gen_config.do_sample:
+        logits = logits / gen_config.temperature
+        if gen_config.top_k and gen_config.top_k > 0:
+            topk_values, _ = torch.topk(logits, k=gen_config.top_k)
+            logits[logits < topk_values[:, [-1]]] = -float('Inf')
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1)
+    else:
+        next_token = torch.argmax(logits, dim=-1, keepdim=True)
+    return next_token, past_key_values
+
 async def stream_text(request: GenerateRequest, device: str):
+    """
+    Generate text as a stream, using parallel generation for each token.
+    """
     global global_model, global_tokenizer, global_tokens
     encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
     input_ids = encoded_input.input_ids
@@ -99,24 +156,10 @@ async def stream_text(request: GenerateRequest, device: str):
     )
     past_key_values = None
     for _ in range(request.max_new_tokens):
-        with torch.no_grad():
-            outputs = global_model(
-                input_ids,
-                past_key_values=past_key_values,
-                use_cache=True,
-                return_dict=True
-            )
-        logits = outputs.logits[:, -1, :]
-        past_key_values = outputs.past_key_values
-        if gen_config.do_sample:
-            logits = logits / gen_config.temperature
-            if gen_config.top_k and gen_config.top_k > 0:
-                topk_values, _ = torch.topk(logits, k=gen_config.top_k)
-                logits[logits < topk_values[:, [-1]]] = -float('Inf')
-            probs = torch.nn.functional.softmax(logits, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)
-        else:
-            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+        # Run next-token generation in parallel
+        next_token, past_key_values = await asyncio.to_thread(
+            generate_next_token, input_ids, past_key_values, gen_config, device
+        )
         token_id = next_token.item()
         token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
         accumulated_text += token_text
@@ -143,8 +186,27 @@ async def stream_text(request: GenerateRequest, device: str):
         yield "\n" + search_summary
     await cleanup_memory(device)
 
+def synchronous_generation(encoded_input, gen_config, device, request: GenerateRequest):
+    """
+    Synchronous function for the full generation in non-streaming mode.
+    Executed in parallel via asyncio.to_thread.
+    """
+    with torch.no_grad():
+        output = global_model.generate(
+            **encoded_input,
+            generation_config=gen_config,
+            return_dict_in_generate=True,
+            output_scores=True,
+            return_legacy_cache=True  # Parameter added to optimize internal cache usage
+        )
+    return output
+
 @app.post("/generate")
-async def generate_text(request: GenerateRequest):
+async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
+    """
+    Endpoint for text generation. Generation (streaming or full) runs in parallel,
+    and responses are cached to speed up repeated requests.
+    """
     global global_model, global_tokenizer, global_tokens, global_response_cache
     if global_model is None or global_tokenizer is None:
         raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
@@ -157,9 +219,7 @@ async def generate_text(request: GenerateRequest):
         repetition_penalty=request.repetition_penalty,
         do_sample=request.do_sample,
     )
-    # Only cache when the request is not a streaming request
     if not request.stream:
-        # Build a key from the relevant parameters
        cache_key = (
            request.input_text,
            request.max_new_tokens,
@@ -175,18 +235,15 @@ async def generate_text(request: GenerateRequest):
             return {"generated_text": global_response_cache[cache_key]}
     try:
         if request.stream:
+            # Streaming generation in parallel
             generator = stream_text(request, device)
             return StreamingResponse(generator, media_type="text/plain")
         else:
             encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
-            with torch.no_grad():
-                output = global_model.generate(
-                    **encoded_input,
-                    generation_config=gen_config,
-                    return_dict_in_generate=True,
-                    output_scores=True,
-                    return_legacy_cache=True  # The requested parameter is added
-                )
+            # Run the full generation in a parallel thread
+            output = await asyncio.to_thread(
+                synchronous_generation, encoded_input, gen_config, device, request
+            )
             input_length = encoded_input["input_ids"].shape[-1]
             generated_text = global_tokenizer.decode(
                 output.sequences[0][input_length:],
@@ -201,14 +258,17 @@ async def generate_text(request: GenerateRequest):
                 search_summary = await perform_duckasgo_search(request.input_text)
                 generated_text += "\n" + search_summary
             await cleanup_memory(device)
-            # Store the generated response in the cache
             global_response_cache[cache_key] = generated_text
+            background_tasks.add_task(lambda: print("Generación completada y cacheada."))
             return {"generated_text": generated_text}
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error durante la generación: {e}")
 
 @app.post("/duckasgo")
 async def duckasgo_search(request: DuckasgoRequest):
+    """
+    Endpoint for DuckDuckGo searches, with caching to speed up responses.
+    """
     global duckasgo_response_cache
     if request.query in duckasgo_response_cache:
         return JSONResponse(content={"query": request.query, "results": duckasgo_response_cache[request.query]})
@@ -217,13 +277,15 @@ async def duckasgo_search(request: DuckasgoRequest):
             results = ddgs.text(request.query, max_results=10)
             if not results:
                 results = []
-            # Store the response in the cache
             duckasgo_response_cache[request.query] = results
             return JSONResponse(content={"query": request.query, "results": results})
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error en la búsqueda: {e}")
 
 def run_server():
+    """
+    Start the Uvicorn server.
+    """
    uvicorn.run(app, host="0.0.0.0", port=7860)
 
 if __name__ == "__main__":
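
For reference, a minimal client sketch against the updated endpoints. This is an illustrative example, not part of the commit: it assumes the server is running locally on the port 7860 configured in run_server, that the third-party requests package is installed, and it only uses the request fields visible in this diff (input_text, max_new_tokens, stream, query).

    import requests

    BASE = "http://localhost:7860"  # port taken from run_server(); adjust if deployed elsewhere

    # Non-streaming generation; identical payloads are served from the server-side cache
    resp = requests.post(
        f"{BASE}/generate",
        json={"input_text": "Hola", "max_new_tokens": 50, "stream": False},
        timeout=120,
    )
    print(resp.json()["generated_text"])

    # Streaming generation: tokens arrive as plain-text chunks
    with requests.post(
        f"{BASE}/generate",
        json={"input_text": "Hola", "max_new_tokens": 50, "stream": True},
        stream=True,
        timeout=120,
    ) as r:
        for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)

    # DuckDuckGo search endpoint
    resp = requests.post(f"{BASE}/duckasgo", json={"query": "FastAPI streaming"}, timeout=60)
    print(resp.json()["results"])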
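
One design note on the change: the diff adds executor = ThreadPoolExecutor(max_workers=4) but routes the blocking generation calls through asyncio.to_thread, which uses the event loop's default thread pool, so the dedicated executor is currently unused. If the intent is to bound generation to that pool, loop.run_in_executor is the usual route. The following is a self-contained, hypothetical sketch of that pattern only; blocking_step is a placeholder standing in for generate_next_token and is not part of the commit.

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    # Dedicated pool, mirroring the executor created in the diff
    executor = ThreadPoolExecutor(max_workers=4)

    def blocking_step(x: int) -> int:
        # Placeholder for a blocking model call such as generate_next_token
        return x * 2

    async def main() -> None:
        loop = asyncio.get_running_loop()
        # Same offloading pattern as asyncio.to_thread, but pinned to the dedicated pool
        result = await loop.run_in_executor(executor, blocking_step, 21)
        print(result)

    asyncio.run(main())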