{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Run pre-trained DeepSeek Coder 1.3B Model on Chat-GPT 4o generated dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First load dataset into pandas dataframe" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total dataset examples: 1044\n", "\n", "\n", "What is the highest number of assists recorded by the Indiana Pacers in a single home game?\n", "SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';\n", "44.0\n" ] } ], "source": [ "import pandas as pd \n", "\n", "# Load dataset and check length\n", "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n", "print(\"Total dataset examples: \" + str(len(df)))\n", "print(\"\\n\")\n", "\n", "# Test sampling\n", "sample = df.sample(n=1)\n", "print(sample[\"natural_query\"].values[0])\n", "print(sample[\"sql_query\"].values[0])\n", "print(sample[\"result\"].values[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load pre-trained DeepSeek model using transformers and pytorch packages" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import torch\n", "\n", "# Set device to cuda if available, otherwise CPU\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Load model and tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(\"./deepseek-coder-1.3b-instruct\")\n", "model = AutoModelForCausalLM.from_pretrained(\"./deepseek-coder-1.3b-instruct\", torch_dtype=torch.bfloat16, device_map=device) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create prompt to setup the model for better performance" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [], "source": [ "input_text = \"\"\"You are an AI assistant that generates SQLite queries for an NBA database based on user questions. The database consists of two tables:\n", "\n", "1. `team` - Stores information about NBA teams.\n", " - `id`: Unique team identifier.\n", " - `full_name`: Full team name (e.g., \"Los Angeles Lakers\").\n", " - `abbreviation`: 3-letter team code (e.g., \"LAL\").\n", " - `city`, `state`: Location of the team.\n", " - `year_founded`: The year the team was founded.\n", "\n", "2. 
`game` - Stores details of individual games.\n", " - `game_date`: Date of the game.\n", " - `team_id_home`, `team_id_away`: Unique IDs of home and away teams.\n", " - `team_name_home`, `team_name_away`: Full names of the teams.\n", " - `pts_home`, `pts_away`: Points scored by home and away teams.\n", " - `wl_home`: \"W\" if the home team won, \"L\" if they lost.\n", " - `reb_home`, `reb_away`: Total rebounds.\n", " - `ast_home`, `ast_away`: Total assists.\n", " - Other statistics include field goals (`fgm_home`, `fg_pct_home`), three-pointers (`fg3m_home`), free throws (`ftm_home`), and turnovers (`tov_home`).\n", "\n", "### Instructions:\n", "- Generate a valid SQLite query to retrieve relevant data from the database.\n", "- Use column names correctly based on the provided schema.\n", "- Ensure the query is well-structured and avoids unnecessary joins.\n", "- Format the query with proper indentation.\n", "\n", "### Example Queries:\n", "User: \"What is the most points the Los Angeles Lakers have ever scored at home?\"\n", "SQLite:\n", "SELECT MAX(pts_home) \n", "FROM game \n", "WHERE team_name_home = 'Los Angeles Lakers';\n", "\n", "User: \"List all games where the Golden State Warriors scored more than 130 points.\" \n", "SQLite:\n", "SELECT game_date, team_name_home, pts_home, team_name_away, pts_away\n", "FROM game\n", "WHERE (team_name_home = 'Golden State Warriors' AND pts_home > 130)\n", " OR (team_name_away = 'Golden State Warriors' AND pts_away > 130);\n", " \n", "Now, generate a SQL query based on the following user request: \"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test model performance on a single example" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\generation\\configuration_utils.py:634: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n", " warnings.warn(\n", "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results.\n", "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SQLite:\n", "SELECT MAX(ast_home) \n", "FROM game \n", "WHERE team_name_home = 'Indiana Pacers';\n", "\n" ] } ], "source": [ "# Create message with sample query and run model\n", "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n", "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n", "\n", "# Note: do_sample=False means greedy decoding, so top_k/top_p are ignored (hence the warning in this cell's output)\n", "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n", "\n", "# Decode and print only the newly generated tokens\n", "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n", "print(query_output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test sample output on the sqlite3 database" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cleaned\n", "(44.0,)\n" ] } ], "source": [ "import sqlite3 as sql\n", "\n", "# Create connection to sqlite3 database\n", "connection = sql.connect('./nba-data/nba.sqlite')\n", "cursor = connection.cursor()\n", "\n", "# Clean model output: strip the \"SQLite:\" or \"SQL:\" prefix if the model included one\n", "if query_output[0:7] == \"SQLite:\":\n", "    print(\"cleaned\")\n", "    query = query_output[7:]\n", "elif query_output[0:4] == \"SQL:\":\n", "    query = query_output[4:]\n", "else:\n", "    query = query_output\n", "\n", "# Execute query from model output and print result\n", "cursor.execute(query)\n", "rows = cursor.fetchall()\n", "for row in rows:\n", "    print(row)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a function to compare output to the ground-truth result from the examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cleaned\n", "[(44.0,)]\n", "\n", "SELECT MAX(ast_home) \n", "FROM game \n", "WHERE team_name_home = 'Indiana Pacers';\n", "\n", "SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';\n", "44.0\n", "44.0\n", "SQL matched? True\n", "Result matched? True\n" ] } ], "source": [ "def compare_result(sample_query, sample_result, query_output):\n", "    # Clean model output to only have the query itself\n", "    if query_output[0:7] == \"SQLite:\":\n", "        query = query_output[7:]\n", "    elif query_output[0:4] == \"SQL:\":\n", "        query = query_output[4:]\n", "    else:\n", "        query = query_output\n", "\n", "    # Try to execute the query; if it fails, count it as a failure of the model\n", "    try:\n", "        # Execute query and obtain result\n", "        cursor.execute(query)\n", "        rows = cursor.fetchall()\n", "\n", "        # If the expected result is multi-column/multi-row (pipe-separated), skip the detailed check and count it as a match\n", "        if \"|\" in sample_result:\n", "            return True, True\n", "        else:\n", "            # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.\n", "            query = query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n", "            sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n", "\n", "            # Compare the query text and the first value of the query result\n", "            return (query == sample_query), (str(rows[0][0]) == str(sample_result))\n", "    except Exception:\n", "        return False, False\n", "\n", "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n", "print(\"SQL matched? \" + str(result[0]))\n", "print(\"Result matched? \" + str(result[1]))" ] },