File size: 10,702 Bytes
75a63b2
bc2875c
75a63b2
6e59463
bc2875c
76b6f27
e71d648
bc2875c
 
 
 
 
76b6f27
bc2875c
 
 
 
 
f9b91f5
89ac1a2
 
 
 
 
 
 
7a83934
89ac1a2
 
 
 
 
 
 
 
 
 
 
 
 
7a83934
89ac1a2
 
 
 
 
 
 
 
7a83934
89ac1a2
 
7a83934
 
 
 
89ac1a2
 
 
 
 
 
 
7a83934
89ac1a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a83934
89ac1a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75e738f
89ac1a2
 
 
 
 
 
7a83934
89ac1a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a83934
 
 
 
 
 
 
89ac1a2
7a83934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89ac1a2
7a83934
 
 
 
 
89ac1a2
7a83934
 
 
 
 
 
 
75e738f
 
7a83934
89ac1a2
 
 
 
 
 
75e738f
7a83934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75e738f
 
 
 
 
 
 
7a83934
75e738f
7a83934
75e738f
 
 
 
 
e71d648
75e738f
 
7a83934
89ac1a2
75e738f
 
 
 
 
 
76b6f27
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import os
import logging
import httpx
import json
from dotenv import load_dotenv
import gradio as gr
from typing import AsyncGenerator, List, Dict, Tuple

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# httpx emits an INFO record for every request that includes the full URL.
# With the root logger at INFO those records would be printed, and a URL can
# carry credentials (e.g. a ?key= query parameter), so only let httpx warn.
logging.getLogger("httpx").setLevel(logging.WARNING)

# Load environment variables. load_dotenv() returns False when no .env file
# (or an empty one) was found, so only claim a load when it actually happened.
if load_dotenv():
    logger.info("Environment variables loaded from .env file")
else:
    logger.info("No .env file found (or it was empty); using existing process environment")

# Report which provider keys are usable -- never the key values themselves.
# An empty value counts as missing, matching the `if not key` checks in the
# ask_* functions.
for _env_var in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY", "GEMINI_API_KEY"):
    logger.info(f"{_env_var} present: {bool(os.environ.get(_env_var))}")

async def ask_openai(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
    """Stream an OpenAI chat completion for ``query``, yielding text deltas.

    Args:
        query: The new user message.
        history: Earlier turns as ``{"user": ..., "bot": ...}`` dicts; a falsy
            ``"bot"`` value means that turn has no assistant message.

    Yields:
        Content fragments as they arrive. Failures are reported in-band as
        strings starting with ``"Error"`` instead of being raised, so the
        caller can simply concatenate everything that is yielded.
    """
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        logger.error("OpenAI API key not provided")
        yield "Error: OpenAI API key not provided."
        return

    # Build message history with user and assistant roles
    messages = []
    for msg in history:
        messages.append({"role": "user", "content": msg["user"]})
        if msg["bot"]:
            messages.append({"role": "assistant", "content": msg["bot"]})
    messages.append({"role": "user", "content": query})

    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        "stream": True,
    }

    try:
        # httpx's default timeout is 5 s, which time-to-first-token can exceed.
        timeout = httpx.Timeout(60.0, connect=10.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream(
                "POST",
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=payload,
            ) as response:
                if response.is_error:
                    # Read the error body while the stream is still open. Once
                    # this `async with` exits the response is closed, and
                    # reading it from the except-handler would raise StreamClosed.
                    await response.aread()
                response.raise_for_status()
                # Server-sent events: each payload arrives on a "data: " line.
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue  # blank separators / non-data fields
                    data = line[6:]  # Remove "data: " prefix
                    if data == "[DONE]":
                        return  # End of stream: stop reading entirely.
                    if not data.strip():
                        continue
                    try:
                        json_data = json.loads(data)
                        choices = json_data.get("choices") or []
                        content = choices[0].get("delta", {}).get("content") if choices else None
                    except json.JSONDecodeError as e:
                        logger.error(f"Error parsing OpenAI stream chunk: {str(e)} - Data: {data}")
                        yield f"Error parsing stream: {str(e)}"
                        continue
                    except Exception as e:
                        logger.error(f"Unexpected error in OpenAI stream: {str(e)} - Data: {data}")
                        yield f"Error in stream: {str(e)}"
                        continue
                    # Role-only / finish chunks carry no content; skip them.
                    if content is not None:
                        yield content

    except httpx.HTTPStatusError as e:
        # The body was read inside the stream above, so .text is available.
        logger.error(f"OpenAI HTTP Status Error: {e.response.status_code}, {e.response.text}")
        yield f"Error: OpenAI HTTP Status Error: {e.response.status_code}, {e.response.text}"
    except Exception as e:
        logger.error(f"OpenAI Error: {str(e)}")
        yield f"Error: OpenAI Error: {str(e)}"

async def ask_anthropic(query: str, history: List[Dict[str, str]]) -> str:
    """Send the conversation to Anthropic's Messages API and return the reply text.

    Args:
        query: The new user message.
        history: Earlier turns as ``{"user": ..., "bot": ...}`` dicts; a falsy
            ``"bot"`` value means that turn has no assistant message.

    Returns:
        The assistant's text, or a string starting with ``"Error"`` describing
        what went wrong (this function does not raise for API/network errors).
    """
    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    if not anthropic_api_key:
        logger.error("Anthropic API key not provided")
        return "Error: Anthropic API key not provided."

    # Build message history with user and assistant roles
    messages = []
    for msg in history:
        messages.append({"role": "user", "content": msg["user"]})
        if msg["bot"]:
            messages.append({"role": "assistant", "content": msg["bot"]})
    messages.append({"role": "user", "content": query})

    headers = {
        "x-api-key": anthropic_api_key,
        "anthropic-version": "2023-06-01",
        "Content-Type": "application/json",
    }

    payload = {
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 1024,
        "messages": messages,
    }

    try:
        # This request is not streamed: the whole completion (up to max_tokens)
        # must finish before any byte arrives, so httpx's 5 s default is too short.
        timeout = httpx.Timeout(60.0, connect=10.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            logger.info(f"Sending Anthropic request: {payload}")
            response = await client.post("https://api.anthropic.com/v1/messages", headers=headers, json=payload)

        response.raise_for_status()
        data = response.json()  # parse once; reused for logging and extraction
        logger.info(f"Anthropic response: {data}")
        # The reply is a list of content blocks; concatenate every text block.
        text_parts = [
            block.get("text", "")
            for block in data.get("content") or []
            if block.get("type") == "text"
        ]
        if not text_parts:
            return f"Error: Anthropic returned no text content (stop_reason={data.get('stop_reason')})"
        return "".join(text_parts)
    except httpx.HTTPStatusError as e:
        logger.error(f"Anthropic HTTP Status Error: {e.response.status_code}, {e.response.text}")
        return f"Error: Anthropic HTTP Status Error: {e.response.status_code}, {e.response.text}"
    except Exception as e:
        logger.error(f"Anthropic Error: {str(e)}")
        return f"Error: Anthropic Error: {str(e)}"

async def ask_gemini(query: str, history: List[Dict[str, str]]) -> str:
    """Send the conversation to Gemini's generateContent endpoint and return the reply.

    The history is flattened into a single "User:/Assistant:" transcript sent
    as one text part.

    Args:
        query: The new user message.
        history: Earlier turns as ``{"user": ..., "bot": ...}`` dicts; a falsy
            ``"bot"`` value means that turn has no assistant message.

    Returns:
        The model's text, or a string starting with ``"Error"`` describing what
        went wrong (this function does not raise for API/network errors).
    """
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("Gemini API key not provided")
        return "Error: Gemini API key not provided."

    # Concatenate history as text for Gemini
    transcript = []
    for msg in history:
        transcript.append(f"User: {msg['user']}\n")
        if msg["bot"]:
            transcript.append(f"Assistant: {msg['bot']}\n")
    transcript.append(f"User: {query}\n")
    full_query = "".join(transcript)

    # Send the key as a header rather than a ?key= query parameter so it never
    # appears in request URLs (httpx logs URLs at INFO, and error reports
    # commonly include the URL).
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": gemini_api_key,
    }

    payload = {
        "contents": [{"parts": [{"text": full_query}]}]
    }

    try:
        # Non-streamed generation can exceed httpx's 5 s default timeout.
        timeout = httpx.Timeout(60.0, connect=10.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent",
                headers=headers,
                json=payload,
            )

        response.raise_for_status()
        data = response.json()
        candidates = data.get("candidates") or []
        if not candidates:
            # Per the API reference, a blocked prompt returns promptFeedback and no candidates.
            return f"Error: Gemini returned no candidates: {data.get('promptFeedback', data)}"
        parts = (candidates[0].get("content") or {}).get("parts") or []
        text = "".join(part.get("text", "") for part in parts)
        if not text:
            return f"Error: Gemini returned no text (finishReason={candidates[0].get('finishReason')})"
        return text
    except httpx.HTTPStatusError as e:
        logger.error(f"Gemini HTTP Status Error: {e.response.status_code}, {e.response.text}")
        return f"Error: Gemini HTTP Status Error: {e.response.status_code}, {e.response.text}"
    except Exception as e:
        logger.error(f"Gemini Error: {str(e)}")
        return f"Error: Gemini Error: {str(e)}"

async def query_model(query: str, providers: List[str], history: List[Dict[str, str]]) -> AsyncGenerator[Tuple[str, List[Dict[str, str]], List[Dict[str, str]]], None]:
    """Query the selected providers and yield chat-display updates.

    Anthropic and Gemini are non-streaming, so they are started as background
    tasks before OpenAI's stream is consumed. Total latency is then roughly
    that of the slowest provider instead of the sum of all of them.

    Args:
        query: The new user message.
        providers: Selected provider names ("OpenAI", "Anthropic", "Gemini").
        history: Earlier turns as ``{"user": ..., "bot": ...}`` dicts.

    Yields:
        ``("", chatbot_messages, history)`` tuples. First the conversation
        plus the new user message, then partial updates while OpenAI streams
        (history unchanged), and finally one update whose history includes
        the new turn with the combined response.
    """
    import asyncio

    logger.info(f"Processing query with providers: {providers}")

    # Start the non-streaming providers now so they run while OpenAI streams.
    # dict insertion order keeps results in Anthropic-then-Gemini order.
    pending = {}
    if "Anthropic" in providers:
        pending["Anthropic"] = asyncio.create_task(ask_anthropic(query, history))
    if "Gemini" in providers:
        pending["Gemini"] = asyncio.create_task(ask_gemini(query, history))

    responses = []  # To collect responses from each provider
    try:
        # The prior conversation plus the new user message does not change
        # while waiting, so build it once instead of on every streamed chunk.
        base_messages = []
        for msg in history:
            base_messages.append({"role": "user", "content": msg["user"]})
            if msg["bot"]:
                base_messages.append({"role": "assistant", "content": msg["bot"]})
        base_messages.append({"role": "user", "content": query})
        # Show the user's message right away, even before any provider answers.
        yield "", list(base_messages), history

        # Handle OpenAI (streaming)
        if "OpenAI" in providers:
            streaming_response = ""
            async for chunk in ask_openai(query, history):
                streaming_response += chunk
                yield "", base_messages + [{"role": "assistant", "content": streaming_response}], history
            if streaming_response.strip():
                responses.append(f"[OpenAI]: {streaming_response}")

        # Collect the non-streaming providers (they catch their own errors and
        # return "Error..." strings, so awaiting them does not raise).
        for name, task in pending.items():
            response = await task
            if response.strip():
                responses.append(f"[{name}]: {response}")
    finally:
        # If the consumer stops early (e.g. the request is cancelled), don't
        # leave orphaned provider requests running in the background.
        for task in pending.values():
            if not task.done():
                task.cancel()

    # Combine responses
    combined_response = "\n\n".join(responses) if responses else "No valid responses received."
    # Update history with the combined response
    updated_history = history + [{"user": query, "bot": combined_response}]
    logger.info(f"Updated history: {updated_history}")

    # Yield final response
    chatbot_messages = []
    for msg in updated_history:
        chatbot_messages.append({"role": "user", "content": msg["user"]})
        if msg["bot"]:
            chatbot_messages.append({"role": "assistant", "content": msg["bot"]})
    yield "", chatbot_messages, updated_history

async def submit_query(query: str, providers: List[str], history: List[Dict[str, str]]) -> AsyncGenerator[Tuple[str, List[Dict[str, str]], List[Dict[str, str]]], None]:
    """Gradio handler for submitting a query.

    Each yielded tuple maps onto the ``(query, chatbot, history_state)``
    outputs. On a validation failure the existing conversation stays visible
    with the warning appended, and ``history`` is returned unchanged.
    Otherwise every update from ``query_model`` is relayed.
    """

    def _display_with_notice(notice: str) -> List[Dict[str, str]]:
        # Render the stored history as chat messages, then the warning.
        display = []
        for msg in history:
            display.append({"role": "user", "content": msg["user"]})
            if msg["bot"]:
                display.append({"role": "assistant", "content": msg["bot"]})
        display.append({"role": "assistant", "content": notice})
        return display

    if not query or not query.strip():
        yield "", _display_with_notice("Please enter a query."), history
        return

    if not providers:
        # Keep the typed query in the box so the user only has to pick a provider.
        yield query, _display_with_notice("Please select at least one provider."), history
        return

    # The textbox output is "" on every update, so the box clears as soon as
    # the first update arrives; query_model's final yield carries the new history.
    async for _, chatbot_messages, updated_history in query_model(query, providers, history):
        yield "", chatbot_messages, updated_history

# Gradio callbacks
def clear_history():
    """Reset the conversation.

    Returns:
        Two new, distinct empty lists: one for the chatbot display and one for
        the stored history state.
    """
    empty_display = []
    empty_state = []
    return empty_display, empty_state

# Define Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Multi-Model Chat")
    gr.Markdown("Chat with OpenAI, Anthropic, or Gemini. Select providers and start typing!")

    providers = gr.CheckboxGroup(choices=["OpenAI", "Anthropic", "Gemini"], label="Select Providers", value=["OpenAI"])
    # Conversation as a list of {"user": ..., "bot": ...} turns; gr.State keeps
    # a separate copy per browser session. The Chatbot below only renders it.
    history_state = gr.State(value=[])
    chatbot = gr.Chatbot(label="Conversation", type="messages")
    query = gr.Textbox(label="Enter your query", placeholder="e.g., What is the capital of the United States?")
    submit_button = gr.Button("Submit")
    clear_button = gr.Button("Clear History")

    # Clicking Submit and pressing Enter in the textbox run the same handler.
    submit_inputs = [query, providers, history_state]
    submit_outputs = [query, chatbot, history_state]
    submit_button.click(
        fn=submit_query,
        inputs=submit_inputs,
        outputs=submit_outputs,
    )
    query.submit(
        fn=submit_query,
        inputs=submit_inputs,
        outputs=submit_outputs,
    )
    clear_button.click(
        fn=clear_history,
        inputs=[],
        outputs=[chatbot, history_state],
    )

# Launch the Gradio app only when run as a script, so importing this module
# (from tests, another app, or Gradio's reload mode, which finds `demo`)
# does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()