testing_rag / agents /sql_agent.py
Bharath Gajula
sadas
42cabf2
raw
history blame contribute delete
990 Bytes
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
import os
class SQLAgent:
    """Answer natural-language questions about a SQLite database.

    Wraps a LangChain SQL agent backed by a Google Gemini chat model: the
    model inspects the schema, writes SQL, runs it through the toolkit and
    phrases the final answer.
    """

    def __init__(
        self,
        db_path: str,
        *,
        model: str = "gemini-1.5-flash",
        temperature: float = 0,
        verbose: bool = True,
    ):
        """Open the database and build the SQL agent.

        Args:
            db_path: Path to an existing SQLite database file, or ``":memory:"``.
            model: Gemini model name passed to ``ChatGoogleGenerativeAI``.
            temperature: Sampling temperature for the chat model; the default
                of 0 keeps SQL generation as deterministic as the model allows.
            verbose: Whether the agent prints its intermediate steps.

        Raises:
            FileNotFoundError: If ``db_path`` is not an existing file.
        """
        path = os.fspath(db_path)
        # SQLite silently creates an empty database for a missing file, which
        # would leave the agent looking at a schema with no tables. Fail loudly.
        if path != ":memory:" and not os.path.isfile(path):
            raise FileNotFoundError(f"SQLite database not found: {path}")
        self.db_path = db_path
        # NOTE(review): the agent executes LLM-generated SQL over this
        # connection, which is opened read-write. Consider a read-only URI
        # if writes are never intended.
        self.db = SQLDatabase.from_uri(f"sqlite:///{path}")
        # If GOOGLE_API_KEY is unset, os.getenv returns None here.
        self.llm = ChatGoogleGenerativeAI(
            model=model,
            temperature=temperature,
            google_api_key=os.getenv("GOOGLE_API_KEY"),
        )
        # The toolkit exposes schema-inspection and query-execution tools
        # over self.db to the agent.
        toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        self.agent = create_sql_agent(
            llm=self.llm,
            toolkit=toolkit,
            verbose=verbose,
        )

    def query(self, question: str) -> str:
        """Run a natural-language question through the agent and return its answer.

        Args:
            question: Natural-language question about the database.

        Returns:
            The agent's final answer text.

        Raises:
            ValueError: If ``question`` is empty or whitespace only.
        """
        if not question or not question.strip():
            raise ValueError("question must be a non-empty string")
        # Chain.run is deprecated; invoke() with the executor's "input" key
        # returns a dict whose "output" entry is the same final answer.
        return self.agent.invoke({"input": question})["output"]