Spaces:
Sleeping
Sleeping
import streamlit as st | |
from langchain_core.messages import HumanMessage, AIMessage | |
def get_llm(provider, config): | |
"""Initialize the selected LLM with configuration""" | |
try: | |
if provider == "OpenAI": | |
from langchain_openai import ChatOpenAI | |
return ChatOpenAI( | |
api_key=config.get("api_key"), | |
model=config.get("model_name", "gpt-3.5-turbo") | |
) | |
elif provider == "Anthropic": | |
from langchain_anthropic import ChatAnthropic | |
return ChatAnthropic( | |
api_key=config.get("api_key"), | |
model=config.get("model_name", "claude-3-sonnet-20240229") | |
) | |
elif provider == "Gemini": | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
return ChatGoogleGenerativeAI( | |
google_api_key=config.get("api_key"), | |
model=config.get("model_name", "gemini-pro") | |
) | |
elif provider == "DeepSeek": | |
from langchain_openai import ChatOpenAI | |
return ChatOpenAI( | |
api_key=config.get("api_key"), | |
base_url=config.get("base_url", "https://api.deepseek.com/v1"), | |
model=config.get("model_name", "deepseek-chat") | |
) | |
elif provider == "Ollama": | |
from langchain_community.chat_models import ChatOllama | |
return ChatOllama( | |
base_url=config.get("base_url", "http://localhost:11434"), | |
model=config.get("model_name", "llama2") | |
) | |
else: | |
raise ValueError("Selected provider is not supported") | |
except ImportError as e: | |
st.error(f"Missing required package: {e}") | |
return None | |
# Conversation history is kept in session state so it persists across script
# runs; entries are {"role": "user" | "assistant", "content": ...} dicts.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
# Sidebar configuration: choose a provider and gather the settings that
# get_llm() reads from `config` (api_key, base_url, model_name).
with st.sidebar:
    st.title("⚙️ LLM Configuration")

    provider = st.selectbox(
        "Select Provider",
        ["OpenAI", "Anthropic", "Gemini", "DeepSeek", "Ollama"],
    )
    config = {}

    if provider == "Ollama":
        # Self-hosted: no API key, but model and server address are editable.
        config["model_name"] = st.text_input(
            "Model Name",
            value="llama2",
            help="Make sure the model is available in your Ollama instance",
        )
        config["base_url"] = st.text_input(
            "Ollama Base URL",
            value="http://localhost:11434",
            help="URL where your Ollama server is running",
        )
    elif provider in ["OpenAI", "Anthropic", "Gemini", "DeepSeek"]:
        # Hosted providers: masked key field first.
        config["api_key"] = st.text_input(
            f"{provider} API Key",
            type="password",
            help=f"Get your API key from {provider}'s platform",
        )

        # DeepSeek is the only hosted provider with an editable endpoint.
        if provider == "DeepSeek":
            config["base_url"] = st.text_input(
                "API Base URL",
                value="https://api.deepseek.com/v1",
            )

        # Pre-fill the model field with a per-provider suggestion.
        default_models = dict(
            OpenAI="gpt-3.5-turbo",
            Anthropic="claude-3-sonnet-20240229",
            Gemini="gemini-pro",
            DeepSeek="deepseek-chat",
        )
        config["model_name"] = st.text_input(
            "Model Name",
            value=default_models.get(provider, ""),
        )
# Main chat interface
st.title("💬 LLM Chat Interface")

# Render every stored turn, oldest first, before handling new input.
for entry in st.session_state["messages"]:
    with st.chat_message(entry["role"]):
        st.markdown(entry["content"])
def _drop_unanswered_turn(turn):
    """Remove *turn* from the chat history if it is still the last entry.

    Used when generating a reply fails, so the unanswered user message is not
    resent with (and duplicated by) the next prompt. Identity is checked so
    that no other history entry can be removed by mistake.
    """
    history = st.session_state.messages
    if history and history[-1] is turn:
        history.pop()


# Handle user input
if prompt := st.chat_input("Type your message..."):
    # Record the user's turn and echo it immediately.
    user_turn = {"role": "user", "content": prompt}
    st.session_state.messages.append(user_turn)
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate response
    with st.spinner("Thinking..."):
        try:
            llm = get_llm(provider, config)
            if llm is None:
                # The turn will never be answered; drop it before halting.
                _drop_unanswered_turn(user_turn)
                st.error("Failed to initialize LLM. Check configuration.")
                st.stop()

            # Convert the full history (including the new prompt) to
            # LangChain message objects.
            lc_messages = [
                HumanMessage(content=msg["content"]) if msg["role"] == "user"
                else AIMessage(content=msg["content"])
                for msg in st.session_state.messages
            ]

            # Get LLM response
            response = llm.invoke(lc_messages)

            # Display and store assistant response
            with st.chat_message("assistant"):
                st.markdown(response.content)
            st.session_state.messages.append(
                {"role": "assistant", "content": response.content}
            )
        except Exception as e:
            # Without this, history would keep a user turn with no reply, and
            # the next request would carry two consecutive user messages,
            # which some providers reject.
            _drop_unanswered_turn(user_turn)
            st.error(f"Error generating response: {e}")