File size: 990 Bytes
42cabf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
import os

class SQLAgent:
    """Answer natural-language questions against a SQLite database.

    The constructor builds a LangChain ``SQLDatabase`` for the given file,
    a Google Gemini chat model, and a SQL agent that combines the two.
    The agent is built once and reused by every :meth:`query` call.
    """

    def __init__(
        self,
        db_path: str,
        *,
        model: str = "gemini-1.5-flash",
        temperature: float = 0,
        verbose: bool = True,
    ) -> None:
        """Set up the database connection, the LLM and the SQL agent.

        Args:
            db_path: Filesystem path of the SQLite database; it is
                interpolated into a ``sqlite:///`` URI.
            model: Gemini model name passed to ``ChatGoogleGenerativeAI``.
            temperature: Sampling temperature passed to the model.
            verbose: Forwarded to ``create_sql_agent``.
        """
        # Path the agent was created for, kept for callers to inspect.
        self.db_path = db_path

        # NOTE(review): SQLite normally creates an empty database when the
        # file does not exist, so a mistyped path may go unnoticed here --
        # confirm whether a missing file should be reported as an error.
        self.db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

        # The API key is read from the environment at construction time;
        # os.getenv yields None when GOOGLE_API_KEY is unset.
        self.llm = ChatGoogleGenerativeAI(
            model=model,
            temperature=temperature,
            google_api_key=os.getenv("GOOGLE_API_KEY"),
        )

        # The toolkit gives the agent its database tools; the same LLM
        # instance drives both the toolkit and the agent.
        toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        self.agent = create_sql_agent(
            llm=self.llm,
            toolkit=toolkit,
            verbose=verbose,
        )

    def query(self, question: str) -> str:
        """Run a natural-language query and return the agent's answer.

        Args:
            question: The question to ask about the database contents.

        Returns:
            The value stored under the ``"output"`` key of the agent's
            result.
        """
        # ``run`` is deprecated on LangChain chains; ``invoke`` is the
        # supported entry point and returns a mapping of output keys.
        result = self.agent.invoke({"input": question})
        return result["output"]