|
import gradio as gr |
|
import time |
|
import json |
|
import traceback |
|
|
|
|
|
print("Starting RAG system initialization...")

# Each RAG component is imported in its own try-block so that one failing
# import is reported individually instead of aborting the whole app. A failed
# import leaves its name undefined; the component-initialization step that
# follows catches the resulting NameError with its own `except Exception`.
try:
    from rag_system.vector_store import VectorStore
    print("✓ VectorStore imported successfully")
except Exception as e:
    print(f"✗ VectorStore import failed: {e}")
    traceback.print_exc()

try:
    from rag_system.retriever import SQLRetriever
    print("✓ SQLRetriever imported successfully")
except Exception as e:
    print(f"✗ SQLRetriever import failed: {e}")
    traceback.print_exc()

try:
    from rag_system.prompt_engine import PromptEngine
    print("✓ PromptEngine imported successfully")
except Exception as e:
    print(f"✗ PromptEngine import failed: {e}")
    traceback.print_exc()

try:
    from rag_system.sql_generator import SQLGenerator
    print("✓ SQLGenerator imported successfully")
except Exception as e:
    print(f"✗ SQLGenerator import failed: {e}")
    traceback.print_exc()
|
|
|
|
|
print("Initializing RAG system components...")

# Pre-declare every component so these module-level names always exist, even
# when initialization fails part-way through (or an import above failed and
# the constructor name is undefined).
vector_store = None
retriever = None
prompt_engine = None
sql_generator = None

try:
    vector_store = VectorStore()
    print("✓ VectorStore initialized")

    retriever = SQLRetriever(vector_store)
    print("✓ SQLRetriever initialized")

    prompt_engine = PromptEngine()
    print("✓ PromptEngine initialized")

    sql_generator = SQLGenerator(retriever, prompt_engine)
    print("✓ SQLGenerator initialized")

    print("🎉 RAG system initialized successfully!")
except Exception as e:
    # Also catches the NameError raised when one of the component imports failed.
    print(f"❌ Error initializing RAG system: {e}")
    traceback.print_exc()
    # The request handlers only check `sql_generator`; resetting it guarantees
    # a failure after its assignment still reads as "not initialized".
    sql_generator = None
|
|
|
def generate_sql(question, table_headers):
    """Generate SQL for a single natural-language question via the RAG system.

    Args:
        question: Natural-language question entered in the UI.
        table_headers: Comma-separated column names, passed unchanged to
            ``sql_generator.generate_sql``.

    Returns:
        A Markdown string with the generated SQL and its model/timing/status
        metadata, or a message starting with ``"❌ Error:"`` when the system
        is not initialized, an input is blank, or generation raises (in which
        case the traceback is included for debugging).
    """
    print(f"generate_sql called with: {question}, {table_headers}")

    if sql_generator is None:
        return "❌ Error: RAG system not initialized. Check the logs for initialization errors."

    if not question or not question.strip():
        return "❌ Error: Please enter a question."

    if not table_headers or not table_headers.strip():
        return "❌ Error: Please enter table headers."

    try:
        print(f"Generating SQL for: {question}")
        print(f"Table headers: {table_headers}")

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        result = sql_generator.generate_sql(question, table_headers)
        processing_time = time.perf_counter() - start_time

        print(f"SQL generation completed in {processing_time:.2f}s")
        print(f"Result: {result}")

        # Blank lines put each field in its own Markdown paragraph so the
        # fields render on separate lines instead of as one run-on paragraph.
        return (
            "**Generated SQL:**\n\n"
            f"```sql\n{result['sql_query']}\n```\n\n"
            f"**Model Used:** {result['model_used']}\n\n"
            f"**Processing Time:** {processing_time:.2f}s\n\n"
            f"**Status:** {result['status']}\n\n"
            f"**Retrieved Examples:** {len(result['retrieved_examples'])} examples used for RAG\n"
        )

    except Exception as e:
        error_msg = f"❌ Error: {e}\n\nFull traceback:\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg
|
|
|
def batch_generate_sql(questions_text, table_headers):
    """Generate SQL for several questions, one per line of ``questions_text``.

    Lines are stripped and blank lines are skipped. A failure on one question
    is reported inline under that question's header and does not stop the
    rest of the batch.

    Args:
        questions_text: Newline-separated natural-language questions.
        table_headers: Comma-separated column names, shared by every question.

    Returns:
        A Markdown report with one section per question and a final summary
        counting results whose ``status`` is ``'success'``, or a message
        starting with ``"❌ Error:"`` when the system is not initialized or an
        input is blank.
    """
    print(f"batch_generate_sql called with: {questions_text}, {table_headers}")

    if sql_generator is None:
        return "❌ Error: RAG system not initialized. Check the logs for initialization errors."

    if not questions_text or not questions_text.strip():
        return "❌ Error: Please enter questions."

    if not table_headers or not table_headers.strip():
        return "❌ Error: Please enter table headers."

    try:
        questions = [q.strip() for q in questions_text.split("\n") if q.strip()]
        total_questions = len(questions)

        # Collect fragments and join once instead of repeated string `+=`.
        parts = ["**Batch Results:**\n", f"Total Queries: {total_questions}\n\n"]
        successful_count = 0

        for i, question in enumerate(questions, start=1):
            print(f"Processing query {i}/{total_questions}: {question}")

            # The header is written exactly once per query; previously a failure
            # while formatting the result emitted it a second time.
            parts.append(f"**Query {i}:** {question}\n")

            try:
                start_time = time.perf_counter()
                result = sql_generator.generate_sql(question, table_headers)
                processing_time = time.perf_counter() - start_time

                section = (
                    f"```sql\n{result['sql_query']}\n```\n"
                    f"Model: {result['model_used']} | Time: {processing_time:.2f}s\n\n"
                )
                if result['status'] == 'success':
                    successful_count += 1
            except Exception as e:
                print(f"Query {i} failed: {e}")
                section = f"❌ Error: {e}\n\n"

            parts.append(section)

        parts.append(f"**Summary:** {successful_count}/{total_questions} queries successful")
        return "".join(parts)

    except Exception as e:
        return f"❌ Error: {e}\n\nFull traceback:\n{traceback.format_exc()}"
|
|
|
def check_system_health():
    """Return a Markdown health report for the RAG system.

    When ``sql_generator`` is ``None`` a short "not initialized" message is
    returned. Otherwise ``sql_generator.get_model_info()`` is queried; if that
    call raises, the report is marked degraded and shows the error instead of
    claiming the system is healthy. Any unexpected failure while building the
    report is returned as an error message with its traceback.
    """
    print("check_system_health called")

    try:
        if sql_generator is None:
            return "❌ System Status: RAG system not initialized\n\nCheck the logs above for initialization errors."

        model_error = None
        try:
            model_info = sql_generator.get_model_info()
            model_status = "Available"
        except Exception as e:
            model_error = str(e)
            model_info = {"error": model_error}
            model_status = f"Error: {e}"

        # Previously this was always "Healthy", even after get_model_info failed.
        status = "✅ Healthy" if model_error is None else "⚠️ Degraded"
        error_text = model_error if model_error is not None else "None"

        if model_info:
            # Fence the JSON so Markdown preserves its line breaks/indentation;
            # default=str keeps values json cannot encode natively from
            # aborting the whole report.
            model_info_md = f"```json\n{json.dumps(model_info, indent=2, default=str)}\n```"
        else:
            model_info_md = "Not available"

        return (
            "**System Health:**\n\n"
            f"- **Status:** {status}\n"
            "- **System Loaded:** ✅ Yes\n"
            "- **System Loading:** ❌ No\n"
            f"- **Error:** {error_text}\n"
            f"- **Model Status:** {model_status}\n"
            f"- **Timestamp:** {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n"
            "**Model Info:**\n\n"
            f"{model_info_md}\n\n"
            "**Initialization Logs:**\n\n"
            "Check the console/logs above for detailed initialization information.\n"
        )

    except Exception as e:
        return f"❌ Health check error: {e}\n\nFull traceback:\n{traceback.format_exc()}"
|
|
|
|
|
with gr.Blocks(title="Text-to-SQL RAG with CodeLlama", theme=gr.themes.Soft()) as demo:
    # A space after "#" is required for Markdown to render this as a heading.
    gr.Markdown("# Text-to-SQL RAG with CodeLlama")
    gr.Markdown("Generate SQL queries from natural language using **RAG (Retrieval-Augmented Generation)** and **CodeLlama** models.")
    gr.Markdown("**Features:** RAG-enhanced generation, CodeLlama integration, Vector-based retrieval, Advanced prompt engineering")

    # Initialization runs at import time, so this banner reflects its outcome.
    if sql_generator is None:
        gr.Markdown("⚠️ **Warning:** RAG system failed to initialize. Check the logs for errors.")
    else:
        gr.Markdown("✅ **Status:** RAG system initialized successfully!")

    with gr.Tab("Single Query"):
        with gr.Row():
            with gr.Column(scale=1):
                question_input = gr.Textbox(
                    label="Question",
                    placeholder="e.g., Show me all employees with salary greater than 50000",
                    lines=3,
                )
                table_headers_input = gr.Textbox(
                    label="Table Headers (comma-separated)",
                    placeholder="e.g., id, name, salary, department",
                    value="id, name, salary, department",
                )
                generate_btn = gr.Button("Generate SQL", variant="primary", size="lg")

            with gr.Column(scale=1):
                output = gr.Markdown(label="Result")

    with gr.Tab("Batch Queries"):
        with gr.Row():
            with gr.Column(scale=1):
                batch_questions = gr.Textbox(
                    label="Questions (one per line)",
                    placeholder="Show me all employees\nCount total employees\nAverage salary by department",
                    lines=5,
                )
                batch_headers = gr.Textbox(
                    label="Table Headers (comma-separated)",
                    placeholder="e.g., id, name, salary, department",
                    value="id, name, salary, department",
                )
                batch_btn = gr.Button("Generate Batch SQL", variant="primary", size="lg")

            with gr.Column(scale=1):
                batch_output = gr.Markdown(label="Batch Results")

    with gr.Tab("System Health"):
        with gr.Row():
            health_btn = gr.Button("🔍 Check System Health", variant="secondary", size="lg")
        health_output = gr.Markdown(label="Health Status")

    # Event wiring must happen inside the Blocks context.
    generate_btn.click(
        generate_sql,
        inputs=[question_input, table_headers_input],
        outputs=output,
    )

    batch_btn.click(
        batch_generate_sql,
        inputs=[batch_questions, batch_headers],
        outputs=batch_output,
    )

    health_btn.click(
        check_system_health,
        outputs=health_output,
    )

    gr.Markdown("---")
    gr.Markdown("""
## How It Works

1. **RAG System**: Retrieves relevant SQL examples from vector database
2. **CodeLlama**: Generates SQL using retrieved examples as context
3. **Vector Search**: Finds similar questions and their SQL solutions
4. **Enhanced Generation**: Combines retrieval + generation for better accuracy

## Technology Stack

- **Backend**: Direct RAG system integration
- **LLM**: CodeLlama-7B-Python-GGUF (primary)
- **Vector DB**: ChromaDB with sentence transformers
- **Frontend**: Gradio interface
- **Hosting**: Hugging Face Spaces

## 📊 Performance

- **Model**: CodeLlama-7B-Python-GGUF
- **Response Time**: < 5 seconds
- **Accuracy**: High (RAG-enhanced)
- **Cost**: Free (local inference)
""")


if __name__ == "__main__":
    demo.launch()