Devakumar868 committed (verified)
Commit d7bff21 · 1 Parent(s): b7abc4e

Update app.py

Files changed (1):
  1. app.py (+7 -9)
app.py CHANGED
@@ -17,18 +17,16 @@ def load_model_once():
     if model is None:
         print("Loading Dia model... This may take a few minutes.")
         try:
-            # Load model with correct parameters for Dia
+            # Load model without trying to move it manually to GPU
+            # The Dia model handles GPU placement internally
             model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
 
-            # Move model to GPU if available
+            print("Model loaded successfully!")
             if torch.cuda.is_available():
-                model = model.cuda()
-                print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
+                print(f"CUDA is available: {torch.cuda.get_device_name()}")
             else:
-                print("Model loaded on CPU")
+                print("CUDA is not available, using CPU")
 
-            print("Model loaded successfully!")
-
         except Exception as e:
             print(f"Error loading model: {e}")
             raise e
@@ -61,7 +59,7 @@ def generate_audio(text, seed=42):
 
     print(f"Generating speech for: {text[:100]}...")
 
-    # Generate audio - disable torch compile for stability
+    # Generate audio - disable torch compile for T4 stability
     with torch.no_grad():
         audio_output = current_model.generate(
             text,
@@ -95,7 +93,7 @@ def generate_audio(text, seed=42):
         print(error_msg)
         return None, error_msg
 
-# Create the Gradio interface - simplified to avoid OAuth triggers
+# Create the Gradio interface
 demo = gr.Blocks(title="Dia TTS Demo")
 
 with demo:
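
For orientation, the sketch below strings together the pieces this commit touches: load Dia once without any manual .cuda() call, report CUDA availability, generate speech under torch.no_grad(), and expose the function through a Gradio Blocks UI. It is a minimal sketch, not the full app: the `from dia.model import Dia` import path, the bare generate(text) call, and the 44.1 kHz sample rate passed to gr.Audio are assumptions for illustration; the real app passes additional generation arguments that are truncated in this diff.

import gradio as gr
import torch
from dia.model import Dia  # import path assumed; adjust to the actual package layout

model = None

def load_model_once():
    """Load Dia once and cache it; the model handles GPU placement internally."""
    global model
    if model is None:
        print("Loading Dia model... This may take a few minutes.")
        model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
        if torch.cuda.is_available():
            print(f"CUDA is available: {torch.cuda.get_device_name()}")
        else:
            print("CUDA is not available, using CPU")
    return model

def generate_audio(text, seed=42):
    """Generate speech for `text`; returns (sample_rate, waveform) and a status string.

    The seed parameter is kept for parity with the app's signature but unused here.
    """
    current_model = load_model_once()
    print(f"Generating speech for: {text[:100]}...")
    try:
        with torch.no_grad():
            # Extra generation kwargs from the real app are omitted in this sketch
            audio_output = current_model.generate(text)
        return (44100, audio_output), "Done"  # sample rate is an assumption
    except Exception as e:
        error_msg = f"Error generating audio: {e}"
        print(error_msg)
        return None, error_msg

# Create the Gradio interface
demo = gr.Blocks(title="Dia TTS Demo")

with demo:
    text_in = gr.Textbox(label="Text to speak")
    audio_out = gr.Audio(label="Generated speech")
    status_out = gr.Textbox(label="Status")
    gr.Button("Generate").click(generate_audio, inputs=text_in, outputs=[audio_out, status_out])

if __name__ == "__main__":
    demo.launch()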