File size: 6,817 Bytes
c399543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import gradio as gr
import requests
import json
import time

def generate_sql(question, table_headers, *, timeout=120):
    """Generate SQL for a single natural-language question via the RAG API.

    Args:
        question: The natural-language question typed by the user.
        table_headers: Comma-separated column names; surrounding whitespace
            is stripped and blank entries are dropped.
        timeout: Seconds to wait for the ``/predict`` endpoint. Without a
            timeout ``requests`` may block indefinitely and hang the UI.
            Keyword-only so Gradio does not try to bind an input to it.

    Returns:
        A Markdown string with the generated SQL and response metadata on
        success, or a string starting with "❌" describing the failure.
    """
    # Reject empty input locally instead of sending a pointless request.
    if not question or not question.strip():
        return "❌ Error: Please enter a question."

    try:
        # Prepare the request payload.
        headers = [h.strip() for h in (table_headers or "").split(",") if h.strip()]
        data = {
            "question": question,
            "table_headers": headers,
        }

        # Make API call to the RAG system.
        response = requests.post(
            "http://localhost:8000/predict", json=data, timeout=timeout
        )

        if response.status_code != 200:
            return f"❌ Error: {response.status_code} - {response.text}"

        result = response.json()
        return f"""
**Generated SQL:**
```sql
{result['sql_query']}
```

**Model Used:** {result['model_used']}
**Processing Time:** {result['processing_time']:.2f}s
**Status:** {result['status']}
**Retrieved Examples:** {len(result['retrieved_examples'])} examples used for RAG
"""
    except Exception as e:
        # UI boundary: surface any failure (network, bad JSON, missing keys)
        # as a message in the result pane rather than crashing the handler.
        return f"❌ Error: {str(e)}"

def batch_generate_sql(questions_text, table_headers, *, timeout=600):
    """Generate SQL for several questions (one per line) via the RAG API.

    Args:
        questions_text: Newline-separated questions; blank lines are ignored.
        table_headers: Comma-separated column names shared by every query;
            surrounding whitespace is stripped and blank entries are dropped.
        timeout: Seconds to wait for the ``/batch`` endpoint. Without a
            timeout ``requests`` may block indefinitely and hang the UI.
            Keyword-only so Gradio does not try to bind an input to it.

    Returns:
        A Markdown summary followed by one section per result on success,
        or a string starting with "❌" describing the failure.
    """
    try:
        # Parse questions, skipping blank lines.
        questions = [q.strip() for q in (questions_text or "").split("\n") if q.strip()]
        if not questions:
            return "❌ Error: Please enter at least one question."

        # Header parsing does not depend on the question, so do it once.
        headers = [h.strip() for h in (table_headers or "").split(",") if h.strip()]
        data = {
            "queries": [
                {"question": q, "table_headers": headers}
                for q in questions
            ]
        }

        # Make API call.
        response = requests.post(
            "http://localhost:8000/batch", json=data, timeout=timeout
        )

        if response.status_code != 200:
            return f"❌ Error: {response.status_code} - {response.text}"

        result = response.json()
        parts = [
            "**Batch Results:**\n",
            f"Total Queries: {result['total_queries']}\n",
            f"Successful: {result['successful_queries']}\n\n",
        ]
        for i, res in enumerate(result["results"], start=1):
            parts.append(f"**Query {i}:** {res['question']}\n")
            parts.append(f"```sql\n{res['sql_query']}\n```\n")
            parts.append(
                f"Model: {res['model_used']} | Time: {res['processing_time']:.2f}s\n\n"
            )
        return "".join(parts)

    except Exception as e:
        # UI boundary: surface any failure (network, bad JSON, missing keys)
        # as a message in the result pane rather than crashing the handler.
        return f"❌ Error: {str(e)}"

def check_system_health(*, timeout=10):
    """Query the RAG API's ``/health`` endpoint and format the result.

    Args:
        timeout: Seconds to wait for the health endpoint. A health check
            should fail fast rather than hang the UI. Keyword-only so
            Gradio does not try to bind an input to it.

    Returns:
        A Markdown report of the fields returned by the endpoint, or a
        string starting with "❌" describing the failure.
    """
    try:
        response = requests.get("http://localhost:8000/health", timeout=timeout)
        if response.status_code != 200:
            # Include the body, matching the other handlers' error format.
            return f"❌ Health check failed: {response.status_code} - {response.text}"

        health_data = response.json()
        # 'timestamp' is passed to time.localtime, so it is treated as a
        # Unix epoch value here (NOTE(review): confirm with the API).
        return f"""
**System Health:**
- **Status:** {health_data['status']}
- **System Loaded:** {health_data['system_loaded']}
- **System Loading:** {health_data['system_loading']}
- **Error:** {health_data['system_error'] or 'None'}
- **Timestamp:** {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(health_data['timestamp']))}

**Model Info:**
{json.dumps(health_data.get('model_info', {}), indent=2) if health_data.get('model_info') else 'Not available'}
"""
    except Exception as e:
        # UI boundary: report any failure instead of crashing the handler.
        return f"❌ Health check error: {str(e)}"

# Create Gradio interface.
# NOTE: the emoji in the labels below were previously mojibake (UTF-8 bytes
# decoded as a Windows code page, e.g. "πŸš€" for 🚀); restored here.
with gr.Blocks(title="Text-to-SQL RAG with CodeLlama", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Text-to-SQL RAG with CodeLlama")
    gr.Markdown("Generate SQL queries from natural language using **RAG (Retrieval-Augmented Generation)** and **CodeLlama** models.")
    gr.Markdown("**Features:** RAG-enhanced generation, CodeLlama integration, Vector-based retrieval, Advanced prompt engineering")

    # Tab 1: one question -> one SQL query.
    with gr.Tab("Single Query"):
        with gr.Row():
            with gr.Column(scale=1):
                question_input = gr.Textbox(
                    label="Question",
                    placeholder="e.g., Show me all employees with salary greater than 50000",
                    lines=3
                )
                table_headers_input = gr.Textbox(
                    label="Table Headers (comma-separated)",
                    placeholder="e.g., id, name, salary, department",
                    value="id, name, salary, department"
                )
                generate_btn = gr.Button("🚀 Generate SQL", variant="primary", size="lg")

            with gr.Column(scale=1):
                output = gr.Markdown(label="Result")

    # Tab 2: several questions (one per line) sharing the same headers.
    with gr.Tab("Batch Queries"):
        with gr.Row():
            with gr.Column(scale=1):
                batch_questions = gr.Textbox(
                    label="Questions (one per line)",
                    placeholder="Show me all employees\nCount total employees\nAverage salary by department",
                    lines=5
                )
                batch_headers = gr.Textbox(
                    label="Table Headers (comma-separated)",
                    placeholder="e.g., id, name, salary, department",
                    value="id, name, salary, department"
                )
                batch_btn = gr.Button("🚀 Generate Batch SQL", variant="primary", size="lg")

            with gr.Column(scale=1):
                batch_output = gr.Markdown(label="Batch Results")

    # Tab 3: backend health report.
    with gr.Tab("System Health"):
        with gr.Row():
            health_btn = gr.Button("🔍 Check System Health", variant="secondary", size="lg")
            health_output = gr.Markdown(label="Health Status")

    # Event handlers
    generate_btn.click(
        generate_sql,
        inputs=[question_input, table_headers_input],
        outputs=output
    )

    batch_btn.click(
        batch_generate_sql,
        inputs=[batch_questions, batch_headers],
        outputs=batch_output
    )

    health_btn.click(
        check_system_health,
        outputs=health_output
    )

    gr.Markdown("---")
    gr.Markdown("""
    ## 🎯 How It Works
    
    1. **RAG System**: Retrieves relevant SQL examples from vector database
    2. **CodeLlama**: Generates SQL using retrieved examples as context
    3. **Vector Search**: Finds similar questions and their SQL solutions
    4. **Enhanced Generation**: Combines retrieval + generation for better accuracy
    
    ## 🛠️ Technology Stack
    
    - **Backend**: FastAPI + Python
    - **LLM**: CodeLlama-7B-Python-GGUF (primary)
    - **Vector DB**: ChromaDB with sentence transformers
    - **Frontend**: Gradio interface
    - **Hosting**: Hugging Face Spaces
    
    ## 📊 Performance
    
    - **Model**: CodeLlama-7B-Python-GGUF
    - **Response Time**: < 5 seconds
    - **Accuracy**: High (RAG-enhanced)
    - **Cost**: Free (local inference)
    """)

# Launch the interface when run as a script.
if __name__ == "__main__":
    demo.launch()