Add tool for generating audio from a sample; update agent initialization and Gradio interface structure; swap back to Gemma for fast testing
Browse files
app.py
CHANGED
|
@@ -44,7 +44,7 @@ def load_file(path: str) -> list | dict:
|
|
| 44 |
if image is not None:
|
| 45 |
return [image]
|
| 46 |
elif ext.endswith(".mp3") or ext.endswith(".wav"):
|
| 47 |
-
return {"audio
|
| 48 |
else:
|
| 49 |
return {"raw document text": text, "file path": path}
|
| 50 |
|
|
@@ -157,7 +157,6 @@ def generate_audio(prompt: str, duration: int) -> gr.Component:
|
|
| 157 |
Args:
|
| 158 |
prompt: The text prompt to generate the audio from.
|
| 159 |
duration: Duration of the generated audio in seconds. Max 30 seconds.
|
| 160 |
-
|
| 161 |
Returns:
|
| 162 |
gr.Component: The generated audio as a Gradio Audio component.
|
| 163 |
"""
|
|
@@ -167,18 +166,21 @@ def generate_audio(prompt: str, duration: int) -> gr.Component:
|
|
| 167 |
name="Sound_Generator",
|
| 168 |
description="Generate music or sound effects from a text prompt using MusicGen."
|
| 169 |
)
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
return gr.Audio(value=sound)
|
| 173 |
|
| 174 |
@tool
|
| 175 |
-
def generate_audio_from_sample(prompt: str, duration: int,
|
| 176 |
"""
|
| 177 |
Generate audio from a text prompt + audio sample using MusicGen.
|
| 178 |
Args:
|
| 179 |
prompt: The text prompt to generate the audio from.
|
| 180 |
duration: Duration of the generated audio in seconds. Max 30 seconds.
|
| 181 |
-
|
| 182 |
|
| 183 |
Returns:
|
| 184 |
gr.Component: The generated audio as a Gradio Audio component.
|
|
@@ -189,21 +191,24 @@ def generate_audio_from_sample(prompt: str, duration: int, sample: list[int, np.
|
|
| 189 |
name="Sound_Generator",
|
| 190 |
description="Generate music or sound effects from a text prompt using MusicGen."
|
| 191 |
)
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
return gr.Audio(value=sound)
|
| 195 |
|
| 196 |
|
| 197 |
-
|
| 198 |
## agent definition
|
| 199 |
class Agent:
|
| 200 |
def __init__(self, ):
|
| 201 |
#client = HfApiModel("deepseek-ai/DeepSeek-R1-0528", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
|
| 202 |
-
client =
|
|
|
|
| 203 |
model_id="claude-opus-4-20250514",
|
| 204 |
api_base="https://api.anthropic.com/v1/",
|
| 205 |
api_key=os.environ["ANTHROPIC_API_KEY"],
|
| 206 |
-
)
|
| 207 |
self.agent = CodeAgent(
|
| 208 |
model=client,
|
| 209 |
tools=[DuckDuckGoSearchTool(max_results=5),
|
|
@@ -271,23 +276,26 @@ def initialize_agent():
|
|
| 271 |
return agent
|
| 272 |
|
| 273 |
## gradio interface
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
|
|
|
| 287 |
gr.Checkbox(value=False, label="Web Search",
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
| 291 |
|
| 292 |
|
| 293 |
if __name__ == "__main__":
|
|
|
|
| 44 |
if image is not None:
|
| 45 |
return [image]
|
| 46 |
elif ext.endswith(".mp3") or ext.endswith(".wav"):
|
| 47 |
+
return {"audio path": path}
|
| 48 |
else:
|
| 49 |
return {"raw document text": text, "file path": path}
|
| 50 |
|
|
|
|
| 157 |
Args:
|
| 158 |
prompt: The text prompt to generate the audio from.
|
| 159 |
duration: Duration of the generated audio in seconds. Max 30 seconds.
|
|
|
|
| 160 |
Returns:
|
| 161 |
gr.Component: The generated audio as a Gradio Audio component.
|
| 162 |
"""
|
|
|
|
| 166 |
name="Sound_Generator",
|
| 167 |
description="Generate music or sound effects from a text prompt using MusicGen."
|
| 168 |
)
|
| 169 |
+
if duration > 30:
|
| 170 |
+
sound = client(prompt, 30)
|
| 171 |
+
else:
|
| 172 |
+
sound = client(prompt, duration)
|
| 173 |
|
| 174 |
return gr.Audio(value=sound)
|
| 175 |
|
| 176 |
@tool
|
| 177 |
+
def generate_audio_from_sample(prompt: str, duration: int, sample_path: str = None) -> gr.Component:
|
| 178 |
"""
|
| 179 |
Generate audio from a text prompt + audio sample using MusicGen.
|
| 180 |
Args:
|
| 181 |
prompt: The text prompt to generate the audio from.
|
| 182 |
duration: Duration of the generated audio in seconds. Max 30 seconds.
|
| 183 |
+
sample_path: audio sample path to guide generation.
|
| 184 |
|
| 185 |
Returns:
|
| 186 |
gr.Component: The generated audio as a Gradio Audio component.
|
|
|
|
| 191 |
name="Sound_Generator",
|
| 192 |
description="Generate music or sound effects from a text prompt using MusicGen."
|
| 193 |
)
|
| 194 |
+
if duration > 30:
|
| 195 |
+
sound = client(prompt, 30, sample_path)
|
| 196 |
+
else:
|
| 197 |
+
sound = client(prompt, duration, sample_path)
|
| 198 |
|
| 199 |
return gr.Audio(value=sound)
|
| 200 |
|
| 201 |
|
|
|
|
| 202 |
## agent definition
|
| 203 |
class Agent:
|
| 204 |
def __init__(self, ):
|
| 205 |
#client = HfApiModel("deepseek-ai/DeepSeek-R1-0528", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
|
| 206 |
+
client = HfApiModel("google/gemma-3-27b-it", provider="nebius", api_key=os.getenv("NEBIUS_API_KEY"))
|
| 207 |
+
"""client = OpenAIServerModel(
|
| 208 |
model_id="claude-opus-4-20250514",
|
| 209 |
api_base="https://api.anthropic.com/v1/",
|
| 210 |
api_key=os.environ["ANTHROPIC_API_KEY"],
|
| 211 |
+
)"""
|
| 212 |
self.agent = CodeAgent(
|
| 213 |
model=client,
|
| 214 |
tools=[DuckDuckGoSearchTool(max_results=5),
|
|
|
|
| 276 |
return agent
|
| 277 |
|
| 278 |
## gradio interface
|
| 279 |
+
|
| 280 |
+
global agent
|
| 281 |
+
agent = initialize_agent()
|
| 282 |
+
demo = gr.ChatInterface(
|
| 283 |
+
fn=respond,
|
| 284 |
+
type='messages',
|
| 285 |
+
multimodal=True,
|
| 286 |
+
title='MultiAgent System for Screenplay Creation and Editing',
|
| 287 |
+
show_progress='full',
|
| 288 |
+
fill_height=True,
|
| 289 |
+
fill_width=True,
|
| 290 |
+
save_history=True,
|
| 291 |
+
autoscroll=True,
|
| 292 |
+
additional_inputs=[
|
| 293 |
gr.Checkbox(value=False, label="Web Search",
|
| 294 |
+
info="Enable web search to find information online. If disabled, the agent will only use the provided files and images.",
|
| 295 |
+
render=False),
|
| 296 |
+
],
|
| 297 |
+
additional_inputs_accordion=gr.Accordion(label="Tools available: ", open=True, render=False)
|
| 298 |
+
)
|
| 299 |
|
| 300 |
|
| 301 |
if __name__ == "__main__":
|