import os
import gradio as gr
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from supabase import create_client, Client
# Load environment variables
load_dotenv()
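# Assumed .env contents (names inferred from the os.environ/os.getenv calls
# below; TAVILY_API_KEY is read implicitly by TavilySearchResults):
#   SUPABASE_URL, SUPABASE_SERVICE_KEY       - Supabase vector store access
#   GOOGLE_API_KEY, GROQ_API_KEY, HF_API_KEY - per-provider LLM credentials
#   TAVILY_API_KEY                           - Tavily web search tool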
# Tool definitions (docstrings are required: @tool uses them as the tool
# descriptions exposed to the model)
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide a by b; raises ValueError on division by zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Return a modulo b."""
    return a % b
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to two matching documents."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Return a string to match the declared return type.
    return "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}">\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
@tool
def web_search(query: str) -> str:
    """Search the web via Tavily and return up to three results."""
    # TavilySearchResults.invoke returns a list of dicts with "url" and
    # "content" keys, not Document objects.
    search_docs = TavilySearchResults(max_results=3).invoke(query)
    return "\n\n---\n\n".join(
        f'<Document source="{doc["url"]}">\n{doc["content"]}\n</Document>'
        for doc in search_docs
    )
@tool
def arvix_search(query: str) -> str:
    """Search arXiv and return up to three paper excerpts."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader metadata exposes Title/Authors/Published rather than a
    # "source" key, so use the paper title as the source attribute.
    return "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("Title", "arxiv")}">\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
# System prompt definition
SYSTEM_PROMPT = """You are a helpful assistant. For every question, reply with only the answer—no explanation,
no units, and no extra words. If the answer is a number, just return the number.
If it is a word or phrase, return only that. If it is a list, return a comma-separated list with no extra words.
Do not include any prefix, suffix, or explanation."""
sys_msg = SystemMessage(content=SYSTEM_PROMPT)
# Initialize vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
supabase: Client = create_client(
    os.environ["SUPABASE_URL"],
    os.environ["SUPABASE_SERVICE_KEY"]
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
)
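# Assumption: SupabaseVectorStore defaults to a Postgres RPC function named
# "match_documents" for similarity search, so the Supabase project needs a
# "documents" table (with an embedding column) and that function set up.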
tools = [multiply, add, subtract, divide, modulus,
         wiki_search, web_search, arvix_search]
# Build graph function with multi-provider support
def build_graph(provider: str = "groq"):
    # Provider selection
    if provider == "google":
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY")
        )
    elif provider == "groq":
        llm = ChatGroq(
            model="llama3-70b-8192",
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY")
        )
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
                temperature=0,
                # HuggingFaceEndpoint takes huggingfacehub_api_token, not api_key.
                huggingfacehub_api_token=os.getenv("HF_API_KEY")
            )
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    llm_with_tools = llm.bind_tools(tools)
    # Graph nodes
    def retriever(state: MessagesState):
        # Look up the most similar stored question and attach it as context.
        similar_question = vector_store.similarity_search(state["messages"][-1].content, k=1)
        if similar_question:
            example_msg = HumanMessage(content=f"Similar reference: {similar_question[0].page_content[:200]}...")
            # MessagesState appends returned messages, so return only the delta.
            return {"messages": [example_msg]}
        return {"messages": []}

    def assistant(state: MessagesState):
        # Prepend the system prompt so every call enforces the answer format.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
    # Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to the tools node when the model issues tool calls, else end.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
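# Minimal standalone usage sketch (assuming GROQ_API_KEY and the Supabase
# variables are set; run_agent below does the same thing for the UI):
#   graph = build_graph("groq")
#   result = graph.invoke({"messages": [HumanMessage(content="What is 6 * 7?")]})
#   print(result["messages"][-1].content)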
# Gradio interface
def run_agent(question, provider):
    try:
        graph = build_graph(provider)
        messages = [HumanMessage(content=question)]
        result = graph.invoke({"messages": messages})
        return result["messages"][-1].content
    except Exception as e:
        return f"Error: {e}"
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## LangGraph Multi-Provider Agent")
    provider = gr.Dropdown(
        choices=["groq", "google", "huggingface"],
        value="groq",
        label="LLM Provider"
    )
    question = gr.Textbox(label="Your Question")
    submit_btn = gr.Button("Run Agent")
    output = gr.Textbox(label="Agent Response", interactive=False)
    submit_btn.click(
        fn=run_agent,
        inputs=[question, provider],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()
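# Gradio serves on http://127.0.0.1:7860 by default; pass share=True to
# demo.launch() if a temporary public URL is needed.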