Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import re | |
| import torch | |
| import sqlite3 | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel, PeftConfig | |
# Hub repository ids of the fine-tuned PEFT adapters. These are pulled from the
# Hugging Face Model Hub rather than the local Kaggle paths used in training.
codellama_model_path = "srishtirai/codellama-sql-finetuned"
mistral_model_path = "srishtirai/mistral-sql-finetuned"
def load_model(model_path, base_model_name=None, torch_dtype=torch.float16):
    """Load a PEFT fine-tuned causal language model and its tokenizer.

    Args:
        model_path: Hub repo id (or local path) of the PEFT adapter. The
            tokenizer is loaded from the same location.
        base_model_name: Optional base-model id/path. When omitted it is read
            from the adapter's ``PeftConfig.base_model_name_or_path``. Pass it
            explicitly when that recorded value is not resolvable here (for
            example, a path from the environment the adapter was trained in).
        torch_dtype: dtype for the base model weights (default
            ``torch.float16``). NOTE(review): some older PyTorch CPU builds
            lack half-precision kernels; pass ``torch.float32`` there.

    Returns:
        ``(model, tokenizer)`` with the PEFT model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Reuse EOS as the pad token so padding is always defined.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    if base_model_name is None:
        base_model_name = PeftConfig.from_pretrained(model_path).base_model_name_or_path

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch_dtype,
        device_map="auto",  # let accelerate place weights on the available devices
    )
    # Wrap the base model with the fine-tuned adapter weights.
    model = PeftModel.from_pretrained(base_model, model_path)
    model.eval()
    return model, tokenizer
# Load both fine-tuned models once, at import time, so every request reuses them.
# NOTE(review): this keeps two full base models (plus adapters) in memory at the
# same time; confirm the Space hardware has enough RAM/VRAM for both.
(codellama_model, codellama_tokenizer), (mistral_model, mistral_tokenizer) = (
    load_model(codellama_model_path),
    load_model(mistral_model_path),
)
def format_input_prompt(schema, question):
    """Build the instruction prompt the fine-tuned models expect.

    The prompt has three headed sections (context, question, response) and
    ends with a response lead-in line followed by a newline, so the model
    continues directly with the SQL.
    """
    sections = (
        "### Context:",
        f"{schema}",
        "### Question:",
        f"{question}",
        "### Response:",
        "Here's the SQL query:",
    )
    return "\n".join(sections) + "\n"
def _extract_sql(text):
    """Return the SQL inside the first fenced code block of *text*, or None.

    A ```sql fence (case-insensitive) is preferred. If there is none, the first
    fence with any other (or no) info string is used, because the model does
    not always tag its code block.
    """
    match = re.search(r"```sql\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
    if match is None:
        match = re.search(r"```[^\n`]*\n(.*?)```", text, re.DOTALL)
    return match.group(1).strip() if match else None


def _extract_explanation(text):
    """Return the text after 'Explanation:' up to the next blank line (or end), or None."""
    match = re.search(r"Explanation:\s*(.*?)($|\n\n)", text, re.DOTALL)
    return match.group(1).strip() if match else None


# Generate SQL plus an explanation with the selected model.
def generate_sql_with_explanation(model_choice, schema, question, max_new_tokens=512, temperature=0.7):
    """Generate a SQL query and explanation for *question* over *schema*.

    Args:
        model_choice: "CodeLlama" selects the CodeLlama model; any other value
            selects Mistral.
        schema: Database schema text inserted into the prompt.
        question: Natural-language question to answer with SQL.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (with top_p=0.95). A value <= 0 or
            None switches to greedy decoding, because sampling requires a
            strictly positive temperature.

    Returns:
        dict with keys:
          "query": extracted SQL, or "SQL query extraction failed.";
          "explanation": extracted explanation, or "Explanation not found.";
          "full_response": decoded prompt + completion, as before.
    """
    if model_choice == "CodeLlama":
        model, tokenizer = codellama_model, codellama_tokenizer
    else:
        model, tokenizer = mistral_model, mistral_tokenizer

    prompt = format_input_prompt(schema, question)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)
    else:
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Parse only the newly generated tokens, so text the user typed into the
    # schema/question (which is echoed in the prompt) cannot be mistaken for
    # the model's SQL or explanation.
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    sql_query = _extract_sql(completion)
    explanation = _extract_explanation(completion)

    return {
        "query": sql_query or "SQL query extraction failed.",
        "explanation": explanation or "Explanation not found.",
        "full_response": full_response,
    }
def _deny_attach(action, *_args):
    """SQLite authorizer callback: forbid ATTACH, allow every other action."""
    return sqlite3.SQLITE_DENY if action == sqlite3.SQLITE_ATTACH else sqlite3.SQLITE_OK


# Optionally execute the generated SQL.
def execute_sql_query(sql_query, schema=None):
    """Run *sql_query* against a throwaway in-memory SQLite database.

    Args:
        sql_query: A single SQL statement to execute.
        schema: Optional SQL script (e.g. CREATE TABLE / INSERT statements)
            applied before the query so it can reference those tables. When
            omitted, the query runs against an empty database.

    Returns:
        The fetched rows as a list of tuples, the string
        "Query executed successfully (No output rows)." when no rows come
        back, or an error message string. Errors are reported, not raised,
        because the result is shown directly in the UI.
    """
    conn = sqlite3.connect(":memory:")  # temporary, isolated database
    try:
        # The SQL is user- and model-controlled; ATTACH could create or write
        # database files on the host, so refuse it.
        conn.set_authorizer(_deny_attach)
        if schema:
            try:
                conn.executescript(schema)
            except Exception as e:
                return f"Error loading schema: {e}"
        try:
            cursor = conn.cursor()
            cursor.execute(sql_query)
            result = cursor.fetchall()
        except Exception as e:
            return f"Error executing SQL: {str(e)}"
    finally:
        # Always release the connection, including on the error paths.
        conn.close()
    return result if result else "Query executed successfully (No output rows)."


# Gradio callback.
def gradio_generate_sql(model_choice, schema, question, run_sql):
    """Generate SQL + explanation for the UI and optionally execute the SQL.

    Returns:
        (sql_query, explanation, execution_result) matching the three output
        components of the interface.
    """
    result = generate_sql_with_explanation(model_choice, schema, question)
    sql_query = result["query"]
    if run_sql:
        # Apply the user's schema first so the query can see their tables.
        execution_result = execute_sql_query(sql_query, schema=schema)
    else:
        execution_result = "SQL execution not requested."
    return sql_query, result["explanation"], execution_result
# Gradio UI: the four inputs map positionally onto gradio_generate_sql's
# parameters, and the three outputs onto the tuple it returns.
iface = gr.Interface(
    fn=gradio_generate_sql,
    inputs=[
        # Preselect a model explicitly. Depending on the Gradio version, an
        # unset Dropdown may submit None, which generate_sql_with_explanation
        # would silently route to Mistral.
        gr.Dropdown(["CodeLlama", "Mistral"], value="CodeLlama", label="Choose Model"),
        gr.Textbox(label="Enter Database Schema", lines=10),
        gr.Textbox(label="Enter your Question"),
        gr.Checkbox(label="Run SQL Query?", value=False),
    ],
    outputs=[
        gr.Code(label="Generated SQL Query", language="sql"),  # SQL syntax highlighting
        gr.Textbox(label="Explanation", lines=5),
        gr.Textbox(label="SQL Execution Result", lines=5),
    ],
    title="SQL Query Generator with Execution",
    description="Select a model, enter your database schema and question. Optionally, execute the generated SQL query.",
)

# Start the Gradio web server.
iface.launch()