import os
import sqlite3
import asyncio
from typing import TypedDict

import requests
import openai
import gradio as gr
from gtts import gTTS
from langgraph.graph import StateGraph, END
def stt_agent(audio_path: str) -> str:
    """Convert speech to text using the OpenAI Whisper API."""
    with open(audio_path, "rb") as afile:
        transcript = openai.audio.transcriptions.create(
            model="whisper-1",
            file=afile,
        )
    return transcript.text.strip()
# Load API keys from environment
openai.api_key = os.getenv("OPENAI_API_KEY")
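# web_search_agent below also reads a SERPAPI_KEY environment variable for the
# SerpAPI fallback; set it alongside OPENAI_API_KEY.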
# --- Business Logic Functions ---
def db_agent(query: str) -> str:
    """Answer revenue questions from the local SQLite database (shop.db)."""
    try:
        conn = sqlite3.connect("shop.db")
        cur = conn.cursor()
        if "max revenue" in query.lower():
            cur.execute(
                """
                SELECT product, SUM(amount) AS revenue
                FROM transactions
                WHERE date = date('now')
                GROUP BY product
                ORDER BY revenue DESC
                LIMIT 1
                """
            )
            row = cur.fetchone()
            conn.close()
            if row:
                return f"Top product today: {row[0]} with ₹{row[1]:,.2f}"
            return "No transactions found for today."
        conn.close()
        # Other revenue-style questions are delegated to the general LLM agent
        # instead of returning None, which the UI could not render.
        return llm_agent(query)
    except sqlite3.OperationalError as e:
        return f"Database error: {e}. Please initialize 'transactions' table in shop.db."
def web_search_agent(query: str) -> str:
    """Look up the query via SerpAPI and summarize the top snippet; fall back to the LLM."""
    try:
        resp = requests.get(
            "https://serpapi.com/search",
            params={"q": query, "api_key": os.getenv("SERPAPI_KEY")},
            timeout=10,
        )
        snippet = resp.json().get("organic_results", [{}])[0].get("snippet", "").strip()
        if snippet:
            return llm_agent(f"Summarize: {snippet}")
    except Exception:
        # Network or parsing failures fall through to a plain LLM answer
        pass
    return llm_agent(query)
def llm_agent(query: str) -> str:
    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": query},
        ],
        temperature=0.2,
    )
    return response.choices[0].message.content.strip()
# Text-to-Speech
def tts_agent(text: str, lang: str = 'en') -> str:
    """Convert text to speech (mp3) and return the file path."""
    tts = gTTS(text=text, lang=lang)
    out_path = "response_audio.mp3"
    tts.save(out_path)
    return out_path
# --- LangGraph Multi-Agent Setup ---
class AssistantState(TypedDict):
    query: str
    answer: str

def route_query(state: AssistantState) -> str:
    """Pick the agent node that should handle the query."""
    q = state["query"].lower()
    if any(k in q for k in ["max revenue", "revenue"]):
        return "db"
    if any(k in q for k in ["who", "what", "when", "where"]):
        return "web"
    return "llm"

def db_node(state: AssistantState) -> dict:
    return {"answer": db_agent(state["query"])}

def web_node(state: AssistantState) -> dict:
    return {"answer": web_search_agent(state["query"])}

def llm_node(state: AssistantState) -> dict:
    return {"answer": llm_agent(state["query"])}

# Build the graph: a conditional entry point routes each query to exactly one agent
builder = StateGraph(AssistantState)
builder.add_node("db", db_node)
builder.add_node("web", web_node)
builder.add_node("llm", llm_node)
builder.set_conditional_entry_point(route_query, {"db": "db", "web": "web", "llm": "llm"})
builder.add_edge("db", END)
builder.add_edge("web", END)
builder.add_edge("llm", END)
graph = builder.compile()
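# Quick sanity check (hypothetical example, assuming shop.db is populated):
#   graph.invoke({"query": "What is the max revenue product today?", "answer": ""})
#   -> {"query": "...", "answer": "Top product today: ..."}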
async def graph_handler(query: str) -> str:
    # If an audio file path is passed, transcribe it to text first
    if query.startswith("audio://"):
        audio_path = query.replace("audio://", "", 1)
        query = stt_agent(audio_path)
    result = await graph.ainvoke({"query": query, "answer": ""})
    return result["answer"]
def handle_query(audio_path, text_query):
    """Prefer the recorded clip; otherwise answer the typed question."""
    if audio_path:
        text_resp = asyncio.run(graph_handler(f"audio://{audio_path}"))
        # Speak the answer back when the question came in as speech
        return text_resp, tts_agent(text_resp)
    if text_query:
        return asyncio.run(graph_handler(text_query)), None
    return "Please record or type a question.", None
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## Shop Voice-Box Assistant (Speech In/Out)")
    inp_audio = gr.Audio(source="microphone", type="filepath", label="Speak your question")
    inp_text = gr.Textbox(label="...or type your question")
    out_text = gr.Textbox(label="Answer (text)")
    out_audio = gr.Audio(label="Answer (speech)")
    submit = gr.Button("Submit")
    # Example questions (typed; they fill the textbox input)
    gr.Examples(
        examples=[
            ["What is the max revenue product today?"],
            ["Who invented the light bulb?"],
            ["Tell me a joke about cats."],
        ],
        inputs=inp_text,
    )
    submit.click(fn=handle_query, inputs=[inp_audio, inp_text], outputs=[out_text, out_audio])
if __name__ == "__main__":
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)