Devakumar868 committed
Commit 2c1a7ab · verified
1 Parent(s): b0aca8d

Update app.py

Files changed (1)
  1. app.py +186 -7
app.py CHANGED
@@ -1,10 +1,189 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- with gr.Blocks(fill_height=True) as demo:
4
- with gr.Sidebar():
5
- gr.Markdown("# Inference Provider")
6
- gr.Markdown("This Space showcases the nari-labs/Dia-1.6B model, served by the fal-ai API. Sign in with your Hugging Face account to use this API.")
7
- button = gr.LoginButton("Sign in")
8
- gr.load("models/nari-labs/Dia-1.6B", accept_token=button, provider="fal-ai")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from dia.model import Dia
5
+ import warnings
6
 
7
+ # Suppress warnings for cleaner output
8
+ warnings.filterwarnings("ignore", category=FutureWarning)
9
+ warnings.filterwarnings("ignore", category=UserWarning)
10
+
11
+ # Global model variable
12
+ model = None
13
+
14
+ def load_model_once():
15
+ """Load the Dia model once and cache it globally"""
16
+ global model
17
+ if model is None:
18
+ print("Loading Dia model... This may take a few minutes.")
19
+ try:
20
+ # Load model with correct parameters for Dia
21
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
22
+
23
+ # Move model to GPU if available
24
+ if torch.cuda.is_available():
25
+ model = model.cuda()
26
+ print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
27
+ else:
28
+ print("Model loaded on CPU")
29
+
30
+ print("Model loaded successfully!")
31
+
32
+ except Exception as e:
33
+ print(f"Error loading model: {e}")
34
+ raise e
35
+
36
+ return model
37
+
38
+ def generate_audio(text, seed=42):
39
+ """Generate audio from text input with error handling"""
40
+ try:
41
+ # Clear GPU cache before generation
42
+ if torch.cuda.is_available():
43
+ torch.cuda.empty_cache()
44
+
45
+ current_model = load_model_once()
46
+
47
+ # Validate input
48
+ if not text or not text.strip():
49
+ return None, "❌ Please enter some text"
50
+
51
+ # Clean and format text
52
+ text = text.strip()
53
+ if not text.startswith('[S1]') and not text.startswith('[S2]'):
54
+ text = '[S1] ' + text
55
+
56
+ # Set seed for reproducibility
57
+ if seed:
58
+ torch.manual_seed(int(seed))
59
+ if torch.cuda.is_available():
60
+ torch.cuda.manual_seed(int(seed))
61
+
62
+ print(f"Generating speech for: {text[:100]}...")
63
+
64
+ # Generate audio - disable torch compile for stability
65
+ with torch.no_grad():
66
+ audio_output = current_model.generate(
67
+ text,
68
+ use_torch_compile=False, # Disabled for T4 compatibility
69
+ verbose=False
70
+ )
71
+
72
+ # Ensure audio_output is numpy array
73
+ if isinstance(audio_output, torch.Tensor):
74
+ audio_output = audio_output.cpu().numpy()
75
+
76
+ # Normalize audio to prevent clipping
77
+ if len(audio_output) > 0:
78
+ max_val = np.max(np.abs(audio_output))
79
+ if max_val > 1.0:
80
+ audio_output = audio_output / max_val * 0.95
81
+
82
+ print("βœ… Audio generated successfully!")
83
+ return (44100, audio_output), "βœ… Audio generated successfully!"
84
+
85
+ except torch.cuda.OutOfMemoryError:
86
+ # Handle GPU memory issues
87
+ if torch.cuda.is_available():
88
+ torch.cuda.empty_cache()
89
+ error_msg = "❌ GPU memory error. Try shorter text or restart the space."
90
+ print(error_msg)
91
+ return None, error_msg
92
+
93
+ except Exception as e:
94
+ error_msg = f"❌ Error: {str(e)}"
95
+ print(error_msg)
96
+ return None, error_msg
97
+
98
+ # Create the Gradio interface - simplified to avoid OAuth triggers
99
+ demo = gr.Blocks(title="Dia TTS Demo")
100
+
101
+ with demo:
102
+ gr.HTML("""
103
+ <div style="text-align: center; padding: 20px;">
104
+ <h1>πŸŽ™οΈ Dia TTS - Ultra-Realistic Text-to-Speech</h1>
105
+ <p style="font-size: 18px; color: #666;">
106
+ Generate multi-speaker, emotion-aware dialogue using the Dia 1.6B model
107
+ </p>
108
+ </div>
109
+ """)
110
+
111
+ with gr.Row():
112
+ with gr.Column():
113
+ text_input = gr.Textbox(
114
+ label="πŸ“ Text to Speech",
115
+ placeholder="[S1] Hello there! How are you today? [S2] I'm doing great, thanks for asking! (laughs)",
116
+ lines=6,
117
+ value="[S1] Welcome to the Dia TTS demo! [S2] This is amazing technology!",
118
+ info="Use [S1] and [S2] for different speakers. Add emotions like (laughs), (sighs), etc."
119
+ )
120
+
121
+ seed_input = gr.Number(
122
+ label="🎲 Random Seed",
123
+ value=42,
124
+ precision=0,
125
+ info="Same seed = consistent voices"
126
+ )
127
+
128
+ generate_btn = gr.Button("🎡 Generate Speech", variant="primary")
129
+
130
+ with gr.Column():
131
+ audio_output = gr.Audio(
132
+ label="πŸ”Š Generated Audio",
133
+ type="numpy"
134
+ )
135
+
136
+ status_text = gr.Textbox(
137
+ label="πŸ“Š Status",
138
+ interactive=False,
139
+ lines=2
140
+ )
141
+
142
+ # Connect the button to the function
143
+ generate_btn.click(
144
+ fn=generate_audio,
145
+ inputs=[text_input, seed_input],
146
+ outputs=[audio_output, status_text]
147
+ )
148
+
149
+ # Add example buttons
150
+ with gr.Row():
151
+ example_btn1 = gr.Button("πŸ“» Podcast", size="sm")
152
+ example_btn2 = gr.Button("πŸ˜„ Chat", size="sm")
153
+ example_btn3 = gr.Button("🎭 Drama", size="sm")
154
+
155
+ # Example button functions
156
+ example_btn1.click(
157
+ lambda: "[S1] Welcome to our podcast! [S2] Thanks for having me on the show!",
158
+ outputs=text_input
159
+ )
160
 
161
+ example_btn2.click(
162
+ lambda: "[S1] Did you see the game? [S2] Yes! (laughs) It was incredible!",
163
+ outputs=text_input
164
+ )
165
+
166
+ example_btn3.click(
167
+ lambda: "[S1] I can't believe you're leaving. (sighs) [S2] I know, it's hard. (sad)",
168
+ outputs=text_input
169
+ )
170
+
171
+ # Usage instructions
172
+ gr.HTML("""
173
+ <div style="margin-top: 20px; padding: 15px; background: #f0f8ff; border-radius: 8px;">
174
+ <h3>πŸ’‘ Usage Tips:</h3>
175
+ <ul>
176
+ <li><strong>Speaker Tags:</strong> Use [S1] and [S2] to switch between speakers</li>
177
+ <li><strong>Emotions:</strong> Add (laughs), (sighs), (excited), (whispers), (sad), etc.</li>
178
+ <li><strong>Length:</strong> Keep text moderate length (5-20 seconds of speech works best)</li>
179
+ <li><strong>Seeds:</strong> Use the same seed number for consistent voice characteristics</li>
180
+ </ul>
181
+
182
+ <p><strong>Supported Emotions:</strong> (laughs), (sighs), (gasps), (excited), (sad), (angry),
183
+ (surprised), (whispers), (shouts), (coughs), (clears throat), (sniffs), (chuckles), (groans)</p>
184
+ </div>
185
+ """)
186
+
187
+ # Launch with basic configuration
188
+ if __name__ == "__main__":
189
+ demo.launch()