import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from data_processor import DataProcessor
from chart_generator import ChartGenerator
from image_verifier import ImageVerifier
from huggingface_hub import login
import logging
import time
import os
from dotenv import load_dotenv
import ast
import requests
import json
load_dotenv()

class LLM_Agent:
    def __init__(self, data_path=None):
        logging.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        self.image_verifier = ImageVerifier()
        # Use Hugging Face Hub model path for the fine-tuned model
        model_path = "ArchCoder/fine-tuned-bart-large"
        self.query_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.query_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

    @staticmethod
    def validate_plot_args(plot_args):
        """Check that plot_args has the required keys; normalize 'y' to a list."""
        required_keys = ['x', 'y', 'chart_type']
        if not all(key in plot_args for key in required_keys):
            return False
        if not isinstance(plot_args['y'], list):
            plot_args['y'] = [plot_args['y']]
        return True

    def process_request(self, data):
        start_time = time.time()
        logging.info(f"Processing request data: {data}")
        query = data.get('query', '')
        data_path = data.get('file_path')
        model_choice = data.get('model', 'bart')
        # If a file is specified, log its path, check that it exists, and
        # re-initialize the data processor and chart generator for it
        if data_path:
            logging.info(f"Data path received: {data_path}")
            if not os.path.exists(data_path):
                logging.error(f"File does not exist at path: {data_path}")
            else:
                logging.info(f"File exists at path: {data_path}")
            self.data_processor = DataProcessor(data_path)
            loaded_columns = self.data_processor.get_columns()
            logging.info(f"Loaded columns from data: {loaded_columns}")
            self.chart_generator = ChartGenerator(self.data_processor.data)
        # Few-shot prompt used by the hosted instruction-tuned models
        enhanced_prompt = (
            "You are VizBot, an expert data visualization assistant. "
            "Given a user's natural language request about plotting data, "
            "output ONLY a valid Python dictionary with keys: x, y, chart_type, "
            "and color (if specified). "
            "Do not include any explanation or extra text.\n\n"
            "Example 1:\n"
            "User: plot the sales in the years with red line\n"
            "Output: {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}\n\n"
            "Example 2:\n"
            "User: show employee expenses and net profit over the years\n"
            "Output: {'x': 'Year', 'y': ['Employee expense', 'Net profit'], 'chart_type': 'line'}\n\n"
            "Example 3:\n"
            "User: display the EBITDA for each year with a blue bar\n"
            "Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
            f"User: {query}\nOutput:"
        )
        try:
            # Fallback dict (as text) used whenever a model response is unusable
            fallback_response = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
            if model_choice in ('flan-t5-base', 'flan-ul2'):
                # Use the Hugging Face Inference API; 'flan-ul2' is served by
                # Flan-T5-XXL, the best hosted substitute available
                model_id = "google/flan-t5-base" if model_choice == 'flan-t5-base' else "google/flan-t5-xxl"
                api_url = f"https://api-inference.huggingface.co/models/{model_id}"
                headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
                payload = {"inputs": enhanced_prompt}
                response = requests.post(api_url, headers=headers, json=payload, timeout=30)
                if response.status_code != 200:
                    logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
                    response_text = fallback_response
                else:
                    try:
                        resp_json = response.json()
                        response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
                        if not response_text:
                            response_text = fallback_response
                    except Exception as e:
                        logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
                        response_text = fallback_response
            else:
                # Local fine-tuned BART model: the 'bart' choice and the default
                # fallback for any unrecognized model name
                inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
                outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
                response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
logging.info(f"LLM response text: {response_text}")
# Clean and parse the response
response_text = response_text.strip()
if response_text.startswith("```") and response_text.endswith("```"):
response_text = response_text[3:-3].strip()
if response_text.startswith("python"):
response_text = response_text[6:].strip()
try:
plot_args = ast.literal_eval(response_text)
except (SyntaxError, ValueError) as e:
logging.warning(f"Invalid LLM response: {e}. Response: {response_text}")
plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
if not LLM_Agent.validate_plot_args(plot_args):
logging.warning("Invalid plot arguments. Using default.")
plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
chart_path = self.chart_generator.generate_chart(plot_args)
verified = self.image_verifier.verify(chart_path, query)
end_time = time.time()
logging.info(f"Processed request in {end_time - start_time} seconds")
return {
"response": response_text,
"chart_path": chart_path,
"verified": verified
}
except Exception as e:
logging.error(f"Error processing request: {e}")
end_time = time.time()
logging.info(f"Processed request in {end_time - start_time} seconds")
return {
"response": f"Error: {str(e)}",
"chart_path": "",
"verified": False
}
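
# Minimal usage sketch (not part of the original module): it assumes a
# hypothetical local CSV at 'data/financials.csv' and, for the hosted-model
# choices, a valid HUGGINGFACEHUB_API_TOKEN in .env.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    agent = LLM_Agent()
    result = agent.process_request({
        "query": "plot the sales in the years with red line",
        "file_path": "data/financials.csv",  # hypothetical sample file
        "model": "bart",  # local fine-tuned model; no API token needed
    })
    print(result["response"], result["chart_path"], result["verified"])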