Ggggggg / app.py
Hjgugugjhuhjggg's picture
Create app.py
4864d22 verified
raw
history blame
9.72 kB
import os
import gc
import torch
import asyncio
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig
)
import uvicorn
from duckduckgo_search import ddg # pip install duckduckgo-search
# Nombre del modelo fijo
MODEL_NAME = "lilmeaty/my_xdd"
# Variables globales para almacenar el modelo y el tokenizador
global_model = None
global_tokenizer = None
# Funci贸n as铆ncrona para limpiar la memoria (RAM y cach茅 CUDA)
async def cleanup_memory(device: str):
gc.collect()
if device == "cuda":
torch.cuda.empty_cache()
await asyncio.sleep(0.01)
# Request para la generaci贸n de texto
class GenerateRequest(BaseModel):
input_text: str = ""
max_new_tokens: int = 200 # l铆mite total de tokens generados (puede ser muy alto)
temperature: float = 1.0
top_p: float = 1.0
top_k: int = 50
repetition_penalty: float = 1.0
do_sample: bool = True
stream: bool = False
# L铆mite de tokens por chunk en modo streaming (si se excede, se emite el chunk acumulado)
chunk_token_limit: int = 20
# Secuencias que, si se detectan, hacen que se detenga la generaci贸n
stop_sequences: list[str] = []
# Si se desea incluir Duckasgo en la respuesta final
include_duckasgo: bool = False
# Request para b煤squedas independientes con Duckasgo
class DuckasgoRequest(BaseModel):
query: str
# Inicializar la aplicaci贸n FastAPI
app = FastAPI()
# Evento de startup: cargar el modelo y tokenizador globalmente
@app.on_event("startup")
async def load_global_model():
global global_model, global_tokenizer
try:
config = AutoConfig.from_pretrained(MODEL_NAME)
global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
# Se usa torch_dtype=torch.float16 para reducir la huella de memoria (si es posible)
global_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config, torch_dtype=torch.float16)
# Configurar token de padding si es necesario
if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
device = "cuda" if torch.cuda.is_available() else "cpu"
global_model.to(device)
print(f"Modelo {MODEL_NAME} cargado correctamente en {device}.")
except Exception as e:
print("Error al cargar el modelo:", e)
# Funci贸n para realizar b煤squeda con Duckasgo de forma as铆ncrona
async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
results = await asyncio.to_thread(ddg, query, max_results=max_results)
if not results:
return "No se encontraron resultados en Duckasgo."
summary = "\nResultados de b煤squeda (Duckasgo):\n"
for idx, res in enumerate(results, start=1):
title = res.get("title", "Sin t铆tulo")
url = res.get("href", "Sin URL")
snippet = res.get("body", "")
summary += f"{idx}. {title}\n URL: {url}\n {snippet}\n"
return summary
# Funci贸n para generar texto en modo streaming, dividiendo en chunks ilimitados
# y deteniendo la generaci贸n si se detectan secuencias de parada.
async def stream_text(request: GenerateRequest, device: str):
global global_model, global_tokenizer
# Codificar la entrada y obtener la longitud inicial
encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
initial_input_len = encoded_input.input_ids.shape[-1]
input_ids = encoded_input.input_ids
# Variables para acumular texto
accumulated_text = "" # todo el texto generado
current_chunk = "" # chunk actual
chunk_token_count = 0
# Configurar la generaci贸n seg煤n los par谩metros recibidos
gen_config = GenerationConfig(
temperature=request.temperature,
max_new_tokens=request.max_new_tokens,
top_p=request.top_p,
top_k=request.top_k,
repetition_penalty=request.repetition_penalty,
do_sample=request.do_sample,
)
past_key_values = None
# Usamos un bucle while para permitir generaci贸n ilimitada en chunks
for _ in range(request.max_new_tokens):
with torch.no_grad():
outputs = global_model(
input_ids,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)
logits = outputs.logits[:, -1, :]
past_key_values = outputs.past_key_values
if gen_config.do_sample:
logits = logits / gen_config.temperature
if gen_config.top_k and gen_config.top_k > 0:
topk_values, _ = torch.topk(logits, k=gen_config.top_k)
logits[logits < topk_values[:, [-1]]] = -float('Inf')
probs = torch.nn.functional.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
token_id = next_token.item()
token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)
accumulated_text += token_text
current_chunk += token_text
chunk_token_count += 1
# Verificar si se alcanz贸 alguna de las secuencias de parada
if request.stop_sequences:
for stop_seq in request.stop_sequences:
if stop_seq in accumulated_text:
# Si se detecta el stop, emitir el chunk actual y terminar
yield current_chunk
await cleanup_memory(device)
return
# Si se supera el l铆mite de tokens por chunk, enviar el chunk acumulado
if chunk_token_count >= request.chunk_token_limit:
yield current_chunk
current_chunk = ""
chunk_token_count = 0
# Permitir que otras tareas se ejecuten
await asyncio.sleep(0)
# Actualizar el input para la siguiente iteraci贸n
input_ids = next_token
# Si se ha generado el token de finalizaci贸n, se emite el chunk y se termina
if token_id == global_tokenizer.eos_token_id:
break
# Emitir el 煤ltimo chunk (si no est谩 vac铆o)
if current_chunk:
yield current_chunk
# Si se solicit贸 incluir Duckasgo, realizar la b煤squeda y agregarla como chunk final
if request.include_duckasgo:
search_summary = await perform_duckasgo_search(request.input_text)
yield "\n" + search_summary
await cleanup_memory(device)
# Endpoint para la generaci贸n de texto (modo streaming o no-streaming)
@app.post("/generate")
async def generate_text(request: GenerateRequest):
global global_model, global_tokenizer
if global_model is None or global_tokenizer is None:
raise HTTPException(status_code=500, detail="El modelo no se ha cargado correctamente.")
device = "cuda" if torch.cuda.is_available() else "cpu"
gen_config = GenerationConfig(
temperature=request.temperature,
max_new_tokens=request.max_new_tokens,
top_p=request.top_p,
top_k=request.top_k,
repetition_penalty=request.repetition_penalty,
do_sample=request.do_sample,
)
try:
if request.stream:
# Modo streaming: se env铆an m煤ltiples chunks seg煤n se generan
generator = stream_text(request, device)
return StreamingResponse(generator, media_type="text/plain")
else:
# Modo no-streaming: generaci贸n completa en una sola respuesta
encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
with torch.no_grad():
output = global_model.generate(
**encoded_input,
generation_config=gen_config,
return_dict_in_generate=True,
output_scores=True
)
input_length = encoded_input["input_ids"].shape[-1]
generated_text = global_tokenizer.decode(
output.sequences[0][input_length:],
skip_special_tokens=True
)
# Si se han definido secuencias de parada, se corta la respuesta en el primer match
if request.stop_sequences:
for stop_seq in request.stop_sequences:
if stop_seq in generated_text:
generated_text = generated_text.split(stop_seq)[0]
break
# Si se solicit贸 incluir Duckasgo, se realiza la b煤squeda y se agrega al final
if request.include_duckasgo:
search_summary = await perform_duckasgo_search(request.input_text)
generated_text += "\n" + search_summary
await cleanup_memory(device)
return {"generated_text": generated_text}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error durante la generaci贸n: {e}")
# Endpoint independiente para b煤squedas con Duckasgo
@app.post("/duckasgo")
async def duckasgo_search(request: DuckasgoRequest):
try:
results = await asyncio.to_thread(ddg, request.query, max_results=10)
if results is None:
results = []
return JSONResponse(content={"query": request.query, "results": results})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error en la b煤squeda: {e}")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)