# Source: Hugging Face file by Manju080, commit c399543 ("Initial commit"), 6.82 kB.
import gradio as gr
import requests
import json
import time
def generate_sql(question, table_headers):
    """Generate SQL for one natural-language question via the RAG API.

    Sends ``question`` and the parsed column names to ``POST /predict`` on the
    local RAG service and formats the JSON reply as Markdown for display.

    Args:
        question: The user's question, e.g. "Show all employees earning > 50000".
        table_headers: Comma-separated column names, e.g. "id, name, salary".
            Blank entries are dropped.

    Returns:
        Markdown with the generated SQL and request metadata on success, or a
        string starting with "❌ Error:" on any failure. Errors are returned
        rather than raised because the result is rendered straight into the UI.
    """
    try:
        # Don't spend a model call on an empty question; tell the user instead.
        if not question or not question.strip():
            return "❌ Error: Please enter a question."
        # Gradio passes "" for an empty textbox; also tolerate None.
        headers = [h.strip() for h in (table_headers or "").split(",") if h.strip()]
        # Without a timeout, requests can wait forever and freeze the UI if the
        # backend stalls; 120 s is a generous ceiling for a single generation.
        response = requests.post(
            "http://localhost:8000/predict",
            json={"question": question, "table_headers": headers},
            timeout=120,
        )
        if response.status_code != 200:
            return f"❌ Error: {response.status_code} - {response.text}"
        result = response.json()
        # Blank lines keep each field on its own line when rendered as Markdown.
        return "\n".join([
            "**Generated SQL:**",
            "",
            "```sql",
            f"{result['sql_query']}",
            "```",
            "",
            f"**Model Used:** {result['model_used']}",
            "",
            f"**Processing Time:** {result['processing_time']:.2f}s",
            "",
            f"**Status:** {result['status']}",
            "",
            f"**Retrieved Examples:** {len(result['retrieved_examples'])} examples used for RAG",
        ])
    except Exception as e:
        # UI boundary: report network, JSON-decoding or missing-field errors as text.
        return f"❌ Error: {e}"
def batch_generate_sql(questions_text, table_headers):
    """Generate SQL for several questions in one call to the RAG API.

    Each non-blank line of ``questions_text`` becomes one query sent to
    ``POST /batch``; every query shares the same table headers.

    Args:
        questions_text: Questions separated by newlines; blank lines are ignored.
        table_headers: Comma-separated column names shared by every question.

    Returns:
        Markdown summarising the batch (totals, then SQL, model and timing for
        each result), or a string starting with "❌ Error:" on failure.
    """
    try:
        questions = [q.strip() for q in (questions_text or "").split("\n") if q.strip()]
        # An empty batch would only produce an empty report; ask for input instead.
        if not questions:
            return "❌ Error: Please enter at least one question (one per line)."
        # All questions target the same table, so parse the headers only once.
        headers = [h.strip() for h in (table_headers or "").split(",") if h.strip()]
        payload = {
            "queries": [{"question": q, "table_headers": headers} for q in questions]
        }
        # Bounded wait so a stalled backend cannot hang the UI indefinitely; a
        # batch gets a longer ceiling than a single query.
        response = requests.post("http://localhost:8000/batch", json=payload, timeout=300)
        if response.status_code != 200:
            return f"❌ Error: {response.status_code} - {response.text}"
        result = response.json()
        # Build the report as a list of lines; blank lines separate Markdown
        # paragraphs so each field renders on its own line.
        parts = [
            "**Batch Results:**",
            "",
            f"Total Queries: {result['total_queries']}",
            "",
            f"Successful: {result['successful_queries']}",
            "",
        ]
        for i, res in enumerate(result["results"], start=1):
            parts += [
                f"**Query {i}:** {res['question']}",
                "",
                f"```sql\n{res['sql_query']}\n```",
                "",
                f"Model: {res['model_used']} | Time: {res['processing_time']:.2f}s",
                "",
            ]
        return "\n".join(parts)
    except Exception as e:
        # UI boundary: report network, JSON-decoding or missing-field errors as text.
        return f"❌ Error: {e}"
def check_system_health():
    """Query the RAG API's ``GET /health`` endpoint and summarise it as Markdown.

    Returns:
        A Markdown status report (status flags, error, timestamp rendered in
        this process's local time, and model info), or a string starting with
        "❌" if the request fails or returns a non-200 status.
    """
    try:
        # Health checks should answer quickly; fail fast instead of hanging the UI.
        response = requests.get("http://localhost:8000/health", timeout=10)
        if response.status_code != 200:
            return f"❌ Health check failed: {response.status_code}"
        health = response.json()
        stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(health["timestamp"]))
        model_info = health.get("model_info")
        # Multi-line JSON must be fenced, otherwise Markdown collapses it into
        # a single paragraph.
        if model_info:
            model_block = "```json\n" + json.dumps(model_info, indent=2) + "\n```"
        else:
            model_block = "Not available"
        return "\n".join([
            "**System Health:**",
            "",
            f"- **Status:** {health['status']}",
            f"- **System Loaded:** {health['system_loaded']}",
            f"- **System Loading:** {health['system_loading']}",
            f"- **Error:** {health['system_error'] or 'None'}",
            f"- **Timestamp:** {stamp}",
            "",
            "**Model Info:**",
            "",
            model_block,
        ])
    except Exception as e:
        # UI boundary: report network, JSON-decoding or missing-field errors as text.
        return f"❌ Health check error: {e}"
# Create Gradio interface

# Footer shown below the tabs. Kept at column 0 inside the string because
# indented lines would be rendered as a Markdown code block.
_ABOUT_MARKDOWN = """
## 🎯 How It Works

1. **RAG System**: Retrieves relevant SQL examples from vector database
2. **CodeLlama**: Generates SQL using retrieved examples as context
3. **Vector Search**: Finds similar questions and their SQL solutions
4. **Enhanced Generation**: Combines retrieval + generation for better accuracy

## 🛠️ Technology Stack

- **Backend**: FastAPI + Python
- **LLM**: CodeLlama-7B-Python-GGUF (primary)
- **Vector DB**: ChromaDB with sentence transformers
- **Frontend**: Gradio interface
- **Hosting**: Hugging Face Spaces

## 📊 Performance

- **Model**: CodeLlama-7B-Python-GGUF
- **Response Time**: < 5 seconds
- **Accuracy**: High (RAG-enhanced)
- **Cost**: Free (local inference)
"""

with gr.Blocks(title="Text-to-SQL RAG with CodeLlama", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Text-to-SQL RAG with CodeLlama")
    gr.Markdown(
        "Generate SQL queries from natural language using "
        "**RAG (Retrieval-Augmented Generation)** and **CodeLlama** models."
    )
    gr.Markdown(
        "**Features:** RAG-enhanced generation, CodeLlama integration, "
        "Vector-based retrieval, Advanced prompt engineering"
    )

    # Tab 1: one question at a time -> generate_sql.
    with gr.Tab("Single Query"):
        with gr.Row():
            with gr.Column(scale=1):
                question_input = gr.Textbox(
                    label="Question",
                    placeholder="e.g., Show me all employees with salary greater than 50000",
                    lines=3,
                )
                table_headers_input = gr.Textbox(
                    label="Table Headers (comma-separated)",
                    placeholder="e.g., id, name, salary, department",
                    value="id, name, salary, department",
                )
                generate_btn = gr.Button("🚀 Generate SQL", variant="primary", size="lg")
            with gr.Column(scale=1):
                output = gr.Markdown(label="Result")

    # Tab 2: newline-separated questions -> batch_generate_sql.
    with gr.Tab("Batch Queries"):
        with gr.Row():
            with gr.Column(scale=1):
                batch_questions = gr.Textbox(
                    label="Questions (one per line)",
                    placeholder="Show me all employees\nCount total employees\nAverage salary by department",
                    lines=5,
                )
                batch_headers = gr.Textbox(
                    label="Table Headers (comma-separated)",
                    placeholder="e.g., id, name, salary, department",
                    value="id, name, salary, department",
                )
                batch_btn = gr.Button("🚀 Generate Batch SQL", variant="primary", size="lg")
            with gr.Column(scale=1):
                batch_output = gr.Markdown(label="Batch Results")

    # Tab 3: backend status -> check_system_health.
    with gr.Tab("System Health"):
        with gr.Row():
            health_btn = gr.Button("🔍 Check System Health", variant="secondary", size="lg")
        health_output = gr.Markdown(label="Health Status")

    # Event handlers are registered inside the Blocks context, as Gradio requires.
    generate_btn.click(
        generate_sql,
        inputs=[question_input, table_headers_input],
        outputs=output,
    )
    batch_btn.click(
        batch_generate_sql,
        inputs=[batch_questions, batch_headers],
        outputs=batch_output,
    )
    health_btn.click(
        check_system_health,
        outputs=health_output,
    )

    gr.Markdown("---")
    gr.Markdown(_ABOUT_MARKDOWN)

# Launch the interface
if __name__ == "__main__":
    demo.launch()