Create app.py
app.py
ADDED
@@ -0,0 +1,242 @@
import os
import gc
import torch
import asyncio
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig
)
import uvicorn
from duckduckgo_search import ddg  # pip install duckduckgo-search
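# Note: the ddg() helper is assumed to come from an older duckduckgo-search
# release; more recent versions expose a DDGS class instead, so the dependency
# may need to be pinned for this import to work.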

# Fixed model name
MODEL_NAME = "lilmeaty/my_xdd"

# Global variables holding the model and the tokenizer
global_model = None
global_tokenizer = None

# Async helper to free memory (RAM and the CUDA cache)
async def cleanup_memory(device: str):
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    await asyncio.sleep(0.01)

# Request body for text generation
class GenerateRequest(BaseModel):
    input_text: str = ""
    max_new_tokens: int = 200  # total cap on generated tokens (can be very high)
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    do_sample: bool = True
    stream: bool = False
    # Token limit per chunk in streaming mode (once exceeded, the accumulated chunk is emitted)
    chunk_token_limit: int = 20
    # Sequences that stop generation as soon as they are detected
    stop_sequences: list[str] = []
    # Whether to append DuckDuckGo search results to the final response
    include_duckasgo: bool = False

# Request body for standalone DuckDuckGo searches
class DuckasgoRequest(BaseModel):
    query: str

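# Illustrative payload for POST /generate (example values only; each key maps
# to a GenerateRequest field above):
#
#   {
#       "input_text": "Hello, how are you?",
#       "max_new_tokens": 100,
#       "temperature": 0.8,
#       "top_p": 0.95,
#       "top_k": 50,
#       "repetition_penalty": 1.1,
#       "do_sample": True,
#       "stream": True,
#       "chunk_token_limit": 20,
#       "stop_sequences": ["\n\n"],
#       "include_duckasgo": False
#   }
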
# Initialize the FastAPI application
app = FastAPI()

# Startup event: load the model and the tokenizer globally
@app.on_event("startup")
async def load_global_model():
    global global_model, global_tokenizer
    try:
        config = AutoConfig.from_pretrained(MODEL_NAME)
        global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
        # torch_dtype=torch.float16 reduces the memory footprint (where supported)
        global_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, torch_dtype=torch.float16)
        # Configure a padding token if needed
        if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
            global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
        device = "cuda" if torch.cuda.is_available() else "cpu"
        global_model.to(device)
        print(f"Model {MODEL_NAME} loaded successfully on {device}.")
    except Exception as e:
        print("Error loading the model:", e)

# Run a DuckDuckGo search asynchronously
async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
    results = await asyncio.to_thread(ddg, query, max_results=max_results)
    if not results:
        return "No DuckDuckGo results were found."
    summary = "\nSearch results (DuckDuckGo):\n"
    for idx, res in enumerate(results, start=1):
        title = res.get("title", "No title")
        url = res.get("href", "No URL")
        snippet = res.get("body", "")
        summary += f"{idx}. {title}\n URL: {url}\n {snippet}\n"
    return summary

# Generate text in streaming mode, splitting the output into chunks and
# stopping generation when a stop sequence is detected.
async def stream_text(request: GenerateRequest, device: str):
    global global_model, global_tokenizer

    # Encode the input and record its initial length
    encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
    initial_input_len = encoded_input.input_ids.shape[-1]
    input_ids = encoded_input.input_ids

    # Buffers for the generated text
    accumulated_text = ""  # all generated text
    current_chunk = ""     # current chunk
    chunk_token_count = 0

    # Configure generation from the request parameters
    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )

    past_key_values = None

    # Token-by-token generation loop, capped at max_new_tokens, emitting chunks as they fill up.
    # Note: this manual loop applies temperature and top_k only; top_p and
    # repetition_penalty are not applied here.
    for _ in range(request.max_new_tokens):
        with torch.no_grad():
            outputs = global_model(
                input_ids,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True
            )
        logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values

        if gen_config.do_sample:
            logits = logits / gen_config.temperature
            if gen_config.top_k and gen_config.top_k > 0:
                topk_values, _ = torch.topk(logits, k=gen_config.top_k)
                logits[logits < topk_values[:, [-1]]] = -float('Inf')
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        token_id = next_token.item()
        token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)

        accumulated_text += token_text
        current_chunk += token_text
        chunk_token_count += 1

        # Check whether any stop sequence has been reached
        if request.stop_sequences:
            for stop_seq in request.stop_sequences:
                if stop_seq in accumulated_text:
                    # On a stop match, emit the current chunk and finish
                    yield current_chunk
                    await cleanup_memory(device)
                    return

        # If the per-chunk token limit is exceeded, emit the accumulated chunk
        if chunk_token_count >= request.chunk_token_limit:
            yield current_chunk
            current_chunk = ""
            chunk_token_count = 0

        # Let other tasks run
        await asyncio.sleep(0)

        # Update the input for the next iteration
        input_ids = next_token

        # Stop once the end-of-sequence token is generated
        if token_id == global_tokenizer.eos_token_id:
            break

    # Emit the last chunk (if it is not empty)
    if current_chunk:
        yield current_chunk

    # If DuckDuckGo results were requested, run the search and append it as a final chunk
    if request.include_duckasgo:
        search_summary = await perform_duckasgo_search(request.input_text)
        yield "\n" + search_summary

    await cleanup_memory(device)

# Text generation endpoint (streaming or non-streaming mode)
@app.post("/generate")
async def generate_text(request: GenerateRequest):
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        raise HTTPException(status_code=500, detail="The model has not been loaded correctly.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )

    try:
        if request.stream:
            # Streaming mode: multiple chunks are sent as they are generated
            generator = stream_text(request, device)
            return StreamingResponse(generator, media_type="text/plain")
        else:
            # Non-streaming mode: the full generation is returned in a single response
            encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
            with torch.no_grad():
                output = global_model.generate(
                    **encoded_input,
                    generation_config=gen_config,
                    return_dict_in_generate=True,
                    output_scores=True
                )
            input_length = encoded_input["input_ids"].shape[-1]
            generated_text = global_tokenizer.decode(
                output.sequences[0][input_length:],
                skip_special_tokens=True
            )
            # If stop sequences are defined, cut the response at the first match
            if request.stop_sequences:
                for stop_seq in request.stop_sequences:
                    if stop_seq in generated_text:
                        generated_text = generated_text.split(stop_seq)[0]
                        break
            # If DuckDuckGo results were requested, run the search and append it at the end
            if request.include_duckasgo:
                search_summary = await perform_duckasgo_search(request.input_text)
                generated_text += "\n" + search_summary
            await cleanup_memory(device)
            return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during generation: {e}")

# Standalone DuckDuckGo search endpoint
@app.post("/duckasgo")
async def duckasgo_search(request: DuckasgoRequest):
    try:
        results = await asyncio.to_thread(ddg, request.query, max_results=10)
        if results is None:
            results = []
        return JSONResponse(content={"query": request.query, "results": results})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search error: {e}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
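
# Minimal usage sketch (illustrative only; it assumes the server is running on
# port 7860 and that the `requests` package is installed on the client side):
#
#   import requests
#
#   # Streaming request: chunks are printed as they arrive
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"input_text": "Hello", "max_new_tokens": 50, "stream": True},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)
#
#   # Standalone DuckDuckGo search
#   print(requests.post("http://localhost:7860/duckasgo",
#                       json={"query": "FastAPI streaming"}).json())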