# Run pre-trained DeepSeek Coder 1.3B Model on Chat-GPT 4o generated dataset

## First load dataset into pandas dataframe

In [83]:
import pandas as pd 

# Load dataset and check length
df = pd.read_csv("./train-data/sql_train.tsv", sep='\t')
print("Total dataset examples: " + str(len(df)))
print("\n")

# Test sampling
sample = df.sample(n=1)
print(sample["natural_query"].values[0])
print(sample["sql_query"].values[0])
print(sample["result"].values[0])

Total dataset examples: 1044


What is the highest number of assists recorded by the Indiana Pacers in a single home game?
SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';
44.0


## Load pre-trained DeepSeek model using transformers and pytorch packages

In [84]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Set device to cuda if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("./deepseek-coder-1.3b-instruct")
model = AutoModelForCausalLM.from_pretrained("./deepseek-coder-1.3b-instruct", torch_dtype=torch.bfloat16, device_map=device) 

## Create prompt to setup the model for better performance

In [85]:
input_text = """You are an AI assistant that generates SQLite queries for an NBA database based on user questions. The database consists of two tables:

1. `team` - Stores information about NBA teams.
 - `id`: Unique team identifier.
 - `full_name`: Full team name (e.g., "Los Angeles Lakers").
 - `abbreviation`: 3-letter team code (e.g., "LAL").
 - `city`, `state`: Location of the team.
 - `year_founded`: The year the team was founded.

2. `game` - Stores details of individual games.
 - `game_date`: Date of the game.
 - `team_id_home`, `team_id_away`: Unique IDs of home and away teams.
 - `team_name_home`, `team_name_away`: Full names of the teams.
 - `pts_home`, `pts_away`: Points scored by home and away teams.
 - `wl_home`: "W" if the home team won, "L" if they lost.
 - `reb_home`, `reb_away`: Total rebounds.
 - `ast_home`, `ast_away`: Total assists.
 - Other statistics include field goals (`fgm_home`, `fg_pct_home`), three-pointers (`fg3m_home`), free throws (`ftm_home`), and turnovers (`tov_home`).

### Instructions:
- Generate a valid SQLite query to retrieve relevant data from the database.
- Use column names correctly based on the provided schema.
- Ensure the query is well-structured and avoids unnecessary joins.
- Format the query with proper indentation.

### Example Queries:
User: "What is the most points the Los Angeles Lakers have ever scored at home?"
SQLite:
SELECT MAX(pts_home) 
FROM game 
WHERE team_name_home = 'Los Angeles Lakers';

User: "List all games where the Golden State Warriors scored more than 130 points." 
SQLite:
SELECT game_date, team_name_home, pts_home, team_name_away, pts_away
FROM game
WHERE (team_name_home = 'Golden State Warriors' AND pts_home > 130)
 OR (team_name_away = 'Golden State Warriors' AND pts_away > 130);
 
Now, generate a SQL query based on the following user request: """

## Test model performance on a single example

In [86]:
# Create message with sample query and run model
message=[{ 'role': 'user', 'content': input_text + sample["natural_query"].values[0]}]
inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)

# Print output
query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
print(query_output)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.


SQLite:
SELECT MAX(ast_home) 
FROM game 
WHERE team_name_home = 'Indiana Pacers';



# Test sample output on sqlite3 database

In [None]:
import sqlite3 as sql

# Create connection to sqlite3 database
connection = sql.connect('./nba-data/nba.sqlite')
cursor = connection.cursor()

# Execute query from model output and print result
if query_output[0:7] == "SQLite:":
 print("cleaned")
 query = query_output[7:]
elif query_output[0:4] == "SQL:":
 query = query_output[4:]
else:
 query = query_output
cursor.execute(query)
rows = cursor.fetchall()
for row in rows:
 print(row)

cleaned
(44.0,)


## Create function to compare output to ground truth result from examples

In [None]:
def compare_result(sample_query, sample_result, query_output):
 # Clean model output to only have the query output
 if query_output[0:7] == "SQLite:":
 query = query_output[7:]
 elif query_output[0:4] == "SQL:":
 query = query_output[4:]
 else:
 query = query_output
 
 # Try to execute query, if it fails, then this is a failure of the model
 try:
 # Execute query and obtain result
 cursor.execute(query)
 rows = cursor.fetchall()

 # Check if this is a multi-line query
 if "|" in sample_result:
 return True, True
 else:
 # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.
 query = query.replace(" ", "").replace("\n", "").replace("\t", "")
 sample_query = sample_query.replace(" ", "").replace("\n", "").replace("\t", "")

 # Compare results and return
 return (query == sample_query), (str(rows[0][0]) == str(sample_result))
 except:
 return False, False

result = compare_result(sample["sql_query"].values[0], sample["result"].values[0], query_output)
print("SQL matched? " + str(result[0]))
print("Result matched? " + str(result[1]))

cleaned
[(44.0,)]

SELECT MAX(ast_home) 
FROM game 
WHERE team_name_home = 'Indiana Pacers';

SELECT MAX(ast_home) FROM game WHERE team_name_home = 'Indiana Pacers';
44.0
44.0
SQL matched? True
Result matched? True
