jnjj committed
Commit 4fceb86 (verified)
Parent: 26eadbd

Update app.py: drop the bitsandbytes quantization path (the text model now loads in FP16 with offload only), make cleanup_memory a coroutine, inline the MusicGen generate call, and import Response directly from fastapi.

Files changed (1)
  1. app.py +9 -38
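
For reference, a minimal sketch of the text-model load path after this commit (see the diff below). The AutoConfig line and torch_dtype=torch.float16 are assumptions: the diff elides the keyword arguments between offload_state_dict and the closing parenthesis, and "FP16" comes from the updated index page, not from a visible line.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_NAME = "jnjj/gemma-3-4b-it-1layer-actual"

# Assumed: config construction is elided in the diff.
text_config = AutoConfig.from_pretrained(MODEL_NAME)
global_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    config=text_config,
    torch_dtype=torch.float16,   # assumption: "FP16" per the service index page
    device_map="auto",           # accelerate places layers across GPU/CPU
    offload_folder="./offload",  # disk offload target, as in the diff
    offload_state_dict=True,
)
global_model = torch.compile(global_model, backend="inductor")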
app.py CHANGED
@@ -5,14 +5,13 @@ import asyncio
 import threading
 import time
 import torch
-from fastapi import FastAPI, HTTPException, BackgroundTasks
-from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, HTMLResponse, FileResponse
+from fastapi import FastAPI, HTTPException, BackgroundTasks, Response
+from fastapi.responses import StreamingResponse, JSONResponse, PlainTextResponse, HTMLResponse
 from pydantic import BaseModel, Field
 from transformers import (
     AutoConfig,
     AutoTokenizer,
     GenerationConfig,
-    BitsAndBytesConfig,
     AutoModelForCausalLM,
     AutoProcessor,
     MusicgenForConditionalGeneration
@@ -28,25 +27,15 @@ from duckduckgo_search import DDGS
 # -----------------------
 MODEL_NAME = "jnjj/gemma-3-4b-it-1layer-actual"
 MAX_CONTEXT_LEN = 1024
-# Quantization and offload for text
-bnb_config = BitsAndBytesConfig(
-    load_in_8bit=False,
-    llm_int8_threshold=6.0,
-    llm_int8_has_fp16_weight=False
-)
-# Maximum context for MusicGen
 MUSICGEN_MAX_TOKENS = 256
 
 global_model = None
 global_tokenizer = None
 global_tokens = {}
-# Diffusers animation pipeline globals
 motion_adapter = None
 anim_pipe = None
-# MusicGen globals
 music_processor = None
 music_model = None
-
 executor = ThreadPoolExecutor(max_workers=4)
 
 # -----------------------
@@ -74,7 +63,6 @@ def load_global_models():
     global_model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
         config=text_config,
-        quantization_config=bnb_config,
        device_map="auto",
         offload_folder="./offload",
         offload_state_dict=True,
@@ -85,7 +73,6 @@ def load_global_models():
     )
     global_model = torch.compile(global_model, backend="inductor")
 
-    # Special tokens
     if global_tokenizer.eos_token_id is not None and global_tokenizer.pad_token_id is None:
         global_tokenizer.pad_token_id = global_tokenizer.eos_token_id
     global_tokens.update({
@@ -121,7 +108,7 @@ def load_global_models():
     )
     music_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
-    print("Text, animation, and audio models loaded with optimizations.")
+    print("Text, animation, and audio models loaded without bitsandbytes.")
 
 @app.get("/", response_class=HTMLResponse)
 def index():
@@ -131,7 +118,7 @@ def index():
     <body>
     <h1>Multimedia Generation Service</h1>
     <ul>
-        <li>Text: 4-bit quantization, offload, torch.compile.</li>
+        <li>Text: FP16, offload, torch.compile.</li>
         <li>Animation: AnimateDiffPipeline with LoRA and CPU offload.</li>
         <li>Audio: MusicGen small, max tokens 256.</li>
     </ul>
@@ -144,9 +131,6 @@ def index():
 def health():
     return {"status": "ok"}
 
-# -----------------------
-# Pydantic Schemas
-# -----------------------
 class GenerateRequest(BaseModel):
     input_text: str = ""
     max_new_tokens: int = 2
@@ -178,10 +162,7 @@ class MusicRequest(BaseModel):
     texts: list[str]
     max_new_tokens: int = MUSICGEN_MAX_TOKENS
 
-# -----------------------
-# Utility Functions
-# -----------------------
-def cleanup_memory():
+async def cleanup_memory():
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -199,9 +180,6 @@ async def perform_duckasgo_search(query: str, max_results: int = 3) -> str:
         text += f"{i}. {r.get('title','')}\n   URL: {r.get('href','')}\n   {r.get('body','')}\n"
     return text
 
-# -----------------------
-# Text Generation
-# -----------------------
 async def generate_next_token(input_ids, past_key_values, gen_config, device):
     with torch.autocast(device_type=device, dtype=torch.float16):
         outputs = global_model(
@@ -240,7 +218,7 @@ async def stream_text(request: GenerateRequest, device: str):
         top_p=request.top_p,
         top_k=request.top_k,
         repetition_penalty=request.repetition_penalty,
-        frequency_penalty=request.frequency_penalty,
+        frequency_penalty=request.frequency.penalty,
         presence_penalty=request.presence_penalty,
         do_sample=request.do_sample
     )
@@ -266,11 +244,8 @@ async def stream_text(request: GenerateRequest, device: str):
         yield buffer
     if request.include_duckasgo:
         yield "\n" + await perform_duckasgo_search(request.input_text)
-    cleanup_memory()
+    await cleanup_memory()
 
-# -----------------------
-# Endpoints
-# -----------------------
 @app.post("/generate")
 async def generate_text(request: GenerateRequest, background_tasks: BackgroundTasks):
     if global_model is None:
@@ -318,14 +293,10 @@ async def generate_music(request: MusicRequest):
         return_tensors="pt"
     ).to(device)
     with torch.no_grad():
-        audio = music_model.generate(
-            **inputs,
-            max_new_tokens=request.max_new_tokens
-        )
-    # audio is a tensor of shape (batch, seq_len)
-    # convert to WAV bytes
+        audio = music_model.generate(**inputs, max_new_tokens=request.max_new_tokens)
     wav_bytes = music_processor.decode(audio[0].cpu()).numpy().tobytes()
     return Response(wav_bytes, media_type="audio/wav")
 
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
+
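
A minimal client sketch for exercising the service after this commit. The /generate route and the request fields (input_text, max_new_tokens, texts) appear in the diff; the /generate_music path is an assumption inferred from the generate_music handler, since its route decorator is outside the shown hunks, and the payload values are illustrative.

import requests  # hypothetical client script, not part of app.py

# Stream tokens from the /generate endpoint (served via StreamingResponse).
resp = requests.post(
    "http://localhost:7860/generate",
    json={"input_text": "Hello", "max_new_tokens": 32},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")

# Fetch audio from the MusicGen endpoint; route path is assumed.
resp = requests.post(
    "http://localhost:7860/generate_music",
    json={"texts": ["lo-fi beat"], "max_new_tokens": 256},
)
# Saves the raw bytes the service returns with media_type audio/wav.
with open("out.wav", "wb") as f:
    f.write(resp.content)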