Ujeshhh commited on
Commit
3ec40df
Β·
verified Β·
1 Parent(s): 235d85c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -84
app.py CHANGED
@@ -1,107 +1,83 @@
1
  import gradio as gr
2
  import pandas as pd
3
- import joblib
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
6
- import io
 
 
7
 
8
- # Load trained model
9
  model = joblib.load("anomaly_detector_rf_model.pkl")
10
 
11
- # Features used during training
12
- feature_cols = [
13
- "amount", "hour", "day_of_week", "is_weekend", "merchant_avg_amount",
14
- "amount_zscore", "log_amount", "type_atm_withdrawal", "type_credit",
15
- "type_debit", "merchant_encoded"
16
- ]
17
 
18
- # Function to detect anomalies
19
  def detect_anomalies(df):
20
- original_df = df.copy()
21
-
22
- for col in ["transaction_id", "merchant", "location", "amount"]:
23
- if col not in original_df.columns:
24
- original_df[col] = "N/A" if col != "amount" else 0.0
25
-
26
- model_input = df.reindex(columns=feature_cols, fill_value=0)
27
- preds = model.predict(model_input)
28
- original_df["is_anomalous"] = preds
29
-
30
- anomalies = original_df[original_df["is_anomalous"] == 1]
31
- return original_df, anomalies[["transaction_id", "merchant", "location", "amount", "is_anomalous"]]
32
-
33
- # Function to plot charts
34
- def plot_charts(df):
35
- fig, axes = plt.subplots(2, 2, figsize=(12, 10))
36
-
37
- if "amount" in df.columns:
38
- sns.histplot(df["amount"], bins=30, kde=True, ax=axes[0, 0])
39
- axes[0, 0].set_title("Amount Distribution")
40
- sns.boxplot(x=df["amount"], ax=axes[0, 1])
41
- axes[0, 1].set_title("Amount Box Plot")
42
- else:
43
- axes[0, 0].text(0.5, 0.5, "No 'amount' column", ha='center')
44
- axes[0, 1].text(0.5, 0.5, "No 'amount' column", ha='center')
45
-
46
- if "day_of_week" in df.columns:
47
- sns.countplot(x=df["day_of_week"], ax=axes[1, 0])
48
- axes[1, 0].set_title("Transactions by Day of Week")
49
- else:
50
- axes[1, 0].text(0.5, 0.5, "No 'day_of_week' column", ha='center')
51
-
52
- if "merchant" in df.columns:
53
- top_merchants = df.groupby("merchant")["amount"].sum().nlargest(5).reset_index()
54
- sns.barplot(data=top_merchants, x="merchant", y="amount", ax=axes[1, 1])
55
- axes[1, 1].set_title("Top 5 Merchants by Amount")
56
- else:
57
- axes[1, 1].text(0.5, 0.5, "No 'merchant' column", ha='center')
58
-
59
- plt.tight_layout()
60
- return fig
61
-
62
- # Function to generate summary + charts + file
63
- def app_interface(csv_file):
64
- df = pd.read_csv(csv_file)
65
- full_df, anomalies = detect_anomalies(df)
66
 
67
- total = len(full_df)
68
- anom_count = len(anomalies)
69
- percent = (anom_count / total) * 100 if total > 0 else 0
70
 
71
- summary = (
72
- f"πŸ”’ **Total Transactions**: {total}\n"
73
- f"⚠️ **Anomalies Detected**: {anom_count}\n"
74
- f"πŸ“Š **Anomaly Percentage**: {percent:.2f}%"
75
  )
76
 
77
- # Convert anomalies to CSV bytes for download
78
- csv_bytes = anomalies.to_csv(index=False).encode()
79
- download = io.BytesIO(csv_bytes)
 
 
80
 
81
- fig = plot_charts(full_df)
 
82
 
83
- return summary, anomalies, fig, download
 
 
84
 
85
- # Gradio App with UI
86
- with gr.Blocks(theme=gr.themes.Soft()) as interface:
87
- gr.Markdown("# πŸ›‘οΈ Financial Abuse & Anomaly Detection App")
88
- gr.Markdown("Upload your **transaction CSV** to detect anomalies and view insights.")
89
 
90
- with gr.Row():
91
- file_input = gr.File(label="πŸ“ Upload CSV File", file_types=[".csv"])
92
- detect_button = gr.Button("🚨 Run Detection", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  with gr.Row():
95
- summary_box = gr.Markdown("")
 
96
 
97
- with gr.Tab("πŸ“‹ Anomalies Detected"):
98
- result_table = gr.Dataframe(label="πŸ”΄ Anomalies")
99
- download_btn = gr.File(label="⬇️ Download Detected Anomalies")
100
 
101
- with gr.Tab("πŸ“Š Transaction Charts"):
102
- chart_output = gr.Plot()
103
 
104
- detect_button.click(fn=app_interface, inputs=file_input,
105
- outputs=[summary_box, result_table, chart_output, download_btn])
 
 
 
106
 
107
- interface.launch(share=True)
 
1
  import gradio as gr
2
  import pandas as pd
 
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
5
+ import os
6
+ import uuid
7
+ import joblib
8
 
9
+ # Load the model
10
  model = joblib.load("anomaly_detector_rf_model.pkl")
11
 
12
+ # Define the features expected by the model
13
+ expected_features = ["amount"] # Update this list as per your trained model
 
 
 
 
14
 
 
15
  def detect_anomalies(df):
16
+ df = df.copy()
17
+ df['is_anomalous'] = model.predict(df[expected_features])
18
+ anomalies = df[df['is_anomalous'] == 1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # Save anomalies to temporary CSV file
21
+ csv_filename = f"/tmp/anomalies_{uuid.uuid4().hex}.csv"
22
+ anomalies.to_csv(csv_filename, index=False)
23
 
24
+ return (
25
+ anomalies[["transaction_id", "merchant", "location", "amount", "is_anomalous"]],
26
+ csv_filename
 
27
  )
28
 
29
+ def generate_summary(df):
30
+ total_transactions = len(df)
31
+ total_anomalies = df['is_anomalous'].sum()
32
+ percent_anomalies = round((total_anomalies / total_transactions) * 100, 2)
33
+ return f"Total Transactions: {total_transactions}\nTotal Anomalies: {total_anomalies}\nAnomaly Rate: {percent_anomalies}%"
34
 
35
+ def generate_charts(df):
36
+ fig, ax = plt.subplots(1, 2, figsize=(12, 5))
37
 
38
+ # Distribution of Amounts
39
+ sns.histplot(df['amount'], bins=30, ax=ax[0], kde=True)
40
+ ax[0].set_title('Transaction Amount Distribution')
41
 
42
+ # Anomalies by Merchant
43
+ anomaly_counts = df[df['is_anomalous'] == 1]['merchant'].value_counts().nlargest(10)
44
+ sns.barplot(x=anomaly_counts.values, y=anomaly_counts.index, ax=ax[1])
45
+ ax[1].set_title('Top 10 Merchants with Anomalies')
46
 
47
+ plt.tight_layout()
48
+ chart_path = f"/tmp/chart_{uuid.uuid4().hex}.png"
49
+ plt.savefig(chart_path)
50
+ plt.close()
51
+ return chart_path
52
+
53
+ def app_interface(file):
54
+ df = pd.read_csv(file.name)
55
+ anomalies, csv_path = detect_anomalies(df)
56
+ summary = generate_summary(df)
57
+ chart_path = generate_charts(df)
58
+ return anomalies, summary, chart_path, csv_path
59
+
60
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
61
+ gr.Markdown("""
62
+ # πŸ” Elder Financial Abuse Detection Tool
63
+ Upload a transaction dataset to identify potential financial abuse patterns in elderly individuals.
64
+ """)
65
 
66
  with gr.Row():
67
+ file_input = gr.File(label="πŸ“‚ Upload Transaction CSV")
68
+ analyze_btn = gr.Button("πŸš€ Analyze")
69
 
70
+ with gr.Row():
71
+ anomalies_output = gr.Dataframe(label="⚠️ Detected Anomalies", wrap=True)
72
+ summary_output = gr.Textbox(label="πŸ“Š Summary")
73
 
74
+ chart_output = gr.Image(label="πŸ“ˆ Analysis Charts")
75
+ csv_download = gr.File(label="πŸ“ Download Anomalies CSV")
76
 
77
+ analyze_btn.click(
78
+ fn=app_interface,
79
+ inputs=[file_input],
80
+ outputs=[anomalies_output, summary_output, chart_output, csv_download]
81
+ )
82
 
83
+ demo.launch()