import os
import re
import json
import math
import gradio as gr
from typing import List, Dict, Any
from together import Together

# -----------------------------
# Tolerant JSON loader (handles trailing commas in phpMyAdmin-style exports)
# -----------------------------
def _remove_trailing_commas(s: str) -> str:
    """Remove trailing commas before ] or } when not inside strings."""
    out = []
    in_str = False
    esc = False
    for i, ch in enumerate(s):
        if in_str:
            out.append(ch)
            if esc:
                esc = False
            elif ch == '\\':
                esc = True
            elif ch == '"':
                in_str = False
            continue
        else:
            if ch == '"':
                in_str = True
                out.append(ch)
                continue
            if ch == ',':
                j = i + 1
                while j < len(s) and s[j] in ' \t\r\n':
                    j += 1
                if j < len(s) and s[j] in ']}':
                    # skip this comma
                    continue
            out.append(ch)
    return ''.join(out)
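
# Example: trailing commas are stripped only outside string literals, so
# commas embedded in values survive.
#
#   _remove_trailing_commas('{"a": "x,y", "b": 1,}')
#   # -> '{"a": "x,y", "b": 1}'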

def _extract_json_objects(text: str) -> List[str]:
    """Extract top-level JSON objects by balancing curly braces, ignoring braces inside strings."""
    objs = []
    in_str = False
    esc = False
    brace_depth = 0
    start = None
    for i, ch in enumerate(text):
        if in_str:
            if esc:
                esc = False
            elif ch == '\\':
                esc = True
            elif ch == '"':
                in_str = False
        else:
            if ch == '"':
                in_str = True
            elif ch == '{':
                if brace_depth == 0:
                    start = i
                brace_depth += 1
            elif ch == '}':
                if brace_depth > 0:
                    brace_depth -= 1
                    if brace_depth == 0 and start is not None:
                        objs.append(text[start:i+1])
                        start = None
    return objs
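
# Example: brace balancing ignores braces inside strings, so both top-level
# objects are recovered even with noise between them.
#
#   _extract_json_objects('garbage {"a": "}"} more {"b": 2}')
#   # -> ['{"a": "}"}', '{"b": 2}']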

def safe_load_phpmyadmin_like_json(raw_text: str) -> List[Dict[str, Any]]:
    """
    Attempt strict JSON first; if it fails (e.g., trailing comma issues),
    fall back to extracting individual objects and parsing them.
    Returns a list of objects (header + tables, etc.).
    """
    try:
        return json.loads(raw_text)
    except json.JSONDecodeError:
        # Try removing trailing commas globally
        cleaned = _remove_trailing_commas(raw_text)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            # Last-resort: parse object-by-object and combine into an array
            chunks = _extract_json_objects(raw_text)
            objs = []
            for ch in chunks:
                s = _remove_trailing_commas(ch)
                try:
                    objs.append(json.loads(s))
                except json.JSONDecodeError:
                    # If a chunk is still invalid, skip it rather than crashing.
                    continue
            return objs
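
# Example usage: strict parsing fails on the trailing comma, so the loader
# falls back to the cleaned string and still recovers both objects.
#
#   safe_load_phpmyadmin_like_json(
#       '[{"type": "header"}, {"type": "table", "name": "users", "data": [],}]'
#   )
#   # -> [{'type': 'header'}, {'type': 'table', 'name': 'users', 'data': []}]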

# -----------------------------
# Enhanced corpus building with better indexing
# -----------------------------
def flatten_json_to_corpus(docs: List[Dict[str, Any]], max_value_len: int = 1000) -> List[Dict[str, Any]]:
    """
    Turn the exported structure into searchable text chunks with enhanced indexing.
    Creates multiple representations of the same data for better retrieval.
    """
    corpus = []
    
    def extract_all_text_values(obj, prefix=""):
        """Recursively extract all text values from nested objects/arrays"""
        texts = []
        if isinstance(obj, dict):
            for k, v in obj.items():
                key_path = f"{prefix}.{k}" if prefix else k
                if isinstance(v, (dict, list)):
                    texts.extend(extract_all_text_values(v, key_path))
                else:
                    val_str = str(v).strip()
                    if val_str and val_str.lower() not in ['null', 'none', '']:
                        texts.append(f"{k}: {val_str}")
        elif isinstance(obj, list):
            for i, item in enumerate(obj):
                texts.extend(extract_all_text_values(item, f"{prefix}[{i}]"))
        else:
            val_str = str(obj).strip()
            if val_str and val_str.lower() not in ['null', 'none', '']:
                texts.append(val_str)
        return texts
    
    for obj_idx, obj in enumerate(docs):
        obj_type = obj.get("type", "unknown")
        
        if obj_type == "table":
            table_name = obj.get("name", f"table_{obj_idx}")
            rows = obj.get("data", [])
            
            if isinstance(rows, list):
                # Create entries for individual rows
                for row_idx, row in enumerate(rows):
                    if isinstance(row, dict):
                        # Standard row representation
                        parts = []
                        all_values = []
                        
                        for k, v in row.items():
                            val = str(v).strip()
                            if len(val) > max_value_len:
                                val = val[:max_value_len] + "…"
                            if val and val.lower() not in ['null', 'none', '']:
                                parts.append(f"{k}={val}")
                                all_values.append(val)
                        
                        # Main row text
                        row_text = f"[table={table_name} row={row_idx}] " + " | ".join(parts)
                        corpus.append({
                            "table": table_name,
                            "idx": row_idx,
                            "text": row_text,
                            "type": "row",
                            "raw_data": row
                        })
                        
                        # Also create a searchable version with just values for name searches
                        if all_values:
                            value_text = f"[table={table_name} row={row_idx}] Contains: " + " ".join(all_values)
                            corpus.append({
                                "table": table_name,
                                "idx": row_idx,
                                "text": value_text,
                                "type": "values",
                                "raw_data": row
                            })
                
                # Create table summary
                if rows:
                    sample_keys = []
                    if rows and isinstance(rows[0], dict):
                        sample_keys = list(rows[0].keys())[:10]
                    
                    table_summary = f"[table={table_name} summary] Table with {len(rows)} rows. Fields: {', '.join(sample_keys)}"
                    corpus.append({
                        "table": table_name,
                        "idx": -1,
                        "text": table_summary,
                        "type": "summary",
                        "raw_data": {"row_count": len(rows), "fields": sample_keys}
                    })
        else:
            # Non-table entries - extract all textual content
            all_texts = extract_all_text_values(obj)
            if all_texts:
                text = f"[{obj_type}] " + " | ".join(all_texts[:20])  # Limit to prevent too long
                if len(text) > 2000:
                    text = text[:2000] + "…"
                corpus.append({
                    "table": obj_type,
                    "idx": obj_idx,
                    "text": text,
                    "type": "meta",
                    "raw_data": obj
                })
    
    return corpus
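
# Example: a one-row table yields three passages, matching the logic above:
# a "row" entry with key=value pairs, a "values" entry for bare-name searches,
# and a table "summary" entry.
#
#   flatten_json_to_corpus([{"type": "table", "name": "users",
#                            "data": [{"id": 1, "name": "Ada"}]}])
#   # texts produced:
#   #   "[table=users row=0] id=1 | name=Ada"
#   #   "[table=users row=0] Contains: 1 Ada"
#   #   "[table=users summary] Table with 1 rows. Fields: id, name"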

# -----------------------------
# Enhanced retrieval with multiple scoring methods
# -----------------------------
def _tokenize_enhanced(s: str) -> List[str]:
    """Enhanced tokenization that handles names and phrases better"""
    # Lowercase word tokens plus 4-character prefixes for partial name matching
    tokens = []
    
    # Get word tokens
    words = re.findall(r"[A-Za-z0-9_]+", s)
    for word in words:
        tokens.append(word.lower())
        if len(word) > 3:
            # Add partial tokens for name matching
            tokens.append(word[:4].lower())
    
    # Also extract quoted phrases so multi-word names can match
    quoted = re.findall(r'"([^"]*)"', s)
    for q in quoted:
        tokens.extend(q.lower().split())
    
    return tokens
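
# Example: lowercase word tokens, 4-char prefixes for longer words, and extra
# tokens contributed by the quoted phrase.
#
#   _tokenize_enhanced('Instructor "Shukdev Datta"')
#   # -> ['instructor', 'inst', 'shukdev', 'shuk', 'datta', 'datt',
#   #     'shukdev', 'datta']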

def calculate_enhanced_score(query: str, doc_text: str, doc_data: Dict) -> float:
    """Enhanced scoring that considers multiple factors"""
    q_lower = query.lower()
    d_lower = doc_text.lower()
    
    score = 0.0
    
    # 1. Exact phrase matching (highest weight)
    if q_lower in d_lower:
        score += 10.0
    
    # 2. Token-based matching
    q_tokens = _tokenize_enhanced(query)
    d_tokens = _tokenize_enhanced(doc_text)
    
    if d_tokens:
        q_set = set(q_tokens)
        d_set = set(d_tokens)
        
        # Exact token matches
        exact_matches = len(q_set & d_set)
        score += exact_matches * 2.0
        
        # Partial matches for names
        for q_tok in q_tokens:
            if len(q_tok) > 2:
                for d_tok in d_tokens:
                    if q_tok in d_tok or d_tok in q_tok:
                        score += 0.5
        
        # Length normalization
        score = score / math.log2(len(d_tokens) + 2)
    
    # 3. Boost for certain types of content
    if "instructor" in q_lower and "instructor" in d_lower:
        score += 5.0
    
    if "batch" in q_lower and "batch" in d_lower:
        score += 3.0
        
    # Boost for rows vs summaries when looking for specific info
    if any(word in q_lower for word in ["who", "name", "person"]):
        if doc_data.get("type") == "row":
            score += 2.0
    
    return score
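
# Illustrative ranking behaviour: for the query "Shukdev Datta", a passage
# containing the exact phrase earns the +10.0 bonus on top of its token credit,
# so it outranks a passage that merely shares the token "datta". Note that the
# log-length normalization is applied to the running total, so it dampens the
# phrase bonus as well; the keyword boosts are added afterwards, undamped.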

def retrieve_top_k_enhanced(query: str, corpus: List[Dict[str, Any]], k: int = 15, per_table_cap: int = 8) -> List[Dict[str, Any]]:
    """Enhanced retrieval with better scoring and diversity"""
    
    # Score every document
    scored = []
    for doc in corpus:
        score = calculate_enhanced_score(query, doc["text"], doc)
        if score > 0:
            scored.append((score, doc))
    
    # Sort by score
    scored.sort(key=lambda x: x[0], reverse=True)
    
    # Apply diversity constraints
    table_counts = {}
    type_counts = {}
    result = []
    
    for score, doc in scored:
        table_name = doc.get("table", "unknown")
        doc_type = doc.get("type", "unknown")
        
        # Check table limit
        if table_counts.get(table_name, 0) >= per_table_cap:
            continue
        
        # Prefer diverse content types
        if type_counts.get(doc_type, 0) >= k // 3 and len(result) > k // 2:
            continue
            
        result.append(doc)
        table_counts[table_name] = table_counts.get(table_name, 0) + 1
        type_counts[doc_type] = type_counts.get(doc_type, 0) + 1
        
        if len(result) >= k:
            break
    
    # If there were too few positive matches, pad with the best-scored
    # documents overall, skipping any passages already selected.
    if len(result) < 3:
        chosen = {id(doc) for doc in result}
        fallback = [doc for _, doc in scored if id(doc) not in chosen]
        result.extend(fallback[: k - len(result)])
    
    return result
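
# End-to-end sketch on hypothetical data: the tolerant loader, corpus builder,
# and retriever compose directly.
#
#   _docs = safe_load_phpmyadmin_like_json(
#       '[{"type": "table", "name": "staff", '
#       '"data": [{"name": "Ada Lovelace", "role": "instructor"},]}]'
#   )
#   _corpus = flatten_json_to_corpus(_docs)
#   _hits = retrieve_top_k_enhanced("who is the instructor", _corpus, k=3)
#   # _hits[0]["text"] -> "[table=staff row=0] name=Ada Lovelace | role=instructor"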

# -----------------------------
# Enhanced prompt building
# -----------------------------
def build_enhanced_prompt(query: str, passages: List[Dict[str, Any]]) -> str:
    """Build a more comprehensive prompt with structured context"""
    
    context_sections = []
    table_summaries = []
    
    for passage in passages:
        if passage.get("type") == "summary":
            table_summaries.append(passage["text"])
        else:
            context_sections.append(passage["text"])
    
    # Combine contexts
    table_context = "\n".join(table_summaries) if table_summaries else ""
    detail_context = "\n\n".join(context_sections)
    
    prompt = f"""You are a thorough JSON database assistant. Answer using ONLY the provided context from the JSON export.

# User Question
{query}

# Available Tables Summary
{table_context}

# Detailed Context (Most Relevant Entries)
{detail_context}

# Instructions
- Search through ALL provided context thoroughly
- For person names, look for partial matches and variations
- For roles like "instructor" or "teacher", check all relevant entries
- If asking about people, include their roles, associations, and related info
- Cite specific table names and row indices when possible
- If information exists in the context but seems incomplete, mention what you found
- Only say "not found" if you genuinely cannot locate relevant information after thorough checking
- Be comprehensive - don't just return the first match you find"""

    return prompt

# -----------------------------
# Together client helper
# -----------------------------
def call_together(api_key: str, prompt: str) -> str:
    if not api_key or not api_key.strip():
        return "⚠️ Please enter your Together API key."
    
    try:
        # Set both the env var and the client argument so the SDK sees the key
        os.environ["TOGETHER_API_KEY"] = api_key.strip()
        client = Together(api_key=api_key.strip())
        
        resp = client.chat.completions.create(
            model="lgai/exaone-3-5-32b-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # Lower temperature for more focused responses
            max_tokens=1000,
        )
        return resp.choices[0].message.content
    except Exception as e:
        return f"❌ API Error: {str(e)}"

# -----------------------------
# Gradio App
# -----------------------------
with gr.Blocks(title="Enhanced JSON Chatbot") as demo:
    gr.Markdown("## πŸ“š Enhanced JSON Chatbot (Together Exaone 3.5 32B)\nUpload your JSON export and ask questions. Enhanced retrieval system for better name and role matching.")

    with gr.Row():
        api_key_tb = gr.Textbox(label="Together API Key", type="password", placeholder="Paste your TOGETHER_API_KEY here")
        topk_slider = gr.Slider(5, 30, value=15, step=1, label="Top-K JSON Passages")

    with gr.Row():
        json_file = gr.File(label="Upload JSON export (e.g., phpMyAdmin export)", file_count="single", file_types=[".json"])
        fallback_path = gr.Textbox(label="Or fixed path on disk (optional)", placeholder="e.g., sultanbr_innovativeskills.json")

    with gr.Accordion("Advanced Settings", open=False):
        per_table_cap = gr.Slider(3, 15, value=8, step=1, label="Max passages per table")
        max_val_len = gr.Slider(200, 2000, value=1000, step=100, label="Max value length per field")

    status = gr.Markdown("🔄 Ready. Upload a JSON file to begin.")
    
    chatbot = gr.Chatbot(height=500)
    
    with gr.Row():
        user_box = gr.Textbox(
            label="Ask about your JSON data...", 
            placeholder="e.g., Who are the batch instructors? or Who is Shukdev Datta?",
            lines=2,
            scale=4
        )
        send_btn = gr.Button("Send", variant="primary", size="lg", scale=1)
    
    with gr.Row():
        clear_btn = gr.Button("Clear Chat", variant="secondary")
        reload_btn = gr.Button("Reload JSON", variant="secondary")

    # States
    state_corpus = gr.State([])
    state_docs = gr.State([])

    def load_json_to_corpus(file_obj, path_text, max_value_len):
        """Load JSON and build enhanced corpus"""
        try:
            if file_obj is not None:
                with open(file_obj.name, "r", encoding="utf-8", errors="replace") as f:
                    raw = f.read()
                source = f"uploaded file: {file_obj.name}"
            else:
                p = (path_text or "").strip()
                if not p:
                    return ("⚠️ Please upload a JSON file or provide a valid path.", [], [])
                with open(p, "r", encoding="utf-8", errors="replace") as f:
                    raw = f.read()
                source = f"file path: {p}"

            docs = safe_load_phpmyadmin_like_json(raw)

            if not isinstance(docs, list):
                docs = [docs]

            corpus = flatten_json_to_corpus(docs, max_value_len=int(max_value_len))

            # Count tables vs other objects
            tables = [d for d in docs if d.get("type") == "table"]
            
            status_msg = f"βœ… Loaded from {source}\n"
            status_msg += f"πŸ“Š {len(docs)} objects total, {len(tables)} tables\n"
            status_msg += f"πŸ” Built {len(corpus)} searchable passages\n"
            status_msg += f"πŸ’¬ Ready for questions!"

            return (status_msg, corpus, docs)

        except Exception as e:
            return (f"❌ Load error: {str(e)}", [], [])

    def ask_enhanced(api_key, query, history, corpus, k, cap):
        if not corpus:
            return history + [[query, "⚠️ Please upload and load the JSON file first."]]
        if not query or not query.strip():
            return history + [["", "⚠️ Please enter a question."]]

        # Enhanced retrieval
        top_passages = retrieve_top_k_enhanced(query.strip(), corpus, k=int(k), per_table_cap=int(cap))
        
        # Build enhanced prompt
        prompt = build_enhanced_prompt(query.strip(), top_passages)

        try:
            answer = call_together(api_key, prompt)
        except Exception as e:
            answer = f"❌ API error: {str(e)}"

        history = history + [[query, answer]]
        return history

    # Event handlers
    json_file.upload(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )
    
    # Use submit (Enter) rather than change, which would fire on every keystroke
    fallback_path.submit(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )

    user_box.submit(
        ask_enhanced,
        inputs=[api_key_tb, user_box, chatbot, state_corpus, topk_slider, per_table_cap],
        outputs=[chatbot],
    ).then(lambda: "", outputs=[user_box])  # Clear input after submit

    send_btn.click(
        ask_enhanced,
        inputs=[api_key_tb, user_box, chatbot, state_corpus, topk_slider, per_table_cap],
        outputs=[chatbot],
    ).then(lambda: "", outputs=[user_box])  # Clear input after send

    reload_btn.click(
        load_json_to_corpus,
        inputs=[json_file, fallback_path, max_val_len],
        outputs=[status, state_corpus, state_docs],
    )

    clear_btn.click(
        lambda: ([], "", "🔄 Chat cleared. Ready for new questions."),
        outputs=[chatbot, user_box, status]  # route the message to the status bar, not the input box
    )

if __name__ == "__main__":
    demo.launch()