# Hugging Face Space - running on ZeroGPU ("Zero") hardware.
import os | |
# Environment switches, applied before torch is imported further down.
os.environ.update(
    {
        "TORCHDYNAMO_DISABLE": "1",
        "TORCH_COMPILE_DISABLE": "1",
        "PYTORCH_DISABLE_CUDNN_BENCHMARK": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)
def is_restricted_environment():
    """Report whether the process looks like it is running on a ZeroGPU Space.

    The check is a heuristic over environment variables and passes when any
    of the following holds:

    * ``ZERO_GPU`` or ``SPACES_ZERO_GPU`` is set to a non-empty value,
    * ``SPACE_ID`` contains ``"zero"`` (case-insensitive),
    * ``HOSTNAME`` contains ``"spaces"`` (case-insensitive).

    Returns:
        bool: ``True`` when at least one signal is present, else ``False``.
    """
    explicit_flag = os.getenv("ZERO_GPU") or os.getenv("SPACES_ZERO_GPU")
    space_id = os.getenv("SPACE_ID", "").lower()
    hostname = os.getenv("HOSTNAME", "").lower()
    # bool() so callers always get a real boolean instead of whichever
    # environment string happened to short-circuit the ``or`` chain.
    return bool(explicit_flag or "zero" in space_id or "spaces" in hostname)
# On ZeroGPU the Unsloth-related switches are all set to "1"; elsewhere the
# environment is left untouched and only a notice is printed.
if not is_restricted_environment():
    print("π§ Local environment detected - Unsloth optimizations enabled")
else:
    os.environ.update(
        dict.fromkeys(
            (
                "UNSLOTH_DISABLE",
                "DISABLE_UNSLOTH",
                "UNSLOTH_IGNORE_ERRORS",
                "UNSLOTH_NO_COMPILE",
            ),
            "1",
        )
    )
    print("π ZeroGPU detected - Unsloth optimizations disabled for compatibility")
import torch | |
import gradio as gr | |
import numpy as np | |
import spaces | |
import logging | |
from huggingface_hub import login | |
import threading | |
import time | |
# Turn TorchDynamo off through its config object as well.
torch._dynamo.config.disable = True
torch._dynamo.config.suppress_errors = True

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Log in to the Hugging Face Hub only when a token is supplied.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Lazily populated state shared by every request (see initialize_model_once).
_tts_model = _speakers_dict = None
_model_initialized = _initialization_in_progress = False
def get_speakers_dict():
    """Map speaker display names to the SDK's speaker objects.

    ``Speakers`` is imported lazily from ``maliba_ai.config.settings``. Each
    known name is looked up as an attribute of that class; names the
    installed SDK does not define are skipped with a warning so that a
    partial speaker set remains usable instead of being discarded entirely.

    Returns:
        dict: ``{name: Speakers.<name>}`` for every known name that resolved.
        An empty dict when the import fails, the lookup raises, or no known
        speaker is defined.
    """
    known_names = (
        "Adama",
        "Moussa",
        "Bourama",
        "Modibo",
        "Seydou",
        "Amadou",
        "Bakary",
        "Ngolo",
        "Amara",
        "Ibrahima",
    )
    absent = object()  # sentinel: distinguishes "missing" from a None value
    try:
        from maliba_ai.config.settings import Speakers

        speakers_dict = {}
        for name in known_names:
            speaker = getattr(Speakers, name, absent)
            if speaker is absent:
                logger.warning(f"Speaker '{name}' is not defined by the installed SDK; skipping")
                continue
            speakers_dict[name] = speaker

        if not speakers_dict:
            logger.error("No known speakers are defined on the Speakers class")
            return {}

        logger.info(f"π€ Successfully loaded {len(speakers_dict)} speakers: {list(speakers_dict.keys())}")
        return speakers_dict
    except Exception as e:
        # Best effort: callers treat an empty dict as "no speakers available".
        logger.error(f"β Failed to import Speakers class: {e}")
        return {}
# Serialises first-time model construction across concurrent callers.
_init_lock = threading.Lock()


def initialize_model_once():
    """Create the TTS model and speaker map on first use and cache them.

    Later calls return the cached objects. Concurrent first calls are
    serialised with a lock: a caller that arrives while another thread is
    loading blocks until that load ends and then reuses its result, rather
    than timing out and starting a second load.

    Construction is attempted up to ``max_retries`` times, but a retry only
    happens when the error text mentions ``unsloth_compiled_module_qwen2``;
    any other error is raised immediately.

    Returns:
        tuple: ``(tts_model, speakers_dict)``.

    Raises:
        ValueError: If the speakers dictionary comes back empty.
        Exception: Whatever importing or constructing ``BambaraTTSInference``
            raised, once retries are exhausted or not applicable.
    """
    global _tts_model, _speakers_dict, _model_initialized, _initialization_in_progress

    # Fast path: no locking needed once the model exists.
    if _model_initialized:
        logger.info("Model already initialized, returning existing instance")
        return _tts_model, _speakers_dict

    max_retries = 2
    retry_delay = 5  # seconds

    with _init_lock:
        # Another thread may have finished while this one waited for the lock.
        if _model_initialized:
            return _tts_model, _speakers_dict

        # Flag kept for any code that inspects it; the lock does the real work.
        _initialization_in_progress = True
        try:
            logger.info("Initializing Bambara TTS model...")
            start_time = time.time()
            from maliba_ai.tts.inference import BambaraTTSInference

            for attempt in range(max_retries):
                try:
                    model = BambaraTTSInference()
                    speakers = get_speakers_dict()
                    if not speakers:
                        raise ValueError("Failed to load speakers dictionary")

                    _tts_model = model
                    _speakers_dict = speakers
                    _model_initialized = True
                    elapsed = time.time() - start_time
                    logger.info(f"Model initialized successfully in {elapsed:.2f} seconds!")
                    return _tts_model, _speakers_dict
                except Exception as e:
                    retryable = "unsloth_compiled_module_qwen2" in str(e)
                    if retryable and attempt < max_retries - 1:
                        logger.warning(f"Unsloth compilation failed, retrying in {retry_delay} seconds... (attempt {attempt + 1}/{max_retries})")
                        time.sleep(retry_delay)
                    else:
                        # Bare raise keeps the original traceback intact.
                        raise
        except Exception as e:
            logger.error(f"Failed to initialize model after {max_retries} attempts: {e}")
            raise
        finally:
            _initialization_in_progress = False
def validate_inputs(text, temperature, top_k, top_p, max_tokens):
    """Validate the text and the advanced generation parameters.

    Checks run in order and the first failure is reported.

    Args:
        text: Text to synthesise; must contain a non-whitespace character.
        temperature: Must lie in ``[0.001, 2.0]``.
        top_k: Must lie in ``[1, 100]``.
        top_p: Must lie in ``[0.1, 1.0]``.
        max_tokens: Must be at least 1.

    Returns:
        tuple[bool, str]: ``(True, "")`` when every check passes, otherwise
        ``(False, message)`` describing the first failed check.
    """
    if not text or not text.strip():
        return False, "Please enter some Bambara text."
    if not (0.001 <= temperature <= 2.0):
        return False, "Temperature must be between 0.001 and 2.0"
    if not (1 <= top_k <= 100):
        return False, "Top-K must be between 1 and 100"
    if not (0.1 <= top_p <= 1.0):
        return False, "Top-P must be between 0.1 and 1.0"
    # Previously accepted without any check; a non-positive budget cannot
    # produce audio.
    if not (max_tokens >= 1):
        return False, "Max Length must be at least 1"
    return True, ""
def generate_speech(text, speaker_name, use_advanced, temperature, top_k, top_p, max_tokens):
    """Synthesise speech for ``text`` with the chosen speaker.

    Args:
        text: Text to synthesise; surrounding whitespace is stripped.
        speaker_name: Key into the speakers dictionary.
        use_advanced: When truthy, the sampling parameters below are
            validated and forwarded to the model; otherwise the model is
            called with text and speaker only.
        temperature: Sampling temperature (advanced mode only).
        top_k: Top-K value, cast to ``int`` (advanced mode only).
        top_p: Top-P value (advanced mode only).
        max_tokens: Forwarded as ``max_new_audio_tokens`` after an ``int``
            cast (advanced mode only).

    Returns:
        tuple: ``((16000, waveform), status_message)`` on success, where
        ``waveform`` is a NumPy array (floating-point data is scaled and
        clipped into ``[-1, 1]``); ``(None, error_message)`` on any failure.
        Exceptions are logged and reported through the message, not raised.
    """
    # Guard against None as well as blank text, matching validate_inputs.
    if not text or not text.strip():
        return None, "Please enter some Bambara text."

    try:
        tts, speakers = initialize_model_once()
        if not tts or not speakers:
            return None, "β Model not properly initialized"

        if speaker_name not in speakers:
            available_speakers = list(speakers.keys())
            return None, f"β Speaker '{speaker_name}' not found. Available: {available_speakers}"

        speaker = speakers[speaker_name]
        logger.info(f"Using speaker: {speaker_name}")

        if use_advanced:
            is_valid, error_msg = validate_inputs(text, temperature, top_k, top_p, max_tokens)
            if not is_valid:
                return None, f"β {error_msg}"
            waveform = tts.generate_speech(
                text=text.strip(),
                speaker_id=speaker,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                max_new_audio_tokens=int(max_tokens),
            )
        else:
            waveform = tts.generate_speech(
                text=text.strip(),
                speaker_id=speaker,
            )

        # Convert tensors before the emptiness check: on a torch.Tensor,
        # ``.size`` is a method, so comparing it with 0 is never true.
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.detach().cpu().numpy()

        if waveform is None or waveform.size == 0:
            return None, "Failed to generate audio. Please try again."

        if np.issubdtype(waveform.dtype, np.floating):
            # Scale down only when the signal exceeds full scale, then clip
            # so the values handed to Gradio stay within [-1, 1].
            peak = np.max(np.abs(waveform))
            if peak > 1.0:
                waveform = waveform / peak
            waveform = np.clip(waveform, -1.0, 1.0)

        sample_rate = 16000
        return (sample_rate, waveform), f"β Audio generated successfully for speaker {speaker_name}"
    except Exception as e:
        # UI boundary: log with traceback and surface the error as status text.
        logger.exception(f"Speech generation failed: {e}")
        return None, f"β Error: {str(e)}"
def get_speaker_names():
    """Return the speaker names to offer in the UI, preferred voices first.

    Names from the preferred list come first, in that order, followed by any
    other loaded speakers in their dictionary order. When no speakers could
    be loaded, a fixed five-name fallback list is returned.
    """
    preferred_order = ["Bourama", "Adama", "Moussa", "Modibo", "Seydou",
                       "Amadou", "Bakary", "Ngolo", "Ibrahima", "Amara"]

    loaded = get_speakers_dict()
    if not loaded:
        logger.warning("No speakers loaded, using fallback list")
        return ["Bourama", "Adama", "Moussa", "Modibo", "Seydou"]

    available = list(loaded.keys())
    leading = [name for name in preferred_order if name in available]
    trailing = [name for name in available if name not in leading]
    ordered_speakers = leading + trailing
    logger.info(f"Available speakers: {ordered_speakers}")
    return ordered_speakers


# Resolved once at import time; used by the dropdown and the example buttons.
SPEAKER_NAMES = get_speaker_names()
# Demo sentences for the example buttons; each entry is [text, speaker name].
# NOTE(review): the non-ASCII letters below look mis-decoded (e.g. "Ι", "Ι²");
# confirm the file is stored and read as UTF-8 before shipping.
examples = [
    ["Aw ni ce", "Adama"],
    ["Mali bΙna diya kΙsΙbΙ, ka a da a kan baara bΙ ka kΙ.", "Moussa"],
    ["Ne bΙ se ka sΙbΙnni yΙlΙma ka kΙ kuma ye", "Bourama"],
    ["I ka kΙnΙ wa?", "Modibo"],
    ["LakΙli karamΙgΙw tun tΙ ka se ka sΙbΙnni kΙ ka Ι²Ι walanda kan wa denmisΙnw tun tΙ ka se ka o sΙbΙnni ninnu ye, kuma tΙ ka u kalan. DenmisΙnw kΙra kunfinw ye.", "Adama"],
    ["sigikafΙ kΙnΙ jamanaw ni Ι²ΙgΙn cΙ, olu ye a haminankow ye, wa o ko ninnu ka kan ka kΙ sariya ani tilennenya kΙnΙ.", "Seydou"],
    ["Aw ni ce. Ne tΙgΙ ye Adama. AwΙ, ne ye maliden de ye. Aw SanbΙ SanbΙ. San min tΙ Ι²inan ye, an bΙΙ ka jΙ ka o seli Ι²ΙgΙn fΙ, hΙΙrΙ ni lafiya la. Ala ka Mali suma. Ala ka Mali yiriwa. Ala ka Mali taa Ι²Ι. Ala ka an ka seliw caya. Ala ka yafa an bΙΙ ma.", "Moussa"],
    ["An dΙlakelen bΙ masike bilenman don ka tΙw gΙn.", "Bourama"],
    ["Aw ni ce. Seidu bΙ aw fo wa aw ka yafa a ma, ka da a kan tuma dΙw la kow ka can.", "Modibo"],
    ["To tΙ nantan ni lafiya, o ka fisa ni so fa dumuniba kΙlΙma ye.", "Amadou"],
    ["Mali ye jamana Ι²uman ye!", "Bakary"],
    ["An ka Ι²ΙgΙn dΙmΙ ka baara kΙ Ι²ΙgΙn fΙ", "Ngolo"],
    ["Hakili to yΙrΙ min na, sabali bΙ yen", "Ibrahima"],
    ["DΙnko Ι²uman ye, a bΙ dΙn mΙgΙ kΙnΙ", "Amara"],
]
def get_safe_examples():
    """Return the examples with every speaker replaced by an available one.

    A speaker present in ``SPEAKER_NAMES`` is kept. Otherwise its mapped
    substitute is used when that substitute is available, and the first
    entry of ``SPEAKER_NAMES`` is the last resort.
    """
    fallback_speakers = {
        "Amadou": "Adama",
        "Bakary": "Modibo",
        "Ngolo": "Adama",
        "Ibrahima": "Seydou",
        "Amara": "Moussa",
    }

    def pick_speaker(requested):
        # Resolution order: the requested voice, its substitute, the default.
        if requested in SPEAKER_NAMES:
            return requested
        substitute = fallback_speakers.get(requested)
        if substitute in SPEAKER_NAMES:
            return substitute
        return SPEAKER_NAMES[0]

    return [[sentence, pick_speaker(voice)] for sentence, voice in examples]
def build_interface():
    """Assemble the Gradio Blocks UI for the Bambara TTS demo.

    The page holds a text box, a speaker dropdown and a generate button, an
    optional group of sampling sliders toggled by a checkbox, the audio and
    status outputs, one button per example sentence, and an "About" panel.
    Both the generate button and pressing Enter in the text box call
    ``generate_speech``.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    # NOTE(review): container nesting was reconstructed from a copy of the
    # file without indentation - confirm the Row/Column layout against the
    # deployed app.
    with gr.Blocks(title="Bambara TTS - EXPERIMENTAL") as demo:
        gr.Markdown("""
        # π€ Bambara Text-to-Speech β οΈ EXPERIMENTAL
        **Powered by MALIBA-AI**
        Convert Bambara text to speech. This model is currently experimental.
        **Bambara** is spoken by millions of people in Mali and West Africa.
        """)

        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="π Bambara Text",
                    placeholder="Type your Bambara text here...",
                    lines=3,
                    max_lines=10,
                    value="I ni ce",
                )
                speaker_dropdown = gr.Dropdown(
                    choices=SPEAKER_NAMES,
                    value="Bourama" if "Bourama" in SPEAKER_NAMES else SPEAKER_NAMES[0],
                    label="π£οΈ Speaker Voice",
                    info=f"Choose from {len(SPEAKER_NAMES)} authentic Bambara voices",
                )
                generate_btn = gr.Button("π΅ Generate Speech", variant="primary", size="lg")

            with gr.Column(scale=1):
                use_advanced = gr.Checkbox(
                    label="βοΈ Use Advanced Settings",
                    value=False,
                    info="Enable to customize generation parameters",
                )
                # Hidden until the checkbox above is ticked.
                with gr.Group(visible=False) as advanced_group:
                    gr.Markdown("**Advanced Parameters:**")
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.8,
                        step=0.1,
                        label="Temperature",
                        info="Higher = more varied",
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=5,
                        label="Top-K",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top-P",
                    )
                    max_tokens = gr.Slider(
                        minimum=256,
                        maximum=4096,
                        value=2048,
                        step=256,
                        label="Max Length",
                    )

        gr.Markdown("### π Generated Audio")
        audio_output = gr.Audio(
            label="Generated Speech",
            type="numpy",
            interactive=False,
            format="wav",
        )
        status_output = gr.Textbox(
            label="Status",
            interactive=False,
            show_label=False,
            container=False,
        )

        with gr.Accordion("Try These Examples", open=True):
            def load_example(text, speaker):
                # Fill in the example and reset the advanced controls to the
                # same values the sliders start with.
                return text, speaker, False, 0.8, 50, 0.9, 2048

            gr.Markdown("**Click any example below:**")
            # Use safe examples with fallbacks for missing speakers
            safe_examples = get_safe_examples()
            for text, speaker in safe_examples:
                btn = gr.Button(f"{text[:30]}{'...' if len(text) > 30 else ''}", size="sm")
                btn.click(
                    # Default arguments bind this iteration's values; a plain
                    # closure would see only the last example.
                    fn=lambda t=text, s=speaker: load_example(t, s),
                    outputs=[text_input, speaker_dropdown, use_advanced, temperature, top_k, top_p, max_tokens],
                )

        with gr.Accordion("About", open=False):
            # The speaker count is computed so it stays correct when fewer
            # speakers than the full set were loaded.
            gr.Markdown(f"""
            ## About MALIBA-AI Bambara TTS
            **β οΈ This is an experimental Bambara TTS model.**
            - **Languages**: Bambara (bm)
            - **Speakers**: {len(SPEAKER_NAMES)} different voice options
            - **Sample Rate**: 16kHz
            ### π Available Speakers:
            {" ".join(SPEAKER_NAMES)}
            **License**: Creative Commons Attribution-NonCommercial-ShareAlike 4.0 (CC BY-NC-SA 4.0)
            ---
            **MALIBA-AI Mission**: Ensuring no Malian is left behind by technological advances π²π±
            """)

        def toggle_advanced(use_adv):
            return gr.Group(visible=use_adv)

        use_advanced.change(
            fn=toggle_advanced,
            inputs=[use_advanced],
            outputs=[advanced_group],
        )

        # The button click and Enter in the text box run the same generation.
        generation_inputs = [text_input, speaker_dropdown, use_advanced, temperature, top_k, top_p, max_tokens]
        generation_outputs = [audio_output, status_output]
        for register in (generate_btn.click, text_input.submit):
            register(
                fn=generate_speech,
                inputs=generation_inputs,
                outputs=generation_outputs,
                show_progress=True,
            )

    return demo
def main():
    """Build the Gradio app and serve it on 0.0.0.0:7860 without a share link."""
    logger.info("Starting Bambara TTS Gradio interface.")
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = build_interface()
    demo.launch(**launch_options)
    logger.info("Gradio interface launched successfully.")


if __name__ == "__main__":
    main()