sudoping01 committed on
Commit
92275ac
·
verified ·
1 Parent(s): faebdf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -38
app.py CHANGED
@@ -1,40 +1,45 @@
1
- import os
2
- import warnings
3
-
4
- # Set environment variables BEFORE any imports to prevent CUDA initialization
5
- os.environ["CUDA_VISIBLE_DEVICES"] = "" # Hide CUDA during startup
6
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
7
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # For debugging
9
-
10
- # Suppress warnings
11
- warnings.filterwarnings("ignore")
12
-
13
  import gradio as gr
14
  import numpy as np
 
15
  import spaces
16
  from huggingface_hub import login
17
 
18
- # These imports should now work without CUDA errors
19
- from maliba_ai.tts.inference import BambaraTTSInference
20
- from maliba_ai.config.speakers import Adame, Moussa, Bourama, Modibo, Seydou
21
 
22
  hf_token = os.getenv("HF_TOKEN")
23
  if hf_token:
24
  login(token=hf_token)
25
 
26
- # Initialize TTS model (this will use CPU during startup)
27
- print("Loading Bambara TTS model...")
28
- tts = BambaraTTSInference()
29
- print("Model loaded successfully!")
30
 
31
- SPEAKERS = {
32
- "Adame": Adame,
33
- "Moussa": Moussa,
34
- "Bourama": Bourama,
35
- "Modibo": Modibo,
36
- "Seydou": Seydou
37
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def validate_inputs(text, temperature, top_k, top_p, max_tokens):
40
  """Validate user inputs"""
@@ -42,7 +47,7 @@ def validate_inputs(text, temperature, top_k, top_p, max_tokens):
42
  return False, "Please enter some Bambara text."
43
 
44
  if not (0.001 <= temperature <= 1):
45
- return False, "Temperature must be between positive"
46
 
47
  if not (1 <= top_k <= 100):
48
  return False, "Top-K must be between 1 and 100"
@@ -59,14 +64,10 @@ def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p,
59
  return None, "Please enter some Bambara text."
60
 
61
  try:
62
- # Re-enable CUDA for GPU context
63
- import torch
64
- if torch.cuda.is_available():
65
- # Remove CUDA visibility restriction for GPU execution
66
- if "CUDA_VISIBLE_DEVICES" in os.environ:
67
- os.environ.pop("CUDA_VISIBLE_DEVICES", None)
68
 
69
- speaker = SPEAKERS[speaker_name]
70
 
71
  if use_advanced:
72
  is_valid, error_msg = validate_inputs(text, temperature, top_k, top_p, max_tokens)
@@ -94,8 +95,14 @@ def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p,
94
  return (sample_rate, waveform), f"✅ Audio generated successfully"
95
 
96
  except Exception as e:
 
 
 
97
  return None, f"❌ Error: {str(e)}"
98
 
 
 
 
99
  examples = [
100
  ["Aw ni ce", "Adame"],
101
  ["I ni ce", "Moussa"],
@@ -117,7 +124,7 @@ with gr.Blocks(title="Bambara TTS - EXPERIMENTAL", theme=gr.themes.Soft()) as de
117
 
118
  **Bambara** is spoken by millions of people in Mali and West Africa.
119
 
120
- ⚡ **Note**: Model loads on CPU during startup, then uses GPU for generation.
121
  """)
122
 
123
  with gr.Row():
@@ -132,7 +139,7 @@ with gr.Blocks(title="Bambara TTS - EXPERIMENTAL", theme=gr.themes.Soft()) as de
132
  )
133
 
134
  speaker_dropdown = gr.Dropdown(
135
- choices=list(SPEAKERS.keys()),
136
  value="Adame",
137
  label="🗣️ Speaker Voice"
138
  )
@@ -216,8 +223,9 @@ with gr.Blocks(title="Bambara TTS - EXPERIMENTAL", theme=gr.themes.Soft()) as de
216
  gr.Markdown("""
217
  **⚠️ This is an experimental Bambara TTS model.**
218
 
219
- The model loads on CPU during startup to avoid CUDA initialization errors,
220
- then switches to GPU during speech generation for optimal performance.
 
221
  """)
222
 
223
  def toggle_advanced(use_adv):
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
+ import os
4
  import spaces
5
  from huggingface_hub import login
6
 
7
+ # DO NOT import maliba_ai here - it will cause CUDA errors
8
+ # from maliba_ai.tts.inference import BambaraTTSInference
9
+ # from maliba_ai.config.speakers import Adame, Moussa, Bourama, Modibo, Seydou
10
 
11
  hf_token = os.getenv("HF_TOKEN")
12
  if hf_token:
13
  login(token=hf_token)
14
 
15
+ # Global variable to store the TTS instance
16
+ tts_instance = None
17
+ SPEAKERS = None
 
18
 
19
+ def initialize_tts():
20
+ """Initialize TTS model and speakers - only called inside GPU context"""
21
+ global tts_instance, SPEAKERS
22
+
23
+ if tts_instance is None:
24
+ print("Loading Bambara TTS model...")
25
+
26
+ # Import here to avoid CUDA initialization during app startup
27
+ from maliba_ai.tts.inference import BambaraTTSInference
28
+ from maliba_ai.config.speakers import Adame, Moussa, Bourama, Modibo, Seydou
29
+
30
+ tts_instance = BambaraTTSInference()
31
+
32
+ SPEAKERS = {
33
+ "Adame": Adame,
34
+ "Moussa": Moussa,
35
+ "Bourama": Bourama,
36
+ "Modibo": Modibo,
37
+ "Seydou": Seydou
38
+ }
39
+
40
+ print("Model loaded successfully!")
41
+
42
+ return tts_instance, SPEAKERS
43
 
44
  def validate_inputs(text, temperature, top_k, top_p, max_tokens):
45
  """Validate user inputs"""
 
47
  return False, "Please enter some Bambara text."
48
 
49
  if not (0.001 <= temperature <= 1):
50
+ return False, "Temperature must be between 0.001 and 1"
51
 
52
  if not (1 <= top_k <= 100):
53
  return False, "Top-K must be between 1 and 100"
 
64
  return None, "Please enter some Bambara text."
65
 
66
  try:
67
+ # Initialize TTS inside GPU context
68
+ tts, speakers = initialize_tts()
 
 
 
 
69
 
70
+ speaker = speakers[speaker_name]
71
 
72
  if use_advanced:
73
  is_valid, error_msg = validate_inputs(text, temperature, top_k, top_p, max_tokens)
 
95
  return (sample_rate, waveform), f"✅ Audio generated successfully"
96
 
97
  except Exception as e:
98
+ import traceback
99
+ error_msg = f"❌ Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
100
+ print(error_msg) # Log to console for debugging
101
  return None, f"❌ Error: {str(e)}"
102
 
103
+ # Define speaker names for UI (without importing the actual speaker objects)
104
+ SPEAKER_NAMES = ["Adame", "Moussa", "Bourama", "Modibo", "Seydou"]
105
+
106
  examples = [
107
  ["Aw ni ce", "Adame"],
108
  ["I ni ce", "Moussa"],
 
124
 
125
  **Bambara** is spoken by millions of people in Mali and West Africa.
126
 
127
+ ⚡ **Note**: The model will load when you first generate speech (may take a moment).
128
  """)
129
 
130
  with gr.Row():
 
139
  )
140
 
141
  speaker_dropdown = gr.Dropdown(
142
+ choices=SPEAKER_NAMES,
143
  value="Adame",
144
  label="🗣️ Speaker Voice"
145
  )
 
223
  gr.Markdown("""
224
  **⚠️ This is an experimental Bambara TTS model.**
225
 
226
+ - The model loads automatically when you first generate speech
227
+ - First generation may take longer due to model initialization
228
+ - GPU acceleration is used for optimal performance
229
  """)
230
 
231
  def toggle_advanced(use_adv):