zoya23 commited on
Commit
f3a7e8e
·
verified ·
1 Parent(s): 0ae1060

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -14
app.py CHANGED
@@ -1,20 +1,37 @@
1
  import streamlit as st
 
2
  import joblib
3
- import numpy as np
4
 
5
- # Load your model
6
- model = joblib.load("log_reg_model.pkl") # or "log_reg_model.pkl"
7
 
8
- # Streamlit App UI
9
- st.title("AI Sleep State Detector")
10
- st.write("Predict sleep state (`onset` or `wakeup`) using step count and hour.")
11
 
12
- # Input Features
13
- step = st.number_input("Step count:", min_value=0, max_value=10000, value=0)
14
- hour = st.number_input("Hour of day (0–23):", min_value=0, max_value=23, value=0)
 
 
 
15
 
16
- # Predict Button
17
- if st.button("Predict Sleep State"):
18
- input_data = np.array([[step, hour]])
19
- prediction = model.predict(input_data)[0]
20
- return "Sleep Onset" if prediction == 1 else "Wakeup"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
  import joblib
 
4
 
5
+ # Load model
6
+ model = joblib.load("log_reg_model.pkl") # or log_reg_model.pkl
7
 
8
+ st.title("Sleep State Prediction App")
 
 
9
 
10
+ # Upload test data
11
+ uploaded_file = st.file_uploader("Upload test_series.parquet", type=["parquet"])
12
+ if uploaded_file is not None:
13
+ test_df = pd.read_parquet(uploaded_file)
14
+ st.write("Sample of uploaded data:")
15
+ st.dataframe(test_df.head())
16
 
17
+ # Check if required columns exist
18
+ if {'series_id', 'step', 'hour'}.issubset(test_df.columns):
19
+ # Predict sleep state
20
+ features = test_df[['step', 'hour']]
21
+ predictions = model.predict(features)
22
+
23
+ # Build submission-like DataFrame
24
+ test_df = test_df.reset_index(drop=True)
25
+ test_df["event"] = predictions
26
+ test_df["row_id"] = test_df["series_id"] + "_" + test_df.index.astype(str)
27
+ submission = test_df[["row_id", "event"]]
28
+
29
+ st.success("Predictions completed.")
30
+ st.write(submission.head())
31
+
32
+ # Option to download
33
+ csv = submission.to_csv(index=False).encode('utf-8')
34
+ st.download_button("Download submission.csv", csv, "submission.csv", "text/csv")
35
+
36
+ else:
37
+ st.error("Uploaded file must contain 'series_id', 'step', and 'hour' columns.")