import ast
import logging
import os
import time

import requests
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from chart_generator import ChartGenerator
from data_processor import DataProcessor
from image_verifier import ImageVerifier

load_dotenv()

class LLM_Agent:
    """Agent that turns natural-language chart requests into rendered charts.

    A query is translated into a dict of plot arguments by either the local
    fine-tuned BART model or a hosted Flan model, then handed to the chart
    generator; the resulting image is verified against the original request.
    """

    def __init__(self, data_path=None):
        logging.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        self.image_verifier = ImageVerifier()

        # Local fine-tuned BART model, used as the default query backend.
        model_path = os.path.join(os.path.dirname(__file__), "fine-tuned-bart-large")
        self.query_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.query_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

    @staticmethod
    def validate_plot_args(plot_args):
        """Check that plot_args is a dict with the required keys; coerce 'y' to a list in place."""
        if not isinstance(plot_args, dict):
            return False
        required_keys = ['x', 'y', 'chart_type']
        if not all(key in plot_args for key in required_keys):
            return False
        if not isinstance(plot_args['y'], list):
            plot_args['y'] = [plot_args['y']]
        return True

    def process_request(self, data):
        start_time = time.time()
        logging.info(f"Processing request data: {data}")
        query = data['query']
        data_path = data.get('file_path')
        model_choice = data.get('model', 'bart')

        # Few-shot + persona prompt used by the hosted Flan models
        flan_prompt = (
            "You are VizBot, an expert data visualization assistant. "
            "Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
            "Do not include any explanation or extra text.\n\n"
            "Example 1:\n"
            "User: plot the sales in the years with red line\n"
            "Output: {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}\n\n"
            "Example 2:\n"
            "User: show employee expenses and net profit over the years\n"
            "Output: {'x': 'Year', 'y': ['Employee expense', 'Net profit'], 'chart_type': 'line'}\n\n"
            "Example 3:\n"
            "User: display the EBITDA for each year with a blue bar\n"
            "Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
            f"User: {query}\nOutput:"
        )
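        # The model is expected to return a dict literal such as
        # "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}",
        # which ast.literal_eval parses further below.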

        # Re-initialize data processor and chart generator if a file is specified
        if data_path:
            self.data_processor = DataProcessor(data_path)
            self.chart_generator = ChartGenerator(self.data_processor.data)

        if model_choice == 'bart':
            # Use local fine-tuned BART model
            inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
            response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
        elif model_choice == 'flan-t5-base':
            # Use Hugging Face Inference API with Flan-T5-Base model
            api_url = "https://api-inference.huggingface.co/models/google/flan-t5-base"
            headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}", "Content-Type": "application/json"}
            response = requests.post(api_url, headers=headers, json={"inputs": flan_prompt})
            if response.status_code != 200:
                logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
                response_text = "Error: Unable to get response from Flan-T5-Base API. Please try again later."
            else:
                try:
                    resp_json = response.json()
                    response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
                except Exception as e:
                    logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
                    response_text = f"Error: Unexpected response from Flan-T5-Base API."
        elif model_choice == 'flan-ul2':
            # Use the Hugging Face Inference API. "google/flan-ul2" is not
            # available through this endpoint, so "google/flan-t5-xxl" is used
            # as the closest available substitute.
            api_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
            headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}", "Content-Type": "application/json"}
            response = requests.post(api_url, headers=headers, json={"inputs": flan_prompt})
            if response.status_code != 200:
                logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
                response_text = "Error: Unable to get response from Flan-T5-XXL API. Please try again later."
            else:
                try:
                    resp_json = response.json()
                    response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
                except Exception as e:
                    logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
                    response_text = f"Error: Unexpected response from Flan-T5-XXL API."
        else:
            # Unrecognized model choice: fall back to the local fine-tuned BART model
            inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
            response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)

        logging.info(f"LLM response text: {response_text}")
        try:
            plot_args = ast.literal_eval(response_text)
        except (SyntaxError, ValueError):
            plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
            logging.warning(f"Invalid LLM response. Using default plot args: {plot_args}")
        if LLM_Agent.validate_plot_args(plot_args):
            chart_path = self.chart_generator.generate_chart(plot_args)
        else:
            logging.warning("Invalid plot arguments. Using default.")
            chart_path = self.chart_generator.generate_chart({'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'})
        verified = self.image_verifier.verify(chart_path, query)
        end_time = time.time()
        logging.info(f"Processed request in {end_time - start_time} seconds")
        return {
            "response": response_text,
            "chart_path": chart_path,
            "verified": verified
        }
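
# Minimal usage sketch (assumptions: a CSV such as "data/financials.csv" with
# 'Year' and 'Sales' columns exists, and HUGGINGFACEHUB_API_TOKEN is set in
# .env when a hosted Flan model is selected; the path and column names here
# are hypothetical examples, not part of this module).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    agent = LLM_Agent("data/financials.csv")  # hypothetical example path
    result = agent.process_request({
        "query": "plot the sales in the years with red line",
        "file_path": "data/financials.csv",
        "model": "bart",
    })
    print(result["chart_path"], result["verified"])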