Bils committed on
Commit
a3b5047
·
verified ·
1 Parent(s): 7840c4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -28
app.py CHANGED
@@ -13,19 +13,9 @@ import tempfile
13
  from dotenv import load_dotenv
14
  import spaces
15
 
16
- # Load environment variables
17
  load_dotenv()
18
  hf_token = os.getenv("HF_TOKEN")
19
 
20
- # Check and enable Xformers for memory-efficient attention
21
- if torch.cuda.is_available():
22
- try:
23
- from xformers.ops import memory_efficient_attention
24
- os.environ["XFORMERS_ATTENTION"] = "1"
25
- print("Xformers is enabled for memory-efficient attention.")
26
- except ImportError:
27
- print("Xformers is not installed or could not be imported.")
28
-
29
  # ---------------------------------------------------------------------
30
  # Load Llama 3 Pipeline with Zero GPU (Encapsulated)
31
  # ---------------------------------------------------------------------
@@ -53,6 +43,7 @@ def generate_script(user_prompt: str, model_id: str, token: str):
53
  except Exception as e:
54
  return f"Error generating script: {e}"
55
 
 
56
  # ---------------------------------------------------------------------
57
  # Load MusicGen Model (Encapsulated)
58
  # ---------------------------------------------------------------------
@@ -62,42 +53,49 @@ def generate_audio(prompt: str, audio_length: int):
62
  musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
63
  musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
64
 
65
- musicgen_model.to("cuda")
66
- inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
 
 
 
67
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
68
- musicgen_model.to("cpu") # Return the model to CPU
69
 
70
- sr = musicgen_model.config.audio_encoder.sampling_rate
71
  audio_data = outputs[0, 0].cpu().numpy()
72
- normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
73
 
 
 
74
  output_path = f"{tempfile.gettempdir()}/generated_audio.wav"
75
- write(output_path, sr, normalized_audio)
76
 
77
  return output_path
78
  except Exception as e:
79
  return f"Error generating audio: {e}"
80
 
 
81
  # ---------------------------------------------------------------------
82
  # Gradio Interface Functions
83
  # ---------------------------------------------------------------------
84
  def interface_generate_script(user_prompt, llama_model_id):
85
  return generate_script(user_prompt, llama_model_id, hf_token)
86
 
 
87
  def interface_generate_audio(script, audio_length):
88
  return generate_audio(script, audio_length)
89
 
 
90
  # ---------------------------------------------------------------------
91
  # Interface
92
  # ---------------------------------------------------------------------
93
  with gr.Blocks() as demo:
94
  # Header
95
- gr.Markdown("""
96
- # 🎙️ AI-Powered Radio Imaging Studio 🚀
 
97
  ### Create stunning **radio promos** with **Llama 3** and **MusicGen**
98
  🔥 **Zero GPU** integration for efficiency and ease!
99
- 🙌 Thanks to the Hugging Face community for supporting this space.
100
- """)
101
 
102
  # Script Generation Section
103
  gr.Markdown("## ✍️ Step 1: Generate Your Promo Script")
@@ -109,43 +107,46 @@ with gr.Blocks() as demo:
109
  info="Describe your promo idea clearly to generate a creative script."
110
  )
111
  llama_model_id = gr.Textbox(
112
- label="🎛️ Llama 3 Model ID",
113
  value="meta-llama/Meta-Llama-3-8B-Instruct",
114
  info="Enter the Hugging Face model ID for Llama 3."
115
  )
116
  generate_script_button = gr.Button("Generate Script ✨")
117
  script_output = gr.Textbox(
118
- label="📜 Generated Promo Script",
119
  lines=4,
120
  interactive=False,
121
  info="Your generated promo script will appear here."
122
  )
123
 
124
  # Audio Generation Section
125
- gr.Markdown("## 🎧 Step 2: Generate Audio from Your Script")
126
  with gr.Row():
127
  audio_length = gr.Slider(
128
- label="🎡 Audio Length (tokens)",
129
  minimum=128,
130
  maximum=1024,
131
  step=64,
132
  value=512,
133
- info="Select the desired audio token length."
134
  )
135
  generate_audio_button = gr.Button("Generate Audio 🎢")
136
  audio_output = gr.Audio(
137
- label="🎢 Generated Audio File",
138
  type="filepath",
139
  interactive=False
140
  )
141
 
142
  # Footer
143
- gr.Markdown("""
 
144
  <br><hr>
145
  <p style="text-align: center; font-size: 0.9em;">
146
  Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
147
  </p>
148
- """, elem_id="footer")
 
 
149
 
150
  # Button Actions
151
  generate_script_button.click(
 
13
  from dotenv import load_dotenv
14
  import spaces
15
 
 
16
  load_dotenv()
17
  hf_token = os.getenv("HF_TOKEN")
18
 
 
 
 
 
 
 
 
 
 
19
  # ---------------------------------------------------------------------
20
  # Load Llama 3 Pipeline with Zero GPU (Encapsulated)
21
  # ---------------------------------------------------------------------
 
43
  except Exception as e:
44
  return f"Error generating script: {e}"
45
 
46
+
47
  # ---------------------------------------------------------------------
48
  # Load MusicGen Model (Encapsulated)
49
  # ---------------------------------------------------------------------
 
53
  musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
54
  musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
55
 
56
+ # Ensure everything is on the same device (GPU or CPU)
57
+ device = "cuda" if torch.cuda.is_available() else "cpu"
58
+ musicgen_model.to(device)
59
+
60
+ inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
61
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
 
62
 
63
+ # Move outputs to CPU for further processing
64
  audio_data = outputs[0, 0].cpu().numpy()
 
65
 
66
+ # Normalize and save the audio file
67
+ normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
68
  output_path = f"{tempfile.gettempdir()}/generated_audio.wav"
69
+ write(output_path, musicgen_model.config.audio_encoder.sampling_rate, normalized_audio)
70
 
71
  return output_path
72
  except Exception as e:
73
  return f"Error generating audio: {e}"
74
 
75
+
76
  # ---------------------------------------------------------------------
77
  # Gradio Interface Functions
78
  # ---------------------------------------------------------------------
79
  def interface_generate_script(user_prompt, llama_model_id):
80
  return generate_script(user_prompt, llama_model_id, hf_token)
81
 
82
+
83
  def interface_generate_audio(script, audio_length):
84
  return generate_audio(script, audio_length)
85
 
86
+
87
  # ---------------------------------------------------------------------
88
  # Interface
89
  # ---------------------------------------------------------------------
90
  with gr.Blocks() as demo:
91
  # Header
92
+ gr.Markdown(
93
+ """
94
+ # 🎧 AI-Powered Radio Imaging Studio 🚀
95
  ### Create stunning **radio promos** with **Llama 3** and **MusicGen**
96
  🔥 **Zero GPU** integration for efficiency and ease!
97
+ """
98
+ )
99
 
100
  # Script Generation Section
101
  gr.Markdown("## ✍️ Step 1: Generate Your Promo Script")
 
107
  info="Describe your promo idea clearly to generate a creative script."
108
  )
109
  llama_model_id = gr.Textbox(
110
+ label="🎿 Llama 3 Model ID",
111
  value="meta-llama/Meta-Llama-3-8B-Instruct",
112
  info="Enter the Hugging Face model ID for Llama 3."
113
  )
114
  generate_script_button = gr.Button("Generate Script ✨")
115
  script_output = gr.Textbox(
116
+ label="🖌️ Generated Promo Script",
117
  lines=4,
118
  interactive=False,
119
  info="Your generated promo script will appear here."
120
  )
121
 
122
  # Audio Generation Section
123
+ gr.Markdown("## 🎡 Step 2: Generate Audio from Your Script")
124
  with gr.Row():
125
  audio_length = gr.Slider(
126
+ label="🎢 Audio Length (tokens)",
127
  minimum=128,
128
  maximum=1024,
129
  step=64,
130
  value=512,
131
+ info="Select the desired audio token length."
132
  )
133
  generate_audio_button = gr.Button("Generate Audio 🎢")
134
  audio_output = gr.Audio(
135
+ label="🎡 Generated Audio File",
136
  type="filepath",
137
  interactive=False
138
  )
139
 
140
  # Footer
141
+ gr.Markdown(
142
+ """
143
  <br><hr>
144
  <p style="text-align: center; font-size: 0.9em;">
145
  Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
146
  </p>
147
+ """,
148
+ elem_id="footer"
149
+ )
150
 
151
  # Button Actions
152
  generate_script_button.click(