# Ggggggg / app.py (Hugging Face Space by jnjj)
# Last commit: 26eadbd "Update app.py" (verified), 11.7 kB
import asyncio
import gc
import io
import os
import threading
import time
import wave
from concurrent.futures import ThreadPoolExecutor

import torch
import uvicorn
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
from duckduckgo_search import DDGS
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import (
    FileResponse,
    HTMLResponse,
    JSONResponse,
    PlainTextResponse,
    Response,
    StreamingResponse,
)
from pydantic import BaseModel, Field
from transformers import (
    AutoConfig,
    AutoTokenizer,
    GenerationConfig,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)
# -----------------------
# Configuration
# -----------------------
MODEL_NAME = "jnjj/gemma-3-4b-it-1layer-actual"
MAX_CONTEXT_LEN = 1024

# bitsandbytes quantization settings passed to the text model's from_pretrained.
# NOTE(review): load_in_8bit is False and load_in_4bit is not set here; confirm
# which quantization (if any) transformers actually applies with this config.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

# Default number of new tokens MusicGen generates per request.
MUSICGEN_MAX_TOKENS = 256

# Text model state, filled in by the startup hook.
global_model = None
global_tokenizer = None
global_tokens = dict()

# AnimateDiff (animation) state.
motion_adapter = None
anim_pipe = None

# MusicGen (audio) state.
music_processor = None
music_model = None

# Shared worker pool used to run blocking inference off the event loop.
executor = ThreadPoolExecutor(max_workers=4)

# -----------------------
# FastAPI app
# -----------------------
app = FastAPI()
@app.on_event("startup")
def load_global_models():
    """Load the text, animation and audio models into module-level globals.

    Registered as a FastAPI startup hook. It populates ``global_model``,
    ``global_tokenizer``, ``global_tokens``, ``motion_adapter``, ``anim_pipe``,
    ``music_processor`` and ``music_model``. Any loading error propagates out
    of the hook.
    """
    global global_model, global_tokenizer, global_tokens
    global motion_adapter, anim_pipe
    global music_processor, music_model

    # --- Text model ---
    text_config = AutoConfig.from_pretrained(MODEL_NAME)
    # NOTE(review): overriding max_position_embeddings may affect positional
    # encoding for this architecture; confirm it is intended.
    text_config.max_position_embeddings = MAX_CONTEXT_LEN
    global_tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        config=text_config,
        use_fast=True,
    )
    global_tokenizer.model_max_length = MAX_CONTEXT_LEN
    global_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        config=text_config,
        quantization_config=bnb_config,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        use_cache=True,
    )
    # Wrap the model with torch.compile (inductor backend).
    global_model = torch.compile(global_model, backend="inductor")

    # Special tokens: fall back to EOS as the pad token when none is defined.
    if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
        global_tokenizer.pad_token_id = global_tokenizer.eos_token_id
    global_tokens.update({
        "eos_token_id": global_tokenizer.eos_token_id,
        "pad_token_id": global_tokenizer.pad_token_id,
    })

    # --- Animation model ---
    # AnimateLCM motion adapter on the epiCRealism base, with the LCM scheduler
    # and the AnimateLCM LoRA applied at weight 0.8.
    motion_adapter = MotionAdapter.from_pretrained(
        "wangfuyun/AnimateLCM", torch_dtype=torch.float16
    )
    anim_pipe = AnimateDiffPipeline.from_pretrained(
        "emilianJR/epiCRealism",
        motion_adapter=motion_adapter,
        torch_dtype=torch.float16,
    )
    anim_pipe.scheduler = LCMScheduler.from_config(
        anim_pipe.scheduler.config, beta_schedule="linear"
    )
    anim_pipe.load_lora_weights(
        "wangfuyun/AnimateLCM",
        weight_name="AnimateLCM_sd15_t2v_lora.safetensors",
        adapter_name="lcm-lora",
    )
    anim_pipe.set_adapters(["lcm-lora"], [0.8])
    anim_pipe.enable_vae_slicing()
    anim_pipe.enable_model_cpu_offload()

    # --- MusicGen model ---
    music_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    music_model = MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small"
    )
    # Put MusicGen on the GPU when one is available, otherwise on the CPU.
    music_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    print("Modelos de texto, animación y audio cargados con optimizaciones.")
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve a static HTML page summarising the service (Spanish UI text)."""
    # NOTE(review): the "4-bit" claim depends on the module-level
    # BitsAndBytesConfig; confirm it matches what is actually loaded.
    return HTMLResponse(
        content="""
<html><head><title>Generador Ultra-Rápido</title></head>
<body>
<h1>Servicio de Generación Multimedia</h1>
<ul>
<li>Texto: cuantización 4-bit, offload, torch.compile.</li>
<li>Animación: AnimateDiffPipeline con LoRA y CPU offload.</li>
<li>Audio: MusicGen small, max tokens 256.</li>
</ul>
</body></html>
""",
        status_code=200,
    )
@app.get("/health")
def health():
    """Liveness probe: always reports ``{"status": "ok"}``."""
    payload = dict(status="ok")
    return payload
# -----------------------
# Pydantic Schemas
# -----------------------
def _rand_rounded(low: float, width: float) -> float:
    """Return ``U * width + low`` (U drawn by ``torch.rand``), rounded to 2 decimals."""
    return round(torch.rand(1).item() * width + low, 2)


def _rand_int(low: int, high: int) -> int:
    """Return an int drawn by ``torch.randint`` from ``[low, high)``."""
    return int(torch.randint(low, high, (1,)).item())


class GenerateRequest(BaseModel):
    """Body of ``POST /generate``.

    Sampling parameters that the client omits are drawn per request from
    torch's global RNG; the approximate ranges are noted on each field.
    """

    input_text: str = ""
    max_new_tokens: int = 2
    # ~0.5 .. 0.8
    temperature: float = Field(default_factory=lambda: _rand_rounded(0.5, 0.3))
    # ~0.75 .. 0.95
    top_p: float = Field(default_factory=lambda: _rand_rounded(0.75, 0.2))
    # 20 .. 60 inclusive
    top_k: int = Field(default_factory=lambda: _rand_int(20, 61))
    # ~1.1 .. 1.8
    repetition_penalty: float = Field(default_factory=lambda: _rand_rounded(1.1, 0.7))
    # ~0.2 .. 0.7
    frequency_penalty: float = Field(default_factory=lambda: _rand_rounded(0.2, 0.5))
    # ~0.2 .. 0.7
    presence_penalty: float = Field(default_factory=lambda: _rand_rounded(0.2, 0.5))
    # 0 .. 1000 inclusive
    seed: int = Field(default_factory=lambda: _rand_int(0, 1001))
    do_sample: bool = True
    stream: bool = True
    chunk_token_limit: int = 2
    stop_sequences: list[str] = []
    include_duckasgo: bool = False
class DuckasgoRequest(BaseModel):
    """Body of ``POST /duckasgo``: the search query to run."""

    query: str = Field(...)
class AnimateRequest(BaseModel):
    """Body of ``POST /animate``: prompt and AnimateDiff sampling parameters."""

    prompt: str
    negative_prompt: str = Field(default="")
    num_frames: int = Field(default=16)
    guidance_scale: float = Field(default=2.0)
    num_inference_steps: int = Field(default=6)
    # When omitted, a seed in 0..1000 is drawn from torch's global RNG.
    seed: int = Field(default_factory=lambda: int(torch.randint(0, 1001, (1,)).item()))
class MusicRequest(BaseModel):
    """Body of ``POST /music``."""

    # Text prompts for MusicGen.
    texts: list[str]
    max_new_tokens: int = Field(default=MUSICGEN_MAX_TOKENS)
# -----------------------
# Utility Functions
# -----------------------
def cleanup_memory():
    """Run Python's garbage collector, then release cached CUDA memory if CUDA exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
    """Search DuckDuckGo for ``query`` and format the hits as plain text.

    The DDGS client is synchronous and does network I/O, so the search runs
    in a worker thread via ``asyncio.to_thread`` instead of blocking the
    event loop.

    Args:
        query: Search text.
        max_results: Maximum number of hits to request.

    Returns:
        A text block listing title, URL and body for each hit. If there are
        no hits, a "no results" message. If any exception is raised, the
        string ``"Error DuckDuckGo: <exception>"`` (errors are reported, not
        raised).
    """
    def _search():
        with DDGS() as ddgs:
            return ddgs.text(query, max_results=max_results)

    try:
        results = await asyncio.to_thread(_search)
    except Exception as e:
        return f"Error DuckDuckGo: {e}"
    if not results:
        return "No se encontraron resultados."
    parts = ["\nResultados DuckDuckGo:\n"]
    for i, r in enumerate(results, 1):
        parts.append(f"{i}. {r.get('title','')}\n URL: {r.get('href','')}\n {r.get('body','')}\n")
    return "".join(parts)
# -----------------------
# Text Generation
# -----------------------
def _apply_top_k_top_p(logits, top_k, top_p):
    """Return ``logits`` with tokens outside the top-k / nucleus set masked to -inf.

    A ``top_k`` of ``None`` or <= 0 disables top-k. A ``top_p`` of ``None`` or
    outside (0, 1) disables nucleus filtering. The most likely token is
    always kept.
    """
    if top_k is not None and top_k > 0:
        # Clamp k so torch.topk never asks for more entries than the vocab has.
        k = min(int(top_k), logits.size(-1))
        # Keep the trailing dim so the threshold broadcasts per batch row.
        kth_best = torch.topk(logits, k=k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Drop a token once the probability mass ranked above it reaches top_p.
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        sorted_remove = mass_before >= top_p
        remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_idx, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))
    return logits


def generate_next_token(input_ids, past_key_values, gen_config, device):
    """Run one forward pass of ``global_model`` and pick the next token.

    This must be a plain (synchronous) function. ``stream_text`` calls it
    through ``asyncio.to_thread``, which runs it in a worker thread and needs
    a real return value; an ``async def`` would return an un-awaited
    coroutine there.

    Args:
        input_ids: Token ids to feed, presumably shaped (batch, seq); batch is
            1 in this file.
        past_key_values: KV cache from the previous step, or ``None``.
        gen_config: Object exposing ``do_sample``, ``temperature``, ``top_k``
            and ``top_p``.
        device: Device type string used for autocast ("cuda" or "cpu").

    Returns:
        ``(next_token, past_key_values)``. ``next_token`` has shape (batch, 1).
        Sampling is used when ``do_sample`` is set and ``temperature > 0``;
        otherwise the step is greedy (argmax).
    """
    # no_grad: without it every step would record autograd history that keeps
    # growing through the reused KV cache.
    with torch.no_grad():
        with torch.autocast(device_type=device, dtype=torch.float16):
            outputs = global_model(
                input_ids,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
        # Sample in float32 for numerical stability of softmax/multinomial.
        logits = outputs.logits[:, -1, :].float()
        past_key_values = outputs.past_key_values
        temperature = gen_config.temperature
        if gen_config.do_sample and temperature is not None and temperature > 0:
            logits = _apply_top_k_top_p(logits / temperature, gen_config.top_k, gen_config.top_p)
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
    return next_token, past_key_values
async def stream_text(request: GenerateRequest, device: str):
    """Asynchronously stream text generated for ``request``.

    The prompt is tokenized without truncation and split into segments of
    ``MAX_CONTEXT_LEN`` tokens. All but the last segment are run through the
    model to build the KV cache. The whole last segment is fed to the first
    decoding step, so every prompt token conditions the output. Decoding then
    proceeds one token at a time until one of these happens:

    - EOS is produced;
    - a stop sequence appears;
    - ``request.max_new_tokens`` tokens have been generated.

    Each model call runs in a worker thread via ``asyncio.to_thread``.

    Yields:
        A text chunk every ``request.chunk_token_limit`` tokens, then any
        remainder. If ``request.include_duckasgo`` is set, DuckDuckGo results
        follow.

    Notes:
        Text before the first stop-sequence match is emitted. A match that
        straddles a chunk already sent cannot be retracted.
        ``torch.manual_seed`` seeds the process-global RNG, so concurrent
        requests can perturb each other's sampling.
        Memory cleanup runs even if the client disconnects mid-stream.
    """
    def _prefill(segment, past):
        with torch.no_grad():
            out = global_model(segment, past_key_values=past, use_cache=True, return_dict=True)
        return out.past_key_values

    try:
        encoded = global_tokenizer(request.input_text, return_tensors="pt", truncation=False)
        all_ids = encoded.input_ids.to(device)
        prompt_len = all_ids.size(1)
        # With no prompt tokens there is nothing to condition on; skip decoding.
        if prompt_len > 0:
            segments = [
                all_ids[:, i:i + MAX_CONTEXT_LEN]
                for i in range(0, prompt_len, MAX_CONTEXT_LEN)
            ]
            past_key_values = None
            for seg in segments[:-1]:
                past_key_values = await asyncio.to_thread(_prefill, seg, past_key_values)
            # Feed the entire last segment; the first step uses its final logits.
            input_ids = segments[-1]
            gen_config = GenerationConfig(
                temperature=request.temperature,
                max_new_tokens=request.max_new_tokens,
                top_p=request.top_p,
                top_k=request.top_k,
                repetition_penalty=request.repetition_penalty,
                frequency_penalty=request.frequency_penalty,
                presence_penalty=request.presence_penalty,
                do_sample=request.do_sample,
            )
            torch.manual_seed(request.seed)
            stops = [s for s in request.stop_sequences if s]
            text = ""    # all decoded output so far
            sent = 0     # number of characters of `text` already yielded
            pending = 0  # tokens decoded since the last yield
            for _ in range(request.max_new_tokens):
                next_token, past_key_values = await asyncio.to_thread(
                    generate_next_token, input_ids, past_key_values, gen_config, device
                )
                tid = next_token.item()
                text += global_tokenizer.decode([tid], skip_special_tokens=True)
                pending += 1
                # next_token is already (batch=1, 1): exactly the next input shape.
                input_ids = next_token
                hits = [pos for pos in (text.find(s) for s in stops) if pos != -1]
                if hits:
                    text = text[:min(hits)]
                    break
                if pending >= request.chunk_token_limit:
                    yield text[sent:]
                    sent = len(text)
                    pending = 0
                if tid == global_tokens["eos_token_id"]:
                    break
            if len(text) > sent:
                yield text[sent:]
        if request.include_duckasgo:
            yield "\n" + await perform_duckasgo_search(request.input_text)
    finally:
        cleanup_memory()
# -----------------------
# Endpoints
# -----------------------
@app.post("/generate")
async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
    """Stream text generated for ``request`` as a ``text/plain`` response.

    ``background_tasks`` is accepted but not used. The response is always
    streamed.

    Raises:
        HTTPException: 500 if the text model has not been loaded.
    """
    model_ready = global_model is not None
    if not model_ready:
        raise HTTPException(status_code=500, detail="Modelo de texto no cargado.")
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    token_stream = stream_text(request, device)
    return StreamingResponse(token_stream, media_type="text/plain")
@app.post("/duckasgo")
def duckasgo(request: DuckasgoRequest):
    """Run a DuckDuckGo text search (up to 10 hits) and return the raw results.

    Responds with ``{"query": ..., "results": [...]}``. Any exception becomes
    an HTTP 500 whose detail is the exception message.
    """
    try:
        with DDGS() as search_client:
            hits = search_client.text(request.query, max_results=10)
            payload = {"query": request.query, "results": hits or []}
            return JSONResponse(content=payload)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/animate")
async def animate(request: AnimateRequest):
    """Render a short animation for ``request`` and return it as a looping GIF.

    The AnimateDiff pipeline runs on the shared ``executor`` thread pool so
    the event loop stays responsive. Frames are encoded at 100 ms each
    (10 fps) and loop forever, the same settings diffusers' ``export_to_gif``
    uses by default.

    Raises:
        HTTPException: 500 if the animation pipeline has not been loaded.
    """
    if anim_pipe is None:
        raise HTTPException(status_code=500, detail="Pipeline de animación no cargado.")

    def run_pipeline():
        return anim_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_frames=request.num_frames,
            guidance_scale=request.guidance_scale,
            num_inference_steps=request.num_inference_steps,
            generator=torch.Generator("cpu").manual_seed(request.seed),
        )

    # get_running_loop() is the supported way to reach the loop from a coroutine.
    event_loop = asyncio.get_running_loop()
    output = await event_loop.run_in_executor(executor, run_pipeline)
    # Presumably a list of PIL images (the pipeline's default output type).
    frames = output.frames[0]
    buf = io.BytesIO()
    # Save with an explicit format. export_to_gif() passes its target to
    # PIL's Image.save as a path, and a BytesIO has no file name, so PIL
    # cannot infer the format and raises "unknown file extension".
    frames[0].save(
        buf,
        format="GIF",
        save_all=True,
        append_images=frames[1:],
        optimize=False,
        duration=100,
        loop=0,
    )
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/gif")
@app.post("/music")
async def generate_music(request: MusicRequest):
    """Generate music from text prompts and return the first clip as a WAV file.

    All prompts in ``request.texts`` are generated together as one padded
    batch, but only the first clip is returned. Generation runs on the shared
    ``executor`` so the event loop is not blocked.

    Raises:
        HTTPException: 500 if MusicGen is not loaded; 400 if ``texts`` is empty.
    """
    if music_model is None or music_processor is None:
        raise HTTPException(status_code=500, detail="Modelo de audio no cargado.")
    if not request.texts:
        raise HTTPException(status_code=400, detail="Se requiere al menos un texto.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def run_generation():
        inputs = music_processor(
            text=request.texts,
            padding=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            return music_model.generate(**inputs, max_new_tokens=request.max_new_tokens)

    audio = await asyncio.get_running_loop().run_in_executor(executor, run_generation)

    # transformers documents MusicGen's generate() output as decoded waveforms
    # of shape (batch_size, num_channels, sequence_length). processor.decode()
    # is the tokenizer's text decode and cannot turn this into audio. Raw float
    # bytes are also not a WAV file, so encode 16-bit PCM WAV explicitly.
    waveform = audio[0].detach().float().cpu()
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # treat as mono
    # Assumes samples lie in [-1, 1]; clamp guards against overshoot.
    pcm = (waveform.clamp(-1.0, 1.0) * 32767.0).round().to(torch.int16)
    # WAV stores interleaved frames: (channels, samples) -> (samples, channels).
    # NOTE: tobytes() uses native byte order. WAV is little-endian, which
    # matches typical x86/ARM hosts.
    pcm_bytes = pcm.t().contiguous().numpy().tobytes()
    sample_rate = int(music_model.config.audio_encoder.sampling_rate)
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as wav_file:
        wav_file.setnchannels(pcm.size(0))
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_bytes)
    return Response(content=wav_buf.getvalue(), media_type="audio/wav")
if __name__ == "__main__":
    # Run the API with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")