# Spaces: Running
import gradio as gr | |
from tools import FreightAgent, EXAMPLE_QUERIES | |
from utils import initialize_database | |
from smolagents import CodeAgent, OpenAIServerModel | |
import os | |
from dotenv import load_dotenv | |
from sql_data import sql_query, get_schema, get_csv_as_dataframe | |
import pandas as pd | |
from sqlalchemy import create_engine, text | |
# Location of the SQLite file. Shared by the engine URL and the existence
# check below so the two cannot drift apart.
DB_PATH = "freights.db"

# Create database engine
engine = create_engine(f"sqlite:///{DB_PATH}")

# Load environment variables (must run before OPENAI_API_KEY is read below)
load_dotenv()

# Initialize the database from the published CSV if the file doesn't exist yet
if not os.path.exists(DB_PATH):
    csv_url = "https://huggingface.co/datasets/sasu-SpidR/fretmaritime/resolve/main/freights.csv"
    initialize_database(csv_url)

# Create the main agent
model_id = "gpt-4.1-mini"
# Indexing os.environ (instead of .get) makes a missing key fail here, at
# startup, with a KeyError naming the variable.
model = OpenAIServerModel(model_id=model_id, api_key=os.environ["OPENAI_API_KEY"])
# NOTE(review): at this point the three tool names refer to the objects
# imported from sql_data. Functions with the same names are defined further
# down in this file and rebind those names only after this line has run, so
# the agent keeps the imported versions -- confirm that this is intended.
agent = CodeAgent(tools=[sql_query, get_schema, get_csv_as_dataframe], model=model)
def sql_query(query: str) -> str:
    """
    Run a SQL query on the freights table and render the result as a markdown table.

    The table is named 'freights'. Its description is as follows:
        Columns:
        - departure: DateTime (Date and time of departure)
        - origin_port_locode: String (Origin port code)
        - origin_port_name: String (Name of the origin port)
        - destination_port: String (Destination port code)
        - destination_port_name: String (Name of the destination port)
        - dv20rate: Float (Rate for 20ft container in USD)
        - dv40rate: Float (Rate for 40ft container in USD)
        - currency: String (Currency of the rates)
        - inserted_on: DateTime (Date when the rate was inserted)

    Args:
        query: The query to perform. This should be correct SQL.

    Returns:
        A markdown table (header row, separator row, then one line per result
        row). If the query yields no rows, the message "Aucun résultat trouvé."
        is returned; if anything fails, a string starting with
        "Error executing query: " followed by the exception text.
    """

    def _render_line(cells) -> str:
        # A literal pipe inside a cell would be read as a column separator,
        # so it is backslash-escaped before the cells are joined.
        escaped = (str(cell).replace("|", "\\|") for cell in cells)
        return "| " + " | ".join(escaped) + " |"

    # NOTE(review): `query` is executed verbatim, with no restriction to
    # read-only statements. Confirm that callers are trusted or that the
    # database is disposable.
    try:
        with engine.connect() as con:
            result = con.execute(text(query))
            # Column names are taken from the result and rows are kept
            # positional, rather than going through a dict keyed by column
            # name, which cannot represent two columns sharing a name.
            headers = list(result.keys())
            rows = result.fetchall()

        if not rows:
            return "Aucun résultat trouvé."

        # Convert to markdown table
        lines = [_render_line(headers), _render_line("---" for _ in headers)]
        lines.extend(_render_line(row) for row in rows)
        return "\n".join(lines) + "\n"
    except Exception as e:
        # Best-effort tool: report the failure as text instead of raising.
        return f"Error executing query: {e}"
def get_schema() -> str:
    """
    Describe the freights table: its name, then each column with its type and meaning.

    Returns:
        A fixed multi-line description that starts and ends with a newline.
    """
    description_lines = (
        "Table: freights",
        "Columns:",
        "- departure: DateTime (Date and time of departure)",
        "- origin_port_locode: String (Origin port code)",
        "- origin_port_name: String (Name of the origin port)",
        "- destination_port: String (Destination port code)",
        "- destination_port_name: String (Name of the destination port)",
        "- dv20rate: Float (Rate for 20ft container in USD)",
        "- dv40rate: Float (Rate for 40ft container in USD)",
        "- currency: String (Currency of the rates)",
        "- inserted_on: DateTime (Date when the rate was inserted)",
    )
    # Leading and trailing newlines mirror a triple-quoted block whose
    # content starts on the line after the opening quotes.
    return "\n" + "\n".join(description_lines) + "\n"
def get_csv_as_dataframe() -> str:
    """
    Read the whole freights table and serialize it as CSV text.

    Returns:
        The table contents in CSV format, without the DataFrame index column.
    """
    frame = pd.read_sql_table("freights", engine)
    csv_text = frame.to_csv(index=False)
    return csv_text
def run_agent(question: str) -> str:
    """
    Send a question to the agent and return its answer.

    The agent can issue requests to retrieve information about the freight data.

    Args:
        question: The question to run the agent with.

    Returns:
        The response of the agent.
    """
    # The run is capped at five agent steps.
    answer = agent.run(question, max_steps=5)
    return answer
if __name__ == "__main__":
    # Build a single-textbox UI around run_agent. EXAMPLE_QUERIES is bound by
    # this module's `from tools import ...` statement (which would raise if the
    # name were missing), so it is passed directly, with no globals() lookup.
    demo = gr.Interface(
        fn=run_agent,
        inputs=gr.Textbox(lines=7, label="Question"),
        outputs=gr.Textbox(),
        title="Freight Agent MCP",
        description="Ask a question about the freight data in natural language",
        examples=EXAMPLE_QUERIES,
    )
    # Launch with the MCP server option enabled alongside the web UI.
    demo.launch(mcp_server=True)