Update app.py
Browse files
app.py
CHANGED
@@ -1,99 +1,27 @@
|
|
1 |
-
# Import necessary libraries
|
2 |
-
import pandas as pd
|
3 |
import yfinance as yf
|
4 |
-
from sklearn.model_selection import train_test_split
|
5 |
-
from sklearn.ensemble import RandomForestRegressor
|
6 |
-
from sklearn.metrics import mean_squared_error
|
7 |
import streamlit as st
|
8 |
-
|
9 |
|
10 |
# Function to fetch historical stock data
|
11 |
def fetch_stock_data(symbol, start_date, end_date):
|
12 |
data = yf.download(symbol, start=start_date, end=end_date)
|
13 |
return data
|
14 |
|
15 |
-
# Function to create features for the model
|
16 |
-
def create_features(data):
|
17 |
-
data['Datetime'] = pd.to_datetime(data['Datetime']) # Replace 'Date' with the actual column name
|
18 |
-
|
19 |
-
data['Year'] = data['Datetime'].dt.year
|
20 |
-
data['Month'] = data['Datetime'].dt.month
|
21 |
-
data['Day'] = data['Datetime'].dt.day
|
22 |
-
data['Hour'] = data['Datetime'].dt.hour
|
23 |
-
data['Minute'] = data['Datetime'].dt.minute
|
24 |
-
|
25 |
-
return data
|
26 |
-
|
27 |
-
# Function to train a machine learning model
|
28 |
-
def train_model(data):
|
29 |
-
features = ['Year', 'Month', 'Day', 'Hour', 'Minute']
|
30 |
-
target = 'Close'
|
31 |
-
|
32 |
-
if len(data) == 0:
|
33 |
-
st.write("Not enough data for training.")
|
34 |
-
return None
|
35 |
-
|
36 |
-
X = data[features]
|
37 |
-
y = data[target]
|
38 |
-
|
39 |
-
if len(data) <= 1:
|
40 |
-
st.write("Not enough data for splitting.")
|
41 |
-
return None
|
42 |
-
|
43 |
-
try:
|
44 |
-
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
45 |
-
|
46 |
-
if len(X_train) == 0 or len(X_test) == 0:
|
47 |
-
st.write("Not enough data after splitting.")
|
48 |
-
return None
|
49 |
-
|
50 |
-
model = RandomForestRegressor()
|
51 |
-
model.fit(X_train, y_train)
|
52 |
-
|
53 |
-
# Evaluate the model
|
54 |
-
predictions = model.predict(X_test)
|
55 |
-
mse = mean_squared_error(y_test, predictions)
|
56 |
-
st.write(f"Mean Squared Error: {mse}")
|
57 |
-
|
58 |
-
return model
|
59 |
-
|
60 |
-
except ValueError as e:
|
61 |
-
st.write(f"Error during train-test split: {e}")
|
62 |
-
return None
|
63 |
-
|
64 |
# Streamlit UI
|
65 |
def main():
|
66 |
-
st.title("Stock
|
67 |
|
68 |
symbol = st.text_input("Enter Stock Symbol (e.g., AAPL):")
|
69 |
start_date = st.date_input("Select Start Date:")
|
70 |
end_date = st.date_input("Select End Date:")
|
71 |
|
72 |
-
if st.button("
|
73 |
# Fetch stock data
|
74 |
stock_data = fetch_stock_data(symbol, start_date, end_date)
|
75 |
|
76 |
-
#
|
77 |
-
|
78 |
-
|
79 |
-
# Train the model
|
80 |
-
model = train_model(stock_data)
|
81 |
-
|
82 |
-
if model:
|
83 |
-
# Predict the stock price for a specific date (e.g., the last date in the dataset)
|
84 |
-
prediction_date = stock_data['Date'].iloc[-1]
|
85 |
-
prediction_features = [[
|
86 |
-
prediction_date.year,
|
87 |
-
prediction_date.month,
|
88 |
-
prediction_date.day,
|
89 |
-
prediction_date.hour,
|
90 |
-
prediction_date.minute
|
91 |
-
]]
|
92 |
-
predicted_price = model.predict(prediction_features)[0]
|
93 |
-
|
94 |
-
st.subheader(f"Predicted Stock Price on {prediction_date} (UTC):")
|
95 |
-
st.write(f"${predicted_price:.2f}")
|
96 |
|
97 |
-
# Run the Streamlit app
|
98 |
if __name__ == '__main__':
|
99 |
main()
|
|
|
|
|
|
|
1 |
import yfinance as yf
|
|
|
|
|
|
|
2 |
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
|
5 |
# Function to fetch historical stock data
|
6 |
def fetch_stock_data(symbol, start_date, end_date):
|
7 |
data = yf.download(symbol, start=start_date, end=end_date)
|
8 |
return data
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
# Streamlit UI
|
11 |
def main():
|
12 |
+
st.title("Stock Data Viewer")
|
13 |
|
14 |
symbol = st.text_input("Enter Stock Symbol (e.g., AAPL):")
|
15 |
start_date = st.date_input("Select Start Date:")
|
16 |
end_date = st.date_input("Select End Date:")
|
17 |
|
18 |
+
if st.button("Fetch Stock Data"):
|
19 |
# Fetch stock data
|
20 |
stock_data = fetch_stock_data(symbol, start_date, end_date)
|
21 |
|
22 |
+
# Display the raw stock data
|
23 |
+
st.subheader("Raw Stock Data:")
|
24 |
+
st.write(stock_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
|
|
26 |
if __name__ == '__main__':
|
27 |
main()
|