Ujeshhh commited on
Commit
155de20
·
verified ·
1 Parent(s): 14bcb8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -27
app.py CHANGED
@@ -1,46 +1,58 @@
1
  import pandas as pd
2
  import joblib
3
  import gradio as gr
 
 
4
 
5
  # Load the trained model
6
  model = joblib.load("anomaly_detector_rf_model.pkl")
7
 
8
- # Define feature columns (include all used during training)
9
- feature_cols = ['hour', 'day_of_week', 'is_weekend', 'amount', 'merchant_avg_amount',
10
- 'amount_zscore', 'log_amount', 'type_atm_withdrawal', 'type_credit',
11
- 'type_debit', 'merchant_encoded']
12
 
13
- def detect_anomalies(file_path):
14
- # Read the dataset
15
- df = pd.read_csv(file_path)
 
16
 
17
- # Ensure all features exist in the dataframe
18
- missing_cols = [col for col in feature_cols if col not in df.columns]
19
- if missing_cols:
20
- return f"Missing columns in dataset: {missing_cols}"
 
 
 
21
 
22
- # Align feature order with model training
23
- df = df[feature_cols]
24
 
25
- # Make predictions
26
- df['is_anomalous'] = model.predict(df)
27
 
28
- # Filter anomalous transactions
29
- anomalies = df[df['is_anomalous'] == 1][['transaction_id', 'merchant', 'location', 'amount']]
30
 
31
- # Save to a new CSV
32
- anomalies.to_csv("predicted_anomalies.csv", index=False)
33
 
34
- return anomalies
 
35
 
36
  # Gradio Interface
 
 
 
 
 
 
37
  interface = gr.Interface(
38
- fn=detect_anomalies,
39
- inputs=gr.File(label="Upload CSV File"),
40
- outputs=gr.Dataframe(label="Predicted Anomalies"),
41
- title="Anomaly Detection System",
42
- description="Upload a transaction dataset to detect anomalies."
43
  )
44
 
45
- if __name__ == "__main__":
46
- interface.launch(share=True)
 
1
  import pandas as pd
2
  import joblib
3
  import gradio as gr
4
+ import seaborn as sns
5
+ import matplotlib.pyplot as plt
6
 
7
  # Load the trained model
8
  model = joblib.load("anomaly_detector_rf_model.pkl")
9
 
10
+ # Define feature order
11
+ feature_order = ['hour', 'day_of_week', 'is_weekend', 'amount', 'merchant_avg_amount',
12
+ 'amount_zscore', 'log_amount', 'type_atm_withdrawal', 'type_credit',
13
+ 'type_debit', 'merchant_encoded']
14
 
15
+ def detect_anomalies(data):
16
+ df = pd.DataFrame(data)
17
+ df = df[feature_order] # Ensure correct feature order
18
+ df['is_anomalous'] = model.predict(df)
19
 
20
+ # Filter anomalies and display relevant details
21
+ anomalies = df[df['is_anomalous'] == 1][['transaction_id', 'merchant', 'location', 'amount']]
22
+ return anomalies
23
+
24
+ # Function to generate plots
25
+ def generate_plots(df):
26
+ fig, axes = plt.subplots(2, 2, figsize=(12, 10))
27
 
28
+ sns.countplot(data=df, x='is_anomalous', palette='Set2', ax=axes[0, 0])
29
+ axes[0, 0].set_title("Anomaly Distribution")
30
 
31
+ sns.countplot(data=df, y='merchant', order=df['merchant'].value_counts().index, palette='viridis', ax=axes[0, 1])
32
+ axes[0, 1].set_title("Transactions by Merchant")
33
 
34
+ sns.histplot(df['amount'], bins=30, kde=True, color='blue', ax=axes[1, 0])
35
+ axes[1, 0].set_title("Transaction Amount Distribution")
36
 
37
+ sns.scatterplot(data=df, x='amount', y='merchant_avg_amount', hue='is_anomalous', palette='coolwarm', ax=axes[1, 1])
38
+ axes[1, 1].set_title("Amount vs. Merchant Average Amount")
39
 
40
+ plt.tight_layout()
41
+ return fig
42
 
43
  # Gradio Interface
44
+ def app_interface(file):
45
+ df = pd.read_csv(file.name)
46
+ anomalies = detect_anomalies(df)
47
+ plot = generate_plots(df)
48
+ return anomalies, plot
49
+
50
  interface = gr.Interface(
51
+ fn=app_interface,
52
+ inputs=[gr.File(label="Upload Transaction Data (CSV)")],
53
+ outputs=[gr.Dataframe(label="Detected Anomalies"), gr.Plot(label="Transaction Analysis Charts")],
54
+ title="Financial Anomaly Detection",
55
+ description="Upload a transaction dataset to detect financial anomalies and visualize transaction patterns."
56
  )
57
 
58
+ interface.launch(share=True)