File size: 3,953 Bytes
8890d1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import pandas as pd
from transformers import TapexTokenizer, BartForConditionalGeneration, pipeline

# Microsoft TAPEX: tokenizer and seq2seq model are loaded from one shared checkpoint id.
_TAPEX_CHECKPOINT = "microsoft/tapex-large-finetuned-wtq"
tokenizer_tapex = TapexTokenizer.from_pretrained(_TAPEX_CHECKPOINT)
model_tapex = BartForConditionalGeneration.from_pretrained(_TAPEX_CHECKPOINT)

# Google TAPAS: one table-question-answering pipeline per checkpoint, created
# in order (WikiTableQuestions first, then WikiSQL) and unpacked to two names.
pipe_tapas, pipe_tapas2 = (
    pipeline(task="table-question-answering", model=checkpoint)
    for checkpoint in (
        "google/tapas-large-finetuned-wtq",
        "google/tapas-large-finetuned-wikisql-supervised",
    )
)

def process_table_query(query, table_data):
    """
    Answer a natural-language question about a table with the TAPEX model.

    Args:
        query: The question to ask about the table.
        table_data: pandas DataFrame; every cell is stringified before it is
            handed to ``tokenizer_tapex``.

    Returns:
        The first sequence decoded from ``model_tapex.generate`` with special
        tokens stripped.
    """
    # The TAPEX tokenizer is fed a table whose cells are all strings.
    string_table = table_data.astype(str)

    # Encode query + table, truncating the combined input to 1024 tokens.
    model_inputs = tokenizer_tapex(
        table=string_table,
        query=query,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    )
    generated_ids = model_tapex.generate(**model_inputs)
    decoded_answers = tokenizer_tapex.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded_answers[0]

# Gradio click handler: reads the uploaded CSV and queries all three models.
def answer_query_from_csv(query, file):
    """
    Answer ``query`` against an uploaded CSV file with TAPEX and both TAPAS models.

    Args:
        query: Natural-language question typed by the user.
        file: Value of the ``gr.File`` input: ``None`` when nothing was
            uploaded, otherwise a path or an object exposing the path as
            ``.name``.

    Returns:
        A 3-tuple ``(tapex_answer, tapas_wtq_answer, tapas_wikisql_answer)``.
        When the file or the query is missing, all three elements are the
        same explanatory message so every output box shows it.
    """
    # Guard clauses: avoid crashing in pd.read_csv(None) or running the
    # models on an empty question.
    if file is None:
        message = "Please upload a CSV file."
        return message, message, message
    if not query or not query.strip():
        message = "Please enter a query."
        return message, message, message

    table_data = _prepare_table(_read_uploaded_csv(file))

    result_tapex = process_table_query(query, table_data)
    result_tapas = _tapas_answer(pipe_tapas(table=table_data, query=query))
    result_tapas2 = _tapas_answer(pipe_tapas2(table=table_data, query=query))

    return result_tapex, result_tapas, result_tapas2


def _read_uploaded_csv(file):
    """Read the CSV referenced by a ``gr.File`` value into a DataFrame."""
    import os

    # Depending on the Gradio version, gr.File yields either a path or a
    # tempfile wrapper whose ``.name`` is the full path; accept both. Real
    # paths are checked first because ``pathlib.Path.name`` is only the
    # basename.
    if isinstance(file, (str, os.PathLike)):
        return pd.read_csv(file)
    return pd.read_csv(getattr(file, "name", file))


def _prepare_table(table_data):
    """Lowercase text cells, expand datetime columns, and stringify every cell."""
    # Lowercase string cells in object columns; non-string values (e.g. NaN)
    # are passed through unchanged.
    for column in table_data.columns:
        if table_data[column].dtype == "object":
            table_data[column] = table_data[column].apply(
                lambda x: x.lower() if isinstance(x, str) else x
            )

    # Expand datetime columns into year/month/day/time columns. This has to
    # run BEFORE the string conversion below: once every column is ``str`` no
    # column is datetime64 any more, which made this step unreachable.
    # NOTE(review): pd.read_csv only yields datetime64 columns when
    # ``parse_dates`` is passed, so with the current reader this branch does
    # not fire yet.
    for column in list(table_data.columns):  # snapshot: new columns are added in the loop
        if pd.api.types.is_datetime64_any_dtype(table_data[column]):
            parts = table_data[column].dt
            # Nullable Int64 keeps "2020" from being rendered as "2020.0" when NaT is present.
            table_data[f"{column}_year"] = parts.year.astype("Int64")
            table_data[f"{column}_month"] = parts.month.astype("Int64")
            table_data[f"{column}_day"] = parts.day.astype("Int64")
            table_data[f"{column}_time"] = parts.strftime("%H:%M:%S")

    # The same table is fed to TAPEX and both TAPAS pipelines; all cells as str.
    return table_data.astype(str)


def _tapas_answer(prediction):
    """Return the text answer from a TAPAS table-question-answering prediction dict."""
    # ``answer`` joins every selected cell and prefixes the predicted
    # aggregation operator (e.g. "COUNT > ..."). Taking only ``cells[0]`` lost
    # both and raised IndexError when no cell was selected.
    return prediction.get("answer", "")

# Build the Gradio Blocks UI: inputs on the left, one answer box per model on the right.
with gr.Blocks() as interface:
    gr.Markdown("# Table Question Answering with TAPEX and TAPAS Models")

    # Token-limit notice rendered under the title.
    gr.Markdown("### Note: Only the first 1024 tokens (query + table data) will be considered. If your table is too large, it will be truncated to fit within this limit.")

    with gr.Row():
        # Left column: the question and the CSV upload.
        with gr.Column():
            query_input = gr.Textbox(label="Enter your query:")
            csv_input = gr.File(label="Upload your CSV file")

        # Right column: answer boxes, created in TAPEX, TAPAS-WTQ, TAPAS-WikiSQL order.
        with gr.Column():
            answer_boxes = [
                gr.Textbox(label=box_label)
                for box_label in (
                    "TAPEX Answer",
                    "TAPAS (WikiTableQuestions) Answer",
                    "TAPAS (WikiSQL) Answer",
                )
            ]
    result_tapex, result_tapas, result_tapas2 = answer_boxes

    submit_btn = gr.Button("Submit")

    # Clicking Submit runs all three models and fills the answer boxes in order.
    submit_btn.click(
        fn=answer_query_from_csv,
        inputs=[query_input, csv_input],
        outputs=[result_tapex, result_tapas, result_tapas2],
    )

# Start the app; share=True also requests a public share link.
interface.launch(share=True)