File size: 4,691 Bytes
b12e5fe
0729c66
 
 
b12e5fe
52ec688
 
 
 
 
 
 
 
 
 
 
 
b12e5fe
0729c66
 
b12e5fe
52ec688
0729c66
dd5aa4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52ec688
b12e5fe
0729c66
7b770ed
 
 
 
 
52ec688
7b770ed
 
 
 
 
b12e5fe
52ec688
dd5aa4f
0729c66
 
 
52ec688
0729c66
dd5aa4f
0729c66
dd5aa4f
b12e5fe
52ec688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b12e5fe
52ec688
 
 
 
 
 
 
 
 
 
0729c66
 
 
52ec688
 
 
 
 
 
243aa91
 
 
 
 
 
52ec688
 
243aa91
52ec688
b12e5fe
 
dd5aa4f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import sqlite3
import requests
import openai
import gradio as gr
import asyncio
from langgraph import Graph, FunctionNode, RouterNode
from gtts import gTTS

def stt_agent(audio_path: str) -> str:
    """Transcribe the audio file at *audio_path* with OpenAI's ``whisper-1`` model.

    The file is opened in binary mode and uploaded to the transcription
    endpoint; the returned transcript text is stripped of surrounding
    whitespace.
    """
    with open(audio_path, "rb") as audio_file:
        result = openai.audio.transcriptions.create(file=audio_file, model="whisper-1")
    text = result.text
    return text.strip()

# Configure the OpenAI client from the environment; the key is None when
# OPENAI_API_KEY is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# --- Business Logic Functions ---
def db_agent(query: str, db_path: str = "shop.db") -> str:
    """Answer revenue questions from the local SQLite sales database.

    Only "max revenue" questions are supported. For those, the product with
    the highest summed ``amount`` among ``transactions`` rows whose ``date``
    equals SQLite's ``date('now')`` is reported. ``date('now')`` is the
    current UTC date.

    Args:
        query: The user's question, matched case-insensitively.
        db_path: Path to the SQLite database file. Defaults to ``shop.db``
            in the current working directory.

    Returns:
        One of the following strings:
        - the top product and its revenue;
        - a "no transactions" notice;
        - a notice that the question is not supported;
        - a database error message.
        The function always returns a string. It never returns ``None``.
    """
    if "max revenue" not in query.lower():
        # The router sends every "revenue" question here. Returning text
        # instead of None keeps the UI and text-to-speech from receiving None.
        return "Sorry, I can only answer 'max revenue' questions from the shop database right now."

    try:
        conn = sqlite3.connect(db_path)
        try:
            row = conn.execute(
                """
                SELECT product, COALESCE(SUM(amount), 0) AS revenue
                FROM transactions
                WHERE date = date('now')
                GROUP BY product
                ORDER BY revenue DESC
                LIMIT 1
                """
            ).fetchone()
        finally:
            # Close explicitly. A `with sqlite3.connect(...)` block only
            # commits or rolls back; it does not close the connection.
            conn.close()
    except sqlite3.OperationalError as e:
        # Typical cause: the 'transactions' table has not been created yet.
        return f"Database error: {e}. Please initialize 'transactions' table in {db_path}."

    if row is None:
        return "No transactions found for today."
    product, revenue = row
    return f"Top product today: {product} with ₹{revenue:,.2f}"

def web_search_agent(query: str, timeout: float = 10.0) -> str:
    """Answer *query* using a SerpAPI web search, falling back to the LLM.

    If the first organic search result has a snippet, that snippet is
    summarized by ``llm_agent``. In any of the following cases, the raw
    query is sent to ``llm_agent`` instead:
    - the search request fails or times out;
    - the response is not JSON;
    - there is no usable snippet.

    Errors raised by ``llm_agent`` itself are not suppressed.

    Args:
        query: The user's question, sent verbatim as the search string.
        timeout: Seconds to wait for SerpAPI. Without a timeout, ``requests``
            can block indefinitely.

    Returns:
        The LLM's answer text.
    """
    import logging

    try:
        resp = requests.get(
            "https://serpapi.com/search",
            params={"q": query, "api_key": os.getenv("SERPAPI_KEY")},
            timeout=timeout,
        )
        payload = resp.json()
    except (requests.RequestException, ValueError) as exc:
        # Log only the exception type. requests' messages embed the full URL,
        # which includes the SerpAPI key as a query parameter.
        logging.getLogger(__name__).warning(
            "Web search failed (%s); answering with the LLM only.", type(exc).__name__
        )
        return llm_agent(query)

    # An empty or missing "organic_results" list means there is no snippet.
    # Do not index into it blindly.
    results = payload.get("organic_results") if isinstance(payload, dict) else None
    snippet = ""
    if results and isinstance(results[0], dict):
        snippet = (results[0].get("snippet") or "").strip()
    if snippet:
        return llm_agent(f"Summarize: {snippet}")
    return llm_agent(query)

def llm_agent(query: str) -> str:
    """Answer *query* with a single chat-completion call to ``gpt-4o-mini``.

    A generic "helpful assistant" system prompt is used with temperature
    0.2.

    Returns:
        The reply text with surrounding whitespace stripped, or an empty
        string if the model returned no text content.
    """
    response = openai.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": query},
        ],
        temperature=0.2,
    )
    # message.content is Optional in the OpenAI SDK (e.g. refusals or tool
    # calls). Guard it so .strip() is never called on None.
    content = response.choices[0].message.content
    return (content or "").strip()

# Text-to-Speech

def tts_agent(text: str, lang: str = 'en') -> str:
    """Synthesize *text* to an MP3 with gTTS and return the file path.

    The audio is always saved to the relative path ``response_audio.mp3``,
    which resolves against the current working directory. Each call
    therefore overwrites the previous result.
    """
    out_path = "response_audio.mp3"
    speech = gTTS(text=text, lang=lang)
    speech.save(out_path)
    return out_path

# --- LangGraph Multi-Agent Setup ---
def _mentions_any(text, keywords):
    """Return True if any keyword occurs as a substring of lower-cased *text*."""
    lowered = text.lower()
    return any(word in lowered for word in keywords)


# Each route pairs a predicate on the incoming query with a target node name.
# NOTE(review): the always-true last route is presumably the catch-all, which
# assumes RouterNode picks the first matching route. Confirm against the
# RouterNode documentation.
router_node = RouterNode(
    name="router",
    routes=[
        (lambda q: _mentions_any(q, ("max revenue", "revenue")), "db"),
        (lambda q: _mentions_any(q, ("who", "what", "when", "where")), "web"),
        (lambda q: True, "llm"),
    ],
)

# One FunctionNode per specialist agent. Node names match the router's targets.
db_node = FunctionNode(name="db", func=db_agent)
web_node = FunctionNode(name="web", func=web_search_agent)
llm_node = FunctionNode(name="llm", func=llm_agent)


def _route_is(target):
    """Build an edge condition that is true when the router's result equals *target*."""
    return lambda route: route == target


# Assemble the graph: every edge leaves the router and fires only for its own target.
graph = Graph("shop-assistant")
graph.add_nodes([router_node, db_node, web_node, llm_node])
for _target in ("db", "web", "llm"):
    graph.add_edge("router", _target, condition=_route_is(_target))

async def graph_handler(query: str) -> str:
    """Run *query* through the agent graph, starting at the router node.

    A query of the form ``audio://<path>`` is treated as an audio file. It is
    first transcribed with ``stt_agent``, and the transcript is routed in its
    place.
    """
    prefix = "audio://"
    if query.startswith(prefix):
        # Strip only the leading scheme. str.replace would also remove any
        # later "audio://" that appears inside the path itself.
        query = stt_agent(query[len(prefix):])
    return await graph.run(input=query, start_node="router")

def handle_query(audio_or_text: str):
    """Answer a question given either an audio file path or plain text.

    Args:
        audio_or_text: Either a path to a ``.wav``/``.mp3`` recording (the
            extension is matched case-insensitively) or the question text.
            ``None`` or blank input yields a prompt to ask something.

    Returns:
        A ``(answer_text, answer_audio_path)`` pair. The audio path is set
        only for audio input with a non-empty answer; otherwise it is
        ``None``. Two values are always returned to match the two Gradio
        outputs (text and audio) this handler is wired to.
    """
    if not audio_or_text or not audio_or_text.strip():
        # Gradio passes None when Submit is pressed without a recording.
        return "Please record or type a question.", None

    is_audio = audio_or_text.lower().endswith((".wav", ".mp3"))
    graph_input = f"audio://{audio_or_text}" if is_audio else audio_or_text
    text_resp = asyncio.run(graph_handler(graph_input))

    audio_path = None
    if is_audio and text_resp:
        # gTTS rejects empty text, so only synthesize speech for a real answer.
        audio_path = tts_agent(text_resp)
    return text_resp, audio_path

# --- Gradio UI ---
def _answer(audio_path, typed_text):
    """Send the UI inputs to handle_query, preferring a recording over typed text.

    Always returns exactly two values (answer text, answer audio path or
    None) to match the two output components wired below.
    """
    result = handle_query(audio_path or typed_text)
    if isinstance(result, tuple):
        return result
    # A text question may come back as a bare string; pad it to two outputs.
    return result, None


with gr.Blocks() as demo:
    gr.Markdown("## Shop Voice-Box Assistant (Speech In/Out)")
    # Gradio 4 replaced Audio(source=...) with Audio(sources=[...]) and rejects
    # the old keyword with TypeError. Try the 3.x spelling first so 3.x keeps
    # its microphone source, then fall back to the 4.x spelling.
    try:
        inp = gr.Audio(source="microphone", type="filepath", label="Speak your question")
    except TypeError:
        inp = gr.Audio(sources=["microphone"], type="filepath", label="Speak your question")
    inp_text = gr.Textbox(label="...or type your question")
    out_text = gr.Textbox(label="Answer (text)")
    out_audio = gr.Audio(label="Answer (speech)")
    submit = gr.Button("Submit")
    # The examples are text questions, so they fill the textbox rather than
    # the microphone input.
    gr.Examples(
        examples=[
            ["What is the max revenue product today?"],
            ["Who invented the light bulb?"],
            ["Tell me a joke about cats."],
        ],
        inputs=inp_text,
    )
    submit.click(fn=_answer, inputs=[inp, inp_text], outputs=[out_text, out_audio])

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860, with no public Gradio
    # share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)