Avinash109 committed · verified
Commit 9c65345 · 1 Parent(s): 2452368

Update app.py

Files changed (1)
  1. app.py +43 -69
app.py CHANGED
@@ -8,27 +8,6 @@ from sklearn.preprocessing import StandardScaler
 from sklearn.model_selection import train_test_split
 import gradio as gr
 import os
-import time
-from fastapi import FastAPI, BackgroundTasks
-from fastapi.middleware.cors import CORSMiddleware
-import asyncio
-
-# FastAPI app
-app = FastAPI()
-
-# Add CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Global variables
-model = None
-scaler = None
-latest_report = "Initializing..."
 
 # Define the Dataset class
 class BankNiftyDataset(Dataset):
@@ -64,11 +43,13 @@ class LSTMModel(nn.Module):
         return out
 
 # Function to train the model
-def train_model(train_loader, val_loader, num_epochs=10):
-    global model
+def train_model(model, train_loader, val_loader, num_epochs=10):
     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=0.001)
 
+    best_val_loss = float('inf')
+    best_model = None
+
     for epoch in range(num_epochs):
         model.train()
         for features, labels in train_loader:
@@ -87,6 +68,13 @@ def train_model(train_loader, val_loader, num_epochs=10):
         val_loss /= len(val_loader)
 
         print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")
+
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            best_model = model.state_dict().copy()
+
+    model.load_state_dict(best_model)
+    return model, best_val_loss
 
 # Function to generate trading signals
 def generate_signals(predictions, actual_values, stop_loss_threshold=0.05):
@@ -101,7 +89,7 @@ def generate_signals(predictions, actual_values, stop_loss_threshold=0.05):
     return signals
 
 # Function to generate a report
-def generate_report(predictions, actual_values, signals):
+def generate_report(predictions, actual_values, signals, val_loss):
     report = []
     cumulative_profit = 0
     for i in range(len(signals)):
@@ -115,12 +103,17 @@ def generate_report(predictions, actual_values, signals):
 
     total_profit = cumulative_profit
     report.append(f"Total Profit: {total_profit:.2f}")
+    report.append(f"Model Validation Loss: {val_loss:.4f}")
     return "\n".join(report)
 
+# Global variables to store the model and scaler
+global_model = None
+global_scaler = None
+
 # Function to process data and make predictions
 def predict():
-    global model, scaler, latest_report
-
+    global global_model, global_scaler
+
     # Load the pre-existing CSV file
     csv_path = 'BANKNIFTY_OPTION_CHAIN_data.csv'
     if not os.path.exists(csv_path):
@@ -128,11 +121,13 @@ def predict():
 
     # Load and preprocess data
     data = pd.read_csv(csv_path)
-    if scaler is None:
-        scaler = StandardScaler()
-        scaled_data = scaler.fit_transform(data[['open', 'high', 'low', 'close', 'volume', 'oi']])
+
+    if global_scaler is None:
+        global_scaler = StandardScaler()
+        scaled_data = global_scaler.fit_transform(data[['open', 'high', 'low', 'close', 'volume', 'oi']])
     else:
-        scaled_data = scaler.transform(data[['open', 'high', 'low', 'close', 'volume', 'oi']])
+        scaled_data = global_scaler.transform(data[['open', 'high', 'low', 'close', 'volume', 'oi']])
+
     data[['open', 'high', 'low', 'close', 'volume', 'oi']] = scaled_data
 
     # Split data
@@ -147,60 +142,39 @@ def predict():
     val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
 
     # Initialize and train the model
-    if model is None:
-        input_dim = 6
-        hidden_dim = 64
-        output_dim = len(target_cols)
-        model = LSTMModel(input_dim, hidden_dim, output_dim)
-
-    train_model(train_loader, val_loader)
+    input_dim = 6
+    hidden_dim = 64
+    output_dim = len(target_cols)
+
+    if global_model is None:
+        global_model = LSTMModel(input_dim, hidden_dim, output_dim)
+
+    global_model, val_loss = train_model(global_model, train_loader, val_loader)
 
     # Make predictions
-    model.eval()
+    global_model.eval()
     predictions = []
     actual_values = val_data['close'].values[seq_len-1:]
     with torch.no_grad():
         for i in range(len(val_dataset)):
             features, _ = val_dataset[i]
-            pred = model(features.unsqueeze(0)).item()
+            pred = global_model(features.unsqueeze(0)).item()
             predictions.append(pred)
 
     # Generate signals and report
     signals = generate_signals(predictions, actual_values)
-    latest_report = generate_report(predictions, actual_values, signals)
+    report = generate_report(predictions, actual_values, signals, val_loss)
 
-    return latest_report
-
-# Background task to update the model and report
-async def update_model_and_report():
-    global latest_report
-    while True:
-        latest_report = predict()
-        await asyncio.sleep(3600)  # Update every hour
-
-# Startup event to begin the background task
-@app.on_event("startup")
-async def startup_event():
-    background_tasks = BackgroundTasks()
-    background_tasks.add_task(update_model_and_report)
-    await background_tasks()
-
-# Gradio interface
-def gradio_interface():
-    return latest_report
+    return report
 
+# Set up the Gradio interface
 iface = gr.Interface(
-    fn=gradio_interface,
+    fn=predict,
     inputs=None,
-    outputs=gr.Textbox(label="Latest Prediction Report"),
+    outputs=gr.Textbox(label="Prediction Report"),
     title="BankNifty Option Chain Predictor",
-    description="This app automatically generates and updates predictions and trading signals based on the latest BankNifty option chain data."
+    description="Click 'Submit' to generate predictions and trading signals based on the latest BankNifty option chain data. The model is automatically trained and improved with each run."
 )
 
-# Combine FastAPI and Gradio
-app = gr.mount_gradio_app(app, iface, path="/")
-
-# Run the FastAPI app
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+# Launch the app
+iface.launch()
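
The main behavioural change in this commit is that train_model now receives the model explicitly, tracks the best validation loss per epoch, and restores the best-scoring weights before returning them. Below is a minimal, self-contained sketch of that best-checkpoint pattern; the tiny linear model and random data are illustrative stand-ins, not the LSTMModel or the BankNifty CSV from app.py:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Illustrative stand-ins only: a tiny linear model and random data.
model = nn.Linear(6, 1)
dataset = TensorDataset(torch.randn(64, 6), torch.randn(64, 1))
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset, batch_size=8, shuffle=False)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

best_val_loss = float('inf')
best_state = None

for epoch in range(3):
    model.train()
    for features, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(features), labels)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for features, labels in val_loader:
            val_loss += criterion(model(features), labels).item()
    val_loss /= len(val_loader)

    # Snapshot the weights whenever the validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = {k: v.clone() for k, v in model.state_dict().items()}

# Restore the best-scoring checkpoint before using the model.
model.load_state_dict(best_state)
print(f"Best validation loss: {best_val_loss:.4f}")

The sketch clones each tensor when snapshotting because state_dict() returns tensors that share storage with the live parameters, so a plain dict copy would keep changing as training continues.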