import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from data_processor import DataProcessor
from chart_generator import ChartGenerator
from image_verifier import ImageVerifier
from huggingface_hub import login
import logging
import time
import os
from dotenv import load_dotenv
import ast
import requests
load_dotenv()
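
# Note: the hosted-API branches below read HUGGINGFACEHUB_API_TOKEN from the
# environment (loaded from a .env file via load_dotenv above).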


class LLM_Agent:
    def __init__(self, data_path=None):
        logging.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        self.image_verifier = ImageVerifier()
        # Use the Hugging Face Hub model path for the fine-tuned model
        model_path = "ArchCoder/fine-tuned-bart-large"
        self.query_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.query_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

    @staticmethod
    def validate_plot_args(plot_args):
        # The parsed LLM output must be a dict with at least x, y, and chart_type
        if not isinstance(plot_args, dict):
            return False
        required_keys = ['x', 'y', 'chart_type']
        if not all(key in plot_args for key in required_keys):
            return False
        # Normalize a scalar 'y' to a one-element list in place
        if not isinstance(plot_args['y'], list):
            plot_args['y'] = [plot_args['y']]
        return True
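
    # Illustrative call (hypothetical values): validate_plot_args({'x': 'Year',
    # 'y': 'Sales', 'chart_type': 'line'}) returns True and normalizes 'y' to
    # ['Sales'] in place; a dict missing any required key returns False.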

    def process_request(self, data):
        start_time = time.time()
        logging.info(f"Processing request data: {data}")
        query = data['query']
        data_path = data.get('file_path')
        model_choice = data.get('model', 'bart')
        # Log the file path and check that it exists (os is imported at module level)
        if data_path:
            logging.info(f"Data path received: {data_path}")
            if not os.path.exists(data_path):
                logging.error(f"File does not exist at path: {data_path}")
            else:
                logging.info(f"File exists at path: {data_path}")
        # Few-shot + persona prompt for Flan-UL2 (best model)
        flan_prompt = (
            "You are VizBot, an expert data visualization assistant. "
            "Given a user's natural language request about plotting data, output ONLY a valid "
            "Python dictionary with keys: x, y, chart_type, and color (if specified). "
            "Do not include any explanation or extra text.\n\n"
            "Example 1:\n"
            "User: plot the sales in the years with red line\n"
            "Output: {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}\n\n"
            "Example 2:\n"
            "User: show employee expenses and net profit over the years\n"
            "Output: {'x': 'Year', 'y': ['Employee expense', 'Net profit'], 'chart_type': 'line'}\n\n"
            "Example 3:\n"
            "User: display the EBITDA for each year with a blue bar\n"
            "Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
            f"User: {query}\nOutput:"
        )
        # Re-initialize the data processor and chart generator if a file is specified
        if data_path:
            self.data_processor = DataProcessor(data_path)
            # Log the loaded columns
            loaded_columns = self.data_processor.get_columns()
            logging.info(f"Loaded columns from data: {loaded_columns}")
            self.chart_generator = ChartGenerator(self.data_processor.data)
        if model_choice == 'bart':
            # Use the local fine-tuned BART model
            inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
            response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
        elif model_choice == 'flan-t5-base':
            # Use the Hugging Face Inference API with the Flan-T5-Base model
            api_url = "https://api-inference.huggingface.co/models/google/flan-t5-base"
            headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}", "Content-Type": "application/json"}
            response = requests.post(api_url, headers=headers, json={"inputs": flan_prompt})
            if response.status_code != 200:
                logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
                response_text = "Error: Unable to get response from the Flan-T5-Base API. Please try again later."
            else:
                try:
                    resp_json = response.json()
                    response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
                except Exception as e:
                    logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
                    response_text = "Error: Unexpected response from the Flan-T5-Base API."
        elif model_choice == 'flan-ul2':
            # 'google/flan-ul2' is not reliably served by the Inference API,
            # so fall back to 'google/flan-t5-xxl' as the best available substitute
            api_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
            headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}", "Content-Type": "application/json"}
            response = requests.post(api_url, headers=headers, json={"inputs": flan_prompt})
            if response.status_code != 200:
                logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
                response_text = "Error: Unable to get response from the Flan-T5-XXL API. Please try again later."
            else:
                try:
                    resp_json = response.json()
                    response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
                except Exception as e:
                    logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
                    response_text = "Error: Unexpected response from the Flan-T5-XXL API."
        else:
            # Default fallback: use the local fine-tuned BART model
            inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
            response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.info(f"LLM response text: {response_text}")
        # Safely parse the model output into plot arguments; fall back to defaults on failure
        try:
            plot_args = ast.literal_eval(response_text)
        except (SyntaxError, ValueError):
            plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
            logging.warning(f"Invalid LLM response. Using default plot args: {plot_args}")
        if LLM_Agent.validate_plot_args(plot_args):
            chart_path = self.chart_generator.generate_chart(plot_args)
        else:
            logging.warning("Invalid plot arguments. Using default.")
            chart_path = self.chart_generator.generate_chart({'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'})
        verified = self.image_verifier.verify(chart_path, query)
        end_time = time.time()
        logging.info(f"Processed request in {end_time - start_time} seconds")
        return {
            "response": response_text,
            "chart_path": chart_path,
            "verified": verified,
        }
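

# Minimal usage sketch (illustrative only; the file path and query below are
# hypothetical and not part of the original module):
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    agent = LLM_Agent()
    result = agent.process_request({
        "query": "plot the sales in the years with red line",
        "file_path": "data/financials.csv",  # hypothetical CSV path
        "model": "bart",
    })
    print(result["chart_path"], result["verified"])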