# Source: app.py by shukdevdattaEX — commit c9a75cb ("Update app.py", verified)
import os
import re
import json
import math
import gradio as gr
from typing import List, Dict, Any, Tuple
from together import Together
# -----------------------------
# Tolerant JSON loader (recovers from trailing commas and malformed exports)
# -----------------------------
def _remove_trailing_commas(s: str) -> str:
"""Remove trailing commas before ] or } when not inside strings."""
out = []
in_str = False
esc = False
for i, ch in enumerate(s):
if in_str:
out.append(ch)
if esc:
esc = False
elif ch == '\\':
esc = True
elif ch == '"':
in_str = False
continue
else:
if ch == '"':
in_str = True
out.append(ch)
continue
if ch == ',':
j = i + 1
while j < len(s) and s[j] in ' \t\r\n':
j += 1
if j < len(s) and s[j] in ']}':
# skip this comma
continue
out.append(ch)
return ''.join(out)
def _extract_json_objects(text: str) -> List[str]:
"""Extract top-level JSON objects by balancing curly braces, ignoring braces inside strings."""
objs = []
in_str = False
esc = False
brace_depth = 0
start = None
for i, ch in enumerate(text):
if in_str:
if esc:
esc = False
elif ch == '\\':
esc = True
elif ch == '"':
in_str = False
else:
if ch == '"':
in_str = True
elif ch == '{':
if brace_depth == 0:
start = i
brace_depth += 1
elif ch == '}':
if brace_depth > 0:
brace_depth -= 1
if brace_depth == 0 and start is not None:
objs.append(text[start:i+1])
start = None
return objs
def safe_load_phpmyadmin_like_json(raw_text: str) -> List[Dict[str, Any]]:
    """Parse *raw_text* as JSON, degrading gracefully when it is malformed.

    Three passes are tried in order, and the first that succeeds wins:

    1. strict ``json.loads`` on the text as given;
    2. ``json.loads`` after stripping trailing commas from the whole text;
    3. extracting each top-level ``{...}`` object, stripping its trailing
       commas and parsing it on its own; objects that still fail to parse
       are dropped.

    Passes 1 and 2 return whatever ``json.loads`` produces; pass 3 always
    returns a list (possibly empty).
    """
    try:
        return json.loads(raw_text)
    except json.JSONDecodeError:
        pass

    try:
        return json.loads(_remove_trailing_commas(raw_text))
    except json.JSONDecodeError:
        pass

    # Last resort: salvage whichever individual objects are parseable.
    recovered = []
    for fragment in _extract_json_objects(raw_text):
        try:
            recovered.append(json.loads(_remove_trailing_commas(fragment)))
        except json.JSONDecodeError:
            # A fragment that is still broken is skipped rather than fatal.
            continue
    return recovered
# -----------------------------
# Enhanced corpus building with better indexing
# -----------------------------
def _is_meaningful(value_text: str) -> bool:
    """Return True unless the text is empty or a textual null marker."""
    return bool(value_text) and value_text.lower() not in ('null', 'none')


def _collect_text_values(obj: Any) -> List[str]:
    """Recursively gather ``key: value`` / bare-value strings from nested data."""
    texts: List[str] = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            if isinstance(value, (dict, list)):
                texts.extend(_collect_text_values(value))
            else:
                value_text = str(value).strip()
                if _is_meaningful(value_text):
                    texts.append(f"{key}: {value_text}")
    elif isinstance(obj, list):
        for item in obj:
            texts.extend(_collect_text_values(item))
    else:
        value_text = str(obj).strip()
        if _is_meaningful(value_text):
            texts.append(value_text)
    return texts


def _table_passages(table_name: Any, rows: List[Any], max_value_len: int) -> List[Dict[str, Any]]:
    """Build the row, values and summary passages for one exported table."""
    passages: List[Dict[str, Any]] = []
    for row_idx, row in enumerate(rows):
        if not isinstance(row, dict):
            continue
        parts = []
        values = []
        for key, value in row.items():
            val = str(value).strip()
            if len(val) > max_value_len:
                val = val[:max_value_len] + "…"
            if _is_meaningful(val):
                parts.append(f"{key}={val}")
                values.append(val)
        # Primary representation: "field=value" pairs.
        passages.append({
            "table": table_name,
            "idx": row_idx,
            "text": f"[table={table_name} row={row_idx}] " + " | ".join(parts),
            "type": "row",
            "raw_data": row,
        })
        # Secondary representation: values only, which helps name searches.
        if values:
            passages.append({
                "table": table_name,
                "idx": row_idx,
                "text": f"[table={table_name} row={row_idx}] Contains: " + " ".join(values),
                "type": "values",
                "raw_data": row,
            })
    if rows:
        first_row = rows[0]
        sample_keys = list(first_row.keys())[:10] if isinstance(first_row, dict) else []
        passages.append({
            "table": table_name,
            "idx": -1,  # -1 marks a whole-table entry rather than a row
            "text": (
                f"[table={table_name} summary] Table with {len(rows)} rows. "
                f"Fields: {', '.join(sample_keys)}"
            ),
            "type": "summary",
            "raw_data": {"row_count": len(rows), "fields": sample_keys},
        })
    return passages


def flatten_json_to_corpus(docs: List[Dict[str, Any]], max_value_len: int = 1000) -> List[Dict[str, Any]]:
    """Turn the exported structure into searchable text passages.

    Each element of *docs* is handled according to its ``"type"``:

    * ``"table"`` objects whose ``"data"`` is a list yield, per dict row, a
      ``"row"`` passage (``field=value`` pairs) and, when the row has any
      non-empty value, a ``"values"`` passage (values only). A non-empty
      table also yields one ``"summary"`` passage with ``idx == -1``.
    * every other element yields a single ``"meta"`` passage built from up
      to 20 extracted text values, capped at 2000 characters.

    Elements that are not dicts (strings, numbers, nested lists) are treated
    as ``"unknown"`` meta entries instead of raising ``AttributeError``.

    Args:
        docs: Parsed top-level objects of the export.
        max_value_len: Field values longer than this are truncated and
            suffixed with an ellipsis.

    Returns:
        A list of passage dicts with keys ``table``, ``idx``, ``text``,
        ``type`` and ``raw_data``.
    """
    corpus: List[Dict[str, Any]] = []
    for obj_idx, obj in enumerate(docs):
        # Non-dict elements have no "type" key; classify them as unknown.
        obj_type = obj.get("type", "unknown") if isinstance(obj, dict) else "unknown"

        if obj_type == "table":
            table_name = obj.get("name", f"table_{obj_idx}")
            rows = obj.get("data", [])
            if isinstance(rows, list):
                corpus.extend(_table_passages(table_name, rows, max_value_len))
            continue

        texts = _collect_text_values(obj)
        if not texts:
            continue
        text = f"[{obj_type}] " + " | ".join(texts[:20])  # cap the number of values
        if len(text) > 2000:
            text = text[:2000] + "…"
        corpus.append({
            "table": obj_type,
            "idx": obj_idx,
            "text": text,
            "type": "meta",
            "raw_data": obj,
        })
    return corpus
# -----------------------------
# Enhanced retrieval with multiple scoring methods
# -----------------------------
def _tokenize_enhanced(s: str) -> List[str]:
"""Enhanced tokenization that handles names and phrases better"""
# Keep original words, lowercase versions, and partial matches
tokens = []
# Get word tokens
words = re.findall(r"[A-Za-z0-9_]+", s)
for word in words:
tokens.append(word.lower())
if len(word) > 3:
# Add partial tokens for name matching
tokens.append(word[:4].lower())
# Also extract quoted phrases and camelCase splits
quoted = re.findall(r'"([^"]*)"', s)
for q in quoted:
tokens.extend(q.lower().split())
return tokens
def calculate_enhanced_score(query: str, doc_text: str, doc_data: Dict) -> float:
    """Score how well *doc_text* matches *query* using several heuristics.

    The score is assembled as follows:

    1. +10 when the lowercased query occurs verbatim in the lowercased text.
    2. When the document has tokens: +2 per distinct shared token, +0.5 for
       every (query token, document token) pair in which one contains the
       other (query tokens of length <= 2 are ignored), then the running
       total is divided by ``log2(len(doc_tokens) + 2)``.
    3. After normalisation: +5 when both sides mention "instructor", +3 when
       both mention "batch", and +2 for ``"row"`` documents when the query
       contains "who", "name" or "person".
    """
    query_lc = query.lower()
    text_lc = doc_text.lower()

    # Exact phrase hit carries the largest single weight.
    total = 10.0 if query_lc in text_lc else 0.0

    query_tokens = _tokenize_enhanced(query)
    doc_tokens = _tokenize_enhanced(doc_text)
    if doc_tokens:
        total += len(set(query_tokens) & set(doc_tokens)) * 2.0
        containment_hits = sum(
            1
            for q_tok in query_tokens
            if len(q_tok) > 2
            for d_tok in doc_tokens
            if q_tok in d_tok or d_tok in q_tok
        )
        total += 0.5 * containment_hits
        # Longer documents are dampened so they do not win on volume alone.
        total = total / math.log2(len(doc_tokens) + 2)

    # Topic boosts are applied after normalisation, so they are not dampened.
    for keyword, bonus in (("instructor", 5.0), ("batch", 3.0)):
        if keyword in query_lc and keyword in text_lc:
            total += bonus

    # Person-style questions favour concrete rows over summaries.
    if any(word in query_lc for word in ("who", "name", "person")) and doc_data.get("type") == "row":
        total += 2.0
    return total
def retrieve_top_k_enhanced(query: str, corpus: List[Dict[str, Any]], k: int = 15, per_table_cap: int = 8) -> List[Dict[str, Any]]:
    """Return up to *k* passages from *corpus* ranked by relevance to *query*.

    Every passage is scored with ``calculate_enhanced_score``; passages with
    a non-positive score are discarded. The ranked list is then filtered for
    diversity:

    * at most *per_table_cap* passages per ``"table"`` value;
    * once more than ``k // 2`` passages are selected, a passage ``"type"``
      that already has ``k // 3`` entries is skipped.

    If fewer than three passages survive, the selection is topped up from
    the highest-scoring passages, ignoring the diversity limits. Passages
    already selected are not added a second time, so the same passage never
    appears twice in the result.

    Args:
        query: The user question.
        corpus: Passage dicts carrying at least a ``"text"`` key.
        k: Maximum number of passages to return.
        per_table_cap: Maximum number of passages taken from a single table.

    Returns:
        The selected passages in descending score order, followed by any
        top-up passages.
    """
    scored = []
    for doc in corpus:
        score = calculate_enhanced_score(query, doc["text"], doc)
        if score > 0:
            scored.append((score, doc))
    scored.sort(key=lambda pair: pair[0], reverse=True)

    table_counts: Dict[Any, int] = {}
    type_counts: Dict[Any, int] = {}
    result: List[Dict[str, Any]] = []
    for _, doc in scored:
        table_name = doc.get("table", "unknown")
        doc_type = doc.get("type", "unknown")
        if table_counts.get(table_name, 0) >= per_table_cap:
            continue
        # Only enforce type diversity once the result is reasonably filled.
        if type_counts.get(doc_type, 0) >= k // 3 and len(result) > k // 2:
            continue
        result.append(doc)
        table_counts[table_name] = table_counts.get(table_name, 0) + 1
        type_counts[doc_type] = type_counts.get(doc_type, 0) + 1
        if len(result) >= k:
            break

    if len(result) < 3:
        # Passages are dicts (unhashable), so membership is tracked by
        # object identity to avoid re-adding what was already selected.
        selected_ids = {id(doc) for doc in result}
        for _, doc in scored[:k]:
            if id(doc) not in selected_ids:
                selected_ids.add(id(doc))
                result.append(doc)
        result = result[:k]
    return result
# -----------------------------
# Enhanced prompt building
# -----------------------------
def build_enhanced_prompt(query: str, passages: List[Dict[str, Any]]) -> str:
    """Assemble the LLM prompt from the question and the retrieved passages.

    Passages whose ``"type"`` is ``"summary"`` are listed one per line under
    the tables-summary heading; all other passages go under the detailed
    context heading, separated by blank lines. Fixed answering instructions
    close the prompt.
    """
    summary_block = "\n".join(
        entry["text"] for entry in passages if entry.get("type") == "summary"
    )
    detail_block = "\n\n".join(
        entry["text"] for entry in passages if entry.get("type") != "summary"
    )
    return f"""You are a thorough JSON database assistant. Answer using ONLY the provided context from the JSON export.
# User Question
{query}
# Available Tables Summary
{summary_block}
# Detailed Context (Most Relevant Entries)
{detail_block}
# Instructions
- Search through ALL provided context thoroughly
- For person names, look for partial matches and variations
- For roles like "instructor" or "teacher", check all relevant entries
- If asking about people, include their roles, associations, and related info
- Cite specific table names and row indices when possible
- If information exists in the context but seems incomplete, mention what you found
- Only say "not found" if you genuinely cannot locate relevant information after thorough checking
- Be comprehensive - don't just return the first match you find"""
# -----------------------------
# Together client helper
# -----------------------------
def call_together(api_key: str, prompt: str) -> str:
    """Send *prompt* to the Together chat-completions API and return the reply.

    Args:
        api_key: Together API key supplied by the user; surrounding
            whitespace is ignored.
        prompt: Full prompt text, sent as a single user message.

    Returns:
        The model's reply text. On any problem (missing key, API failure,
        empty response) a human-readable message is returned instead; this
        function does not raise.
    """
    key = (api_key or "").strip()
    if not key:
        return "⚠️ Please enter your Together API key."
    try:
        # The key is passed to the client explicitly and is deliberately NOT
        # written to os.environ: the environment is process-global, so doing
        # so would leak one user's key into every other session of the app.
        client = Together(api_key=key)
        resp = client.chat.completions.create(
            model="lgai/exaone-3-5-32b-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # Lower temperature for more focused responses
            max_tokens=1000,
        )
        choices = getattr(resp, "choices", None) or []
        if not choices:
            return "❌ API Error: the model returned no choices."
        content = choices[0].message.content
        if content is None:
            # Keep the return type a string so the chat UI never shows None.
            return "❌ API Error: the model returned an empty message."
        return content
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing.
        return f"❌ API Error: {str(e)}"
# -----------------------------
# Gradio App
# -----------------------------
with gr.Blocks(title="Enhanced JSON Chatbot") as demo:
    gr.Markdown("## 📚 Enhanced JSON Chatbot (Together Exaone 3.5 32B)\nUpload your JSON export and ask questions. Enhanced retrieval system for better name and role matching.")

    with gr.Row():
        api_key_tb = gr.Textbox(label="Together API Key", type="password", placeholder="Paste your TOGETHER_API_KEY here")
        topk_slider = gr.Slider(5, 30, value=15, step=1, label="Top-K JSON Passages")

    with gr.Row():
        json_file = gr.File(label="Upload JSON export (e.g., phpMyAdmin export)", file_count="single", file_types=[".json"])
        fallback_path = gr.Textbox(label="Or fixed path on disk (optional)", placeholder="e.g., sultanbr_innovativeskills.json")

    with gr.Accordion("Advanced Settings", open=False):
        per_table_cap = gr.Slider(3, 15, value=8, step=1, label="Max passages per table")
        max_val_len = gr.Slider(200, 2000, value=1000, step=100, label="Max value length per field")

    status = gr.Markdown("🔄 Ready. Upload JSON file to begin.")
    chatbot = gr.Chatbot(height=500)

    with gr.Row():
        user_box = gr.Textbox(
            label="Ask about your JSON data...",
            placeholder="e.g., Who are the batch instructors? or Who is Shukdev Datta?",
            lines=2,
            scale=4
        )
        send_btn = gr.Button("Send", variant="primary", size="lg", scale=1)

    with gr.Row():
        clear_btn = gr.Button("Clear Chat", variant="secondary")
        reload_btn = gr.Button("Reload JSON", variant="secondary")

    # Per-session state: the searchable passages and the raw parsed objects.
    state_corpus = gr.State([])
    state_docs = gr.State([])

    def load_json_to_corpus(file_obj, path_text, max_value_len):
        """Load the JSON export and build the searchable corpus.

        The uploaded file wins over the typed path. Returns a
        ``(status_markdown, corpus, docs)`` tuple; on any failure the corpus
        and docs are empty lists and the status carries the error text.
        """
        try:
            if file_obj is not None:
                # The upload may arrive as a plain path string or as an
                # object exposing the temp-file path via `.name`.
                path = file_obj if isinstance(file_obj, str) else file_obj.name
                source = f"uploaded file: {path}"
            else:
                # NOTE(review): this path is user-supplied and opened on the
                # server; restrict it before exposing the app publicly.
                path = (path_text or "").strip()
                if not path:
                    return ("⚠️ Please upload a JSON file or provide a valid path.", [], [])
                source = f"file path: {path}"

            with open(path, "r", encoding="utf-8", errors="replace") as f:
                raw = f.read()

            docs = safe_load_phpmyadmin_like_json(raw)
            if not isinstance(docs, list):
                docs = [docs]
            corpus = flatten_json_to_corpus(docs, max_value_len=int(max_value_len))

            # Count tables vs other objects; non-dict entries are not tables.
            tables = [d for d in docs if isinstance(d, dict) and d.get("type") == "table"]
            status_msg = f"✅ Loaded from {source}\n"
            status_msg += f"📊 {len(docs)} objects total, {len(tables)} tables\n"
            status_msg += f"🔍 Built {len(corpus)} searchable passages\n"
            status_msg += "💬 Ready for questions!"
            return (status_msg, corpus, docs)
        except Exception as e:
            # UI boundary: surface the problem in the status line.
            return (f"❌ Load error: {str(e)}", [], [])

    def ask_enhanced(api_key, query, history, corpus, k, cap):
        """Answer one question and return the updated chat history."""
        history = history or []  # the chatbot value may be None before any message
        if not corpus:
            return history + [[query, "⚠️ Please upload and load the JSON file first."]]
        if not query or not query.strip():
            return history + [["", "⚠️ Please enter a question."]]

        # Enhanced retrieval
        top_passages = retrieve_top_k_enhanced(query.strip(), corpus, k=int(k), per_table_cap=int(cap))
        # Build enhanced prompt
        prompt = build_enhanced_prompt(query.strip(), top_passages)
        try:
            answer = call_together(api_key, prompt)
        except Exception as e:
            answer = f"❌ API error: {str(e)}"
        return history + [[query, answer]]

    # Event handlers
    json_file.upload(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )
    fallback_path.change(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )
    user_box.submit(
        ask_enhanced,
        inputs=[api_key_tb, user_box, chatbot, state_corpus, topk_slider, per_table_cap],
        outputs=[chatbot],
    ).then(lambda: "", outputs=[user_box])  # Clear input after submit
    send_btn.click(
        ask_enhanced,
        inputs=[api_key_tb, user_box, chatbot, state_corpus, topk_slider, per_table_cap],
        outputs=[chatbot],
    ).then(lambda: "", outputs=[user_box])  # Clear input after send
    reload_btn.click(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )
    # The confirmation goes to the status line; the question box is emptied
    # rather than being filled with the status message.
    clear_btn.click(
        lambda: ([], "🔄 Chat cleared. Ready for new questions.", ""),
        outputs=[chatbot, status, user_box]
    )

if __name__ == "__main__":
    demo.launch()