Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -57,6 +57,7 @@ print(traceback.format_exc())
|
|
| 57 |
# =============================================================================
|
| 58 |
# create_agent_app: Given a database path, initialize the agent workflow.
|
| 59 |
# =============================================================================
|
|
|
|
| 60 |
def create_agent_app(db_path: str):
|
| 61 |
# Use ChatGroq as our LLM here; you can swap to ChatMistralAI if preferred.
|
| 62 |
from langchain_groq import ChatGroq
|
|
@@ -141,6 +142,7 @@ def create_agent_app(db_path: str):
|
|
| 141 |
# -------------------------------------------------------------------------
|
| 142 |
# Update database URI and file path, create SQLDatabase connection.
|
| 143 |
# -------------------------------------------------------------------------
|
|
|
|
| 144 |
abs_db_path_local = os.path.abspath(db_path)
|
| 145 |
global DATABASE_URI
|
| 146 |
DATABASE_URI = abs_db_path_local
|
|
@@ -157,6 +159,7 @@ def create_agent_app(db_path: str):
|
|
| 157 |
# -------------------------------------------------------------------------
|
| 158 |
# Create SQL toolkit.
|
| 159 |
# -------------------------------------------------------------------------
|
|
|
|
| 160 |
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
| 161 |
toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
|
| 162 |
tools_instance = toolkit_instance.get_tools()
|
|
@@ -164,6 +167,7 @@ def create_agent_app(db_path: str):
|
|
| 164 |
# -------------------------------------------------------------------------
|
| 165 |
# Define workflow nodes and fallback functions.
|
| 166 |
# -------------------------------------------------------------------------
|
|
|
|
| 167 |
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
|
| 168 |
return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
|
| 169 |
|
|
@@ -205,6 +209,7 @@ def create_agent_app(db_path: str):
|
|
| 205 |
# -------------------------------------------------------------------------
|
| 206 |
# Get tools for listing tables and fetching schema.
|
| 207 |
# -------------------------------------------------------------------------
|
|
|
|
| 208 |
list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
|
| 209 |
get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
|
| 210 |
|
|
@@ -234,6 +239,7 @@ def create_agent_app(db_path: str):
|
|
| 234 |
# =============================================================================
|
| 235 |
# create_app: The application factory.
|
| 236 |
# =============================================================================
|
|
|
|
| 237 |
def create_app():
|
| 238 |
flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
|
| 239 |
socketio = SocketIO(flask_app, cors_allowed_origins="*")
|
|
@@ -246,6 +252,7 @@ def create_app():
|
|
| 246 |
# -------------------------------------------------------------------------
|
| 247 |
# Serve uploaded files via a custom route.
|
| 248 |
# -------------------------------------------------------------------------
|
|
|
|
| 249 |
@flask_app.route("/files/<path:filename>")
|
| 250 |
def uploaded_file(filename):
|
| 251 |
return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)
|
|
@@ -253,6 +260,7 @@ def create_app():
|
|
| 253 |
# -------------------------------------------------------------------------
|
| 254 |
# Helper: run_agent runs the agent with the given prompt.
|
| 255 |
# -------------------------------------------------------------------------
|
|
|
|
| 256 |
def run_agent(prompt, socketio):
|
| 257 |
global agent_app, abs_file_path, db_path
|
| 258 |
if not abs_file_path:
|
|
@@ -284,6 +292,7 @@ def create_app():
|
|
| 284 |
# -------------------------------------------------------------------------
|
| 285 |
# Route: index page.
|
| 286 |
# -------------------------------------------------------------------------
|
|
|
|
| 287 |
@flask_app.route("/")
|
| 288 |
def index():
|
| 289 |
return render_template("index.html")
|
|
@@ -291,6 +300,7 @@ def create_app():
|
|
| 291 |
# -------------------------------------------------------------------------
|
| 292 |
# Route: generate (POST) – receives a prompt and runs the agent.
|
| 293 |
# -------------------------------------------------------------------------
|
|
|
|
| 294 |
@flask_app.route("/generate", methods=["POST"])
|
| 295 |
def generate():
|
| 296 |
try:
|
|
@@ -310,6 +320,7 @@ def create_app():
|
|
| 310 |
# -------------------------------------------------------------------------
|
| 311 |
# Route: upload (GET/POST) – handles uploading the SQLite DB file.
|
| 312 |
# -------------------------------------------------------------------------
|
|
|
|
| 313 |
@flask_app.route("/upload", methods=["GET", "POST"])
|
| 314 |
def upload():
|
| 315 |
global abs_file_path, agent_app, db_path
|
|
@@ -339,6 +350,7 @@ def create_app():
|
|
| 339 |
# =============================================================================
|
| 340 |
# Create the app for Gunicorn compatibility.
|
| 341 |
# =============================================================================
|
|
|
|
| 342 |
app, socketio_instance = create_app()
|
| 343 |
|
| 344 |
if __name__ == "__main__":
|
|
|
|
| 57 |
# =============================================================================
|
| 58 |
# create_agent_app: Given a database path, initialize the agent workflow.
|
| 59 |
# =============================================================================
|
| 60 |
+
|
| 61 |
def create_agent_app(db_path: str):
|
| 62 |
# Use ChatGroq as our LLM here; you can swap to ChatMistralAI if preferred.
|
| 63 |
from langchain_groq import ChatGroq
|
|
|
|
| 142 |
# -------------------------------------------------------------------------
|
| 143 |
# Update database URI and file path, create SQLDatabase connection.
|
| 144 |
# -------------------------------------------------------------------------
|
| 145 |
+
|
| 146 |
abs_db_path_local = os.path.abspath(db_path)
|
| 147 |
global DATABASE_URI
|
| 148 |
DATABASE_URI = abs_db_path_local
|
|
|
|
| 159 |
# -------------------------------------------------------------------------
|
| 160 |
# Create SQL toolkit.
|
| 161 |
# -------------------------------------------------------------------------
|
| 162 |
+
|
| 163 |
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
| 164 |
toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
|
| 165 |
tools_instance = toolkit_instance.get_tools()
|
|
|
|
| 167 |
# -------------------------------------------------------------------------
|
| 168 |
# Define workflow nodes and fallback functions.
|
| 169 |
# -------------------------------------------------------------------------
|
| 170 |
+
|
| 171 |
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
|
| 172 |
return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
|
| 173 |
|
|
|
|
| 209 |
# -------------------------------------------------------------------------
|
| 210 |
# Get tools for listing tables and fetching schema.
|
| 211 |
# -------------------------------------------------------------------------
|
| 212 |
+
|
| 213 |
list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
|
| 214 |
get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
|
| 215 |
|
|
|
|
| 239 |
# =============================================================================
|
| 240 |
# create_app: The application factory.
|
| 241 |
# =============================================================================
|
| 242 |
+
|
| 243 |
def create_app():
|
| 244 |
flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
|
| 245 |
socketio = SocketIO(flask_app, cors_allowed_origins="*")
|
|
|
|
| 252 |
# -------------------------------------------------------------------------
|
| 253 |
# Serve uploaded files via a custom route.
|
| 254 |
# -------------------------------------------------------------------------
|
| 255 |
+
|
| 256 |
@flask_app.route("/files/<path:filename>")
|
| 257 |
def uploaded_file(filename):
|
| 258 |
return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)
|
|
|
|
| 260 |
# -------------------------------------------------------------------------
|
| 261 |
# Helper: run_agent runs the agent with the given prompt.
|
| 262 |
# -------------------------------------------------------------------------
|
| 263 |
+
|
| 264 |
def run_agent(prompt, socketio):
|
| 265 |
global agent_app, abs_file_path, db_path
|
| 266 |
if not abs_file_path:
|
|
|
|
| 292 |
# -------------------------------------------------------------------------
|
| 293 |
# Route: index page.
|
| 294 |
# -------------------------------------------------------------------------
|
| 295 |
+
|
| 296 |
@flask_app.route("/")
|
| 297 |
def index():
|
| 298 |
return render_template("index.html")
|
|
|
|
| 300 |
# -------------------------------------------------------------------------
|
| 301 |
# Route: generate (POST) – receives a prompt and runs the agent.
|
| 302 |
# -------------------------------------------------------------------------
|
| 303 |
+
|
| 304 |
@flask_app.route("/generate", methods=["POST"])
|
| 305 |
def generate():
|
| 306 |
try:
|
|
|
|
| 320 |
# -------------------------------------------------------------------------
|
| 321 |
# Route: upload (GET/POST) – handles uploading the SQLite DB file.
|
| 322 |
# -------------------------------------------------------------------------
|
| 323 |
+
|
| 324 |
@flask_app.route("/upload", methods=["GET", "POST"])
|
| 325 |
def upload():
|
| 326 |
global abs_file_path, agent_app, db_path
|
|
|
|
| 350 |
# =============================================================================
|
| 351 |
# Create the app for Gunicorn compatibility.
|
| 352 |
# =============================================================================
|
| 353 |
+
|
| 354 |
app, socketio_instance = create_app()
|
| 355 |
|
| 356 |
if __name__ == "__main__":
|