|
import asyncio
import contextlib
import logging
import os
import sqlite3

import gradio as gr
import openai
import requests
from gtts import gTTS
from langgraph import Graph, FunctionNode, RouterNode
|
|
|
def stt_agent(audio_path: str) -> str:
    """Transcribe the audio file at *audio_path* with the ``whisper-1`` model.

    Args:
        audio_path: Path of the audio file to send for transcription.

    Returns:
        The transcript text with surrounding whitespace removed.
    """
    with open(audio_path, mode="rb") as audio_file:
        result = openai.audio.transcriptions.create(
            file=audio_file,
            model="whisper-1",
        )
    spoken_text = result.text
    return spoken_text.strip()
|
|
|
|
|
# Configure the OpenAI module from the environment; this is None when the
# OPENAI_API_KEY variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
|
|
|
|
def db_agent(query: str, db_path: str = "shop.db") -> str:
    """Answer a revenue question from the shop's SQLite database.

    Only questions containing the phrase "max revenue" (case-insensitive)
    are understood. They are answered with the product whose summed
    ``amount`` in the ``transactions`` table is highest among rows whose
    ``date`` equals SQLite's ``date('now')``.

    Args:
        query: The user's question.
        db_path: SQLite database file to read. Defaults to ``"shop.db"``.

    Returns:
        A human-readable answer. A string is always returned (never
        ``None``), so the result can be shown or spoken directly.
    """
    if "max revenue" not in query.lower():
        # Other revenue questions can reach this agent too; reply with text
        # rather than None so downstream display/speech always has a string.
        return (
            "Sorry, I can only answer 'max revenue' questions "
            "from the shop database."
        )

    # NOTE(review): SQLite evaluates date('now') in UTC and the comparison is
    # an exact match, so this presumably expects `date` to hold UTC
    # 'YYYY-MM-DD' text -- confirm against the table schema.
    top_product_sql = """
        SELECT product, SUM(amount) AS revenue
        FROM transactions
        WHERE date = date('now')
        GROUP BY product
        ORDER BY revenue DESC
        LIMIT 1
    """

    try:
        # closing() guarantees the connection is released even when the
        # query raises; sqlite3's own context manager only handles commits.
        with contextlib.closing(sqlite3.connect(db_path)) as conn:
            row = conn.execute(top_product_sql).fetchone()
    except sqlite3.OperationalError as e:
        return (
            f"Database error: {e}. "
            f"Please initialize 'transactions' table in {db_path}."
        )

    # SUM() yields NULL when every amount is NULL; treat that as "no data"
    # instead of crashing in the number format below.
    if row is None or row[1] is None:
        return "No transactions found for today."
    return f"Top product today: {row[0]} with ₹{row[1]:,.2f}"
|
|
|
def web_search_agent(query: str) -> str:
    """Answer *query* using the top web-search snippet when one is available.

    The query is sent to the SerpApi search endpoint. If the first organic
    result carries a non-empty snippet, the language model is asked to
    summarize it; otherwise the language model answers the query directly.
    The search step is best-effort: any failure there is logged and the
    plain language-model answer is used instead.

    Args:
        query: The user's question.

    Returns:
        The language model's reply.
    """
    snippet = ""
    try:
        resp = requests.get(
            "https://serpapi.com/search",
            params={"q": query, "api_key": os.getenv("SERPAPI_KEY")},
            # Without a timeout a stalled connection would hang the request
            # handler indefinitely.
            timeout=10,
        )
        resp.raise_for_status()
        results = resp.json().get("organic_results") or [{}]
        snippet = (results[0].get("snippet") or "").strip()
    except (
        requests.RequestException,  # network errors, timeouts, HTTP errors
        ValueError,  # body was not valid JSON
        LookupError,  # unexpected result container shape
        AttributeError,
        TypeError,
    ):
        # Only the search step is guarded, so failures of the language-model
        # call below are not masked and retried.
        logging.getLogger(__name__).warning(
            "Web search failed; answering without search context",
            exc_info=True,
        )
        snippet = ""

    if snippet:
        return llm_agent(f"Summarize: {snippet}")
    return llm_agent(query)
|
|
|
def llm_agent(query: str) -> str:
    """Send *query* to the ``gpt-4o-mini`` chat model and return its reply.

    Args:
        query: Text used as the user message.

    Returns:
        The first choice's message content, stripped of surrounding
        whitespace.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": query},
    ]
    completion = openai.chat.completions.create(
        messages=conversation,
        model="gpt-4o-mini",
        temperature=0.2,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content.strip()
|
|
|
|
|
|
|
def tts_agent(text: str, lang: str = 'en') -> str:
    """Synthesize *text* as speech, save it as an mp3 and return the file path.

    Args:
        text: The text to speak.
        lang: Language code passed to gTTS.

    Returns:
        The path of the saved file, always ``"response_audio.mp3"``.
    """
    destination = "response_audio.mp3"
    gTTS(text=text, lang=lang).save(destination)
    return destination
|
|
|
|
|
def _mentions_any(keywords):
    """Build a predicate: does the lower-cased query contain any keyword?"""

    def predicate(q):
        lowered = q.lower()
        return any(k in lowered for k in keywords)

    return predicate


# Routes are listed in priority order: revenue questions, then question-word
# queries, then a catch-all. Matching is plain substring containment, so a
# keyword also matches inside longer words (e.g. "who" inside "whole").
router_node = RouterNode(
    name="router",
    routes=[
        (_mentions_any(["max revenue", "revenue"]), "db"),
        (_mentions_any(["who", "what", "when", "where"]), "web"),
        (lambda q: True, "llm"),
    ],
)
|
|
|
# Wrap each agent function as a graph node; the node names are the same
# labels used as route targets ("db", "web", "llm").
db_node = FunctionNode(name="db", func=db_agent)
web_node = FunctionNode(name="web", func=web_search_agent)
llm_node = FunctionNode(name="llm", func=llm_agent)
|
|
|
|
|
# Assemble the "shop-assistant" graph from the router and the three agent nodes.
# NOTE(review): Graph, FunctionNode and RouterNode are imported from the
# top-level `langgraph` package -- confirm the installed version exports them.
graph = Graph("shop-assistant")
graph.add_nodes([router_node, db_node, web_node, llm_node])
# One conditional edge from the router to each agent; each condition is true
# only when the value it receives equals that agent's name (presumably the
# label selected by the router -- confirm against the Graph API).
graph.add_edge("router", "db", condition=lambda r: r == "db")
graph.add_edge("router", "web", condition=lambda r: r == "web")
graph.add_edge("router", "llm", condition=lambda r: r == "llm")
|
|
|
async def graph_handler(query: str) -> str:
    """Run *query* through the assistant graph, transcribing audio first.

    A query of the form ``audio://<path>`` is treated as a reference to an
    audio file: the file is transcribed and the transcript becomes the
    query. Any other query is used as-is.

    Args:
        query: Plain question text, or ``audio://`` followed by a file path.

    Returns:
        Whatever ``graph.run`` produces when started at the ``"router"`` node.
    """
    audio_prefix = "audio://"
    if query.startswith(audio_prefix):
        # Strip only the leading marker; str.replace would also remove any
        # later "audio://" occurrence inside the path itself.
        audio_path = query[len(audio_prefix):]
        query = stt_agent(audio_path)
    return await graph.run(input=query, start_node="router")
|
|
|
def handle_query(audio_or_text: str):
    """Answer a spoken or typed question for the UI.

    Input ending in ``.wav`` or ``.mp3`` (any letter case) is treated as the
    path of a recording: it is transcribed, answered, and the answer is also
    synthesized to speech. Any other input is answered as text.

    Args:
        audio_or_text: An audio file path or the question text. ``None`` or
            blank input is accepted and produces a prompt to ask something.

    Returns:
        A ``(text, audio_path)`` pair. ``audio_path`` is ``None`` for text
        input, for blank input, and when there is no answer text to speak.
        A pair is always returned so it lines up with a two-output handler.
    """
    if not audio_or_text or not audio_or_text.strip():
        # Nothing recorded or typed; avoid calling str methods on None.
        return "Please record or type a question first.", None

    is_audio = audio_or_text.lower().endswith((".wav", ".mp3"))
    text_input = f"audio://{audio_or_text}" if is_audio else audio_or_text
    text_resp = asyncio.run(graph_handler(text_input))

    # Only synthesize speech when there is actually text to speak.
    if is_audio and text_resp:
        return text_resp, tts_agent(text_resp)
    return text_resp, None
|
|
|
|
|
def _respond(audio_path, typed_text):
    """Answer from the recording when there is one, else from the typed text."""
    result = handle_query(audio_path or typed_text or "")
    # handle_query may give back a bare string; pad it so that both output
    # components always receive a value.
    return result if isinstance(result, tuple) else (result, None)


# UI layout: a microphone input and a text input feed the same handler, which
# fills a text answer and (for spoken questions) an audio answer.
with gr.Blocks() as demo:
    gr.Markdown("## Shop Voice-Box Assistant (Speech In/Out)")
    # NOTE(review): `source=` is kept as in the original; confirm it matches
    # the installed Gradio version's Audio signature.
    inp = gr.Audio(source="microphone", type="filepath", label="Speak your question")
    txt = gr.Textbox(label="...or type your question")
    out_text = gr.Textbox(label="Answer (text)")
    out_audio = gr.Audio(label="Answer (speech)")
    submit = gr.Button("Submit")

    # The examples are text questions, so they populate the text box rather
    # than the audio input.
    gr.Examples(
        examples=[
            ["What is the max revenue product today?"],
            ["Who invented the light bulb?"],
            ["Tell me a joke about cats."],
        ],
        inputs=txt,
        outputs=[out_text, out_audio],
    )
    submit.click(fn=_respond, inputs=[inp, txt], outputs=[out_text, out_audio])
|
|
|
if __name__ == "__main__":
    # Start the app only when run as a script: listen on 0.0.0.0 port 7860
    # with sharing turned off.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)