Update app.py
Browse files
app.py
CHANGED
@@ -17,18 +17,16 @@ def load_model_once():
|
|
17 |
if model is None:
|
18 |
print("Loading Dia model... This may take a few minutes.")
|
19 |
try:
|
20 |
-
# Load model
|
|
|
21 |
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
|
22 |
|
23 |
-
|
24 |
if torch.cuda.is_available():
|
25 |
-
|
26 |
-
print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
|
27 |
else:
|
28 |
-
print("
|
29 |
|
30 |
-
print("Model loaded successfully!")
|
31 |
-
|
32 |
except Exception as e:
|
33 |
print(f"Error loading model: {e}")
|
34 |
raise e
|
@@ -61,7 +59,7 @@ def generate_audio(text, seed=42):
|
|
61 |
|
62 |
print(f"Generating speech for: {text[:100]}...")
|
63 |
|
64 |
-
# Generate audio - disable torch compile for stability
|
65 |
with torch.no_grad():
|
66 |
audio_output = current_model.generate(
|
67 |
text,
|
@@ -95,7 +93,7 @@ def generate_audio(text, seed=42):
|
|
95 |
print(error_msg)
|
96 |
return None, error_msg
|
97 |
|
98 |
-
# Create the Gradio interface
|
99 |
demo = gr.Blocks(title="Dia TTS Demo")
|
100 |
|
101 |
with demo:
|
|
|
17 |
if model is None:
|
18 |
print("Loading Dia model... This may take a few minutes.")
|
19 |
try:
|
20 |
+
# Load model without trying to move it manually to GPU
|
21 |
+
# The Dia model handles GPU placement internally
|
22 |
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
|
23 |
|
24 |
+
print("Model loaded successfully!")
|
25 |
if torch.cuda.is_available():
|
26 |
+
print(f"CUDA is available: {torch.cuda.get_device_name()}")
|
|
|
27 |
else:
|
28 |
+
print("CUDA is not available, using CPU")
|
29 |
|
|
|
|
|
30 |
except Exception as e:
|
31 |
print(f"Error loading model: {e}")
|
32 |
raise e
|
|
|
59 |
|
60 |
print(f"Generating speech for: {text[:100]}...")
|
61 |
|
62 |
+
# Generate audio - disable torch compile for T4 stability
|
63 |
with torch.no_grad():
|
64 |
audio_output = current_model.generate(
|
65 |
text,
|
|
|
93 |
print(error_msg)
|
94 |
return None, error_msg
|
95 |
|
96 |
+
# Create the Gradio interface
|
97 |
demo = gr.Blocks(title="Dia TTS Demo")
|
98 |
|
99 |
with demo:
|