Transcendental-Programmer commited on
Commit
e7b58b1
·
1 Parent(s): f0b383a

feat: add streamlit app

Browse files
Files changed (4) hide show
  1. README.md +39 -0
  2. requirements.txt +9 -9
  3. src/api/server.py +36 -0
  4. webapp/streamlit_app.py +57 -0
README.md CHANGED
@@ -30,3 +30,42 @@ MIT
30
 
31
  ## Contributing
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  ## Contributing
32
 
33
+
34
+ ## Federated Credit Scoring Demo (with Web App)
35
+
36
+ This project includes a demo where multiple banks (clients) collaboratively train a credit scoring model using federated learning. A Streamlit web app allows you to enter customer features and get a credit score prediction from the federated model.
37
+
38
+ ### Quick Start
39
+
40
+ 1. **Install dependencies**
41
+
42
+ ```bash
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ 2. **Start the Federated Server**
47
+
48
+ ```bash
49
+ python -m src.main --mode server --config config/server_config.yaml
50
+ ```
51
+
52
+ 3. **Start at least two Clients (in separate terminals)**
53
+
54
+ ```bash
55
+ python -m src.main --mode client --config config/client_config.yaml
56
+ ```
57
+
58
+ 4. **Run the Web App**
59
+
60
+ ```bash
61
+ streamlit run webapp/streamlit_app.py
62
+ ```
63
+
64
+ 5. **Use the Web App**
65
+ - Enter 32 features (dummy values are fine for demo)
66
+ - Click "Predict Credit Score" to get a prediction from the federated model
67
+ - View training progress in the app
68
+
69
+ *For best results, keep the server and at least two clients running in parallel.*
70
+
71
+ ---
requirements.txt CHANGED
@@ -1,12 +1,12 @@
1
  # Core ML frameworks
2
- tensorflow>=2.6.0
3
  tensorflow-federated
4
- torch>=1.9.0
5
  transformers
6
 
7
  # Data processing
8
- pandas>=1.3.0
9
- numpy>=1.19.0
10
  scikit-learn
11
 
12
  # RAG components
@@ -18,14 +18,14 @@ tensorflow-privacy
18
  pysyft
19
 
20
  # API and web
21
- flask>=2.0.0
22
  fastapi
23
  uvicorn
24
- requests>=2.25.0
 
25
 
26
  # Configuration and utilities
27
- pyyaml>=5.4.0
28
-
29
  # Testing and development
30
  pytest
31
  black
@@ -37,4 +37,4 @@ sphinx
37
  sphinx-rtd-theme
38
 
39
  # Additional requirements
40
- pyyaml>=5.4.1
 
1
  # Core ML frameworks
2
+ tensorflow
3
  tensorflow-federated
4
+ torch
5
  transformers
6
 
7
  # Data processing
8
+ pandas
9
+ numpy
10
  scikit-learn
11
 
12
  # RAG components
 
18
  pysyft
19
 
20
  # API and web
21
+ flask
22
  fastapi
23
  uvicorn
24
+ requests
25
+ streamlit
26
 
27
  # Configuration and utilities
28
+ pyyaml
 
29
  # Testing and development
30
  pytest
31
  black
 
37
  sphinx-rtd-theme
38
 
39
  # Additional requirements
40
+ pyyaml
src/api/server.py CHANGED
@@ -149,6 +149,42 @@ class FederatedAPI:
149
  logger.error(f"Error processing RAG query: {str(e)}")
150
  return jsonify({'error': str(e)}), 500
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def run(self, debug: bool = False):
153
  """Run the API server"""
154
  logger.info(f"Starting Federated API server on {self.host}:{self.port}")
 
149
  logger.error(f"Error processing RAG query: {str(e)}")
150
  return jsonify({'error': str(e)}), 500
151
 
152
+ @self.app.route('/predict', methods=['POST'])
153
+ def predict():
154
+ """Predict using the current global model."""
155
+ try:
156
+ data = request.get_json()
157
+ features = data.get('features')
158
+ if features is None or not isinstance(features, list) or len(features) != 32:
159
+ return jsonify({'error': 'features must be a list of 32 floats'}), 400
160
+
161
+ # Get global model weights
162
+ model_weights = self.coordinator.get_global_model()
163
+ if model_weights is None:
164
+ return jsonify({'error': 'Global model not available yet'}), 503
165
+
166
+ # Build the model (same as client)
167
+ import tensorflow as tf
168
+ import numpy as np
169
+ input_dim = 32
170
+ model = tf.keras.Sequential([
171
+ tf.keras.layers.Input(shape=(input_dim,)),
172
+ tf.keras.layers.Dense(128, activation='relu'),
173
+ tf.keras.layers.Dense(64, activation='relu'),
174
+ tf.keras.layers.Dense(1)
175
+ ])
176
+ model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
177
+ model.set_weights([np.array(w) for w in model_weights])
178
+
179
+ # Prepare input and predict
180
+ x = np.array(features, dtype=np.float32).reshape(1, -1)
181
+ pred = model.predict(x)
182
+ prediction = float(pred[0, 0])
183
+ return jsonify({'prediction': prediction})
184
+ except Exception as e:
185
+ logger.error(f"Error in prediction endpoint: {str(e)}")
186
+ return jsonify({'error': str(e)}), 500
187
+
188
  def run(self, debug: bool = False):
189
  """Run the API server"""
190
  logger.info(f"Starting Federated API server on {self.host}:{self.port}")
webapp/streamlit_app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import numpy as np
4
+
5
+ st.set_page_config(page_title="Federated Credit Scoring Demo", layout="centered")
6
+ st.title("Federated Credit Scoring Demo (Federated Learning)")
7
+
8
+ SERVER_URL = st.sidebar.text_input("Server URL", value="http://localhost:8080")
9
+
10
+ st.markdown("""
11
+ This demo shows how multiple banks can collaboratively train a credit scoring model using federated learning, without sharing raw data.
12
+ Enter customer features below to get a credit score prediction from the federated model.
13
+ """)
14
+
15
+ # --- Feature Input Form ---
16
+ st.header("Enter Customer Features")
17
+ with st.form("feature_form"):
18
+ features = []
19
+ cols = st.columns(4)
20
+ for i in range(32):
21
+ with cols[i % 4]:
22
+ val = st.number_input(f"Feature {i+1}", value=0.0, format="%.4f", key=f"f_{i}")
23
+ features.append(val)
24
+ submitted = st.form_submit_button("Predict Credit Score")
25
+
26
+ # --- Prediction ---
27
+ prediction = None
28
+ if submitted:
29
+ try:
30
+ resp = requests.post(f"{SERVER_URL}/predict", json={"features": features}, timeout=10)
31
+ if resp.status_code == 200:
32
+ prediction = resp.json().get("prediction")
33
+ st.success(f"Predicted Credit Score: {prediction:.2f}")
34
+ else:
35
+ st.error(f"Prediction failed: {resp.json().get('error', 'Unknown error')}")
36
+ except Exception as e:
37
+ st.error(f"Error connecting to server: {e}")
38
+
39
+ # --- Training Progress ---
40
+ st.header("Federated Training Progress")
41
+ try:
42
+ status = requests.get(f"{SERVER_URL}/training_status", timeout=5)
43
+ if status.status_code == 200:
44
+ data = status.json()
45
+ st.write(f"Current Round: {data.get('current_round', 0)} / {data.get('total_rounds', 10)}")
46
+ st.write(f"Active Clients: {data.get('active_clients', 0)}")
47
+ st.write(f"Clients Ready: {data.get('clients_ready', 0)}")
48
+ st.write(f"Training Active: {data.get('training_active', False)}")
49
+ else:
50
+ st.warning("Could not fetch training status.")
51
+ except Exception as e:
52
+ st.warning(f"Could not connect to server for training status: {e}")
53
+
54
+ st.markdown("---")
55
+ st.markdown("""
56
+ *This is a demo. All data is synthetic. For best results, run the federated server and at least two clients in parallel.*
57
+ """)