Spaces:
Sleeping
Sleeping
import os | |
import re | |
import json | |
import math | |
import gradio as gr | |
from typing import List, Dict, Any, Tuple | |
from together import Together | |
# ----------------------------- | |
# Tolerant JSON loader (recovers from trailing commas and other malformed exports)
# ----------------------------- | |
def _remove_trailing_commas(s: str) -> str: | |
"""Remove trailing commas before ] or } when not inside strings.""" | |
out = [] | |
in_str = False | |
esc = False | |
for i, ch in enumerate(s): | |
if in_str: | |
out.append(ch) | |
if esc: | |
esc = False | |
elif ch == '\\': | |
esc = True | |
elif ch == '"': | |
in_str = False | |
continue | |
else: | |
if ch == '"': | |
in_str = True | |
out.append(ch) | |
continue | |
if ch == ',': | |
j = i + 1 | |
while j < len(s) and s[j] in ' \t\r\n': | |
j += 1 | |
if j < len(s) and s[j] in ']}': | |
# skip this comma | |
continue | |
out.append(ch) | |
return ''.join(out) | |
def _extract_json_objects(text: str) -> List[str]: | |
"""Extract top-level JSON objects by balancing curly braces, ignoring braces inside strings.""" | |
objs = [] | |
in_str = False | |
esc = False | |
brace_depth = 0 | |
start = None | |
for i, ch in enumerate(text): | |
if in_str: | |
if esc: | |
esc = False | |
elif ch == '\\': | |
esc = True | |
elif ch == '"': | |
in_str = False | |
else: | |
if ch == '"': | |
in_str = True | |
elif ch == '{': | |
if brace_depth == 0: | |
start = i | |
brace_depth += 1 | |
elif ch == '}': | |
if brace_depth > 0: | |
brace_depth -= 1 | |
if brace_depth == 0 and start is not None: | |
objs.append(text[start:i+1]) | |
start = None | |
return objs | |
def safe_load_phpmyadmin_like_json(raw_text: str) -> List[Dict[str, Any]]:
    """
    Parse a JSON export, tolerating common defects.

    Strategy, stopping at the first success:
      1. Strict ``json.loads`` on the text, after dropping a leading UTF-8 BOM
         (``json.loads`` rejects a BOM outright).
      2. Strict parse again after removing trailing commas before ``]``/``}``.
      3. Extract every top-level ``{...}`` object, clean each one the same way
         and parse it individually. Chunks that still fail are skipped.

    Args:
        raw_text: The complete file contents.

    Returns:
        Whatever step 1 or 2 parses. This is normally a list of export objects,
        but a top-level object or scalar is returned as-is. Otherwise, the list
        of objects recovered in step 3.

    Raises:
        json.JSONDecodeError: The error from step 1, when no step yields any
            data (e.g. empty or non-JSON input). Previously an empty list was
            returned silently in that case.
    """
    # open(..., encoding="utf-8") keeps a BOM as U+FEFF, and json.loads refuses
    # a document that starts with it.
    text = raw_text[1:] if raw_text.startswith("\ufeff") else raw_text

    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        strict_error = exc

    try:
        return json.loads(_remove_trailing_commas(text))
    except json.JSONDecodeError:
        pass  # fall through to object-by-object recovery

    objs = []
    for chunk in _extract_json_objects(text):
        try:
            objs.append(json.loads(_remove_trailing_commas(chunk)))
        except json.JSONDecodeError:
            # Best effort: skip an unrecoverable chunk rather than lose the rest.
            continue

    if not objs:
        # Nothing salvageable: surface the real parse error to the caller
        # instead of pretending an empty export was loaded.
        raise strict_error
    return objs
# ----------------------------- | |
# Enhanced corpus building with better indexing | |
# ----------------------------- | |
def flatten_json_to_corpus(docs: List[Dict[str, Any]], max_value_len: int = 1000) -> List[Dict[str, Any]]:
    """
    Turn parsed export entries into searchable text passages.

    Every passage is a dict with keys ``table``, ``idx``, ``text``, ``type``
    and ``raw_data``.

    For an entry with ``"type": "table"`` whose ``"data"`` is a list:
      * each dict row yields a ``"row"`` passage of ``key=value`` pairs;
      * when the row has any non-empty value, it also yields a ``"values"``
        passage containing only the values;
      * a ``"summary"`` passage follows, with the row count and up to 10
        field names taken from the first row.

    Table entries whose ``"data"`` is not a list produce no passages.

    Every other entry, including non-dict entries such as bare strings,
    numbers or arrays, becomes a single ``"meta"`` passage. It is built from
    up to 20 of its leaf values and capped at 2000 characters.

    Values that are empty, or that stringify to "null"/"none" in any case,
    are omitted.

    Args:
        docs: Parsed top-level entries (e.g. from safe_load_phpmyadmin_like_json).
        max_value_len: Per-field character cap for table row values. Longer
            values are cut and suffixed with an ellipsis.

    Returns:
        The passages, in input order.
    """
    ellipsis = "\u2026"  # marker appended to truncated text
    empty_markers = ("null", "none", "")

    def is_meaningful(text: str) -> bool:
        """Return False for empty text or a stringified null/None."""
        return bool(text) and text.lower() not in empty_markers

    def extract_all_text_values(obj: Any) -> List[str]:
        """Recursively collect leaf values: "key: value" for dict items, the bare value otherwise."""
        texts: List[str] = []
        if isinstance(obj, dict):
            for k, v in obj.items():
                if isinstance(v, (dict, list)):
                    texts.extend(extract_all_text_values(v))
                else:
                    val_str = str(v).strip()
                    if is_meaningful(val_str):
                        texts.append(f"{k}: {val_str}")
        elif isinstance(obj, list):
            for item in obj:
                texts.extend(extract_all_text_values(item))
        else:
            val_str = str(obj).strip()
            if is_meaningful(val_str):
                texts.append(val_str)
        return texts

    corpus: List[Dict[str, Any]] = []
    for obj_idx, obj in enumerate(docs):
        # Generic JSON may hold bare strings, numbers or arrays at the top
        # level. Those have no .get(), so index them as untyped meta entries.
        obj_type = obj.get("type", "unknown") if isinstance(obj, dict) else "unknown"

        if obj_type != "table":
            all_texts = extract_all_text_values(obj)
            if all_texts:
                # Cap field count and length so one big object can't flood the prompt.
                text = f"[{obj_type}] " + " | ".join(all_texts[:20])
                if len(text) > 2000:
                    text = text[:2000] + ellipsis
                corpus.append({
                    "table": obj_type,
                    "idx": obj_idx,
                    "text": text,
                    "type": "meta",
                    "raw_data": obj,
                })
            continue

        table_name = obj.get("name", f"table_{obj_idx}")
        rows = obj.get("data", [])
        if not isinstance(rows, list):
            # Malformed table payload. A dict/str here has no rows to index,
            # and len()/rows[0] would be meaningless or raise.
            continue

        for row_idx, row in enumerate(rows):
            if not isinstance(row, dict):
                continue
            parts = []
            all_values = []
            for k, v in row.items():
                val = str(v).strip()
                if len(val) > max_value_len:
                    val = val[:max_value_len] + ellipsis
                if is_meaningful(val):
                    parts.append(f"{k}={val}")
                    all_values.append(val)
            tag = f"[table={table_name} row={row_idx}]"
            corpus.append({
                "table": table_name,
                "idx": row_idx,
                "text": f"{tag} " + " | ".join(parts),
                "type": "row",
                "raw_data": row,
            })
            # A values-only variant lets bare-name queries match without field-name noise.
            if all_values:
                corpus.append({
                    "table": table_name,
                    "idx": row_idx,
                    "text": f"{tag} Contains: " + " ".join(all_values),
                    "type": "values",
                    "raw_data": row,
                })

        if rows:
            sample_keys = list(rows[0].keys())[:10] if isinstance(rows[0], dict) else []
            corpus.append({
                "table": table_name,
                "idx": -1,
                "text": f"[table={table_name} summary] Table with {len(rows)} rows. Fields: {', '.join(sample_keys)}",
                "type": "summary",
                "raw_data": {"row_count": len(rows), "fields": sample_keys},
            })
    return corpus
# ----------------------------- | |
# Enhanced retrieval with multiple scoring methods | |
# ----------------------------- | |
def _tokenize_enhanced(s: str) -> List[str]: | |
"""Enhanced tokenization that handles names and phrases better""" | |
# Keep original words, lowercase versions, and partial matches | |
tokens = [] | |
# Get word tokens | |
words = re.findall(r"[A-Za-z0-9_]+", s) | |
for word in words: | |
tokens.append(word.lower()) | |
if len(word) > 3: | |
# Add partial tokens for name matching | |
tokens.append(word[:4].lower()) | |
# Also extract quoted phrases and camelCase splits | |
quoted = re.findall(r'"([^"]*)"', s) | |
for q in quoted: | |
tokens.extend(q.lower().split()) | |
return tokens | |
def calculate_enhanced_score(query: str, doc_text: str, doc_data: Dict) -> float:
    """
    Heuristically score how relevant *doc_text* is to *query*; higher is better.

    Components:
      1. +10 if the whole lowercased query occurs verbatim in the lowercased
         document.
      2. Token overlap, using tokens from ``_tokenize_enhanced``:
         * +2 per distinct token shared by query and document;
         * +0.5 for every (query token, document token) pair, duplicates
           included, where both tokens are longer than 2 characters and one
           contains the other.
         When the document has any tokens, the running total (including
         step 1) is then divided by ``log2(len(doc_tokens) + 2)`` to damp long
         passages.
      3. Unnormalized boosts:
         * +5 when both texts contain "instructor";
         * +3 when both contain "batch";
         * +2 for ``"row"`` passages when the query contains "who", "name" or
           "person" (substring checks).

    Args:
        query: The user's question.
        doc_text: Passage text to score.
        doc_data: The passage dict; only its ``"type"`` key is read.

    Returns:
        A non-negative relevance score; 0.0 means no signal.
    """
    q_lower = query.lower()
    d_lower = doc_text.lower()
    score = 0.0

    # 1. Exact phrase match (strongest signal).
    if q_lower in d_lower:
        score += 10.0

    # 2. Token-based matching.
    q_tokens = _tokenize_enhanced(query)
    d_tokens = _tokenize_enhanced(doc_text)
    if d_tokens:
        score += len(set(q_tokens) & set(d_tokens)) * 2.0

        # Partial (substring) matches help with name variants. Both sides must
        # be longer than 2 chars. Otherwise 1-2 char document tokens, such as
        # row indices ("0", "1") or short codes ("m", "id"), occur inside
        # almost any query word and inflate the scores of unrelated passages.
        long_d_tokens = [t for t in d_tokens if len(t) > 2]
        for q_tok in q_tokens:
            if len(q_tok) > 2:
                for d_tok in long_d_tokens:
                    if q_tok in d_tok or d_tok in q_tok:
                        score += 0.5

        # Length normalization so long passages don't win by sheer size.
        score /= math.log2(len(d_tokens) + 2)

    # 3. Domain boosts for the kinds of questions this app targets.
    if "instructor" in q_lower and "instructor" in d_lower:
        score += 5.0
    if "batch" in q_lower and "batch" in d_lower:
        score += 3.0
    # Person-style questions are answered by concrete rows, not summaries.
    if any(word in q_lower for word in ["who", "name", "person"]):
        if doc_data.get("type") == "row":
            score += 2.0

    return score
def retrieve_top_k_enhanced(query: str, corpus: List[Dict[str, Any]], k: int = 15, per_table_cap: int = 8) -> List[Dict[str, Any]]:
    """
    Return up to *k* corpus passages ranked by ``calculate_enhanced_score``.

    Passages scoring 0 are dropped. The ranking is walked highest score
    first, with ties in corpus order. A passage is skipped when:
      * its table already has *per_table_cap* passages; or
      * its ``type`` already has ``k // 3`` passages and more than ``k // 2``
        passages have been selected.

    If fewer than 3 passages survive those caps, the selection is topped up to
    *k* with the remaining positively-scored passages, ignoring the caps. If
    nothing scored at all, table ``"summary"`` passages are used instead, so
    the prompt still shows which tables exist. No passage is returned twice.

    Args:
        query: The user's question.
        corpus: Passages as built by ``flatten_json_to_corpus``.
        k: Maximum number of passages to return.
        per_table_cap: Maximum passages taken from one table before the top-up.

    Returns:
        The selected passage dicts (the corpus objects themselves, not copies).
    """
    # Score every passage; zero-score passages carry no signal.
    scored = []
    for doc in corpus:
        score = calculate_enhanced_score(query, doc["text"], doc)
        if score > 0:
            scored.append((score, doc))
    # Highest first; sort is stable, so ties keep corpus order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    table_counts: Dict[Any, int] = {}
    type_counts: Dict[Any, int] = {}
    result: List[Dict[str, Any]] = []
    for _score, doc in scored:
        if len(result) >= k:
            break
        table_name = doc.get("table", "unknown")
        doc_type = doc.get("type", "unknown")
        # Keep a single table from crowding out the others.
        if table_counts.get(table_name, 0) >= per_table_cap:
            continue
        # Past the halfway mark, stop taking more of an over-represented passage type.
        if type_counts.get(doc_type, 0) >= k // 3 and len(result) > k // 2:
            continue
        result.append(doc)
        table_counts[table_name] = table_counts.get(table_name, 0) + 1
        type_counts[doc_type] = type_counts.get(doc_type, 0) + 1

    if len(result) < 3:
        # Too little context: top up, skipping anything already chosen (the
        # best-scored passages are usually already in `result`).
        chosen = {id(doc) for doc in result}
        extras = [doc for _score, doc in scored]
        if not extras:
            # Nothing matched at all: fall back to table overviews.
            extras = [doc for doc in corpus if doc.get("type") == "summary"]
        for doc in extras:
            if len(result) >= k:
                break
            if id(doc) not in chosen:
                chosen.add(id(doc))
                result.append(doc)

    return result
# ----------------------------- | |
# Enhanced prompt building | |
# ----------------------------- | |
def build_enhanced_prompt(query: str, passages: List[Dict[str, Any]]) -> str:
    """
    Assemble the LLM prompt for *query* from the retrieved *passages*.

    Layout of the prompt:
      * ``"summary"`` passages, one per line, under "Available Tables Summary";
      * all other passages, separated by blank lines, under "Detailed Context";
      * a fixed block of answering instructions at the end.

    Retrieval order is preserved within each group.
    """
    summaries = [p["text"] for p in passages if p.get("type") == "summary"]
    details = [p["text"] for p in passages if p.get("type") != "summary"]

    table_context = "\n".join(summaries)  # "" when there are no summaries
    detail_context = "\n\n".join(details)

    return f"""You are a thorough JSON database assistant. Answer using ONLY the provided context from the JSON export.
# User Question
{query}
# Available Tables Summary
{table_context}
# Detailed Context (Most Relevant Entries)
{detail_context}
# Instructions
- Search through ALL provided context thoroughly
- For person names, look for partial matches and variations
- For roles like "instructor" or "teacher", check all relevant entries
- If asking about people, include their roles, associations, and related info
- Cite specific table names and row indices when possible
- If information exists in the context but seems incomplete, mention what you found
- Only say "not found" if you genuinely cannot locate relevant information after thorough checking
- Be comprehensive - don't just return the first match you find"""
# ----------------------------- | |
# Together client helper | |
# ----------------------------- | |
def call_together(api_key: str, prompt: str) -> str:
    """
    Send *prompt* as one user message to Together's chat-completions API.

    The key is handed only to this request's client and is deliberately not
    written to ``os.environ``. Environment variables are process-global, so in
    a multi-user Gradio deployment that would leave one visitor's key in the
    server process after the request, where every other session could see it.

    Args:
        api_key: Together API key typed by the user; surrounding whitespace is ignored.
        prompt: The fully assembled prompt text.

    Returns:
        One of:
          * the model's reply text ("" if the API returned no content);
          * a warning string when *api_key* is blank;
          * an error string when any exception occurs while creating the
            client or calling the API.
    """
    key = (api_key or "").strip()
    if not key:
        return "β οΈ Please enter your Together API key."
    try:
        client = Together(api_key=key)
        resp = client.chat.completions.create(
            model="lgai/exaone-3-5-32b-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # low temperature keeps answers close to the supplied context
            max_tokens=1000,
        )
        return resp.choices[0].message.content or ""
    except Exception as e:  # UI boundary: surface any failure as chat text
        return f"β API Error: {str(e)}"
# ----------------------------- | |
# Gradio App | |
# ----------------------------- | |
with gr.Blocks(title="Enhanced JSON Chatbot") as demo:
    # ---- Layout ----
    gr.Markdown("## π Enhanced JSON Chatbot (Together Exaone 3.5 32B)\nUpload your JSON export and ask questions. Enhanced retrieval system for better name and role matching.")
    with gr.Row():
        api_key_tb = gr.Textbox(label="Together API Key", type="password", placeholder="Paste your TOGETHER_API_KEY here")
        topk_slider = gr.Slider(5, 30, value=15, step=1, label="Top-K JSON Passages")
    with gr.Row():
        json_file = gr.File(label="Upload JSON export (e.g., phpMyAdmin export)", file_count="single", file_types=[".json"])
        fallback_path = gr.Textbox(label="Or fixed path on disk (optional)", placeholder="e.g., sultanbr_innovativeskills.json")
    with gr.Accordion("Advanced Settings", open=False):
        per_table_cap = gr.Slider(3, 15, value=8, step=1, label="Max passages per table")
        max_val_len = gr.Slider(200, 2000, value=1000, step=100, label="Max value length per field")
    status = gr.Markdown("π Ready. Upload JSON file to begin.")
    chatbot = gr.Chatbot(height=500)
    with gr.Row():
        user_box = gr.Textbox(
            label="Ask about your JSON data...",
            placeholder="e.g., Who are the batch instructors? or Who is Shukdev Datta?",
            lines=2,
            scale=4,
        )
        send_btn = gr.Button("Send", variant="primary", size="lg", scale=1)
    with gr.Row():
        clear_btn = gr.Button("Clear Chat", variant="secondary")
        reload_btn = gr.Button("Reload JSON", variant="secondary")

    # Per-session state: the searchable passages and the parsed export objects.
    state_corpus = gr.State([])
    state_docs = gr.State([])

    def load_json_to_corpus(file_obj, path_text, max_value_len):
        """
        Read the uploaded file (or, if none, the typed path), parse it and build the corpus.

        Returns a ``(status_markdown, corpus, docs)`` triple for the
        ``status``, ``state_corpus`` and ``state_docs`` outputs. On any
        failure, the status carries the error and both lists are empty.
        """
        try:
            if file_obj is not None:
                # Depending on the Gradio version, the upload arrives either as an
                # object exposing .name or as the path string itself; accept both.
                upload_path = getattr(file_obj, "name", file_obj)
                with open(upload_path, "r", encoding="utf-8", errors="replace") as f:
                    raw = f.read()
                source = f"uploaded file: {upload_path}"
            else:
                p = (path_text or "").strip()
                if not p:
                    return ("β οΈ Please upload a JSON file or provide a valid path.", [], [])
                with open(p, "r", encoding="utf-8", errors="replace") as f:
                    raw = f.read()
                source = f"file path: {p}"

            docs = safe_load_phpmyadmin_like_json(raw)
            if not isinstance(docs, list):
                docs = [docs]
            corpus = flatten_json_to_corpus(docs, max_value_len=int(max_value_len))

            # Only dict entries can be tables. Bare values in generic JSON would
            # otherwise crash this count with AttributeError.
            tables = [d for d in docs if isinstance(d, dict) and d.get("type") == "table"]
            status_msg = f"β Loaded from {source}\n"
            status_msg += f"π {len(docs)} objects total, {len(tables)} tables\n"
            status_msg += f"π Built {len(corpus)} searchable passages\n"
            status_msg += f"π¬ Ready for questions!"
            return (status_msg, corpus, docs)
        except Exception as e:
            # UI boundary: report any read/parse failure in the status area.
            return (f"β Load error: {str(e)}", [], [])

    def ask_enhanced(api_key, query, history, corpus, k, cap):
        """Answer *query* from the loaded corpus; return *history* with one new [question, answer] pair."""
        history = history or []  # the Chatbot value may arrive as None
        if not corpus:
            return history + [[query, "β οΈ Please upload and load the JSON file first."]]
        if not query or not query.strip():
            return history + [["", "β οΈ Please enter a question."]]

        question = query.strip()
        top_passages = retrieve_top_k_enhanced(question, corpus, k=int(k), per_table_cap=int(cap))
        prompt = build_enhanced_prompt(question, top_passages)
        try:
            answer = call_together(api_key, prompt)
        except Exception as e:
            answer = f"β API error: {str(e)}"
        return history + [[query, answer]]

    # ---- Event handlers ----
    json_file.upload(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )
    # Load a typed path when the user presses Enter. A .change listener fires
    # on every keystroke and re-reads (usually failing on) a half-typed path.
    fallback_path.submit(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )
    user_box.submit(
        ask_enhanced,
        inputs=[api_key_tb, user_box, chatbot, state_corpus, topk_slider, per_table_cap],
        outputs=[chatbot],
    ).then(lambda: "", outputs=[user_box])  # Clear input after submit
    send_btn.click(
        ask_enhanced,
        inputs=[api_key_tb, user_box, chatbot, state_corpus, topk_slider, per_table_cap],
        outputs=[chatbot],
    ).then(lambda: "", outputs=[user_box])  # Clear input after send
    reload_btn.click(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )
    # The confirmation is a status message, not a question: show it in the
    # status area and empty the question box, so Send can't submit it.
    clear_btn.click(
        lambda: ([], "", "π Chat cleared. Ready for new questions."),
        outputs=[chatbot, user_box, status],
    )
def _main() -> None:
    """Launch the module-level Gradio app ``demo``."""
    demo.launch()


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    _main()