sudoping01 commited on
Commit
0456d9c
Β·
verified Β·
1 Parent(s): 7274cc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -31
app.py CHANGED
@@ -6,9 +6,39 @@ os.environ["TORCH_COMPILE_DISABLE"] = "1"
6
  os.environ["PYTORCH_DISABLE_CUDNN_BENCHMARK"] = "1"
7
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
 
9
- # Disable Unsloth optimizations that cause issues in ZeroGPU
10
- os.environ["UNSLOTH_DISABLE"] = "1"
11
- os.environ["DISABLE_UNSLOTH"] = "1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  import torch
14
  import gradio as gr
@@ -64,7 +94,7 @@ def get_speakers_dict():
64
 
65
  @spaces.GPU()
66
  def initialize_model_once():
67
- """Initialize model exactly like your old working version"""
68
  global _tts_model, _speakers_dict, _model_initialized, _initialization_in_progress
69
 
70
  if _model_initialized:
@@ -80,34 +110,44 @@ def initialize_model_once():
80
 
81
  _initialization_in_progress = True
82
 
83
- try:
84
- logger.info("Initializing Bambara TTS model...")
85
- start_time = time.time()
86
-
87
- # Use the correct import path
88
- from maliba_ai.tts.inference import BambaraTTSInference
89
-
90
- model = BambaraTTSInference()
91
- speakers = get_speakers_dict()
92
-
93
- if not speakers:
94
- raise ValueError("Failed to load speakers dictionary")
95
-
96
- _tts_model = model
97
- _speakers_dict = speakers
98
- _model_initialized = True
99
-
100
- elapsed = time.time() - start_time
101
- logger.info(f"Model initialized successfully in {elapsed:.2f} seconds!")
102
-
103
- return _tts_model, _speakers_dict
104
-
105
- except Exception as e:
106
- logger.error(f"Failed to initialize model: {e}")
107
- _initialization_in_progress = False
108
- raise e
 
 
 
 
 
 
 
 
 
109
  finally:
110
- _initialization_in_progress = False
 
111
 
112
  def validate_inputs(text, temperature, top_k, top_p, max_tokens):
113
  """Same validation as your old version"""
 
6
  os.environ["PYTORCH_DISABLE_CUDNN_BENCHMARK"] = "1"
7
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
 
9
+ # Check if we're in ZeroGPU or similar restricted environment
10
+ def is_restricted_environment():
11
+ return (
12
+ os.getenv("ZERO_GPU") or
13
+ "zero" in str(os.getenv("SPACE_ID", "")).lower() or
14
+ os.getenv("SPACES_ZERO_GPU") or
15
+ "spaces" in str(os.getenv("HOSTNAME", "")).lower()
16
+ )
17
+
18
+ # Disable Unsloth optimizations in restricted environments
19
+ if is_restricted_environment():
20
+ os.environ["UNSLOTH_DISABLE"] = "1"
21
+ os.environ["DISABLE_UNSLOTH"] = "1"
22
+ os.environ["UNSLOTH_IGNORE_ERRORS"] = "1"
23
+ os.environ["UNSLOTH_NO_COMPILE"] = "1"
24
+ print("πŸš€ ZeroGPU detected - Unsloth optimizations disabled for compatibility")
25
+ else:
26
+ print("πŸ”§ Local environment detected - Unsloth optimizations enabled")
27
+
28
+ # Check if we're in ZeroGPU or similar restricted environment
29
+ def is_restricted_environment():
30
+ return (
31
+ os.getenv("ZERO_GPU") or
32
+ "zero" in str(os.getenv("SPACE_ID", "")).lower() or
33
+ os.getenv("SPACES_ZERO_GPU")
34
+ )
35
+
36
+ # Disable Unsloth optimizations in restricted environments
37
+ if is_restricted_environment():
38
+ os.environ["UNSLOTH_DISABLE"] = "1"
39
+ os.environ["DISABLE_UNSLOTH"] = "1"
40
+ os.environ["UNSLOTH_IGNORE_ERRORS"] = "1"
41
+ print("πŸš€ ZeroGPU detected - Unsloth optimizations disabled for compatibility")
42
 
43
  import torch
44
  import gradio as gr
 
94
 
95
  @spaces.GPU()
96
  def initialize_model_once():
97
+ """Initialize model with retry logic for Unsloth compilation issues"""
98
  global _tts_model, _speakers_dict, _model_initialized, _initialization_in_progress
99
 
100
  if _model_initialized:
 
110
 
111
  _initialization_in_progress = True
112
 
113
+ # Retry logic for Unsloth compilation issues
114
+ max_retries = 2
115
+ for attempt in range(max_retries + 1):
116
+ try:
117
+ logger.info(f"Initializing Bambara TTS model... (attempt {attempt + 1})")
118
+ start_time = time.time()
119
+
120
+ # Use the correct import path
121
+ from maliba_ai.tts.inference import BambaraTTSInference
122
+
123
+ model = BambaraTTSInference()
124
+ speakers = get_speakers_dict()
125
+
126
+ if not speakers:
127
+ raise ValueError("Failed to load speakers dictionary")
128
+
129
+ _tts_model = model
130
+ _speakers_dict = speakers
131
+ _model_initialized = True
132
+
133
+ elapsed = time.time() - start_time
134
+ logger.info(f"Model initialized successfully in {elapsed:.2f} seconds!")
135
+
136
+ return _tts_model, _speakers_dict
137
+
138
+ except Exception as e:
139
+ error_msg = str(e)
140
+ if "unsloth_compiled_module" in error_msg and attempt < max_retries:
141
+ logger.warning(f"Unsloth compilation failed (attempt {attempt + 1}/{max_retries + 1}), retrying...")
142
+ time.sleep(2) # Brief delay before retry
143
+ continue
144
+ else:
145
+ logger.error(f"Failed to initialize model: {e}")
146
+ _initialization_in_progress = False
147
+ raise e
148
  finally:
149
+ if not _model_initialized:
150
+ _initialization_in_progress = False
151
 
152
  def validate_inputs(text, temperature, top_k, top_p, max_tokens):
153
  """Same validation as your old version"""