“Transcendental-Programmer”
commited on
Commit
·
4ac0bf8
1
Parent(s):
b0741bf
fix: Update chart generation and LLM agent functionality
Browse files- app.py +9 -2
- chart_generator.py +26 -4
- llm_agent.py +100 -70
app.py
CHANGED
@@ -16,7 +16,9 @@ logging.getLogger('PIL').setLevel(logging.WARNING)
|
|
16 |
|
17 |
app = Flask(__name__, static_folder=os.path.join(os.path.dirname(__file__), '..', 'static'))
|
18 |
|
19 |
-
CORS
|
|
|
|
|
20 |
agent = LLM_Agent()
|
21 |
|
22 |
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '..', 'data', 'uploads')
|
@@ -53,7 +55,12 @@ def plot():
|
|
53 |
@app.route('/static/<path:filename>')
|
54 |
def serve_static(filename):
|
55 |
logging.info(f"Serving static file: {filename}")
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
@app.route('/upload', methods=['POST'])
|
59 |
def upload_file():
|
|
|
16 |
|
17 |
app = Flask(__name__, static_folder=os.path.join(os.path.dirname(__file__), '..', 'static'))
|
18 |
|
19 |
+
# Configure CORS to allow all origins for development
|
20 |
+
CORS(app, origins=["*"], supports_credentials=True)
|
21 |
+
|
22 |
agent = LLM_Agent()
|
23 |
|
24 |
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '..', 'data', 'uploads')
|
|
|
55 |
@app.route('/static/<path:filename>')
|
56 |
def serve_static(filename):
|
57 |
logging.info(f"Serving static file: {filename}")
|
58 |
+
response = send_from_directory(app.static_folder, filename)
|
59 |
+
# Add CORS headers for images
|
60 |
+
response.headers.add('Access-Control-Allow-Origin', '*')
|
61 |
+
response.headers.add('Access-Control-Allow-Headers', 'Content-Type')
|
62 |
+
response.headers.add('Access-Control-Allow-Methods', 'GET')
|
63 |
+
return response
|
64 |
|
65 |
@app.route('/upload', methods=['POST'])
|
66 |
def upload_file():
|
chart_generator.py
CHANGED
@@ -27,32 +27,54 @@ class ChartGenerator:
|
|
27 |
missing_cols.append(y)
|
28 |
if missing_cols:
|
29 |
logging.error(f"Missing columns in data: {missing_cols}")
|
|
|
30 |
raise ValueError(f"Missing columns in data: {missing_cols}")
|
31 |
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
33 |
for y in y_cols:
|
34 |
color = plot_args.get('color', None)
|
35 |
if plot_args.get('chart_type', 'line') == 'bar':
|
36 |
ax.bar(self.data[x_col], self.data[y], label=y, color=color)
|
37 |
else:
|
38 |
-
ax.plot(self.data[x_col], self.data[y], label=y, color=color)
|
39 |
|
40 |
ax.set_xlabel(x_col)
|
|
|
|
|
41 |
ax.legend()
|
|
|
42 |
|
|
|
|
|
|
|
43 |
|
44 |
chart_filename = 'chart.png'
|
45 |
output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static', 'images')
|
46 |
if not os.path.exists(output_dir):
|
47 |
os.makedirs(output_dir)
|
|
|
48 |
|
49 |
full_path = os.path.join(output_dir, chart_filename)
|
50 |
|
51 |
if os.path.exists(full_path):
|
52 |
os.remove(full_path)
|
|
|
53 |
|
54 |
-
|
|
|
|
|
55 |
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
return os.path.join('static', 'images', chart_filename)
|
|
|
27 |
missing_cols.append(y)
|
28 |
if missing_cols:
|
29 |
logging.error(f"Missing columns in data: {missing_cols}")
|
30 |
+
logging.info(f"Available columns: {list(self.data.columns)}")
|
31 |
raise ValueError(f"Missing columns in data: {missing_cols}")
|
32 |
|
33 |
+
# Clear any existing plots
|
34 |
+
plt.clf()
|
35 |
+
plt.close('all')
|
36 |
+
|
37 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
38 |
+
|
39 |
for y in y_cols:
|
40 |
color = plot_args.get('color', None)
|
41 |
if plot_args.get('chart_type', 'line') == 'bar':
|
42 |
ax.bar(self.data[x_col], self.data[y], label=y, color=color)
|
43 |
else:
|
44 |
+
ax.plot(self.data[x_col], self.data[y], label=y, color=color, marker='o')
|
45 |
|
46 |
ax.set_xlabel(x_col)
|
47 |
+
ax.set_ylabel('Value')
|
48 |
+
ax.set_title(f'{plot_args.get("chart_type", "line").title()} Chart')
|
49 |
ax.legend()
|
50 |
+
ax.grid(True, alpha=0.3)
|
51 |
|
52 |
+
# Rotate x-axis labels if needed
|
53 |
+
if len(self.data[x_col]) > 5:
|
54 |
+
plt.xticks(rotation=45)
|
55 |
|
56 |
chart_filename = 'chart.png'
|
57 |
output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static', 'images')
|
58 |
if not os.path.exists(output_dir):
|
59 |
os.makedirs(output_dir)
|
60 |
+
logging.info(f"Created output directory: {output_dir}")
|
61 |
|
62 |
full_path = os.path.join(output_dir, chart_filename)
|
63 |
|
64 |
if os.path.exists(full_path):
|
65 |
os.remove(full_path)
|
66 |
+
logging.info(f"Removed existing chart file: {full_path}")
|
67 |
|
68 |
+
# Save with high DPI for better quality
|
69 |
+
plt.savefig(full_path, dpi=300, bbox_inches='tight', facecolor='white')
|
70 |
+
plt.close(fig)
|
71 |
|
72 |
+
# Verify file was created
|
73 |
+
if os.path.exists(full_path):
|
74 |
+
file_size = os.path.getsize(full_path)
|
75 |
+
logging.info(f"Chart generated and saved to {full_path} (size: {file_size} bytes)")
|
76 |
+
else:
|
77 |
+
logging.error(f"Failed to create chart file at {full_path}")
|
78 |
+
raise FileNotFoundError(f"Chart file was not created at {full_path}")
|
79 |
|
80 |
return os.path.join('static', 'images', chart_filename)
|
llm_agent.py
CHANGED
@@ -10,6 +10,7 @@ import os
|
|
10 |
from dotenv import load_dotenv
|
11 |
import ast
|
12 |
import requests
|
|
|
13 |
|
14 |
load_dotenv()
|
15 |
|
@@ -36,7 +37,7 @@ class LLM_Agent:
|
|
36 |
def process_request(self, data):
|
37 |
start_time = time.time()
|
38 |
logging.info(f"Processing request data: {data}")
|
39 |
-
query = data
|
40 |
data_path = data.get('file_path')
|
41 |
model_choice = data.get('model', 'bart')
|
42 |
|
@@ -49,8 +50,16 @@ class LLM_Agent:
|
|
49 |
else:
|
50 |
logging.info(f"File exists at path: {data_path}")
|
51 |
|
52 |
-
#
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
"You are VizBot, an expert data visualization assistant. "
|
55 |
"Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
|
56 |
"Do not include any explanation or extra text.\n\n"
|
@@ -66,73 +75,94 @@ class LLM_Agent:
|
|
66 |
f"User: {query}\nOutput:"
|
67 |
)
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
else:
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
|
113 |
-
response_text = f"Error: Unexpected response from Flan-T5-XXL API."
|
114 |
-
else:
|
115 |
-
# Default fallback to local fine-tuned BART model
|
116 |
-
inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
|
117 |
-
outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
|
118 |
-
response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
chart_path = self.chart_generator.generate_chart(plot_args)
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from dotenv import load_dotenv
|
11 |
import ast
|
12 |
import requests
|
13 |
+
import json
|
14 |
|
15 |
load_dotenv()
|
16 |
|
|
|
37 |
def process_request(self, data):
|
38 |
start_time = time.time()
|
39 |
logging.info(f"Processing request data: {data}")
|
40 |
+
query = data.get('query', '')
|
41 |
data_path = data.get('file_path')
|
42 |
model_choice = data.get('model', 'bart')
|
43 |
|
|
|
50 |
else:
|
51 |
logging.info(f"File exists at path: {data_path}")
|
52 |
|
53 |
+
# Re-initialize data processor and chart generator if a file is specified
|
54 |
+
if data_path:
|
55 |
+
self.data_processor = DataProcessor(data_path)
|
56 |
+
# Log loaded columns
|
57 |
+
loaded_columns = self.data_processor.get_columns()
|
58 |
+
logging.info(f"Loaded columns from data: {loaded_columns}")
|
59 |
+
self.chart_generator = ChartGenerator(self.data_processor.data)
|
60 |
+
|
61 |
+
# Enhanced prompt for better model responses
|
62 |
+
enhanced_prompt = (
|
63 |
"You are VizBot, an expert data visualization assistant. "
|
64 |
"Given a user's natural language request about plotting data, output ONLY a valid Python dictionary with keys: x, y, chart_type, and color (if specified). "
|
65 |
"Do not include any explanation or extra text.\n\n"
|
|
|
75 |
f"User: {query}\nOutput:"
|
76 |
)
|
77 |
|
78 |
+
try:
|
79 |
+
if model_choice == 'bart':
|
80 |
+
# Use local fine-tuned BART model
|
81 |
+
inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
|
82 |
+
outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
|
83 |
+
response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
84 |
+
elif model_choice == 'flan-t5-base':
|
85 |
+
# Use Hugging Face Inference API with Flan-T5-Base model
|
86 |
+
api_url = "https://api-inference.huggingface.co/models/google/flan-t5-base"
|
87 |
+
headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
|
88 |
+
payload = {"inputs": enhanced_prompt}
|
89 |
+
|
90 |
+
response = requests.post(api_url, headers=headers, json=payload, timeout=30)
|
91 |
+
if response.status_code != 200:
|
92 |
+
logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
|
93 |
+
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
94 |
+
else:
|
95 |
+
try:
|
96 |
+
resp_json = response.json()
|
97 |
+
response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
|
98 |
+
if not response_text:
|
99 |
+
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
100 |
+
except Exception as e:
|
101 |
+
logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
|
102 |
+
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
103 |
+
elif model_choice == 'flan-ul2':
|
104 |
+
# Use Hugging Face Inference API with Flan-T5-XXL model (best available)
|
105 |
+
api_url = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
|
106 |
+
headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
|
107 |
+
payload = {"inputs": enhanced_prompt}
|
108 |
+
|
109 |
+
response = requests.post(api_url, headers=headers, json=payload, timeout=30)
|
110 |
+
if response.status_code != 200:
|
111 |
+
logging.error(f"Hugging Face API error: {response.status_code} {response.text}")
|
112 |
+
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
113 |
+
else:
|
114 |
+
try:
|
115 |
+
resp_json = response.json()
|
116 |
+
response_text = resp_json[0]['generated_text'] if isinstance(resp_json, list) else resp_json.get('generated_text', '')
|
117 |
+
if not response_text:
|
118 |
+
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
119 |
+
except Exception as e:
|
120 |
+
logging.error(f"Error parsing Hugging Face API response: {e}, raw: {response.text}")
|
121 |
+
response_text = "{'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}"
|
122 |
else:
|
123 |
+
# Default fallback to local fine-tuned BART model
|
124 |
+
inputs = self.query_tokenizer(query, return_tensors="pt", max_length=512, truncation=True)
|
125 |
+
outputs = self.query_model.generate(**inputs, max_length=100, num_return_sequences=1)
|
126 |
+
response_text = self.query_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
+
logging.info(f"LLM response text: {response_text}")
|
129 |
+
|
130 |
+
# Clean and parse the response
|
131 |
+
response_text = response_text.strip()
|
132 |
+
if response_text.startswith("```") and response_text.endswith("```"):
|
133 |
+
response_text = response_text[3:-3].strip()
|
134 |
+
if response_text.startswith("python"):
|
135 |
+
response_text = response_text[6:].strip()
|
136 |
+
|
137 |
+
try:
|
138 |
+
plot_args = ast.literal_eval(response_text)
|
139 |
+
except (SyntaxError, ValueError) as e:
|
140 |
+
logging.warning(f"Invalid LLM response: {e}. Response: {response_text}")
|
141 |
+
plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
|
142 |
+
|
143 |
+
if not LLM_Agent.validate_plot_args(plot_args):
|
144 |
+
logging.warning("Invalid plot arguments. Using default.")
|
145 |
+
plot_args = {'x': 'Year', 'y': ['Sales'], 'chart_type': 'line'}
|
146 |
+
|
147 |
chart_path = self.chart_generator.generate_chart(plot_args)
|
148 |
+
verified = self.image_verifier.verify(chart_path, query)
|
149 |
+
|
150 |
+
end_time = time.time()
|
151 |
+
logging.info(f"Processed request in {end_time - start_time} seconds")
|
152 |
+
|
153 |
+
return {
|
154 |
+
"response": response_text,
|
155 |
+
"chart_path": chart_path,
|
156 |
+
"verified": verified
|
157 |
+
}
|
158 |
+
|
159 |
+
except Exception as e:
|
160 |
+
logging.error(f"Error processing request: {e}")
|
161 |
+
end_time = time.time()
|
162 |
+
logging.info(f"Processed request in {end_time - start_time} seconds")
|
163 |
+
|
164 |
+
return {
|
165 |
+
"response": f"Error: {str(e)}",
|
166 |
+
"chart_path": "",
|
167 |
+
"verified": False
|
168 |
+
}
|