quazim commited on
Commit
43f4544
·
1 Parent(s): 8d76635

added elastic model

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoProcessor, MusicgenForConditionalGeneration
 
4
  import scipy.io.wavfile
5
  import numpy as np
6
  import subprocess
@@ -51,7 +52,13 @@ def setup_flash_attention():
51
  def load_model():
52
  """Load the musicgen model and processor"""
53
  processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
54
- model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-large")
 
 
 
 
 
 
55
  return processor, model
56
 
57
  def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
@@ -75,6 +82,7 @@ def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0
75
  temperature=temperature,
76
  top_k=top_k,
77
  top_p=top_p,
 
78
  )
79
 
80
  # Convert to numpy array and prepare for output
@@ -159,6 +167,8 @@ with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
159
  ["classical violin concerto"],
160
  ["reggae with steel drums and bass"],
161
  ["rock ballad with electric guitar solo"],
 
 
162
  ],
163
  inputs=text_input,
164
  label="Example Prompts"
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoProcessor
4
+ from elastic_models.transformers import MusicgenForConditionalGeneration
5
  import scipy.io.wavfile
6
  import numpy as np
7
  import subprocess
 
52
  def load_model():
53
  """Load the musicgen model and processor"""
54
  processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
55
+ model = MusicgenForConditionalGeneration.from_pretrained(
56
+ "facebook/musicgen-large",
57
+ torch_dtype=torch.float16,
58
+ device=self.device,
59
+ mode="S",
60
+ __paged=True,
61
+ )
62
  return processor, model
63
 
64
  def generate_music(text_prompt, duration=10, temperature=1.0, top_k=250, top_p=0.0):
 
82
  temperature=temperature,
83
  top_k=top_k,
84
  top_p=top_p,
85
+ cache_implementation="paged"
86
  )
87
 
88
  # Convert to numpy array and prepare for output
 
167
  ["classical violin concerto"],
168
  ["reggae with steel drums and bass"],
169
  ["rock ballad with electric guitar solo"],
170
+ ["test example"],
171
+
172
  ],
173
  inputs=text_input,
174
  label="Example Prompts"