File size: 3,187 Bytes
1d91ffa
 
 
4d16da0
1d91ffa
 
 
 
 
4beb772
73ab43d
4beb772
a48a101
 
 
 
 
bc10f71
a48a101
 
 
 
 
 
9ed9be5
a48a101
1d91ffa
a48a101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d91ffa
 
4beb772
1d91ffa
a48a101
 
1d91ffa
 
 
 
 
8dfd657
1d91ffa
 
a48a101
1d91ffa
a48a101
1d91ffa
 
a48a101
 
 
 
 
 
 
 
 
 
 
9ed9be5
a48a101
 
 
9ed9be5
1d91ffa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import openai
from datasets import load_dataset
import logging
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# SECURITY: never hard-code API keys in source. Read the key from the
# environment; the previously committed key must be revoked and rotated.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    logger.warning("OPENAI_API_KEY is not set; OpenAI requests will fail.")

# Load the train split of every RagBench subset; skip any that fail so the
# app still starts with whatever data is available.
datasets = {}
dataset_names = ['covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 
                 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 
                 'tatqa', 'techqa']

for name in dataset_names:
    try:
        datasets[name] = load_dataset("rungalileo/ragbench", name, split='train')
        logger.info(f"Successfully loaded {name}")
    except Exception as e:
        # WARNING (not INFO) so a missing dataset is visible in the logs.
        logger.warning(f"Skipping {name}: {str(e)}")

def process_query(query, dataset_choice="all"):
    """Answer ``query`` using keyword-matched context from the RagBench datasets.

    Parameters
    ----------
    query : str
        The user's question; split on whitespace into search keywords.
    dataset_choice : str, optional
        A single dataset name to restrict the search to, or ``"all"``
        (default) to search every loaded dataset.

    Returns
    -------
    str
        The model's answer, or a generic fallback message if any step fails.
    """
    try:
        # Hoist keyword lowercasing out of the document loop.
        keywords = [word.lower() for word in query.split()]
        relevant_contexts = []

        # Search through selected or all datasets.
        search_datasets = [dataset_choice] if dataset_choice != "all" else datasets.keys()

        for dataset_name in search_datasets:
            if dataset_name not in datasets:
                continue
            for doc in datasets[dataset_name]['documents']:
                doc_lower = doc.lower()
                # Score each document by how many query keywords it contains.
                score = sum(1 for keyword in keywords if keyword in doc_lower)
                if score > 0:
                    relevant_contexts.append((score, doc, dataset_name))

        # Use the highest-scoring context instead of an arbitrary first hit.
        if relevant_contexts:
            _, context, source = max(relevant_contexts, key=lambda item: item[0])
            context_info = f"From {source}: {context}"
        else:
            context_info = "Searching across all available datasets..."

        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a knowledgeable expert. Provide direct, informative answers based on the available data."},
                {"role": "user", "content": f"Context: {context_info}\nQuestion: {query}"}
            ],
            max_tokens=300,
            temperature=0.7,
        )

        return response.choices[0].message.content.strip()

    except Exception:
        # Log the real failure instead of silently masking it; still return a
        # user-friendly message so the UI never shows a traceback.
        logger.exception("process_query failed for query %r", query)
        return f"Currently searching through all available datasets for information about {query}."

# Build the Gradio UI: a question box plus a dataset selector feeding
# process_query, with a text box for the model's answer.
question_input = gr.Textbox(label="Question", placeholder="Ask any question...")
dataset_selector = gr.Dropdown(
    choices=["all"] + dataset_names,
    label="Select Dataset",
    value="all"
)
answer_output = gr.Textbox(label="Expert Response")

# Sample queries shown beneath the interface, each paired with its dataset.
example_queries = [
    ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
    ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
    ["What are the key financial metrics for Q3?", "finqa"]
]

demo = gr.Interface(
    fn=process_query,
    inputs=[question_input, dataset_selector],
    outputs=answer_output,
    title="Multi-Dataset Knowledge Base",
    description="Search across all RagBench datasets for comprehensive information",
    examples=example_queries
)

if __name__ == "__main__":
    demo.launch(debug=True)