Bils committed on
Commit
16060e9
·
verified ·
1 Parent(s): 1c1b50f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -67
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import os
3
  import torch
4
- import time
5
  from transformers import (
6
  AutoTokenizer,
7
  AutoModelForCausalLM,
@@ -12,7 +11,7 @@ from transformers import (
12
  from scipy.io.wavfile import write
13
  import tempfile
14
  from dotenv import load_dotenv
15
- import spaces # Hugging Face Spaces library for ZeroGPU support
16
 
17
  # Load environment variables (e.g., Hugging Face token)
18
  load_dotenv()
@@ -24,43 +23,23 @@ musicgen_model = None
24
  musicgen_processor = None
25
 
26
  # ---------------------------------------------------------------------
27
- # Helper: Safe Model Loader with Retry Logic
28
  # ---------------------------------------------------------------------
29
- def safe_load_model(model_id, token, retries=3, delay=5):
30
- for attempt in range(retries):
 
 
31
  try:
 
32
  model = AutoModelForCausalLM.from_pretrained(
33
  model_id,
34
  use_auth_token=token,
35
  torch_dtype=torch.float16,
36
- device_map="auto",
37
  trust_remote_code=True,
38
- offload_folder="/tmp", # Stream shards
39
- cache_dir="/tmp" # Cache directory for shard downloads
40
  )
41
- return model
42
- except Exception as e:
43
- print(f"Attempt {attempt + 1} failed: {e}")
44
- time.sleep(delay)
45
- raise RuntimeError(f"Failed to load model {model_id} after {retries} attempts")
46
-
47
- # ---------------------------------------------------------------------
48
- # Load Llama 3 Model with Zero GPU (Lazy Loading)
49
- # ---------------------------------------------------------------------
50
- @spaces.GPU(duration=600) # Increased duration to handle large models
51
- def load_llama_pipeline_zero_gpu(model_id: str, token: str):
52
- global llama_pipeline
53
- if llama_pipeline is None:
54
- try:
55
- print("Starting model loading...")
56
- tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
57
- print("Tokenizer loaded.")
58
- model = safe_load_model(model_id, token)
59
- print("Model loaded. Initializing pipeline...")
60
  llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
61
- print("Pipeline initialized successfully.")
62
  except Exception as e:
63
- print(f"Error loading Llama pipeline: {e}")
64
  return str(e)
65
  return llama_pipeline
66
 
@@ -75,31 +54,28 @@ def generate_script(user_input: str, pipeline_llama):
75
  )
76
  combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
77
  result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
78
- return result[0]['generated_text'].split("Refined script:")[-1].strip()
79
  except Exception as e:
80
  return f"Error generating script: {e}"
81
 
82
  # ---------------------------------------------------------------------
83
  # Load MusicGen Model (Lazy Loading)
84
  # ---------------------------------------------------------------------
85
- @spaces.GPU(duration=600)
86
  def load_musicgen_model():
87
  global musicgen_model, musicgen_processor
88
  if musicgen_model is None or musicgen_processor is None:
89
  try:
90
- print("Loading MusicGen model...")
91
  musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
92
  musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
93
- print("MusicGen model loaded successfully.")
94
  except Exception as e:
95
- print(f"Error loading MusicGen model: {e}")
96
  return None, str(e)
97
  return musicgen_model, musicgen_processor
98
 
99
  # ---------------------------------------------------------------------
100
  # Generate Audio
101
  # ---------------------------------------------------------------------
102
- @spaces.GPU(duration=600)
103
  def generate_audio(prompt: str, audio_length: int):
104
  global musicgen_model, musicgen_processor
105
  if musicgen_model is None or musicgen_processor is None:
@@ -125,51 +101,75 @@ def generate_audio(prompt: str, audio_length: int):
125
  # ---------------------------------------------------------------------
126
  # Gradio Interface
127
  # ---------------------------------------------------------------------
128
- def generate_script_interface(user_prompt, llama_model_id):
129
  # Load Llama 3 Pipeline with Zero GPU
130
  pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
131
  if isinstance(pipeline_llama, str):
132
- return pipeline_llama
133
 
134
  # Generate Script
135
  script = generate_script(user_prompt, pipeline_llama)
136
- return script
137
 
138
- def generate_audio_interface(script, audio_length):
139
  # Generate Audio
140
  audio_data = generate_audio(script, audio_length)
141
- return audio_data
 
142
 
143
  # ---------------------------------------------------------------------
144
  # Interface
145
  # ---------------------------------------------------------------------
146
- with gr.Blocks() as demo:
147
- gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
148
-
149
- with gr.Row():
150
- user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
151
- llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-8B") # Using a smaller model for better compatibility
152
- audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
153
-
154
- with gr.Row():
155
- generate_script_button = gr.Button("Generate Promo Script")
156
- script_output = gr.Textbox(label="Generated Script", interactive=False)
157
-
158
- with gr.Row():
159
- generate_audio_button = gr.Button("Generate Audio")
160
- audio_output = gr.Audio(label="Generated Audio", type="filepath")
161
-
162
- generate_script_button.click(
163
- generate_script_interface,
164
- inputs=[user_prompt, llama_model_id],
165
- outputs=script_output
166
- )
167
-
168
- generate_audio_button.click(
169
- generate_audio_interface,
170
- inputs=[script_output, audio_length],
171
- outputs=audio_output
172
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  # ---------------------------------------------------------------------
175
  # Launch App
 
1
  import gradio as gr
2
  import os
3
  import torch
 
4
  from transformers import (
5
  AutoTokenizer,
6
  AutoModelForCausalLM,
 
11
  from scipy.io.wavfile import write
12
  import tempfile
13
  from dotenv import load_dotenv
14
+ import spaces # Assumes Hugging Face Spaces library supports `@spaces.GPU`
15
 
16
  # Load environment variables (e.g., Hugging Face token)
17
  load_dotenv()
 
23
  musicgen_processor = None
24
 
25
  # ---------------------------------------------------------------------
26
+ # Load Llama 3 Model with Zero GPU (Lazy Loading)
27
  # ---------------------------------------------------------------------
28
# ---------------------------------------------------------------------
# Load Llama 3 Model with Zero GPU (Lazy Loading)
# ---------------------------------------------------------------------
@spaces.GPU(duration=300)  # ZeroGPU slot: large-model loading can exceed the default window
def load_llama_pipeline_zero_gpu(model_id: str, token: str):
    """Lazily build and cache a text-generation pipeline for ``model_id``.

    Returns the cached ``transformers`` pipeline on success, or the error
    message as a plain ``str`` on failure — callers distinguish the two with
    ``isinstance(result, str)``.

    On failure the module-level cache is left as ``None``, so the next call
    retries the load from scratch.
    """
    global llama_pipeline
    if llama_pipeline is None:
        try:
            # `token=` is the current auth kwarg; `use_auth_token` is deprecated.
            tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=token,
                torch_dtype=torch.float16,  # half precision: sufficient for inference, halves memory
                device_map="auto",          # let accelerate place layers on the allocated GPU
                trust_remote_code=True,
            )
            # Assign the cache only after both tokenizer and model loaded,
            # so a failure never leaves a half-built pipeline behind.
            llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
        except Exception as e:
            return str(e)
    return llama_pipeline
45
 
 
54
  )
55
  combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
56
  result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
57
+ return result[0]["generated_text"].split("Refined script:")[-1].strip()
58
  except Exception as e:
59
  return f"Error generating script: {e}"
60
 
61
  # ---------------------------------------------------------------------
62
  # Load MusicGen Model (Lazy Loading)
63
  # ---------------------------------------------------------------------
64
# ---------------------------------------------------------------------
# Load MusicGen Model (Lazy Loading)
# ---------------------------------------------------------------------
@spaces.GPU(duration=300)
def load_musicgen_model():
    """Lazily load and cache the MusicGen model and its processor.

    Returns ``(model, processor)`` on success, or ``(None, error_message)``
    on failure.

    The globals are assigned only after BOTH loads succeed: previously a
    processor failure left ``musicgen_model`` set while the processor stayed
    ``None``, leaving a half-initialized cache and forcing a wasteful model
    reload on the next call.
    """
    global musicgen_model, musicgen_processor
    if musicgen_model is None or musicgen_processor is None:
        try:
            model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
            processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
        except Exception as e:
            # Neither global touched: the cache stays empty and retryable.
            return None, str(e)
        # Commit to the cache atomically — both or neither.
        musicgen_model, musicgen_processor = model, processor
    return musicgen_model, musicgen_processor
74
 
75
  # ---------------------------------------------------------------------
76
  # Generate Audio
77
  # ---------------------------------------------------------------------
78
+ @spaces.GPU(duration=300)
79
  def generate_audio(prompt: str, audio_length: int):
80
  global musicgen_model, musicgen_processor
81
  if musicgen_model is None or musicgen_processor is None:
 
101
  # ---------------------------------------------------------------------
102
  # Gradio Interface
103
  # ---------------------------------------------------------------------
104
def radio_imaging_app(user_prompt, llama_model_id, audio_length):
    """End-to-end handler: generate a promo script, then audio for it.

    Returns ``(script_or_error, audio_or_None)``:
    - If the Llama pipeline fails to load, the error string is returned in
      the script slot and no audio is attempted.
    - If script generation itself fails (``generate_script`` returns an
      "Error generating script: ..." string), audio generation is skipped —
      previously that error text was fed to MusicGen as a music prompt.
    """
    # Load Llama 3 Pipeline with Zero GPU (returns a str on failure).
    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
    if isinstance(pipeline_llama, str):
        return pipeline_llama, None

    # Generate Script
    script = generate_script(user_prompt, pipeline_llama)
    if isinstance(script, str) and script.startswith("Error generating script:"):
        # Surface the error; don't hand an error message to MusicGen.
        return script, None

    # Generate Audio
    audio_data = generate_audio(script, audio_length)
    return script, audio_data
116
+
117
 
118
  # ---------------------------------------------------------------------
119
  # Interface
120
  # ---------------------------------------------------------------------
121
def _generate_script_only(user_prompt, llama_model_id):
    """Step-1 handler: produce just the promo script, no audio.

    Replaces the previous wiring that called the full pipeline with
    ``gr.State(0)`` as the audio length (triggering a pointless 0-token
    MusicGen run) and declared ``outputs=[script_output, None]``, which is
    not a valid Gradio output list.
    """
    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
    if isinstance(pipeline_llama, str):
        # Loading failed — show the error message in the script box.
        return pipeline_llama
    return generate_script(user_prompt, pipeline_llama)


with gr.Blocks(css="""
#app-title {
    text-align: center;
    font-size: 2rem;
    font-weight: bold;
    color: #4CAF50;
}
#subsection {
    margin: 20px 0;
    font-size: 1.2rem;
    color: #333;
    text-align: center;
}
""") as demo:
    gr.Markdown('<div id="app-title">🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)</div>')

    with gr.Tab("Step 1: Generate Promo Script"):
        with gr.Row():
            user_prompt = gr.Textbox(
                label="Enter Your Promo Idea",
                placeholder="E.g., A 15-second hype jingle for a morning talk show.",
            )
            llama_model_id = gr.Textbox(
                label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B"
            )

        generate_script_button = gr.Button("Generate Script")
        script_output = gr.Textbox(label="Generated Promo Script", interactive=False)

        generate_script_button.click(
            fn=_generate_script_only,
            inputs=[user_prompt, llama_model_id],
            outputs=script_output,
        )

    with gr.Tab("Step 2: Generate Audio"):
        with gr.Row():
            audio_length = gr.Slider(
                label="Audio Length (tokens)",
                minimum=128,
                maximum=1024,
                step=64,
                value=512,
            )
        generate_audio_button = gr.Button("Generate Audio")
        audio_output = gr.Audio(label="Generated Audio", type="filepath")

        generate_audio_button.click(
            fn=generate_audio,
            inputs=[script_output, audio_length],
            outputs=audio_output,
        )
173
 
174
  # ---------------------------------------------------------------------
175
  # Launch App