import os
import gc
import asyncio

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
)
from duckduckgo_search import ddg
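# Note: `ddg` is the search helper from older duckduckgo_search releases; newer
# versions of the package expose a `DDGS` class instead, so this import may
# require pinning an older release (or adapting the search calls).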

MODEL_NAME = "lilmeaty/my_xdd"

# Loaded once at startup and shared by all requests.
global_model = None
global_tokenizer = None

async def cleanup_memory(device: str):
    """Free Python and CUDA memory between generations."""
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    # Yield control briefly so other coroutines can run.
    await asyncio.sleep(0.01)

class GenerateRequest(BaseModel):
    input_text: str = ""
    max_new_tokens: int = 200
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    do_sample: bool = True
    stream: bool = False
    # Maximum number of tokens accumulated before a streamed chunk is emitted.
    chunk_token_limit: int = 20
    # Generation stops as soon as any of these sequences appears in the output.
    stop_sequences: list[str] = []
    # If True, append DuckDuckGo search results to the response.
    include_duckasgo: bool = False


class DuckasgoRequest(BaseModel):
    query: str

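# Illustrative /generate request body (example values, not the field defaults):
# {
#   "input_text": "Write a haiku about the sea",
#   "max_new_tokens": 64,
#   "temperature": 0.8,
#   "stream": true,
#   "stop_sequences": ["\n\n"],
#   "include_duckasgo": false
# }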

app = FastAPI()

@app.on_event("startup")
async def load_global_model():
    global global_model, global_tokenizer
    try:
        config = AutoConfig.from_pretrained(MODEL_NAME)
        global_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)
        global_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME, config=config, torch_dtype=torch.float16
        )
        # Ensure a pad token exists, falling back to the config's pad token or EOS.
        if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
            global_tokenizer.pad_token_id = config.pad_token_id or global_tokenizer.eos_token_id
        device = "cuda" if torch.cuda.is_available() else "cpu"
        global_model.to(device)
        print(f"Model {MODEL_NAME} loaded successfully on {device}.")
    except Exception as e:
        print("Error loading the model:", e)

async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
    """Run a DuckDuckGo search in a worker thread and return a plain-text summary."""
    results = await asyncio.to_thread(ddg, query, max_results=max_results)
    if not results:
        return "No results found on Duckasgo."
    summary = "\nSearch results (Duckasgo):\n"
    for idx, res in enumerate(results, start=1):
        title = res.get("title", "No title")
        url = res.get("href", "No URL")
        snippet = res.get("body", "")
        summary += f"{idx}. {title}\n URL: {url}\n {snippet}\n"
    return summary

async def stream_text(request: GenerateRequest, device: str):
    global global_model, global_tokenizer

    encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
    initial_input_len = encoded_input.input_ids.shape[-1]
    input_ids = encoded_input.input_ids

    accumulated_text = ""
    current_chunk = ""
    chunk_token_count = 0

    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )

    past_key_values = None

    # Manual token-by-token decoding loop: the KV cache (past_key_values) lets each
    # step feed only the newly sampled token back into the model. Note that top_p
    # and repetition_penalty are carried in gen_config but not applied here.
    for _ in range(request.max_new_tokens):
        with torch.no_grad():
            outputs = global_model(
                input_ids,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True
            )
        logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values

        if gen_config.do_sample:
            # Temperature scaling, then top-k filtering, then multinomial sampling.
            logits = logits / gen_config.temperature
            if gen_config.top_k and gen_config.top_k > 0:
                topk_values, _ = torch.topk(logits, k=gen_config.top_k)
                logits[logits < topk_values[:, [-1]]] = -float('Inf')
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        token_id = next_token.item()
        token_text = global_tokenizer.decode([token_id], skip_special_tokens=True)

        accumulated_text += token_text
        current_chunk += token_text
        chunk_token_count += 1

        # Stop as soon as any stop sequence appears in the accumulated output.
        if request.stop_sequences:
            for stop_seq in request.stop_sequences:
                if stop_seq in accumulated_text:
                    yield current_chunk
                    await cleanup_memory(device)
                    return

        # Flush the current chunk once it reaches the configured size.
        if chunk_token_count >= request.chunk_token_limit:
            yield current_chunk
            current_chunk = ""
            chunk_token_count = 0

        await asyncio.sleep(0)

        # Only the new token is fed back; the KV cache holds the rest of the context.
        input_ids = next_token

        if token_id == global_tokenizer.eos_token_id:
            break

    if current_chunk:
        yield current_chunk

    if request.include_duckasgo:
        search_summary = await perform_duckasgo_search(request.input_text)
        yield "\n" + search_summary

    await cleanup_memory(device)

@app.post("/generate")
async def generate_text(request: GenerateRequest):
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        raise HTTPException(status_code=500, detail="The model has not been loaded correctly.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen_config = GenerationConfig(
        temperature=request.temperature,
        max_new_tokens=request.max_new_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )

    try:
        if request.stream:
            generator = stream_text(request, device)
            return StreamingResponse(generator, media_type="text/plain")
        else:
            encoded_input = global_tokenizer(request.input_text, return_tensors="pt").to(device)
            with torch.no_grad():
                output = global_model.generate(
                    **encoded_input,
                    generation_config=gen_config,
                    return_dict_in_generate=True,
                    output_scores=True
                )
            # Decode only the newly generated tokens, skipping the prompt.
            input_length = encoded_input["input_ids"].shape[-1]
            generated_text = global_tokenizer.decode(
                output.sequences[0][input_length:],
                skip_special_tokens=True
            )

            # Truncate at the first stop sequence found, if any.
            if request.stop_sequences:
                for stop_seq in request.stop_sequences:
                    if stop_seq in generated_text:
                        generated_text = generated_text.split(stop_seq)[0]
                        break

            if request.include_duckasgo:
                search_summary = await perform_duckasgo_search(request.input_text)
                generated_text += "\n" + search_summary
            await cleanup_memory(device)
            return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during generation: {e}")

@app.post("/duckasgo")
async def duckasgo_search(request: DuckasgoRequest):
    try:
        results = await asyncio.to_thread(ddg, request.query, max_results=10)
        if results is None:
            results = []
        return JSONResponse(content={"query": request.query, "results": results})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during search: {e}")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
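

# Example usage once the server is running (assumed to be local on port 7860):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"input_text": "Hello", "max_new_tokens": 50, "stream": false}'
#
#   curl -X POST http://localhost:7860/duckasgo \
#        -H "Content-Type: application/json" \
#        -d '{"query": "FastAPI streaming responses"}'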