Update app.py
app.py CHANGED
@@ -80,108 +80,12 @@ def load_model():
         logger.error(f"Error loading model: {str(e)}")
         raise
 
-def generate_podcast_script(api_key, content, uploaded_file, duration, num_hosts):
-    try:
-        genai.configure(api_key=api_key)
-        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
-
-        combined_content = content or ""
-        if uploaded_file:
-            file_content = uploaded_file.read().decode('utf-8')
-            combined_content += "\n" + file_content if combined_content else file_content
-
-        prompt = f"""
-        Create a podcast script for {'one person' if num_hosts == 1 else 'two people'} discussing:
-        {combined_content}
-
-        Duration: {duration}. Include natural speech, humor, and occasional off-topic thoughts.
-        Use speech fillers like um, ah. Vary emotional tone.
-
-        Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
-        Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.
-
-        Use emotion tags in angle brackets: <laugh>, <sigh>, <chuckle>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.
-
-        Example: "I can't believe I stayed up all night <yawn> only to find out the meeting was canceled <groan>."
-
-        Ensure content flows naturally and stays on topic. Match the script length to {duration}.
-        """
-
-        response = model.generate_content(prompt)
-        return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
-    except Exception as e:
-        logger.error(f"Error generating podcast script: {str(e)}")
-        raise
-
-def process_prompt(prompt, voice, tokenizer, device):
-    prompt = f"{voice}: {prompt}"
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
-
-    start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
-    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
-
-    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)  # SOH SOT Text EOT EOH
-
-    # No padding needed for single input
-    attention_mask = torch.ones_like(modified_input_ids)
-
-    return modified_input_ids.to(device), attention_mask.to(device)
-
-def parse_output(generated_ids):
-    token_to_find = 128257
-    token_to_remove = 128258
-
-    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
-
-    if len(token_indices[1]) > 0:
-        last_occurrence_idx = token_indices[1][-1].item()
-        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
-    else:
-        cropped_tensor = generated_ids
-
-    processed_rows = []
-    for row in cropped_tensor:
-        masked_row = row[row != token_to_remove]
-        processed_rows.append(masked_row)
-
-    code_lists = []
-    for row in processed_rows:
-        row_length = row.size(0)
-        new_length = (row_length // 7) * 7
-        trimmed_row = row[:new_length]
-        trimmed_row = [t - 128266 for t in trimmed_row]
-        code_lists.append(trimmed_row)
-
-    return code_lists[0]  # Return just the first one for single sample
-
-def redistribute_codes(code_list, snac_model):
-    device = next(snac_model.parameters()).device  # Get the device of SNAC model
-
-    layer_1 = []
-    layer_2 = []
-    layer_3 = []
-    for i in range((len(code_list)+1)//7):
-        layer_1.append(code_list[7*i])
-        layer_2.append(code_list[7*i+1]-4096)
-        layer_3.append(code_list[7*i+2]-(2*4096))
-        layer_3.append(code_list[7*i+3]-(3*4096))
-        layer_2.append(code_list[7*i+4]-(4*4096))
-        layer_3.append(code_list[7*i+5]-(5*4096))
-        layer_3.append(code_list[7*i+6]-(6*4096))
-
-    # Move tensors to the same device as the SNAC model
-    codes = [
-        torch.tensor(layer_1, device=device).unsqueeze(0),
-        torch.tensor(layer_2, device=device).unsqueeze(0),
-        torch.tensor(layer_3, device=device).unsqueeze(0)
-    ]
-
-    audio_hat = snac_model.decode(codes)
-    return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
-
 @spaces.GPU()
 def text_to_speech(text, voice, temperature=0.6, top_p=0.95, repetition_penalty=1.1, max_new_tokens=1200):
     global model, tokenizer, snac_model
+    if model is None or tokenizer is None or snac_model is None:
+        load_model()
+
     if not text.strip():
         return None
 
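The guard added inside text_to_speech lets the function initialise the module-level model, tokenizer, and snac_model globals on first use instead of assuming load_model() already ran in that process. A minimal, self-contained sketch of the same lazy-initialisation pattern; the stand-in loader and the placeholder objects below are illustrative only, not the real loading code:

model = None
tokenizer = None
snac_model = None

def load_model():
    # Stand-in loader for illustration; the real load_model() in app.py
    # populates these globals with the LLM, its tokenizer, and the SNAC vocoder.
    global model, tokenizer, snac_model
    model, tokenizer, snac_model = object(), object(), object()

def text_to_speech(text, voice="tara"):
    # Lazy-initialisation guard: load on the first call instead of assuming
    # load_model() already ran in this worker process.
    if model is None or tokenizer is None or snac_model is None:
        load_model()
    if not text.strip():
        return None
    return None  # tokenisation, generation, and SNAC decoding elided in this sketch
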
@@ -238,44 +142,11 @@ def render_podcast(api_key, script, voice1, voice2, num_hosts):
         logger.error(f"Error rendering podcast: {str(e)}")
         raise
 
-#
-with gr.Blocks() as demo:
-    gr.Markdown("# AI Podcast Generator")
-
-    api_key_input = gr.Textbox(label="Enter your Gemini API Key", type="password")
-
-    with gr.Row():
-        content_input = gr.Textbox(label="Paste your content (optional)")
-        document_upload = gr.File(label="Upload Document (optional)")
-
-    duration = gr.Radio(["1-5 min", "5-10 min", "10-15 min"], label="Estimated podcast duration")
-    num_hosts = gr.Radio([1, 2], label="Number of podcast hosts", value=2)
-
-    voice_options = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
-    voice1_select = gr.Dropdown(label="Select Voice 1", choices=voice_options, value="tara")
-    voice2_select = gr.Dropdown(label="Select Voice 2", choices=voice_options, value="leo")
-
-    generate_btn = gr.Button("Generate Script")
-    script_output = gr.Textbox(label="Generated Script", lines=10)
-
-    render_btn = gr.Button("Render Podcast")
-    audio_output = gr.Audio(label="Generated Podcast")
-
-    generate_btn.click(generate_podcast_script,
-                       inputs=[api_key_input, content_input, document_upload, duration, num_hosts],
-                       outputs=script_output)
-
-    render_btn.click(render_podcast,
-                     inputs=[api_key_input, script_output, voice1_select, voice2_select, num_hosts],
-                     outputs=audio_output)
-
-    num_hosts.change(lambda x: gr.update(visible=x == 2),
-                     inputs=[num_hosts],
-                     outputs=[voice2_select])
+# ... (rest of the code remains the same)
 
 if __name__ == "__main__":
     try:
-        load_model()
+        load_model()  # Load models at startup
         demo.launch()
     except Exception as e:
         logger.error(f"Error launching the application: {str(e)}")