import os
import io
import re
import json
import sqlite3
import base64
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import streamlit as st
from PIL import Image
from scipy import stats
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# The DataAnalysisChatbot class is defined inline below
# (originally imported via: from paste import DataAnalysisChatbot).
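# Assumed third-party dependencies (a sketch, not an authoritative requirements list):
# pandas, numpy, matplotlib, seaborn, scikit-learn, scipy, streamlit, Pillow.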
class DataAnalysisChatbot: | |
def __init__(self): | |
self.data = None | |
self.data_source = None | |
self.conversation_history = [] | |
self.available_commands = { | |
"load": self.load_data, | |
"info": self.get_data_info, | |
"describe": self.describe_data, | |
"missing": self.check_missing_values, | |
"correlate": self.correlation_analysis, | |
"visualize": self.visualize_data, | |
"analyze": self.analyze_column, | |
"trend": self.analyze_trend, | |
"outliers": self.detect_outliers, | |
"predict": self.predictive_analysis, | |
"test": self.hypothesis_testing, | |
"report": self.generate_report, | |
"help": self.get_help | |
} | |
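        # Example of the routing above: the query "analyze temperature" contains the
        # keyword "analyze", so process_query dispatches it to self.analyze_column;
        # queries with no recognised keyword fall back to simple keyword matching.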
def process_query(self, query): | |
"""Process user query and route to appropriate function""" | |
# Add the user query to conversation history | |
self.conversation_history.append({"role": "user", "message": query, "timestamp": datetime.now()}) | |
# Check if data is loaded (except for load command and help) | |
if self.data is None and not any(cmd in query.lower() for cmd in ["load", "help"]): | |
response = "Please load data first using the 'load' command. Example: load csv path/to/file.csv" | |
self._add_to_history(response) | |
return response | |
# Parse the command | |
command = self._extract_command(query) | |
if command in self.available_commands: | |
response = self.available_commands[command](query) | |
else: | |
# Natural language understanding would go here | |
# For now, use simple keyword matching | |
if "mean" in query.lower() or "average" in query.lower(): | |
response = self.analyze_column(query) | |
elif "correlate" in query.lower() or "relationship" in query.lower(): | |
response = self.correlation_analysis(query) | |
elif "visual" in query.lower() or "plot" in query.lower() or "chart" in query.lower() or "graph" in query.lower(): | |
response = self.visualize_data(query) | |
else: | |
response = "I'm not sure how to process that query. Type 'help' for available commands." | |
self._add_to_history(response) | |
return response | |
def _extract_command(self, query): | |
"""Extract the main command from the query""" | |
words = query.lower().split() | |
for word in words: | |
if word in self.available_commands: | |
return word | |
return None | |
def _add_to_history(self, response): | |
"""Add bot response to conversation history""" | |
self.conversation_history.append({"role": "bot", "message": response, "timestamp": datetime.now()}) | |
def _extract_column_names(self, query): | |
"""Extract column names mentioned in the query""" | |
if self.data is None: | |
return [] | |
columns = [] | |
for col in self.data.columns: | |
if col.lower() in query.lower(): | |
columns.append(col) | |
return columns | |
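    # Example with hypothetical column names: if the data has columns
    # ["temperature", "humidity"], then _extract_column_names("correlate temperature humidity")
    # returns both; matching is a case-insensitive substring check against the query.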
# DATA ACCESS AND RETRIEVAL | |
def load_data(self, query): | |
"""Load data from various sources""" | |
query_lower = query.lower() | |
# CSV Loading | |
if "csv" in query_lower: | |
match = re.search(r'load\s+csv\s+(.+?)(?:\s|$)', query) | |
if match: | |
file_path = match.group(1) | |
try: | |
self.data = pd.read_csv(file_path) | |
self.data_source = f"CSV: {file_path}" | |
return f"Successfully loaded data from {file_path}. {len(self.data)} rows and {len(self.data.columns)} columns found." | |
except Exception as e: | |
return f"Error loading CSV file: {str(e)}" | |
# Excel Loading | |
elif "excel" in query_lower or "xlsx" in query_lower: | |
match = re.search(r'load\s+(?:excel|xlsx)\s+(.+?)(?:\s|$)', query) | |
if match: | |
file_path = match.group(1) | |
try: | |
self.data = pd.read_excel(file_path) | |
self.data_source = f"Excel: {file_path}" | |
return f"Successfully loaded data from Excel file {file_path}. {len(self.data)} rows and {len(self.data.columns)} columns found." | |
except Exception as e: | |
return f"Error loading Excel file: {str(e)}" | |
# SQL Database Loading | |
elif "sql" in query_lower or "database" in query_lower: | |
# Extract database path and query using regex | |
db_match = re.search(r'load\s+(?:sql|database)\s+(.+?)\s+query\s+(.+?)(?:\s|$)', query, re.IGNORECASE | re.DOTALL) | |
if db_match: | |
db_path = db_match.group(1) | |
sql_query = db_match.group(2) | |
try: | |
conn = sqlite3.connect(db_path) | |
self.data = pd.read_sql_query(sql_query, conn) | |
conn.close() | |
self.data_source = f"SQL: {db_path}, Query: {sql_query}" | |
return f"Successfully loaded data from SQL query. {len(self.data)} rows and {len(self.data.columns)} columns found." | |
except Exception as e: | |
return f"Error executing SQL query: {str(e)}" | |
# JSON Loading | |
elif "json" in query_lower: | |
match = re.search(r'load\s+json\s+(.+?)(?:\s|$)', query) | |
if match: | |
file_path = match.group(1) | |
try: | |
with open(file_path, 'r') as f: | |
json_data = json.load(f) | |
self.data = pd.json_normalize(json_data) | |
self.data_source = f"JSON: {file_path}" | |
return f"Successfully loaded data from JSON file {file_path}. {len(self.data)} rows and {len(self.data.columns)} columns found." | |
except Exception as e: | |
return f"Error loading JSON file: {str(e)}" | |
return "Please specify the data source format and path. Example: 'load csv data.csv' or 'load sql database.db query SELECT * FROM table'" | |
def get_data_info(self, query): | |
"""Get basic information about the loaded data""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
info = f"Data Source: {self.data_source}\n" | |
info += f"Rows: {len(self.data)}\n" | |
info += f"Columns: {len(self.data.columns)}\n" | |
info += f"Column Names: {', '.join(self.data.columns)}\n" | |
info += f"Data Types:\n{self.data.dtypes.to_string()}\n" | |
memory_usage = self.data.memory_usage(deep=True).sum() | |
if memory_usage < 1024: | |
memory_str = f"{memory_usage} bytes" | |
elif memory_usage < 1024 * 1024: | |
memory_str = f"{memory_usage / 1024:.2f} KB" | |
else: | |
memory_str = f"{memory_usage / (1024 * 1024):.2f} MB" | |
info += f"Memory Usage: {memory_str}" | |
return info | |
def describe_data(self, query): | |
"""Provide descriptive statistics for the data""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
# Check if specific columns are mentioned | |
columns = self._extract_column_names(query) | |
if columns: | |
try: | |
desc = self.data[columns].describe().to_string() | |
return f"Descriptive statistics for columns {', '.join(columns)}:\n{desc}" | |
except Exception as e: | |
return f"Error generating descriptive statistics: {str(e)}" | |
else: | |
# If no specific columns mentioned, describe all numeric columns | |
numeric_cols = self.data.select_dtypes(include=['number']).columns.tolist() | |
if not numeric_cols: | |
return "No numeric columns found in the data for descriptive statistics." | |
desc = self.data[numeric_cols].describe().to_string() | |
return f"Descriptive statistics for all numeric columns:\n{desc}" | |
def check_missing_values(self, query): | |
"""Check for missing values in the data""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
missing_values = self.data.isnull().sum() | |
missing_percentage = (missing_values / len(self.data) * 100).round(2) | |
result = "Missing Values Analysis:\n" | |
for col, count in missing_values.items(): | |
if count > 0: | |
result += f"{col}: {count} missing values ({missing_percentage[col]}%)\n" | |
if not any(missing_values > 0): | |
result += "No missing values found in the dataset." | |
else: | |
total_missing = missing_values.sum() | |
total_cells = self.data.size | |
overall_percentage = (total_missing / total_cells * 100).round(2) | |
result += f"\nOverall: {total_missing} missing values out of {total_cells} cells ({overall_percentage}%)" | |
return result | |
# DATA ANALYSIS AND INTERPRETATION | |
def analyze_column(self, query): | |
"""Analyze a specific column""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
columns = self._extract_column_names(query) | |
if not columns: | |
return "Please specify a column name to analyze. Available columns: " + ", ".join(self.data.columns) | |
column = columns[0] # Take the first column mentioned | |
try: | |
col_data = self.data[column] | |
if pd.api.types.is_numeric_dtype(col_data): | |
# Numeric column analysis | |
stats = { | |
"Count": len(col_data), | |
"Missing": col_data.isnull().sum(), | |
"Mean": col_data.mean(), | |
"Median": col_data.median(), | |
"Mode": col_data.mode()[0] if not col_data.mode().empty else None, | |
"Std Dev": col_data.std(), | |
"Min": col_data.min(), | |
"Max": col_data.max(), | |
"25%": col_data.quantile(0.25), | |
"75%": col_data.quantile(0.75), | |
"Skewness": col_data.skew(), | |
"Kurtosis": col_data.kurt() | |
} | |
result = f"Analysis of column '{column}' (Numeric):\n" | |
for stat_name, stat_value in stats.items(): | |
if isinstance(stat_value, float): | |
result += f"{stat_name}: {stat_value:.4f}\n" | |
else: | |
result += f"{stat_name}: {stat_value}\n" | |
# Check for outliers using IQR method | |
Q1 = stats["25%"] | |
Q3 = stats["75%"] | |
IQR = Q3 - Q1 | |
lower_bound = Q1 - 1.5 * IQR | |
upper_bound = Q3 + 1.5 * IQR | |
outliers = col_data[(col_data < lower_bound) | (col_data > upper_bound)] | |
result += f"Outliers (IQR method): {len(outliers)} found\n" | |
# Add histogram data as ASCII art or description | |
hist_data = np.histogram(col_data.dropna(), bins=10) | |
result += "\nDistribution Summary:\n" | |
max_count = max(hist_data[0]) | |
for i, count in enumerate(hist_data[0]): | |
bin_start = f"{hist_data[1][i]:.2f}" | |
bin_end = f"{hist_data[1][i+1]:.2f}" | |
bar_length = int((count / max_count) * 20) | |
result += f"{bin_start} to {bin_end}: {'#' * bar_length} ({count})\n" | |
else: | |
# Categorical column analysis | |
value_counts = col_data.value_counts() | |
top_n = min(10, len(value_counts)) | |
result = f"Analysis of column '{column}' (Categorical):\n" | |
result += f"Count: {len(col_data)}\n" | |
result += f"Missing: {col_data.isnull().sum()}\n" | |
result += f"Unique Values: {col_data.nunique()}\n" | |
result += f"\nTop {top_n} values:\n" | |
for value, count in value_counts.head(top_n).items(): | |
percentage = (count / len(col_data)) * 100 | |
result += f"{value}: {count} ({percentage:.2f}%)\n" | |
return result | |
except Exception as e: | |
return f"Error analyzing column '{column}': {str(e)}" | |
def correlation_analysis(self, query): | |
"""Analyze correlations between columns""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
# Extract specific columns if mentioned | |
columns = self._extract_column_names(query) | |
# If no specific columns or fewer than 2 columns mentioned, use all numeric columns | |
if len(columns) < 2: | |
numeric_columns = self.data.select_dtypes(include=['number']).columns.tolist() | |
if len(numeric_columns) < 2: | |
return "Not enough numeric columns for correlation analysis." | |
# If we found numeric columns but none were specified, use all numeric | |
if not columns: | |
columns = numeric_columns | |
# If one was specified, find its highest correlations | |
elif len(columns) == 1: | |
target_col = columns[0] | |
if target_col not in numeric_columns: | |
return f"Column '{target_col}' is not numeric and cannot be used for correlation analysis." | |
# Get correlations with target column | |
corr_matrix = self.data[numeric_columns].corr() | |
target_corr = corr_matrix[target_col].sort_values(ascending=False) | |
result = f"Correlation analysis for '{target_col}':\n" | |
for col, corr_val in target_corr.items(): | |
if col != target_col: | |
strength = "" | |
if abs(corr_val) > 0.7: | |
strength = "Strong" | |
elif abs(corr_val) > 0.3: | |
strength = "Moderate" | |
else: | |
strength = "Weak" | |
direction = "positive" if corr_val > 0 else "negative" | |
result += f"{col}: {corr_val:.4f} ({strength} {direction} correlation)\n" | |
return result | |
try: | |
# Calculate correlations between specified columns | |
corr_matrix = self.data[columns].corr() | |
result = "Correlation Matrix:\n" | |
result += corr_matrix.to_string() | |
# Find strongest correlations | |
corr_pairs = [] | |
for i in range(len(columns)): | |
for j in range(i+1, len(columns)): | |
col1, col2 = columns[i], columns[j] | |
corr_val = corr_matrix.loc[col1, col2] | |
corr_pairs.append((col1, col2, corr_val)) | |
# Sort by absolute correlation value | |
corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True) | |
result += "\n\nStrongest Correlations:\n" | |
for col1, col2, corr_val in corr_pairs: | |
strength = "" | |
if abs(corr_val) > 0.7: | |
strength = "Strong" | |
elif abs(corr_val) > 0.3: | |
strength = "Moderate" | |
else: | |
strength = "Weak" | |
direction = "positive" if corr_val > 0 else "negative" | |
result += f"{col1} vs {col2}: {corr_val:.4f} ({strength} {direction} correlation)\n" | |
return result | |
except Exception as e: | |
return f"Error performing correlation analysis: {str(e)}" | |
def visualize_data(self, query): | |
"""Generate visualizations based on data""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
# Extract columns from query | |
columns = self._extract_column_names(query) | |
# Determine visualization type from query | |
viz_type = None | |
if "scatter" in query.lower(): | |
viz_type = "scatter" | |
elif "histogram" in query.lower() or "distribution" in query.lower(): | |
viz_type = "histogram" | |
elif "box" in query.lower(): | |
viz_type = "box" | |
elif "bar" in query.lower(): | |
viz_type = "bar" | |
elif "pie" in query.lower(): | |
viz_type = "pie" | |
elif "heatmap" in query.lower() or "correlation" in query.lower(): | |
viz_type = "heatmap" | |
elif "line" in query.lower() or "trend" in query.lower(): | |
viz_type = "line" | |
else: | |
# Default to bar chart for one column, scatter for two | |
if len(columns) == 1: | |
viz_type = "bar" | |
elif len(columns) >= 2: | |
viz_type = "scatter" | |
else: | |
return "Please specify columns and visualization type (scatter, histogram, box, bar, pie, heatmap, line)" | |
try: | |
plt.figure(figsize=(10, 6)) | |
if viz_type == "scatter" and len(columns) >= 2: | |
plt.scatter(self.data[columns[0]], self.data[columns[1]]) | |
plt.xlabel(columns[0]) | |
plt.ylabel(columns[1]) | |
plt.title(f"Scatter Plot: {columns[0]} vs {columns[1]}") | |
# Add regression line | |
if len(self.data) > 2: # Need at least 3 points for meaningful regression | |
x = self.data[columns[0]].values.reshape(-1, 1) | |
y = self.data[columns[1]].values | |
model = LinearRegression() | |
model.fit(x, y) | |
plt.plot(x, model.predict(x), color='red', linewidth=2) | |
# Add correlation coefficient | |
corr = self.data[columns].corr().loc[columns[0], columns[1]] | |
plt.annotate(f"r = {corr:.4f}", xy=(0.05, 0.95), xycoords='axes fraction') | |
elif viz_type == "histogram" and columns: | |
sns.histplot(self.data[columns[0]], kde=True) | |
plt.xlabel(columns[0]) | |
plt.ylabel("Frequency") | |
plt.title(f"Histogram of {columns[0]}") | |
elif viz_type == "box" and columns: | |
if len(columns) == 1: | |
sns.boxplot(y=self.data[columns[0]]) | |
plt.ylabel(columns[0]) | |
else: | |
plt.boxplot([self.data[col].dropna() for col in columns]) | |
plt.xticks(range(1, len(columns) + 1), columns, rotation=45) | |
plt.title(f"Box Plot of {', '.join(columns)}") | |
elif viz_type == "bar" and columns: | |
if len(columns) == 1: | |
# For a single column, show value counts | |
value_counts = self.data[columns[0]].value_counts().nlargest(15) | |
value_counts.plot(kind='bar') | |
plt.xlabel(columns[0]) | |
plt.ylabel("Count") | |
plt.title(f"Bar Chart of {columns[0]} (Top 15 Categories)") | |
else: | |
# For multiple columns, show means | |
self.data[columns].mean().plot(kind='bar') | |
plt.ylabel("Mean Value") | |
plt.title(f"Mean Values of {', '.join(columns)}") | |
elif viz_type == "pie" and columns: | |
# Only use first column for pie chart | |
value_counts = self.data[columns[0]].value_counts().nlargest(10) | |
plt.pie(value_counts, labels=value_counts.index, autopct='%1.1f%%') | |
plt.title(f"Pie Chart of {columns[0]} (Top 10 Categories)") | |
elif viz_type == "heatmap": | |
# Use numeric columns for heatmap | |
if not columns: | |
columns = self.data.select_dtypes(include=['number']).columns.tolist() | |
if len(columns) < 2: | |
return "Need at least 2 numeric columns for heatmap." | |
corr_matrix = self.data[columns].corr() | |
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1) | |
plt.title("Correlation Heatmap") | |
elif viz_type == "line" and columns: | |
# Check if there's a datetime column to use as index | |
datetime_cols = [col for col in self.data.columns if pd.api.types.is_datetime64_dtype(self.data[col])] | |
if datetime_cols and len(columns) >= 1: | |
time_col = datetime_cols[0] | |
for col in columns: | |
if col != time_col: | |
plt.plot(self.data[time_col], self.data[col], label=col) | |
plt.xlabel(time_col) | |
plt.legend() | |
else: | |
# No datetime column, just plot the values | |
for col in columns: | |
plt.plot(self.data[col], label=col) | |
plt.legend() | |
plt.title(f"Line Plot of {', '.join(columns)}") | |
# Save figure to a temporary file | |
temp_file = f"temp_viz_{datetime.now().strftime('%Y%m%d%H%M%S')}.png" | |
plt.tight_layout() | |
plt.savefig(temp_file) | |
plt.close() | |
return f"Visualization created and saved as {temp_file}" | |
except Exception as e: | |
plt.close() # Close any open figures in case of error | |
return f"Error creating visualization: {str(e)}" | |
def analyze_trend(self, query): | |
"""Analyze trends over time or sequence""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
# Extract columns from query | |
columns = self._extract_column_names(query) | |
if len(columns) < 1: | |
return "Please specify at least one column to analyze for trends." | |
try: | |
result = "Trend Analysis:\n" | |
# Look for a date/time column | |
date_columns = [] | |
for col in self.data.columns: | |
if pd.api.types.is_datetime64_dtype(self.data[col]): | |
date_columns.append(col) | |
elif any(date_term in col.lower() for date_term in ["date", "time", "year", "month", "day"]): | |
try: | |
# Try to convert to datetime | |
pd.to_datetime(self.data[col]) | |
date_columns.append(col) | |
except: | |
pass | |
# If we found date columns, use the first one | |
if date_columns: | |
time_col = date_columns[0] | |
result += f"Using {time_col} as the time variable.\n\n" | |
# Convert to datetime if not already | |
if not pd.api.types.is_datetime64_dtype(self.data[time_col]): | |
self.data[time_col] = pd.to_datetime(self.data[time_col], errors='coerce') | |
# Sort by time | |
data_sorted = self.data.sort_values(by=time_col) | |
for col in columns: | |
if col == time_col: | |
continue | |
if not pd.api.types.is_numeric_dtype(self.data[col]): | |
result += f"Skipping non-numeric column {col}\n" | |
continue | |
# Calculate trend statistics | |
result += f"Trend for {col}:\n" | |
# Calculate overall change | |
first_val = data_sorted[col].iloc[0] | |
last_val = data_sorted[col].iloc[-1] | |
total_change = last_val - first_val | |
pct_change = (total_change / first_val * 100) if first_val != 0 else float('inf') | |
result += f" Starting value: {first_val}\n" | |
result += f" Ending value: {last_val}\n" | |
result += f" Total change: {total_change} ({pct_change:.2f}%)\n" | |
# Perform trend analysis with linear regression | |
x = np.arange(len(data_sorted)).reshape(-1, 1) | |
y = data_sorted[col].values | |
# Handle missing values | |
mask = ~np.isnan(y) | |
x_clean = x[mask] | |
y_clean = y[mask] | |
if len(y_clean) >= 2: # Need at least 2 points for regression | |
model = LinearRegression() | |
model.fit(x_clean, y_clean) | |
slope = model.coef_[0] | |
avg_val = np.mean(y_clean) | |
result += f" Trend slope: {slope:.4f} per time unit\n" | |
result += f" Relative trend: {slope / avg_val * 100:.2f}% of mean per time unit\n" | |
# Determine if trend is significant | |
if abs(slope / avg_val) > 0.01: | |
direction = "increasing" if slope > 0 else "decreasing" | |
strength = "strongly" if abs(slope / avg_val) > 0.05 else "moderately" | |
result += f" The {col} is {strength} {direction} over time.\n" | |
else: | |
result += f" The {col} shows little change over time.\n" | |
# R-squared to show fit quality | |
y_pred = model.predict(x_clean) | |
r2 = r2_score(y_clean, y_pred) | |
result += f" R-squared: {r2:.4f} (higher means more consistent trend)\n" | |
# Calculate periodicity if enough data points | |
if len(y_clean) >= 4: | |
result += self._check_seasonality(y_clean) | |
result += "\n" | |
else: | |
# No date column found, use sequence order | |
result += "No date/time column found. Analyzing trends based on sequence order.\n\n" | |
for col in columns: | |
if not pd.api.types.is_numeric_dtype(self.data[col]): | |
result += f"Skipping non-numeric column {col}\n" | |
continue | |
# Get non-missing values | |
values = self.data[col].dropna().values | |
if len(values) < 2: | |
result += f"Not enough non-missing values in {col} for trend analysis.\n" | |
continue | |
# Calculate basic trend | |
result += f"Trend for {col}:\n" | |
# Linear regression for trend | |
x = np.arange(len(values)).reshape(-1, 1) | |
y = values | |
model = LinearRegression() | |
model.fit(x, y) | |
slope = model.coef_[0] | |
avg_val = np.mean(y) | |
result += f" Trend slope: {slope:.4f} per unit\n" | |
result += f" Relative trend: {slope / avg_val * 100:.2f}% of mean per unit\n" | |
# Determine trend direction and strength | |
if abs(slope / avg_val) > 0.01: | |
direction = "increasing" if slope > 0 else "decreasing" | |
strength = "strongly" if abs(slope / avg_val) > 0.05 else "moderately" | |
result += f" The {col} is {strength} {direction} over the sequence.\n" | |
else: | |
result += f" The {col} shows little change over the sequence.\n" | |
# R-squared | |
y_pred = model.predict(x) | |
r2 = r2_score(y, y_pred) | |
result += f" R-squared: {r2:.4f}\n" | |
# Check for simple patterns | |
if len(values) >= 4: | |
result += self._check_seasonality(values) | |
result += "\n" | |
return result | |
except Exception as e: | |
return f"Error analyzing trends: {str(e)}" | |
def _check_seasonality(self, values): | |
"""Helper function to check for seasonality in a time series""" | |
result = "" | |
# Compute autocorrelation | |
acf = [] | |
mean = np.mean(values) | |
variance = np.var(values) | |
if variance == 0: # All values are the same | |
return " No seasonality detected (constant values).\n" | |
# Compute autocorrelation up to 1/3 of series length | |
max_lag = min(len(values) // 3, 20) # Max 20 lags | |
for lag in range(1, max_lag + 1): | |
numerator = 0 | |
for i in range(len(values) - lag): | |
numerator += (values[i] - mean) * (values[i + lag] - mean) | |
acf.append(numerator / (len(values) - lag) / variance) | |
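        # The loop above computes a lag-normalised sample autocorrelation:
        #   acf[lag] = sum_i (x_i - mean) * (x_{i+lag} - mean) / ((n - lag) * variance)
        # Peaks above 0.2 in this sequence are treated as candidate seasonal periods below.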
# Find potential seasonality by looking for peaks in autocorrelation | |
peaks = [] | |
for i in range(1, len(acf) - 1): | |
if acf[i] > acf[i-1] and acf[i] > acf[i+1] and acf[i] > 0.2: | |
peaks.append((i+1, acf[i])) | |
if peaks: | |
# Sort by correlation strength | |
peaks.sort(key=lambda x: x[1], reverse=True) | |
result += " Potential seasonality detected with periods: " | |
result += ", ".join([f"{p[0]} (r={p[1]:.2f})" for p in peaks[:3]]) | |
result += "\n" | |
else: | |
result += " No clear seasonality detected.\n" | |
return result | |
def detect_outliers(self, query): | |
"""Detect outliers in the data""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
# Extract columns from query | |
columns = self._extract_column_names(query) | |
# If no columns specified, use all numeric columns | |
if not columns: | |
columns = self.data.select_dtypes(include=['number']).columns.tolist() | |
if not columns: | |
return "No numeric columns found for outlier detection." | |
try: | |
result = "Outlier Detection Results:\n" | |
for col in columns: | |
if not pd.api.types.is_numeric_dtype(self.data[col]): | |
result += f"Skipping non-numeric column: {col}\n" | |
continue | |
# Drop missing values | |
col_data = self.data[col].dropna() | |
if len(col_data) < 5: | |
result += f"Not enough data in {col} for outlier detection.\n" | |
continue | |
result += f"\nColumn: {col}\n" | |
# Method 1: IQR method | |
Q1 = col_data.quantile(0.25) | |
Q3 = col_data.quantile(0.75) | |
IQR = Q3 - Q1 | |
lower_bound = Q1 - 1.5 * IQR | |
upper_bound = Q3 + 1.5 * IQR | |
outliers_iqr = col_data[(col_data < lower_bound) | (col_data > upper_bound)] | |
result += f" IQR Method: {len(outliers_iqr)} outliers found\n" | |
result += f" Lower bound: {lower_bound:.4f}, Upper bound: {upper_bound:.4f}\n" | |
if len(outliers_iqr) > 0: | |
result += f" Outlier range: {outliers_iqr.min():.4f} to {outliers_iqr.max():.4f}\n" | |
if len(outliers_iqr) <= 10: | |
result += f" Outlier values: {', '.join(map(str, outliers_iqr.tolist()))}\n" | |
else: | |
result += f" First 5 outliers: {', '.join(map(str, outliers_iqr.iloc[:5].tolist()))}\n" | |
# Method 2: Z-score method | |
z_scores = stats.zscore(col_data) | |
outliers_zscore = col_data[abs(z_scores) > 3] | |
result += f" Z-score Method (|z| > 3): {len(outliers_zscore)} outliers found\n" | |
if len(outliers_zscore) > 0: | |
result += f" Outlier range: {outliers_zscore.min():.4f} to {outliers_zscore.max():.4f}\n" | |
if len(outliers_zscore) <= 10: | |
result += f" Outlier values: {', '.join(map(str, outliers_zscore.tolist()))}\n" | |
else: | |
result += f" First 5 outliers: {', '.join(map(str, outliers_zscore.iloc[:5].tolist()))}\n" | |
# Compare methods | |
common_outliers = set(outliers_iqr.index).intersection(set(outliers_zscore.index)) | |
result += f" {len(common_outliers)} outliers detected by both methods\n" | |
# Impact of outliers | |
mean_with_outliers = col_data.mean() | |
mean_without_outliers = col_data[~col_data.index.isin(outliers_iqr.index)].mean() | |
impact = abs((mean_without_outliers - mean_with_outliers) / mean_with_outliers * 100) | |
result += f" Impact on mean: {impact:.2f}% change if IQR outliers removed\n" | |
return result | |
except Exception as e: | |
return f"Error detecting outliers: {str(e)}" | |
def predictive_analysis(self, query): | |
"""Perform simple predictive analysis""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
# Extract target and features from query | |
columns = self._extract_column_names(query) | |
if len(columns) < 2: | |
return "Please specify at least two columns: one target and one or more features." | |
# Last column is target, rest are features | |
target_col = columns[-1] | |
feature_cols = columns[:-1] | |
try: | |
# Check if columns are numeric | |
for col in columns: | |
if not pd.api.types.is_numeric_dtype(self.data[col]): | |
return f"Column '{col}' is not numeric. Simple predictive analysis requires numeric data." | |
# Prepare data | |
X = self.data[feature_cols].dropna() | |
y = self.data.loc[X.index, target_col] | |
if len(X) < 10: | |
return "Not enough complete data rows for predictive analysis (need at least 10)." | |
# Split data | |
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) | |
# Fit model | |
model = LinearRegression() | |
model.fit(X_train, y_train) | |
# Make predictions | |
y_train_pred = model.predict(X_train) | |
y_test_pred = model.predict(X_test) | |
# Calculate metrics | |
train_mse = mean_squared_error(y_train, y_train_pred) | |
test_mse = mean_squared_error(y_test, y_test_pred) | |
train_r2 = r2_score(y_train, y_train_pred) | |
test_r2 = r2_score(y_test, y_test_pred) | |
# Prepare results | |
result = f"Predictive Analysis: Predicting '{target_col}' using {', '.join(feature_cols)}\n\n" | |
result += "Model Information:\n" | |
result += f" Linear Regression with {len(feature_cols)} feature(s)\n" | |
result += f" Training data: {len(X_train)} rows\n" | |
result += f" Testing data: {len(X_test)} rows\n\n" | |
result += "Feature Importance:\n" | |
for i, feature in enumerate(feature_cols): | |
result += f" {feature}: coefficient = {model.coef_[i]:.4f}\n" | |
result += f" Intercept: {model.intercept_:.4f}\n\n" | |
result += "Model Equation:\n" | |
equation = f"{target_col} = {model.intercept_:.4f}" | |
for i, feature in enumerate(feature_cols): | |
coef = model.coef_[i] | |
sign = "+" if coef >= 0 else "" | |
equation += f" {sign} {coef:.4f} Γ {feature}" | |
result += f" {equation}\n\n" | |
result += "Model Performance:\n" | |
result += f" Training set:\n" | |
result += f" Mean Squared Error: {train_mse:.4f}\n" | |
result += f" RΒ² Score: {train_r2:.4f}\n\n" | |
result += f" Test set:\n" | |
result += f" Mean Squared Error: {test_mse:.4f}\n" | |
result += f" RΒ² Score: {test_r2:.4f}\n\n" | |
# Interpret the results | |
result += "Interpretation:\n" | |
# Interpret RΒ² score | |
if test_r2 >= 0.7: | |
result += " The model explains a high proportion of the variance in the target variable.\n" | |
elif test_r2 >= 0.4: | |
result += " The model explains a moderate proportion of the variance in the target variable.\n" | |
else: | |
result += " The model explains only a small proportion of the variance in the target variable.\n" | |
# Check for overfitting | |
if train_r2 - test_r2 > 0.2: | |
result += " The model shows signs of overfitting (performs much better on training than test data).\n" | |
# Feature importance interpretation | |
most_important_feature = feature_cols[abs(model.coef_).argmax()] | |
result += f" The most influential feature is '{most_important_feature}'.\n" | |
# Sample prediction | |
            row_sample = X_test.iloc[0]
            prediction = model.predict(X_test.iloc[[0]])[0]  # pass a one-row DataFrame so feature names are kept
result += "\nSample Prediction:\n" | |
result += " For the values:\n" | |
for feature in feature_cols: | |
result += f" {feature} = {row_sample[feature]}\n" | |
result += f" Predicted {target_col} = {prediction:.4f}\n" | |
return result | |
except Exception as e: | |
return f"Error performing predictive analysis: {str(e)}" | |
def hypothesis_testing(self, query): | |
"""Perform hypothesis testing on the data""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
# Extract columns from query | |
columns = self._extract_column_names(query) | |
if len(columns) == 0: | |
return "Please specify at least one column for hypothesis testing." | |
try: | |
result = "Hypothesis Testing Results:\n\n" | |
# Single column analysis (distribution tests) | |
if len(columns) == 1: | |
col = columns[0] | |
if not pd.api.types.is_numeric_dtype(self.data[col]): | |
return f"Column '{col}' is not numeric. Basic hypothesis testing requires numeric data." | |
data = self.data[col].dropna() | |
# Normality test | |
stat, p_value = stats.shapiro(data) if len(data) < 5000 else stats.normaltest(data) | |
result += f"Normality Test for '{col}':\n" | |
test_name = "Shapiro-Wilk" if len(data) < 5000 else "D'Agostino's KΒ²" | |
result += f" Test used: {test_name}\n" | |
result += f" Statistic: {stat:.4f}\n" | |
result += f" p-value: {p_value:.4f}\n" | |
result += f" Interpretation: The data is {'not ' if p_value < 0.05 else ''}normally distributed (95% confidence).\n\n" | |
# Basic statistics | |
mean = data.mean() | |
median = data.median() | |
std_dev = data.std() | |
# One-sample t-test (against 0 or population mean) | |
population_mean = 0 # Default null hypothesis mean | |
t_stat, p_value = stats.ttest_1samp(data, population_mean) | |
result += f"One-sample t-test for '{col}':\n" | |
result += f" Null Hypothesis: The mean of '{col}' is equal to {population_mean}\n" | |
result += f" Alternative Hypothesis: The mean of '{col}' is not equal to {population_mean}\n" | |
result += f" t-statistic: {t_stat:.4f}\n" | |
result += f" p-value: {p_value:.4f}\n" | |
result += f" Sample Mean: {mean:.4f}\n" | |
result += f" Interpretation: {'Reject' if p_value < 0.05 else 'Fail to reject'} the null hypothesis (95% confidence).\n" | |
result += f" In other words: The mean is {'statistically different from' if p_value < 0.05 else 'not statistically different from'} {population_mean}.\n" | |
            # Two-column analysis (both columns numeric): two-sample t-test
            elif len(columns) == 2 and all(pd.api.types.is_numeric_dtype(self.data[c]) for c in columns):
                col1, col2 = columns
data1 = self.data[col1].dropna() | |
data2 = self.data[col2].dropna() | |
# Check if the columns are independent or paired | |
are_paired = len(data1) == len(data2) and (self.data[columns].count().min() / self.data[columns].count().max() > 0.9) | |
test_type = "paired" if are_paired else "independent" | |
result += f"Two-sample {'Paired' if are_paired else 'Independent'} t-test:\n" | |
result += f" Comparing '{col1}' and '{col2}'\n" | |
result += f" Null Hypothesis: The means of the two columns are equal\n" | |
result += f" Alternative Hypothesis: The means of the two columns are not equal\n\n" | |
if are_paired: | |
# Use paired t-test for related samples | |
# Make sure we have pairs of non-NaN values | |
valid_rows = self.data[columns].dropna() | |
t_stat, p_value = stats.ttest_rel(valid_rows[col1], valid_rows[col2]) | |
else: | |
# Use independent t-test | |
t_stat, p_value = stats.ttest_ind(data1, data2, equal_var=False) # Use Welch's t-test | |
result += f" t-statistic: {t_stat:.4f}\n" | |
result += f" p-value: {p_value:.4f}\n" | |
result += f" Mean of '{col1}': {data1.mean():.4f}\n" | |
result += f" Mean of '{col2}': {data2.mean():.4f}\n" | |
result += f" Difference in means: {data1.mean() - data2.mean():.4f}\n" | |
result += f" Interpretation: {'Reject' if p_value < 0.05 else 'Fail to reject'} the null hypothesis (95% confidence).\n" | |
result += f" In other words: The means are {'statistically different' if p_value < 0.05 else 'not statistically different'} from each other.\n" | |
            # Two-column analysis (one categorical, one numeric): one-way ANOVA
            elif len(columns) == 2:
col1, col2 = columns | |
# Check if one is categorical and one is numeric | |
if (pd.api.types.is_numeric_dtype(self.data[col1]) and | |
not pd.api.types.is_numeric_dtype(self.data[col2])): | |
numeric_col, cat_col = col1, col2 | |
elif (pd.api.types.is_numeric_dtype(self.data[col2]) and | |
not pd.api.types.is_numeric_dtype(self.data[col1])): | |
numeric_col, cat_col = col2, col1 | |
else: | |
return "For ANOVA, one column should be categorical and one should be numeric." | |
# Perform one-way ANOVA | |
groups = [] | |
labels = [] | |
for category, group in self.data.groupby(cat_col): | |
if len(group[numeric_col].dropna()) > 0: | |
groups.append(group[numeric_col].dropna()) | |
labels.append(str(category)) | |
if len(groups) < 2: | |
return "Not enough groups with data for ANOVA." | |
f_stat, p_value = stats.f_oneway(*groups) | |
result += "One-way ANOVA:\n" | |
result += f" Comparing '{numeric_col}' across groups of '{cat_col}'\n" | |
result += f" Null Hypothesis: The means of '{numeric_col}' are equal across all groups\n" | |
result += f" Alternative Hypothesis: At least one group has a different mean\n\n" | |
result += f" F-statistic: {f_stat:.4f}\n" | |
result += f" p-value: {p_value:.4f}\n" | |
result += f" Group means:\n" | |
for i, (label, group) in enumerate(zip(labels, groups)): | |
result += f" {label}: {group.mean():.4f} (n={len(group)})\n" | |
result += f" Interpretation: {'Reject' if p_value < 0.05 else 'Fail to reject'} the null hypothesis (95% confidence).\n" | |
result += f" In other words: There {'is' if p_value < 0.05 else 'is no'} statistically significant difference between groups.\n" | |
# Multiple column comparison | |
else: | |
result += "Correlation Analysis:\n" | |
numeric_cols = [col for col in columns if pd.api.types.is_numeric_dtype(self.data[col])] | |
if len(numeric_cols) < 2: | |
return "Need at least two numeric columns for correlation analysis." | |
corr_matrix = self.data[numeric_cols].corr() | |
result += " Pearson Correlation Matrix:\n" | |
result += f"{corr_matrix.to_string()}\n\n" | |
result += " Significance Tests (p-values):\n" | |
p_matrix = pd.DataFrame(index=corr_matrix.index, columns=corr_matrix.columns) | |
for i in range(len(numeric_cols)): | |
for j in range(i+1, len(numeric_cols)): | |
col_i, col_j = numeric_cols[i], numeric_cols[j] | |
valid_data = self.data[[col_i, col_j]].dropna() | |
_, p_value = stats.pearsonr(valid_data[col_i], valid_data[col_j]) | |
p_matrix.loc[col_i, col_j] = p_value | |
p_matrix.loc[col_j, col_i] = p_value | |
result += f"{p_matrix.to_string()}\n\n" | |
result += " Significant Correlations (p < 0.05):\n" | |
for i in range(len(numeric_cols)): | |
for j in range(i+1, len(numeric_cols)): | |
col_i, col_j = numeric_cols[i], numeric_cols[j] | |
if p_matrix.loc[col_i, col_j] < 0.05: | |
corr_val = corr_matrix.loc[col_i, col_j] | |
p_val = p_matrix.loc[col_i, col_j] | |
result += f" {col_i} vs {col_j}: r={corr_val:.4f}, p={p_val:.4f}\n" | |
return result | |
except Exception as e: | |
return f"Error performing hypothesis testing: {str(e)}" | |
def generate_report(self, query): | |
"""Generate a comprehensive report on the data""" | |
if self.data is None: | |
return "No data loaded. Please load data first." | |
try: | |
report = "# Data Analysis Report\n\n" | |
# 1. Dataset Overview | |
report += "## 1. Dataset Overview\n\n" | |
report += f"**Data Source:** {self.data_source}\n" | |
report += f"**Number of Rows:** {len(self.data)}\n" | |
report += f"**Number of Columns:** {len(self.data.columns)}\n\n" | |
# Column types summary | |
dtype_counts = {} | |
for dtype in self.data.dtypes: | |
dtype_name = str(dtype) | |
if dtype_name in dtype_counts: | |
dtype_counts[dtype_name] += 1 | |
else: | |
dtype_counts[dtype_name] = 1 | |
report += "**Column Data Types:**\n" | |
for dtype, count in dtype_counts.items(): | |
report += f"- {dtype}: {count} columns\n" | |
report += "\n" | |
# 2. Data Quality Assessment | |
report += "## 2. Data Quality Assessment\n\n" | |
# Missing values | |
missing_values = self.data.isnull().sum() | |
missing_percentage = (missing_values / len(self.data) * 100).round(2) | |
missing_cols = missing_values[missing_values > 0] | |
if len(missing_cols) > 0: | |
report += "**Missing Values:**\n" | |
for col, count in missing_cols.items(): | |
report += f"- {col}: {count} missing values ({missing_percentage[col]}%)\n" | |
else: | |
report += "**Missing Values:** None\n" | |
report += "\n" | |
# 3. Descriptive Statistics | |
report += "## 3. Descriptive Statistics\n\n" | |
# Numeric columns | |
numeric_cols = self.data.select_dtypes(include=['number']).columns.tolist() | |
if numeric_cols: | |
report += "**Numeric Columns:**\n" | |
report += "```\n" | |
report += self.data[numeric_cols].describe().to_string() | |
report += "\n```\n\n" | |
# Categorical columns | |
cat_cols = self.data.select_dtypes(exclude=['number']).columns.tolist() | |
if cat_cols: | |
report += "**Categorical Columns:**\n" | |
for col in cat_cols[:5]: # Limit to first 5 for brevity | |
value_counts = self.data[col].value_counts().head(5) | |
report += f"Top values for '{col}':\n" | |
report += "```\n" | |
report += value_counts.to_string() | |
report += "\n```\n" | |
report += f"Unique values: {self.data[col].nunique()}\n\n" | |
if len(cat_cols) > 5: | |
report += f"(Analysis limited to first 5 out of {len(cat_cols)} categorical columns)\n\n" | |
# 4. Correlation Analysis | |
report += "## 4. Correlation Analysis\n\n" | |
if len(numeric_cols) >= 2: | |
corr_matrix = self.data[numeric_cols].corr() | |
report += "**Correlation Matrix:**\n" | |
report += "```\n" | |
report += corr_matrix.round(2).to_string() | |
report += "\n```\n\n" | |
# Strongest correlations | |
corr_pairs = [] | |
for i in range(len(numeric_cols)): | |
for j in range(i+1, len(numeric_cols)): | |
col1, col2 = numeric_cols[i], numeric_cols[j] | |
corr_val = corr_matrix.loc[col1, col2] | |
if abs(corr_val) > 0.5: # Only report moderate to strong correlations | |
corr_pairs.append((col1, col2, corr_val)) | |
if corr_pairs: | |
# Sort by absolute correlation value | |
corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True) | |
report += "**Strongest Correlations:**\n" | |
for col1, col2, corr_val in corr_pairs[:10]: # Top 10 | |
direction = "positive" if corr_val > 0 else "negative" | |
report += f"- {col1} vs {col2}: {corr_val:.4f} ({direction})\n" | |
report += "\n" | |
else: | |
report += "No moderate or strong correlations (|r| > 0.5) found between variables.\n\n" | |
else: | |
report += "Insufficient numeric columns for correlation analysis.\n\n" | |
# 5. Key Insights | |
report += "## 5. Key Insights\n\n" | |
insights = [] | |
# Data quality insights | |
total_missing = missing_values.sum() | |
if total_missing > 0: | |
total_cells = self.data.size | |
overall_percentage = (total_missing / total_cells * 100).round(2) | |
if overall_percentage > 10: | |
insights.append(f"The dataset has a high proportion of missing values ({overall_percentage}% overall), which may require imputation or handling.") | |
# Distribution insights for numeric columns | |
for col in numeric_cols[:5]: # Limit to first 5 for brevity | |
col_data = self.data[col].dropna() | |
if len(col_data) == 0: | |
continue | |
mean = col_data.mean() | |
median = col_data.median() | |
skew = col_data.skew() | |
# Check for skewed distributions | |
if abs(skew) > 1: | |
skew_direction = "positively" if skew > 0 else "negatively" | |
insights.append(f"'{col}' is {skew_direction} skewed (skew={skew:.2f}), with mean={mean:.2f} and median={median:.2f}.") | |
# Check for outliers | |
Q1 = col_data.quantile(0.25) | |
Q3 = col_data.quantile(0.75) | |
IQR = Q3 - Q1 | |
lower_bound = Q1 - 1.5 * IQR | |
upper_bound = Q3 + 1.5 * IQR | |
outliers = col_data[(col_data < lower_bound) | (col_data > upper_bound)] | |
outlier_percentage = (len(outliers) / len(col_data) * 100).round(2) | |
if outlier_percentage > 5: | |
insights.append(f"'{col}' has a high proportion of outliers ({outlier_percentage}% of values).") | |
# Correlation insights | |
if len(corr_pairs) > 0: | |
top_corr = corr_pairs[0] | |
direction = "positively" if top_corr[2] > 0 else "negatively" | |
insights.append(f"The strongest relationship is between '{top_corr[0]}' and '{top_corr[1]}' (r={top_corr[2]:.2f}), which are {direction} correlated.") | |
# Report insights | |
if insights: | |
for i, insight in enumerate(insights, 1): | |
report += f"{i}. {insight}\n" | |
else: | |
report += "No significant insights detected based on initial analysis.\n" | |
report += "\n" | |
# 6. Next Steps | |
report += "## 6. Recommendations for Further Analysis\n\n" | |
recommendations = [ | |
"Conduct more detailed analysis on columns with high missing value rates.", | |
"For skewed numeric distributions, consider transformations (e.g., log, sqrt) before analysis.", | |
"Investigate outliers to determine if they represent valid data points or errors.", | |
"For strongly correlated variables, explore causality or consider dimensionality reduction.", | |
"Consider predictive modeling using the identified relationships." | |
] | |
for i, rec in enumerate(recommendations, 1): | |
report += f"{i}. {rec}\n" | |
# Save the report to a file | |
report_filename = f"data_analysis_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md" | |
with open(report_filename, "w") as f: | |
f.write(report) | |
return f"Report generated and saved as {report_filename}" | |
except Exception as e: | |
return f"Error generating report: {str(e)}" | |
def get_help(self, query): | |
"""Display help information about available commands""" | |
help_text = "Available Commands:\n\n" | |
help_text += "DATA LOADING AND INSPECTION\n" | |
help_text += " load csv <path> - Load data from a CSV file\n" | |
help_text += " load excel <path> - Load data from an Excel file\n" | |
help_text += " load json <path> - Load data from a JSON file\n" | |
help_text += " load sql <db_path> query <sql> - Load data from a SQL database\n" | |
help_text += " info - Get basic information about the loaded data\n" | |
help_text += " describe [column1 column2...] - Get descriptive statistics\n" | |
help_text += " missing - Check for missing values in the data\n" | |
help_text += "\n" | |
help_text += "DATA ANALYSIS\n" | |
help_text += " analyze <column> - Analyze a specific column\n" | |
help_text += " correlate [column1 column2...] - Analyze correlations between columns\n" | |
help_text += " trend <column1 column2...> - Analyze trends over time or sequence\n" | |
help_text += " outliers [column1 column2...] - Detect outliers in the data\n" | |
help_text += " test <column1> [column2] - Perform hypothesis testing\n" | |
help_text += "\n" | |
help_text += "VISUALIZATION AND REPORTING\n" | |
help_text += " visualize <type> <column1 column2...> - Generate visualizations\n" | |
help_text += " Visualization types: scatter, histogram, box, bar, pie, heatmap, line\n" | |
help_text += " report - Generate a comprehensive report on the data\n" | |
help_text += "\n" | |
help_text += "EXAMPLES:\n" | |
help_text += " load csv data.csv\n" | |
help_text += " analyze temperature\n" | |
help_text += " correlate temperature humidity pressure\n" | |
help_text += " visualize scatter temperature humidity\n" | |
help_text += " trend sales date\n" | |
return help_text | |
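# Minimal standalone usage sketch (assumes a local CSV named "data.csv" with a numeric
# column called "temperature"; both names are hypothetical):
#
#     bot = DataAnalysisChatbot()
#     print(bot.process_query("load csv data.csv"))
#     print(bot.process_query("analyze temperature"))
#
# The Streamlit front end below drives the same process_query interface.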
# Page configuration | |
st.set_page_config( | |
page_title="Data Analysis Assistant", | |
page_icon="π", | |
layout="wide", | |
initial_sidebar_state="expanded" | |
) | |
# Initialize session state variables if they don't exist | |
if 'chatbot' not in st.session_state: | |
st.session_state.chatbot = DataAnalysisChatbot() | |
if 'conversation' not in st.session_state: | |
st.session_state.conversation = [] | |
if 'data_loaded' not in st.session_state: | |
st.session_state.data_loaded = False | |
if 'current_file' not in st.session_state: | |
st.session_state.current_file = None | |
if 'data_preview' not in st.session_state: | |
st.session_state.data_preview = None | |
# Function to get a download link for a file | |
def get_download_link(file_path, link_text): | |
with open(file_path, 'rb') as f: | |
data = f.read() | |
b64 = base64.b64encode(data).decode() | |
href = f'<a href="data:file/txt;base64,{b64}" download="{os.path.basename(file_path)}">{link_text}</a>' | |
return href | |
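# The link embeds the file as a base64 data URI, so no separate file server is needed;
# for larger files, Streamlit's st.download_button would be a leaner alternative.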
# Function to convert matplotlib figure to Streamlit-compatible format | |
def plt_to_streamlit(): | |
buf = io.BytesIO() | |
plt.savefig(buf, format='png') | |
buf.seek(0) | |
return buf | |
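# plt_to_streamlit is a small helper; a typical (hypothetical) call site would be
#     st.image(plt_to_streamlit(), caption="My figure")
# placed right after building a matplotlib figure and before plt.close().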
# Custom CSS | |
st.markdown(""" | |
<style> | |
.main-header { | |
font-size: 2.5rem; | |
font-weight: 700; | |
color: #1E88E5; | |
margin-bottom: 1rem; | |
} | |
.sub-header { | |
font-size: 1.5rem; | |
font-weight: 600; | |
color: #333; | |
margin-bottom: 1rem; | |
} | |
.chat-user { | |
background-color: #E3F2FD; | |
padding: 10px 15px; | |
border-radius: 15px; | |
margin-bottom: 10px; | |
font-size: 1rem; | |
} | |
.chat-bot { | |
background-color: #F5F5F5; | |
padding: 10px 15px; | |
border-radius: 15px; | |
margin-bottom: 10px; | |
font-size: 1rem; | |
} | |
.file-info { | |
padding: 10px; | |
background-color: #E8F5E9; | |
border-radius: 5px; | |
margin-bottom: 10px; | |
} | |
.sidebar-content { | |
padding: 10px; | |
} | |
.highlight-text { | |
color: #1E88E5; | |
font-weight: bold; | |
} | |
.stButton>button { | |
width: 100%; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# Sidebar for data loading and information | |
with st.sidebar: | |
st.markdown('<div class="sidebar-content">', unsafe_allow_html=True) | |
st.markdown('<p class="sub-header">π Data Loading</p>', unsafe_allow_html=True) | |
# File uploader | |
uploaded_file = st.file_uploader("Upload your data file", type=['csv', 'xlsx', 'json', 'db', 'sqlite']) | |
# Load data button (only show if file is uploaded) | |
if uploaded_file is not None: | |
file_type = uploaded_file.name.split('.')[-1].lower() | |
# Save the uploaded file to a temporary location | |
temp_file_path = f"temp_upload_{datetime.now().strftime('%Y%m%d%H%M%S')}.{file_type}" | |
with open(temp_file_path, "wb") as f: | |
f.write(uploaded_file.getbuffer()) | |
# Load data based on file type | |
if st.button("Load Data"): | |
try: | |
if file_type == 'csv': | |
response = st.session_state.chatbot.process_query(f"load csv {temp_file_path}") | |
elif file_type in ['xlsx', 'xls']: | |
response = st.session_state.chatbot.process_query(f"load excel {temp_file_path}") | |
elif file_type == 'json': | |
response = st.session_state.chatbot.process_query(f"load json {temp_file_path}") | |
elif file_type in ['db', 'sqlite']: | |
# For SQL databases, we need to prompt for a query | |
st.session_state.current_file = temp_file_path | |
st.session_state.data_loaded = False | |
response = "SQL database loaded. Please enter a query in the main chat." | |
else: | |
response = "Unsupported file format. Please upload CSV, Excel, JSON, or SQLite files." | |
st.session_state.conversation.append({"role": "user", "message": f"Loading {uploaded_file.name}"}) | |
st.session_state.conversation.append({"role": "bot", "message": response}) | |
if "Successfully loaded data" in response: | |
st.session_state.data_loaded = True | |
st.session_state.current_file = temp_file_path | |
# Get data preview | |
if st.session_state.chatbot.data is not None: | |
st.session_state.data_preview = st.session_state.chatbot.data.head() | |
except Exception as e: | |
st.error(f"Error loading data: {str(e)}") | |
# Display data information if data is loaded | |
if st.session_state.data_loaded and st.session_state.chatbot.data is not None: | |
st.markdown('<p class="sub-header">π Data Information</p>', unsafe_allow_html=True) | |
# Display basic info | |
st.markdown('<div class="file-info">', unsafe_allow_html=True) | |
st.write(f"**Rows:** {len(st.session_state.chatbot.data)}") | |
st.write(f"**Columns:** {len(st.session_state.chatbot.data.columns)}") | |
st.write(f"**Data Source:** {st.session_state.chatbot.data_source}") | |
st.markdown('</div>', unsafe_allow_html=True) | |
# Quick actions | |
st.markdown('<p class="sub-header">β‘ Quick Actions</p>', unsafe_allow_html=True) | |
col1, col2 = st.columns(2) | |
with col1: | |
if st.button("Describe Data"): | |
response = st.session_state.chatbot.process_query("describe") | |
st.session_state.conversation.append({"role": "user", "message": "Describe data"}) | |
st.session_state.conversation.append({"role": "bot", "message": response}) | |
with col2: | |
if st.button("Check Missing"): | |
response = st.session_state.chatbot.process_query("missing") | |
st.session_state.conversation.append({"role": "user", "message": "Check missing values"}) | |
st.session_state.conversation.append({"role": "bot", "message": response}) | |
col1, col2 = st.columns(2) | |
with col1: | |
if st.button("Correlations"): | |
response = st.session_state.chatbot.process_query("correlate") | |
st.session_state.conversation.append({"role": "user", "message": "Show correlations"}) | |
st.session_state.conversation.append({"role": "bot", "message": response}) | |
with col2: | |
if st.button("Generate Report"): | |
response = st.session_state.chatbot.process_query("report") | |
st.session_state.conversation.append({"role": "user", "message": "Generate report"}) | |
st.session_state.conversation.append({"role": "bot", "message": response}) | |
# If report was generated, provide download link | |
if "Report generated and saved as" in response: | |
report_filename = response.split("Report generated and saved as ")[-1].strip() | |
st.markdown( | |
get_download_link(report_filename, "π₯ Download Report"), | |
unsafe_allow_html=True | |
) | |
# Help section | |
st.markdown('<p class="sub-header">β Help</p>', unsafe_allow_html=True) | |
if st.button("Show Commands"): | |
response = st.session_state.chatbot.process_query("help") | |
st.session_state.conversation.append({"role": "user", "message": "Show available commands"}) | |
st.session_state.conversation.append({"role": "bot", "message": response}) | |
st.markdown('</div>', unsafe_allow_html=True) | |
# Main area | |
st.markdown('<h1 class="main-header">π Data Analysis Assistant</h1>', unsafe_allow_html=True) | |
# Show data preview if data is loaded | |
if st.session_state.data_loaded and st.session_state.data_preview is not None: | |
st.markdown('<p class="sub-header">Data Preview</p>', unsafe_allow_html=True) | |
st.dataframe(st.session_state.data_preview, use_container_width=True) | |
# Display conversation history | |
st.markdown('<p class="sub-header">Chat History</p>', unsafe_allow_html=True) | |
chat_container = st.container() | |
with chat_container: | |
for message in st.session_state.conversation: | |
if message["role"] == "user": | |
st.markdown(f'<div class="chat-user">π€ <b>You:</b> {message["message"]}</div>', unsafe_allow_html=True) | |
else: | |
# Process bot messages for special content | |
bot_message = message["message"] | |
# Check if it's a visualization result | |
if "Visualization created and saved as" in bot_message: | |
# Extract the filename and load the image | |
img_file = bot_message.split("Visualization created and saved as ")[-1].strip() | |
if os.path.exists(img_file): | |
st.markdown(f'<div class="chat-bot">π€ <b>Assistant:</b></div>', unsafe_allow_html=True) | |
try: | |
img = Image.open(img_file) | |
st.image(img, caption="Generated Visualization", use_column_width=True) | |
except Exception as e: | |
st.error(f"Error displaying visualization: {str(e)}") | |
st.markdown(f'<div class="chat-bot">π€ <b>Assistant:</b> {bot_message}</div>', unsafe_allow_html=True) | |
else: | |
st.markdown(f'<div class="chat-bot">π€ <b>Assistant:</b> {bot_message}</div>', unsafe_allow_html=True) | |
# Check if it's a report result | |
elif "Report generated and saved as" in bot_message: | |
report_filename = bot_message.split("Report generated and saved as ")[-1].strip() | |
st.markdown( | |
f'<div class="chat-bot">π€ <b>Assistant:</b> {bot_message}<br/>{get_download_link(report_filename, "π₯ Download Report")}</div>', | |
unsafe_allow_html=True | |
) | |
# Regular message | |
else: | |
# Format code blocks | |
if "```" in bot_message: | |
parts = bot_message.split("```") | |
formatted_message = "" | |
for i, part in enumerate(parts): | |
if i % 2 == 0: # Outside code block | |
formatted_message += part | |
else: # Inside code block | |
formatted_message += f"<pre style='background-color: #f0f0f0; padding: 10px; border-radius: 5px; overflow-x: auto;'>{part}</pre>" | |
st.markdown(f'<div class="chat-bot">π€ <b>Assistant:</b> {formatted_message}</div>', unsafe_allow_html=True) | |
else: | |
st.markdown(f'<div class="chat-bot">π€ <b>Assistant:</b> {bot_message}</div>', unsafe_allow_html=True) | |
# User input | |
st.markdown('<p class="sub-header">Ask a Question</p>', unsafe_allow_html=True) | |
user_input = st.text_area("Enter your query", height=100, key="user_query") | |
# Handle SQL query case | |
if st.session_state.current_file is not None and not st.session_state.data_loaded and st.session_state.current_file.endswith(('db', 'sqlite')): | |
sql_query = st.text_area("Enter SQL query", height=100, key="sql_query") | |
if st.button("Run SQL Query") and sql_query: | |
response = st.session_state.chatbot.process_query(f"load sql {st.session_state.current_file} query {sql_query}") | |
st.session_state.conversation.append({"role": "user", "message": f"SQL query: {sql_query}"}) | |
st.session_state.conversation.append({"role": "bot", "message": response}) | |
if "Successfully loaded data" in response: | |
st.session_state.data_loaded = True | |
if st.session_state.chatbot.data is not None: | |
st.session_state.data_preview = st.session_state.chatbot.data.head() | |
# Submit button for regular queries | |
if st.button("Submit") and user_input: | |
# Add user message to conversation | |
st.session_state.conversation.append({"role": "user", "message": user_input}) | |
# Process query | |
response = st.session_state.chatbot.process_query(user_input) | |
# Add bot response to conversation | |
st.session_state.conversation.append({"role": "bot", "message": response}) | |
    # Note: assigning st.session_state.user_query = "" here would raise a
    # StreamlitAPIException, because a widget's state cannot be modified after the
    # widget has been instantiated in the same run; clearing the text area would
    # need an on_click callback instead.
# Add warning for demo mode | |
st.markdown("---") | |
st.markdown("**Note:** File uploads and data processing are handled locally. Make sure you have the necessary dependencies installed.", unsafe_allow_html=True) | |
# Footer | |
st.markdown("---") | |
st.markdown("Β© 2025 Data Analysis Assistant | Built with Streamlit") | |