sborhade commited on
Commit
a00aae6
·
verified ·
1 Parent(s): 8cf7cdb

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +23 -7
inference.py CHANGED
@@ -1,18 +1,34 @@
1
  import pickle
2
  import pandas as pd
 
3
 
4
  def load_model():
5
- with open("model/expense_forecaster_model.pkl", "rb") as f:
6
- model = pickle.load(f)
7
- return model
 
 
 
 
8
 
9
  def predict(data):
10
  model = load_model()
11
- df = pd.DataFrame([data])
12
- prediction = model.predict(df)
13
- return prediction.tolist()
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  if __name__ == "__main__":
16
- example_input = {"income": 5000, "previous_expenses": 3000, "month": 12} #example data, change this to match your feature names.
17
  prediction = predict(example_input)
18
  print(f"Prediction: {prediction}")
 
1
  import pickle
2
  import pandas as pd
3
+ import json
4
 
5
  def load_model():
6
+ try:
7
+ with open("model/expense_forecaster_model.pkl", "rb") as f:
8
+ model = pickle.load(f)
9
+ return model
10
+ except Exception as e:
11
+ print(f"Error loading model: {e}")
12
+ return None
13
 
14
  def predict(data):
15
  model = load_model()
16
+ if model is None:
17
+ return {"error": "Model loading failed"}
18
+
19
+ try:
20
+ # Ensure data is a dictionary
21
+ if not isinstance(data, dict):
22
+ return {"error": "Input data must be a dictionary"}
23
+
24
+ df = pd.DataFrame([data])
25
+ prediction = model.predict(df)
26
+ return prediction.tolist()
27
+
28
+ except Exception as e:
29
+ return {"error": f"Prediction error: {e}"}
30
 
31
  if __name__ == "__main__":
32
+ example_input = {"income": 5000, "previous_expenses": 3000, "month": 12}
33
  prediction = predict(example_input)
34
  print(f"Prediction: {prediction}")