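"""LLM agent that turns natural-language chart requests into rendered charts.

Routes a user query either to a local fine-tuned BART model or to the
Hugging Face Inference API (Flan-T5 variants), parses the model's
dict-literal reply, validates it, and hands it to ChartGenerator.
"""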
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from data_processor import DataProcessor
from chart_generator import ChartGenerator
from image_verifier import ImageVerifier
from huggingface_hub import login
import logging
import time
import os
from dotenv import load_dotenv
import ast
import requests
load_dotenv()
class LLM_Agent:
    def __init__(self, data_path=None):
        logging.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        self.image_verifier = ImageVerifier()
        # Load the locally fine-tuned BART model shipped alongside this file.
        model_path = os.path.join(os.path.dirname(__file__), "fine-tuned-bart-large")
        self.query_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.query_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    @staticmethod
    def validate_plot_args(plot_args):
        # Require the keys the chart generator needs; coerce 'y' to a list.
        required_keys = ['x', 'y', 'chart_type']
        if not all(key in plot_args for key in required_keys):
            return False
        if not isinstance(plot_args['y'], list):
            plot_args['y'] = [plot_args['y']]
        return True
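    # Example: validate_plot_args({'x': 'Year', 'y': 'Sales', 'chart_type': 'bar'})
    # returns True and normalizes plot_args['y'] to ['Sales'] in place.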
    def process_request(self, data):
        start_time = time.time()
        logging.info(f"Processing request data: {data}")
        query = data['query']
        data_path = data.get('file_path')
        model_choice = data.get('model', 'bart')
        # Few-shot + persona prompt for the hosted Flan models
        flan_prompt = (
            "You are VizBot, an expert data visualization assistant. "
            "Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
            "Do not include any explanation or extra text.\n\n"
            "Example 1:\n"
            "User: plot the sales in the years with red line\n"
            "Output: {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line', 'color': 'red'}\n\n"
            "Example 2:\n"
            "User: show employee expenses and net profit over the years\n"
            "Output: {'x': 'Year', 'y': ['Employee expense', 'Net profit'], 'chart_type': 'line'}\n\n"
            "Example 3:\n"
            "User: display the EBITDA for each year with a blue bar\n"
            "Output: {'x': 'Year', 'y': ['EBITDA'], 'chart_type': 'bar', 'color': 'blue'}\n\n"
            f"User: {query}\nOutput:"
        )
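        # The few-shot examples above constrain the hosted models to emit a
        # bare dict literal, which ast.literal_eval can parse further down.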
        # Re-initialize data processor and chart generator if a file is specified
        if data_path:
            self.data_processor = DataProcessor(data_path)
            self.chart_generator = ChartGenerator(self.data_processor.data)
        if model_choice == 'bart':
            # Use local fine-tuned BART model
            inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
            response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
        elif model_choice == 'flan-t5-base':
            # Use the Hugging Face Inference API with the Flan-T5-Base model
            api_url = "https://api-inference.huggingface.co/models/google/flan-t5-base"
            headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}", "Content-Type": "application/json"}
            response = requests.post(api_url, headers=headers, json={"inputs": flan_prompt})
            if response.status_code != 200:
                logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
                response_text = "Error: Unable to get response from Flan-T5-Base API. Please try again later."
            else:
                try:
                    resp_json = response.json()
                    response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
                except Exception as e:
                    logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
                    response_text = "Error: Unexpected response from Flan-T5-Base API."
        elif model_choice == 'flan-ul2':
            # The hosted Inference API does not serve "google/flan-ul2", so
            # fall back to "google/flan-t5-xxl" as the best available model.
            api_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
            headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}", "Content-Type": "application/json"}
            response = requests.post(api_url, headers=headers, json={"inputs": flan_prompt})
            if response.status_code != 200:
                logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
                response_text = "Error: Unable to get response from Flan-T5-XXL API. Please try again later."
            else:
                try:
                    resp_json = response.json()
                    response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
                except Exception as e:
                    logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
                    response_text = "Error: Unexpected response from Flan-T5-XXL API."
        else:
            # Default fallback to the local fine-tuned BART model
            inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
            response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
logging.info(f"LLM response text: {response_text}")
        try:
            plot_args = ast.literal_eval(response_text)
        except (SyntaxError, ValueError):
            plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
            logging.warning(f"Invalid LLM response. Using default plot args: {plot_args}")
        if LLM_Agent.validate_plot_args(plot_args):
            chart_path = self.chart_generator.generate_chart(plot_args)
        else:
            logging.warning("Invalid plot arguments. Using default.")
            chart_path = self.chart_generator.generate_chart({'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'})
        verified = self.image_verifier.verify(chart_path, query)
        end_time = time.time()
        logging.info(f"Processed request in {end_time - start_time} seconds")
        return {
            "response": response_text,
            "chart_path": chart_path,
            "verified": verified
        }
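
# Minimal usage sketch (illustrative, not part of the agent proper): assumes a
# hypothetical CSV at "data/finances.csv" with Year/Sales columns and that the
# fine-tuned BART weights are present next to this file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    agent = LLM_Agent()
    result = agent.process_request({
        "query": "plot the sales in the years with red line",
        "file_path": "data/finances.csv",  # hypothetical sample data
        "model": "bart",
    })
    print(result["chart_path"], result["verified"])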