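"""LLM agent that turns natural-language chart requests into plot arguments.

Parses a user's query with either a local fine-tuned BART model or the
Hugging Face Inference API (Flan-T5 variants), renders the requested chart,
and verifies the generated image against the query.
"""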
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from data_processor import DataProcessor
from chart_generator import ChartGenerator
from image_verifier import ImageVerifier
import logging
import time
import os
from dotenv import load_dotenv
import ast
import requests

load_dotenv()
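
# The Hugging Face Inference API branches in process_request authenticate with
# the HUGGINGFACEHUB_API_TOKEN environment variable, which load_dotenv() reads
# from a local .env file if one is present.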

class LLM_Agent:
    def __init__(self, data_path=None):
        logging.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        self.image_verifier = ImageVerifier()

        # Use Hugging Face Hub model path for fine-tuned model
        model_path = "ArchCoder/fine-tuned-bart-large"
        self.query_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.query_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

    @staticmethod
    def validate_plot_args(plot_args):
        """Check that parsed plot arguments carry the required keys,
        normalizing 'y' to a list as a side effect."""
        required_keys = ['x', 'y', 'chart_type']
        if not all(key in plot_args for key in required_keys):
            return False
        # Accept a bare column name for 'y' by wrapping it in a list.
        if not isinstance(plot_args['y'], list):
            plot_args['y'] = [plot_args['y']]
        return True

    def process_request(self, data):
        start_time = time.time()
        logging.info(f"Processing request data: {data}")
        query = data.get('query', '')
        data_path = data.get('file_path')
        model_choice = data.get('model', 'bart')
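        # model_choice selects between the local fine-tuned BART model
        # ('bart', the default) and the hosted Flan-T5 variants
        # ('flan-t5-base', 'flan-ul2').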

        # Re-initialize the data processor and chart generator if a file is
        # specified, logging whether the path actually exists.
        if data_path:
            logging.info(f"Data path received: {data_path}")
            if not os.path.exists(data_path):
                logging.error(f"File does not exist at path: {data_path}")
            else:
                logging.info(f"File exists at path: {data_path}")
            self.data_processor = DataProcessor(data_path)
            # Log the loaded columns so query/column mismatches are easy
            # to diagnose.
            loaded_columns = self.data_processor.get_columns()
            logging.info(f"Loaded columns from data: {loaded_columns}")
            self.chart_generator = ChartGenerator(self.data_processor.data)

        # Enhanced prompt for better model responses
        enhanced_prompt = (
            "You are VizBot, an expert data visualization assistant. "
            "Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
            "Do not include any explanation or extra text.\n\n"
            "Example 1:\n"
            "User: plot the sales in the years with red line\n"
            "Output: {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}\n\n"
            "Example 2:\n"
            "User: show employee expenses and net profit over the years\n"
            "Output: {'x': 'Year', 'y': ['Employee expense', 'Net profit'], 'chart_type': 'line'}\n\n"
            "Example 3:\n"
            "User: display the EBITDA for each year with a blue bar\n"
            "Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
            f"User: {query}\nOutput:"
        )
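
        # Note: the local BART branch below receives the raw query rather than
        # this few-shot prompt, presumably because the model was fine-tuned on
        # bare requests; the hosted Flan models receive the full prompt.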

        try:
            if model_choice in ('flan-t5-base', 'flan-ul2'):
                # Use the Hugging Face Inference API. The 'flan-ul2' choice is
                # served by google/flan-t5-xxl, the largest Flan-T5 checkpoint
                # available on the API.
                hf_model = "google/flan-t5-base" if model_choice == 'flan-t5-base' else "google/flan-t5-xxl"
                api_url = f"https://api-inference.huggingface.co/models/{hf_model}"
                headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
                payload = {"inputs": enhanced_prompt}
                fallback = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"

                response = requests.post(api_url, headers=headers, json=payload, timeout=30)
                if response.status_code != 200:
                    logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
                    response_text = fallback
                else:
                    try:
                        resp_json = response.json()
                        response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
                        if not response_text:
                            response_text = fallback
                    except Exception as e:
                        logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
                        response_text = fallback
            else:
                # 'bart' and any unrecognized model choice use the local
                # fine-tuned BART model, which expects the raw query.
                inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
                outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
                response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)

            logging.info(f"LLM response text: {response_text}")
            
            # Strip Markdown code fences (and an optional "python" language
            # tag) that instruction-tuned models sometimes wrap around their
            # output before parsing.
            response_text = response_text.strip()
            if response_text.startswith("```") and response_text.endswith("```"):
                response_text = response_text[3:-3].strip()
            if response_text.startswith("python"):
                response_text = response_text[len("python"):].strip()
            
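            # ast.literal_eval safely evaluates the dict-literal string
            # without executing arbitrary code, unlike eval().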
            try:
                plot_args = ast.literal_eval(response_text)
            except (SyntaxError, ValueError) as e:
                logging.warning(f"Invalid LLM response: {e}. Response: {response_text}")
                plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
            
            if not LLM_Agent.validate_plot_args(plot_args):
                logging.warning("Invalid plot arguments. Using default.")
                plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
            
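            # Render the chart from the validated arguments, then ask the
            # image verifier whether the output matches the original query.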
            chart_path = self.chart_generator.generate_chart(plot_args)
            verified = self.image_verifier.verify(chart_path, query)
            
            end_time = time.time()
            logging.info(f"Processed request in {end_time - start_time} seconds")
            
            return {
                "response": response_text,
                "chart_path": chart_path,
                "verified": verified
            }
            
        except Exception as e:
            logging.error(f"Error processing request: {e}")
            end_time = time.time()
            logging.info(f"Processed request in {end_time - start_time} seconds")
            
            return {
                "response": f"Error: {str(e)}",
                "chart_path": "",
                "verified": False
            }
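

# Minimal usage sketch (not part of the module API). The CSV path and column
# names are illustrative placeholders; they assume a file with Year and Sales
# columns that DataProcessor can read.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    agent = LLM_Agent("data/financials.csv")  # hypothetical sample file
    result = agent.process_request({
        "query": "plot the sales in the years with red line",
        "file_path": "data/financials.csv",
        "model": "bart",
    })
    print(result["response"], result["chart_path"], result["verified"])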