InferenceLab commited on
Commit
2a97206
Β·
verified Β·
1 Parent(s): 45e5ab4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -2,13 +2,18 @@ import gradio as gr
2
  import asyncio
3
  import wave
4
  import os
 
 
5
  from google import genai
6
  from google.genai import types
7
 
 
 
 
 
8
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
9
  client = genai.Client(http_options={'api_version': 'v1alpha'}, api_key=GOOGLE_API_KEY)
10
 
11
-
12
  # Save PCM audio to WAV file
13
  def save_wave_file(filename, pcm, channels=1, rate=24000, sample_width=2):
14
  with wave.open(filename, "wb") as wf:
@@ -17,7 +22,7 @@ def save_wave_file(filename, pcm, channels=1, rate=24000, sample_width=2):
17
  wf.setframerate(rate)
18
  wf.writeframes(pcm)
19
 
20
- # Async music generation
21
  async def generate_music(prompt, bpm, temperature):
22
  audio_chunks = []
23
 
@@ -45,24 +50,29 @@ async def generate_music(prompt, bpm, temperature):
45
  await session.pause()
46
 
47
  all_pcm = b"".join(audio_chunks)
 
 
48
  output_path = "generated_music.wav"
49
  save_wave_file(output_path, all_pcm)
50
- return output_path, output_path, "Music generated successfully."
 
 
 
51
 
52
  except Exception as e:
53
- return None, None, f"Error: {str(e)}"
54
 
55
- # Wrapper for Gradio to run async function
56
  def generate_music_gradio(prompt, bpm, temperature):
57
- return asyncio.run(generate_music(prompt, bpm, temperature))
 
58
 
59
  # Gradio UI
60
  with gr.Blocks(title="Gemini Lyria Music Generator") as demo:
61
- gr.Markdown("## Lyria Music Generator")
62
 
63
- # Section 1: Input
64
  with gr.Group():
65
- gr.Markdown("### Input")
66
  with gr.Row():
67
  prompt_input = gr.Textbox(
68
  label="Music Style / Prompt",
@@ -71,19 +81,17 @@ with gr.Blocks(title="Gemini Lyria Music Generator") as demo:
71
  with gr.Row():
72
  bpm_input = gr.Slider(label="BPM", minimum=60, maximum=180, value=90)
73
  temp_input = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0)
74
- generate_btn = gr.Button("Generate Music")
75
 
76
- # Section 2: Output
77
  with gr.Group():
78
- gr.Markdown("### Output")
79
  with gr.Row():
80
- output_audio = gr.Audio(label="Generated Audio")
81
  download_file = gr.File(label="Download WAV")
82
  status_output = gr.Textbox(label="Status", interactive=False)
83
 
84
- # Section 3: Examples
85
  with gr.Group():
86
- gr.Markdown("### Examples")
87
  examples = gr.Examples(
88
  examples=[
89
  ["minimal techno", 125, 1.0],
@@ -95,7 +103,6 @@ with gr.Blocks(title="Gemini Lyria Music Generator") as demo:
95
  inputs=[prompt_input, bpm_input, temp_input]
96
  )
97
 
98
- # Event binding
99
  generate_btn.click(
100
  fn=generate_music_gradio,
101
  inputs=[prompt_input, bpm_input, temp_input],
 
2
  import asyncio
3
  import wave
4
  import os
5
+ import numpy as np
6
+ import nest_asyncio
7
  from google import genai
8
  from google.genai import types
9
 
10
+ # Enable nested asyncio (required for Gradio + asyncio)
11
+ nest_asyncio.apply()
12
+
13
+ # Configure Gemini API
14
  GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
15
  client = genai.Client(http_options={'api_version': 'v1alpha'}, api_key=GOOGLE_API_KEY)
16
 
 
17
  # Save PCM audio to WAV file
18
  def save_wave_file(filename, pcm, channels=1, rate=24000, sample_width=2):
19
  with wave.open(filename, "wb") as wf:
 
22
  wf.setframerate(rate)
23
  wf.writeframes(pcm)
24
 
25
+ # Async music generation function
26
  async def generate_music(prompt, bpm, temperature):
27
  audio_chunks = []
28
 
 
50
  await session.pause()
51
 
52
  all_pcm = b"".join(audio_chunks)
53
+
54
+ # Save WAV file
55
  output_path = "generated_music.wav"
56
  save_wave_file(output_path, all_pcm)
57
+
58
+ # Convert PCM to numpy array for audio playback
59
+ audio_np = np.frombuffer(all_pcm, dtype=np.int16)
60
+ return (24000, audio_np), output_path, "Music generated successfully!"
61
 
62
  except Exception as e:
63
+ return None, None, f"❌ Error: {str(e)}"
64
 
65
+ # Wrapper for Gradio
66
  def generate_music_gradio(prompt, bpm, temperature):
67
+ loop = asyncio.get_event_loop()
68
+ return loop.run_until_complete(generate_music(prompt, bpm, temperature))
69
 
70
  # Gradio UI
71
  with gr.Blocks(title="Gemini Lyria Music Generator") as demo:
72
+ gr.Markdown("## 🎢 Gemini Lyria Music Generator")
73
 
 
74
  with gr.Group():
75
+ gr.Markdown("### πŸŽ› Input")
76
  with gr.Row():
77
  prompt_input = gr.Textbox(
78
  label="Music Style / Prompt",
 
81
  with gr.Row():
82
  bpm_input = gr.Slider(label="BPM", minimum=60, maximum=180, value=90)
83
  temp_input = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=1.0)
84
+ generate_btn = gr.Button("🎧 Generate Music")
85
 
 
86
  with gr.Group():
87
+ gr.Markdown("### 🎡 Output")
88
  with gr.Row():
89
+ output_audio = gr.Audio(label="Generated Audio", type="numpy")
90
  download_file = gr.File(label="Download WAV")
91
  status_output = gr.Textbox(label="Status", interactive=False)
92
 
 
93
  with gr.Group():
94
+ gr.Markdown("### πŸ” Examples")
95
  examples = gr.Examples(
96
  examples=[
97
  ["minimal techno", 125, 1.0],
 
103
  inputs=[prompt_input, bpm_input, temp_input]
104
  )
105
 
 
106
  generate_btn.click(
107
  fn=generate_music_gradio,
108
  inputs=[prompt_input, bpm_input, temp_input],