Ujeshhh committed on
Commit
5d337dd
·
verified ·
1 Parent(s): 0daaeca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -29
app.py CHANGED
@@ -7,7 +7,7 @@ import seaborn as sns
7
  # Load trained model
8
  model = joblib.load("anomaly_detector_rf_model.pkl")
9
 
10
- # Define feature columns used during training
11
  feature_cols = [
12
  "amount", "hour", "day_of_week", "is_weekend", "merchant_avg_amount",
13
  "amount_zscore", "log_amount", "type_atm_withdrawal", "type_credit",
@@ -16,49 +16,70 @@ feature_cols = [
16
 
17
  # Function to detect anomalies
18
  def detect_anomalies(df):
19
- # Ensure 'amount' column is present
20
- if "amount" not in df.columns:
21
- return "Error: 'amount' column is missing from the uploaded CSV file."
22
 
23
- # Select only required features and maintain correct order
24
- df = df.reindex(columns=feature_cols, fill_value=0)
 
25
 
26
- # Make predictions
27
- df["is_anomalous"] = model.predict(df)
 
28
 
29
- # Filter anomalies
30
- anomalies = df[df["is_anomalous"] == 1]
31
 
32
- # Keep only available columns
33
- available_cols = [col for col in ["transaction_id", "merchant", "location", "amount"] if col in anomalies.columns]
34
- return anomalies[available_cols] if available_cols else "No relevant columns found in the dataset."
35
-
36
- # Function to visualize anomalies
37
  def plot_charts(df):
38
  fig, axes = plt.subplots(2, 2, figsize=(12, 10))
39
- sns.histplot(df["amount"], bins=30, kde=True, ax=axes[0, 0])
40
- sns.boxplot(x=df["amount"], ax=axes[0, 1])
41
- sns.countplot(x=df["day_of_week"], ax=axes[1, 0])
42
- sns.barplot(x=df["merchant"], y=df["amount"], ax=axes[1, 1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  plt.tight_layout()
45
  return fig
46
 
47
- # Gradio Interface
48
  def app_interface(csv_file):
49
  df = pd.read_csv(csv_file)
50
  anomalies = detect_anomalies(df)
51
  fig = plot_charts(df)
52
-
53
  return anomalies, fig
54
 
55
- interface = gr.Interface(
56
- fn=app_interface,
57
- inputs="file",
58
- outputs=["dataframe", "plot"],
59
- title="Financial Anomaly Detector",
60
- description="Upload a transaction CSV file to detect fraudulent transactions."
61
- )
 
 
 
 
 
 
 
 
 
62
 
63
- # Launch the Gradio app with public access
64
  interface.launch(share=True)
 
7
  # Load trained model
8
  model = joblib.load("anomaly_detector_rf_model.pkl")
9
 
10
+ # Features used during training
11
  feature_cols = [
12
  "amount", "hour", "day_of_week", "is_weekend", "merchant_avg_amount",
13
  "amount_zscore", "log_amount", "type_atm_withdrawal", "type_credit",
 
16
 
17
  # Function to detect anomalies
18
  def detect_anomalies(df):
19
+ original_df = df.copy()
 
 
20
 
21
+ for col in ["transaction_id", "merchant", "location", "amount"]:
22
+ if col not in original_df.columns:
23
+ original_df[col] = "N/A" if col != "amount" else 0.0
24
 
25
+ model_input = df.reindex(columns=feature_cols, fill_value=0)
26
+ preds = model.predict(model_input)
27
+ original_df["is_anomalous"] = preds
28
 
29
+ anomalies = original_df[original_df["is_anomalous"] == 1]
30
+ return anomalies[["transaction_id", "merchant", "location", "amount", "is_anomalous"]]
31
 
32
+ # Function to generate charts
 
 
 
 
33
  def plot_charts(df):
34
  fig, axes = plt.subplots(2, 2, figsize=(12, 10))
35
+
36
+ if "amount" in df.columns:
37
+ sns.histplot(df["amount"], bins=30, kde=True, ax=axes[0, 0])
38
+ axes[0, 0].set_title("Amount Distribution")
39
+ sns.boxplot(x=df["amount"], ax=axes[0, 1])
40
+ axes[0, 1].set_title("Amount Box Plot")
41
+ else:
42
+ axes[0, 0].text(0.5, 0.5, "No 'amount' column", ha='center')
43
+ axes[0, 1].text(0.5, 0.5, "No 'amount' column", ha='center')
44
+
45
+ if "day_of_week" in df.columns:
46
+ sns.countplot(x=df["day_of_week"], ax=axes[1, 0])
47
+ axes[1, 0].set_title("Transactions by Day of Week")
48
+ else:
49
+ axes[1, 0].text(0.5, 0.5, "No 'day_of_week' column", ha='center')
50
+
51
+ if "merchant" in df.columns:
52
+ top_merchants = df.groupby("merchant")["amount"].sum().nlargest(5).reset_index()
53
+ sns.barplot(data=top_merchants, x="merchant", y="amount", ax=axes[1, 1])
54
+ axes[1, 1].set_title("Top 5 Merchants by Amount")
55
+ else:
56
+ axes[1, 1].text(0.5, 0.5, "No 'merchant' column", ha='center')
57
 
58
  plt.tight_layout()
59
  return fig
60
 
61
+ # Gradio Interface logic
62
  def app_interface(csv_file):
63
  df = pd.read_csv(csv_file)
64
  anomalies = detect_anomalies(df)
65
  fig = plot_charts(df)
 
66
  return anomalies, fig
67
 
68
+ # Launching with UI
69
+ with gr.Blocks(theme=gr.themes.Soft()) as interface:
70
+ gr.Markdown("# 🛑️ Financial Abuse & Anomaly Detection App")
71
+ gr.Markdown("Upload your **financial transaction CSV** to detect suspicious activity and view insightful visualizations.")
72
+
73
+ with gr.Row():
74
+ file_input = gr.File(label="📁 Upload Transaction CSV", file_types=[".csv"])
75
+ submit_btn = gr.Button("🔍 Detect Anomalies", variant="primary")
76
+
77
+ with gr.Tab("📋 Anomalies Detected"):
78
+ result_df = gr.Dataframe(label="🔴 Detected Suspicious Transactions")
79
+
80
+ with gr.Tab("📊 Transaction Insights"):
81
+ chart_output = gr.Plot(label="📈 Transaction Summary Charts")
82
+
83
+ submit_btn.click(fn=app_interface, inputs=file_input, outputs=[result_df, chart_output])
84
 
 
85
  interface.launch(share=True)