jarif commited on
Commit
69b6195
·
verified ·
1 Parent(s): a12ca3f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -12
src/streamlit_app.py CHANGED
@@ -29,33 +29,39 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
 
30
  # Load model and threshold
31
  try:
32
- # Adjust paths for Hugging Face Spaces (relative to repository root)
33
- model_path = os.path.join(os.path.dirname(__file__), '..', 'fraud_gnn_model.pth')
34
- threshold_path = os.path.join(os.path.dirname(__file__), '..', 'optimal_threshold.txt')
35
-
36
- # Alternative: If model files are in a 'models' folder
 
 
 
 
 
 
37
  # model_path = os.path.join(os.path.dirname(__file__), '..', 'models', 'fraud_gnn_model.pth')
38
  # threshold_path = os.path.join(os.path.dirname(__file__), '..', 'models', 'optimal_threshold.txt')
39
-
40
  if not os.path.exists(model_path):
41
- raise FileNotFoundError(f"Model file not found at {model_path}")
42
  if not os.path.exists(threshold_path):
43
- raise FileNotFoundError(f"Threshold file not found at {threshold_path}")
44
-
45
  model = FraudGNN(input_dim=7, hidden_dim=16, output_dim=1).to(device)
46
  model.load_state_dict(torch.load(model_path, map_location=device))
47
  model.eval()
48
-
49
  with open(threshold_path, 'r') as f:
50
  threshold = float(f.read())
51
  except FileNotFoundError as e:
52
- st.error(f"Error: {e}. Please ensure model and threshold files are uploaded to the repository root.")
53
  st.stop()
54
  except Exception as e:
55
  st.error(f"Error loading model or threshold: {e}")
56
  st.stop()
57
 
58
- # City and state mappings (unchanged)
59
  city_mapping = {
60
  'Atlanta': 0, 'Bronx': 1, 'Brooklyn': 2, 'Chicago': 3, 'Dallas': 4, 'Houston': 5,
61
  'Indianapolis': 6, 'Las Vegas': 7, 'Los Angeles': 8, 'Louisville': 9, 'Miami': 10,
 
29
 
30
  # Load model and threshold
31
  try:
32
+ # Try root directory first (Hugging Face Spaces working directory)
33
+ model_path = 'fraud_gnn_model.pth'
34
+ threshold_path = 'optimal_threshold.txt'
35
+
36
+ # Fallback: Try relative to src/ (if files are misplaced)
37
+ if not os.path.exists(model_path):
38
+ model_path = os.path.join(os.path.dirname(__file__), 'fraud_gnn_model.pth')
39
+ if not os.path.exists(threshold_path):
40
+ threshold_path = os.path.join(os.path.dirname(__file__), 'optimal_threshold.txt')
41
+
42
+ # Alternative: If files are in a 'models/' folder (uncomment if applicable)
43
  # model_path = os.path.join(os.path.dirname(__file__), '..', 'models', 'fraud_gnn_model.pth')
44
  # threshold_path = os.path.join(os.path.dirname(__file__), '..', 'models', 'optimal_threshold.txt')
45
+
46
  if not os.path.exists(model_path):
47
+ raise FileNotFoundError(f"Model file not found at {model_path}. Please upload fraud_gnn_model.pth to the repository root.")
48
  if not os.path.exists(threshold_path):
49
+ raise FileNotFoundError(f"Threshold file not found at {threshold_path}. Please upload optimal_threshold.txt to the repository root.")
50
+
51
  model = FraudGNN(input_dim=7, hidden_dim=16, output_dim=1).to(device)
52
  model.load_state_dict(torch.load(model_path, map_location=device))
53
  model.eval()
54
+
55
  with open(threshold_path, 'r') as f:
56
  threshold = float(f.read())
57
  except FileNotFoundError as e:
58
+ st.error(f"Error: {e}")
59
  st.stop()
60
  except Exception as e:
61
  st.error(f"Error loading model or threshold: {e}")
62
  st.stop()
63
 
64
+ # City and state mappings
65
  city_mapping = {
66
  'Atlanta': 0, 'Bronx': 1, 'Brooklyn': 2, 'Chicago': 3, 'Dallas': 4, 'Houston': 5,
67
  'Indianapolis': 6, 'Las Vegas': 7, 'Los Angeles': 8, 'Louisville': 9, 'Miami': 10,