NLSQL / app.py
HusnaManakkot's picture
Update app.py
91b0752 verified
raw
history blame
1.76 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
from difflib import get_close_matches
# Candidate pool for fuzzy matching: only the first 100 rows of the Spider
# "train" split are loaded. find_closest_match scans every row on each query,
# so widening this slice gives the matcher more candidates at the cost of
# slower per-query matching (and presumably a larger download at startup).
spider_dataset = load_dataset("spider", split='train[:100]')
# Tokenizer and seq2seq model are loaded from the same Hub checkpoint, a
# T5-base model fine-tuned on WikiSQL (per its id).
# NOTE(review): the checkpoint was tuned on WikiSQL, not Spider; output SQL
# for Spider questions may not match Spider's schemas - confirm if relevant.
tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
def find_closest_match(query, dataset, *, cutoff=0.6):
    """Return the dataset question most similar to *query*, or ``None``.

    Similarity is computed by :func:`difflib.get_close_matches`, which ranks
    candidates by ``SequenceMatcher.ratio()`` and drops any whose ratio is
    below *cutoff*. The comparison is character-based and case-sensitive.

    Args:
        query: Free-text question typed by the user. ``None`` or an empty
            string yields ``None`` rather than raising.
        dataset: Iterable of records, each indexable by ``"question"``
            (e.g. rows of the Spider dataset).
        cutoff: Minimum similarity in ``[0, 1]`` a candidate must reach.
            Defaults to 0.6, the same threshold ``get_close_matches`` uses,
            so existing callers see unchanged results.

    Returns:
        The single best-matching question string, or ``None`` when no
        candidate reaches *cutoff* (or *dataset* is empty).

    Raises:
        ValueError: For a non-empty *query*, if *cutoff* lies outside
            ``[0, 1]`` (raised by ``get_close_matches``).
    """
    # An empty text box submits ""; None would make difflib raise TypeError.
    if not query:
        return None
    questions = [item["question"] for item in dataset]
    matches = get_close_matches(query, questions, n=1, cutoff=cutoff)
    return matches[0] if matches else None
def generate_sql_from_user_input(query):
    """Produce SQL for the Spider question closest to the user's text.

    The user's text is not passed to the model directly. The most similar
    question in ``spider_dataset`` is looked up first, and that question is
    the one translated by the T5 model.

    Args:
        query: Natural-language question entered by the user.

    Returns:
        A ``(matched_question, sql)`` pair. When no dataset question is close
        enough, the first element is an explanatory message and the second
        is an empty string.
    """
    closest = find_closest_match(query, spider_dataset)
    if closest:
        # Task prefix expected by the T5 checkpoint, followed by the question.
        prompt = "translate English to SQL: " + closest
        encoded = tokenizer(prompt, return_tensors="pt", padding=True)
        generated_ids = model.generate(**encoded, max_length=512)
        # Only one sequence was generated; decode it without pad/EOS tokens.
        sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return closest, sql
    return "No close match found in the dataset.", ""
# Web UI: one text box in, two text boxes out - the Spider question that was
# matched, and the SQL the model generated for that question.
interface = gr.Interface(
    title="NL to SQL with T5 using Spider Dataset",
    description="This model finds the closest match in the Spider dataset for your query and generates the corresponding SQL.",
    fn=generate_sql_from_user_input,
    inputs=gr.Textbox(label="Enter your natural language query"),
    outputs=[
        gr.Textbox(label="Matched Query from Dataset"),
        gr.Textbox(label="Generated SQL Query"),
    ],
)

# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()