Update app.py
app.py
CHANGED
@@ -5,14 +5,13 @@ import asyncio
 import threading
 import time
 import torch
-from fastapi import FastAPI, HTTPException, BackgroundTasks
-from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, HTMLResponse
+from fastapi import FastAPI, HTTPException, BackgroundTasks, Response
+from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, HTMLResponse
 from pydantic import BaseModel, Field
 from transformers import (
     AutoConfig,
     AutoTokenizer,
     GenerationConfig,
-    BitsAndBytesConfig,
     AutoModelForCausalLM,
     AutoProcessor,
     MusicgenForConditionalGeneration
@@ -28,25 +27,15 @@ from duckduckgo_search import DDGS
 # -----------------------
 MODEL_NAME = "jnjj/gemma-3-4b-it-1layer-actual"
 MAX_CONTEXT_LEN = 1024
-# Cuantización y offload para texto
-bnb_config = BitsAndBytesConfig(
-    load_in_8bit=False,
-    llm_int8_threshold=6.0,
-    llm_int8_has_fp16_weight=False
-)
-# Contexto máximo para MusicGen
 MUSICGEN_MAX_TOKENS = 256
 
 global_model = None
 global_tokenizer = None
 global_tokens = {}
-# Diffusers animation pipeline globals
 motion_adapter = None
 anim_pipe = None
-# MusicGen globals
 music_processor = None
 music_model = None
-
 executor = ThreadPoolExecutor(max_workers=4)
 
 # -----------------------
@@ -74,7 +63,6 @@ def load_global_models():
     global_model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
         config=text_config,
-        quantization_config=bnb_config,
         device_map="auto",
         offload_folder="./offload",
         offload_state_dict=True,
@@ -85,7 +73,6 @@ def load_global_models():
     )
     global_model = torch.compile(global_model, backend="inductor")
 
-    # Tokens especiales
     if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
         global_tokenizer.pad_token_id = global_tokenizer.eos_token_id
     global_tokens.update({
@@ -121,7 +108,7 @@ def load_global_models():
     )
     music_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
-    print("Modelos de texto, animación y audio cargados
+    print("Modelos de texto, animación y audio cargados sin bitsandbytes.")
 
 @app.get("/", response_class=HTMLResponse)
 def index():
@@ -131,7 +118,7 @@ def index():
     <body>
     <h1>Servicio de Generación Multimedia</h1>
     <ul>
-      <li>Texto:
+      <li>Texto: FP16, offload, torch.compile.</li>
       <li>Animación: AnimateDiffPipeline con LoRA y CPU offload.</li>
       <li>Audio: MusicGen small, max tokens 256.</li>
     </ul>
@@ -144,9 +131,6 @@ def index():
 def health():
     return {"status": "ok"}
 
-# -----------------------
-# Pydantic Schemas
-# -----------------------
 class GenerateRequest(BaseModel):
     input_text: str = ""
     max_new_tokens: int = 2
@@ -178,10 +162,7 @@ class MusicRequest(BaseModel):
     texts: list[str]
     max_new_tokens: int = MUSICGEN_MAX_TOKENS
 
-
-# Utility Functions
-# -----------------------
-def cleanup_memory():
+async def cleanup_memory():
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -199,9 +180,6 @@ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
         text += f"{i}. {r.get('title','')}\n URL: {r.get('href','')}\n {r.get('body','')}\n"
     return text
 
-# -----------------------
-# Text Generation
-# -----------------------
 async def generate_next_token(input_ids, past_key_values, gen_config, device):
     with torch.autocast(device_type=device, dtype=torch.float16):
         outputs = global_model(
@@ -240,7 +218,7 @@ async def stream_text(request: GenerateRequest, device: str):
         top_p=request.top_p,
         top_k=request.top_k,
         repetition_penalty=request.repetition_penalty,
-        frequency_penalty=request.
+        frequency_penalty=request.frequency_penalty,
         presence_penalty=request.presence_penalty,
         do_sample=request.do_sample
     )
@@ -266,11 +244,8 @@ async def stream_text(request: GenerateRequest, device: str):
         yield buffer
     if request.include_duckasgo:
         yield "\n" + await perform_duckasgo_search(request.input_text)
-    cleanup_memory()
+    await cleanup_memory()
 
-# -----------------------
-# Endpoints
-# -----------------------
 @app.post("/generate")
 async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
     if global_model is None:
@@ -318,14 +293,10 @@ async def generate_music(request: MusicRequest):
         return_tensors="pt"
     ).to(device)
     with torch.no_grad():
-        audio = music_model.generate(
-            **inputs,
-            max_new_tokens=request.max_new_tokens
-        )
-    # audio is a tensor of shape (batch, seq_len)
-    # convert to WAV bytes
+        audio = music_model.generate(**inputs, max_new_tokens=request.max_new_tokens)
     wav_bytes = music_processor.decode(audio[0].cpu()).numpy().tobytes()
     return Response(wav_bytes, media_type="audio/wav")
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
+
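
For quick manual testing after this change, a minimal client sketch against the endpoints visible in the diff follows. Only GET /health and POST /generate appear explicitly in the hunks; the music route path used below is a guess from the generate_music handler name, and the streaming behaviour of /generate is inferred from the stream_text generator, so treat both as assumptions.

# Client sketch under the assumptions stated above.
import requests

BASE = "http://localhost:7860"  # uvicorn.run(..., port=7860) in app.py

# Health check: GET /health returns {"status": "ok"} per the diff.
print(requests.get(f"{BASE}/health").json())

# Text generation: field names follow the GenerateRequest model in the diff;
# assumes the endpoint streams plain text chunks.
payload = {"input_text": "Hola", "max_new_tokens": 64, "do_sample": True}
with requests.post(f"{BASE}/generate", json=payload, stream=True) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)

# Music generation: MusicRequest takes a list of prompts and a token budget.
# The /generate_music path is hypothetical; only the handler body is shown.
music = {"texts": ["lo-fi beat"], "max_new_tokens": 256}
resp = requests.post(f"{BASE}/generate_music", json=music)
with open("out.wav", "wb") as f:
    f.write(resp.content)  # raw bytes returned with media_type="audio/wav"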