suryanshs16103 committed on
Commit
5354994
·
verified ·
1 Parent(s): 6f69e3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -182
app.py CHANGED
@@ -1,182 +1,183 @@
1
- import datetime
2
- import requests
3
- import matplotlib.pyplot as plt
4
- from mplfinance.original_flavor import candlestick_ohlc
5
- import numpy as np
6
- from sklearn.linear_model import LinearRegression
7
- import os
8
- from pathlib import Path
9
- import streamlit as st
10
-
11
- PLOT_DIR = Path("./Plots")
12
-
13
- if not os.path.exists(PLOT_DIR):
14
- os.mkdir(PLOT_DIR)
15
-
16
- host = "https://api.gateio.ws"
17
- prefix = "/api/v4"
18
- headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
19
- endpoint = '/spot/candlesticks'
20
- url = host + prefix + endpoint
21
- max_API_request_allowed = 900
22
-
23
- def lin_reg(data, threshold_channel_len):
24
- list_f = []
25
- X = []
26
- y = []
27
- for i in range(0, len(data)):
28
- X.append(data[i][0])
29
- avg = (data[i][2] + data[i][3]) / 2
30
- y.append(avg)
31
- X = np.array(X).reshape(-1, 1)
32
- y = np.array(y).reshape(-1, 1)
33
- l = 0
34
- j = threshold_channel_len
35
- while l < j and j <= len(data):
36
- score = []
37
- list_pf = []
38
- while j <= len(data):
39
- reg = LinearRegression().fit(X[l:j], y[l:j])
40
- temp_coeff = list(reg.coef_)
41
- temp_intercept = list(reg.intercept_)
42
- list_pf.append([temp_coeff[0][0], temp_intercept[0], l, j - 1])
43
- score.append([reg.score(X[l:j], y[l:j]), j])
44
- j = j + 1
45
- req_score = float("-inf")
46
- ind = -1
47
- temp_ind = -1
48
- for i in range(len(score)):
49
- if req_score < score[i][0]:
50
- ind = score[i][1]
51
- req_score = score[i][0]
52
- temp_ind = i
53
- list_f.append(list_pf[temp_ind])
54
- l = ind
55
- j = ind + threshold_channel_len
56
- return list_f
57
-
58
- def binary_search(data, line_type, m, b, epsilon):
59
- right = float("-inf")
60
- left = float("inf")
61
- get_y_intercept = lambda x, y: y - m * x
62
- for i in range(len(data)):
63
- d = data[i]
64
- curr_y = d[2]
65
- if line_type == "bottom":
66
- curr_y = d[3]
67
- curr = get_y_intercept(d[0], curr_y)
68
- right = max(right, curr)
69
- left = min(left, curr)
70
-
71
- sign = -1
72
- if line_type == "bottom":
73
- left, right = right, left
74
- sign = 1
75
- ans = right
76
- while left <= right:
77
- mid = left + (right - left) // 2
78
- intersection_count = 0
79
- for i in range(len(data)):
80
- d = data[i]
81
- curr_y = m * d[0] + mid
82
- candle_y = d[2]
83
- if line_type == "bottom":
84
- candle_y = d[3]
85
- if line_type == "bottom" and (curr_y > candle_y and (curr_y - candle_y > epsilon)):
86
- intersection_count += 1
87
- if line_type == "top" and (curr_y < candle_y and (candle_y - curr_y > epsilon)):
88
- intersection_count += 1
89
- if intersection_count == 0:
90
- right = mid + 1 * sign
91
- ans = mid
92
- else:
93
- left = mid - 1 * sign
94
- return ans
95
-
96
- def plot_lines(lines, plt, converted_data):
97
- for m, b, start, end in lines:
98
- x_data = list(np.linspace(converted_data[start][0], converted_data[end][0], 10))
99
- y_data = [m * x + b for x in x_data]
100
- plt.plot(x_data, y_data)
101
-
102
- def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime):
103
- curr_datetime = start_datetime
104
- total_dates = 0
105
- while curr_datetime <= end_datetime:
106
- total_dates += 1
107
- curr_datetime += interval_timedelta
108
- data = []
109
- for i in range(0, total_dates, max_API_request_allowed):
110
- query_param = {
111
- "currency_pair": "{}_USDT".format(currency),
112
- "from": int((start_datetime + i * interval_timedelta).timestamp()),
113
- "to": int((start_datetime + (i + max_API_request_allowed - 1) * interval_timedelta).timestamp()),
114
- "interval": interval,
115
- }
116
- r = requests.get(url=url, headers=headers, params=query_param)
117
- if r.status_code != 200:
118
- st.error("Invalid API Request")
119
- return []
120
- data += r.json()
121
- return data
122
-
123
- def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
124
- start_date_month, start_date_day, start_date_year = [int(x) for x in startdate.strip().split("/")]
125
- end_date_month, end_date_day, end_date_year = [int(x) for x in enddate.strip().split("/")]
126
-
127
- if interval == "1h":
128
- interval_timedelta = datetime.timedelta(hours=1)
129
- elif interval == "4h":
130
- interval_timedelta = datetime.timedelta(hours=4)
131
- elif interval == "1d":
132
- interval_timedelta = datetime.timedelta(days=1)
133
- else:
134
- interval_timedelta = datetime.timedelta(weeks=1)
135
-
136
- start_datetime = datetime.datetime(year=start_date_year, month=start_date_month, day=start_date_day)
137
- end_datetime = datetime.datetime(year=end_date_year, month=end_date_month, day=end_date_day)
138
-
139
- data = get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime)
140
- if len(data) == 0:
141
- return
142
- converted_data = []
143
- for d in data:
144
- converted_data.append([matplotlib.dates.date2num(datetime.datetime.utcfromtimestamp(float(d[0]))), float(d[5]), float(d[3]), float(d[4]), float(d[2])])
145
-
146
- fig, ax = plt.subplots()
147
- candlestick_ohlc(ax, converted_data, width=0.4, colorup='#77d879', colordown='#db3f3f')
148
-
149
- fitting_lines_data = lin_reg(converted_data, threshold_channel_len)
150
- top_fitting_lines_data = []
151
- bottom_fitting_lines_data = []
152
- epsilon = 0
153
- for i in range(len(fitting_lines_data)):
154
- m, b, start, end = fitting_lines_data[i]
155
- top_b = binary_search(converted_data[start:end + 1], "top", m, b, epsilon)
156
- bottom_b = binary_search(converted_data[start:end + 1], "bottom", m, b, epsilon)
157
- top_fitting_lines_data.append([m, top_b, start, end])
158
- bottom_fitting_lines_data.append([m, bottom_b, start, end])
159
-
160
- plot_lines(top_fitting_lines_data, plt, converted_data)
161
- plot_lines(bottom_fitting_lines_data, plt, converted_data)
162
- plt.title("{}_USDT".format(currency))
163
- file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
164
- file_location = os.path.join(PLOT_DIR, file_name)
165
- plt.savefig(file_location)
166
- st.pyplot(fig)
167
-
168
- def main():
169
- st.title("Cryptocurrency Regression Analysis")
170
- st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")
171
-
172
- currency = st.text_input("Currency", "BTC")
173
- interval = st.selectbox("Interval", ["1h", "4h", "1d", "1w"])
174
- startdate = st.text_input("Start Date (MM/DD/YYYY)", "01/01/2022")
175
- enddate = st.text_input("End Date (MM/DD/YYYY)", "12/31/2022")
176
- threshold_channel_len = st.number_input("Threshold Channel Length", min_value=1, max_value=1000, value=10)
177
-
178
- if st.button("Generate Plot"):
179
- testcasecase(currency, interval, startdate, enddate, threshold_channel_len, 1)
180
-
181
- if __name__ == "__main__":
182
- main()
 
 
1
+ import datetime
2
+ import requests
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ from mplfinance.original_flavor import candlestick_ohlc
6
+ import numpy as np
7
+ from sklearn.linear_model import LinearRegression
8
+ import os
9
+ from pathlib import Path
10
+ import streamlit as st
11
+
12
# Directory where every generated chart is saved as a PNG file.
PLOT_DIR = Path("./Plots")

# Create the directory up front. exist_ok avoids the check-then-create race
# between two app instances, and parents=True covers a missing parent.
PLOT_DIR.mkdir(parents=True, exist_ok=True)

# Gate.io REST API v4 spot candlestick endpoint and the headers sent with it.
host = "https://api.gateio.ws"
prefix = "/api/v4"
headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
endpoint = '/spot/candlesticks'
url = host + prefix + endpoint

# Number of candles requested per API call; get_API_data walks the requested
# date range in chunks of this many intervals.
max_API_request_allowed = 900
23
+
24
def lin_reg(data, threshold_channel_len):
    """Split the candles into consecutive channels and fit a line to each.

    Starting at the first candle, every window ``[start, stop)`` holding at
    least ``threshold_channel_len`` candles is fitted with a linear
    regression of the mid price ``(row[2] + row[3]) / 2`` against ``row[0]``.
    The window with the highest R^2 score wins (the shortest one on ties),
    and the next channel starts at the candle right after it. The process
    stops when fewer than the minimum number of candles remain.

    Args:
        data: Sequence of rows where ``row[0]`` is the x value and
            ``row[2]`` / ``row[3]`` are the two prices that get averaged.
        threshold_channel_len: Minimum number of candles in one channel.
            Values below 2 are treated as 2, because a regression over a
            single point has no defined R^2 score.

    Returns:
        A list of ``[slope, intercept, first_index, last_index]`` entries,
        one per channel, with inclusive indices into ``data``. The list is
        empty when ``data`` is shorter than the minimum channel length or
        the threshold is not positive.
    """
    channels = []
    if threshold_channel_len <= 0:
        return channels

    xs = np.array([row[0] for row in data]).reshape(-1, 1)
    mids = np.array([(row[2] + row[3]) / 2 for row in data]).reshape(-1, 1)

    # A one-candle window yields a NaN score; it used to derail the window
    # bookkeeping and end in a fit over an empty slice. Two candles is the
    # smallest window with a meaningful fit, and it reproduces what a
    # threshold of 1 selected before (a two-candle window always scores 1.0).
    min_len = max(threshold_channel_len, 2)
    total = len(data)

    start = 0
    while start + min_len <= total:
        best_score = float("-inf")
        best_fit = None
        best_stop = None
        for stop in range(start + min_len, total + 1):
            window_x = xs[start:stop]
            window_y = mids[start:stop]
            reg = LinearRegression().fit(window_x, window_y)
            r2 = reg.score(window_x, window_y)
            # Strict comparison keeps the earliest (shortest) window on ties
            # and skips NaN scores.
            if best_score < r2:
                best_score = r2
                best_fit = [reg.coef_[0][0], reg.intercept_[0], start, stop - 1]
                best_stop = stop
        if best_fit is None:
            # No window produced a comparable score; stop instead of looping
            # on an undefined position.
            break
        channels.append(best_fit)
        start = best_stop
    return channels
58
+
59
def binary_search(data, line_type, m, b, epsilon):
    """Return the intercept of the tightest channel boundary with slope ``m``.

    For ``"top"`` the result is the lowest intercept such that no candle's
    ``row[2]`` value rises above the line ``y = m * x + intercept`` by more
    than ``epsilon``. For ``"bottom"`` it is the highest intercept such that
    no candle's ``row[3]`` value drops below the line by more than
    ``epsilon``.

    The answer is computed in closed form from the per-candle intercepts.
    The earlier iterative search stepped through float bounds in whole units,
    never ran for the bottom line because its bounds were swapped, and could
    spin forever when every candle produced the same intercept.

    Args:
        data: Non-empty sequence of rows with ``row[0]`` as x, ``row[2]`` as
            the value bounded by the top line and ``row[3]`` as the value
            bounded by the bottom line.
        line_type: Either ``"top"`` or ``"bottom"``.
        m: Slope shared with the fitted regression line.
        b: Intercept of the fitted line. Not read; kept so existing callers
            keep working.
        epsilon: How far a candle may poke through the boundary.

    Returns:
        The intercept of the requested boundary line.

    Raises:
        ValueError: If ``data`` is empty or ``line_type`` is not recognised.
    """
    if line_type not in ("top", "bottom"):
        raise ValueError("line_type must be 'top' or 'bottom'")
    if len(data) == 0:
        raise ValueError("data must contain at least one candle")

    if line_type == "top":
        # The highest per-candle intercept touches the most prominent peak.
        return max(row[2] - m * row[0] for row in data) - epsilon
    # The lowest per-candle intercept touches the deepest trough.
    return min(row[3] - m * row[0] for row in data) + epsilon
96
+
97
def plot_lines(lines, plt, converted_data):
    """Draw every line of ``lines`` across the candles it spans.

    Each entry of ``lines`` is ``(slope, intercept, first_index, last_index)``.
    The x-range is taken from column 0 of the first and last candle in
    ``converted_data``, sampled at 10 evenly spaced points, and handed to
    ``plt.plot`` together with the matching y values.
    """
    for slope, intercept, first, last in lines:
        x_from = converted_data[first][0]
        x_to = converted_data[last][0]
        xs = list(np.linspace(x_from, x_to, num=10))
        ys = [slope * x + intercept for x in xs]
        plt.plot(xs, ys)
102
+
103
def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime):
    """Download candlesticks for ``<currency>_USDT`` between two datetimes.

    The range is walked in chunks of ``max_API_request_allowed`` intervals,
    one HTTP request per chunk, and the JSON rows of all chunks are
    concatenated in order.

    Args:
        currency: Base currency symbol placed in front of ``_USDT``.
        interval_timedelta: Length of one candle as a ``timedelta``; must be
            positive.
        interval: The same interval in the API's string form, sent as the
            ``interval`` query parameter.
        start_datetime: First candle time. Naive values are read as UTC so
            the request window matches the UTC timestamps in the response.
        end_datetime: Last candle time (inclusive); same convention.

    Returns:
        The list of candle rows as returned by the API. The list is empty
        when the range is empty or a request fails; failures are also
        reported through ``st.error``.

    Raises:
        ValueError: If ``interval_timedelta`` is not positive (the candle
            count would otherwise never terminate).
    """
    if interval_timedelta <= datetime.timedelta(0):
        raise ValueError("interval_timedelta must be positive")

    # Number of interval steps k >= 0 with start + k * interval <= end.
    total_candles = (end_datetime - start_datetime) // interval_timedelta + 1
    if total_candles <= 0:
        return []

    def to_epoch(moment):
        # Naive datetimes would otherwise be converted using the server's
        # local time zone, shifting the window relative to the UTC candles.
        if moment.tzinfo is None:
            moment = moment.replace(tzinfo=datetime.timezone.utc)
        return int(moment.timestamp())

    candles = []
    for offset in range(0, total_candles, max_API_request_allowed):
        chunk_start = start_datetime + offset * interval_timedelta
        # Clamp the final chunk so candles past the requested end are not fetched.
        chunk_end = min(
            start_datetime + (offset + max_API_request_allowed - 1) * interval_timedelta,
            end_datetime,
        )
        query_param = {
            "currency_pair": "{}_USDT".format(currency),
            "from": to_epoch(chunk_start),
            "to": to_epoch(chunk_end),
            "interval": interval,
        }
        try:
            # Without a timeout a stalled connection would block the app forever.
            r = requests.get(url=url, headers=headers, params=query_param, timeout=30)
        except requests.RequestException as exc:
            st.error("Could not reach the API: {}".format(exc))
            return []
        if r.status_code != 200:
            st.error("Invalid API Request")
            return []
        try:
            payload = r.json()
        except ValueError:
            payload = None
        if not isinstance(payload, list):
            st.error("Invalid API Response")
            return []
        candles += payload
    return candles
123
+
124
def _parse_mdy(text):
    """Parse an ``MM/DD/YYYY`` string into a naive datetime at midnight.

    Raises ValueError when the text is malformed or names an impossible date.
    """
    month, day, year = [int(part) for part in text.strip().split("/")]
    return datetime.datetime(year=year, month=month, day=day)


def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
    """Fetch candles, fit regression channels and render the chart.

    The candles for ``<currency>_USDT`` are downloaded, drawn as a
    candlestick chart and overlaid with the top and bottom boundary of every
    channel found by ``lin_reg``. The figure is saved under ``PLOT_DIR`` as
    ``figure_<testcasecase_id>_<currency>_USDT.png`` and shown in the
    Streamlit page.

    Args:
        currency: Base currency symbol.
        interval: ``"1h"``, ``"4h"`` or ``"1d"``; any other value is treated
            as one week.
        startdate: Start date as ``MM/DD/YYYY``.
        enddate: End date as ``MM/DD/YYYY``.
        threshold_channel_len: Minimum number of candles per channel.
        testcasecase_id: Identifier embedded in the saved file name.

    Returns:
        None. Invalid dates are reported through ``st.error``; nothing is
        drawn when the dates are invalid or no data is returned.
    """
    try:
        start_datetime = _parse_mdy(startdate)
        end_datetime = _parse_mdy(enddate)
    except ValueError:
        # User input: report it in the page rather than raising a traceback.
        st.error("Dates must be valid and formatted as MM/DD/YYYY")
        return
    if start_datetime > end_datetime:
        st.error("Start date must not be after end date")
        return

    known_intervals = {
        "1h": datetime.timedelta(hours=1),
        "4h": datetime.timedelta(hours=4),
        "1d": datetime.timedelta(days=1),
    }
    interval_timedelta = known_intervals.get(interval, datetime.timedelta(weeks=1))

    data = get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime)
    if len(data) == 0:
        return

    # Reorder each API row into the (time, open, high, low, close) layout that
    # candlestick_ohlc consumes. Timestamps are decoded as UTC.
    converted_data = [
        [
            matplotlib.dates.date2num(
                datetime.datetime.fromtimestamp(float(d[0]), tz=datetime.timezone.utc)
            ),
            float(d[5]),
            float(d[3]),
            float(d[4]),
            float(d[2]),
        ]
        for d in data
    ]

    fig, ax = plt.subplots()
    try:
        # Candle width is expressed in days on a date axis; scale it with the
        # interval so sub-daily candles do not overlap (0.4 for daily candles).
        candle_width = 0.4 * (interval_timedelta / datetime.timedelta(days=1))
        candlestick_ohlc(ax, converted_data, width=candle_width, colorup='#77d879', colordown='#db3f3f')

        epsilon = 0
        top_fitting_lines_data = []
        bottom_fitting_lines_data = []
        for m, b, start, end in lin_reg(converted_data, threshold_channel_len):
            segment = converted_data[start:end + 1]
            top_b = binary_search(segment, "top", m, b, epsilon)
            bottom_b = binary_search(segment, "bottom", m, b, epsilon)
            top_fitting_lines_data.append([m, top_b, start, end])
            bottom_fitting_lines_data.append([m, bottom_b, start, end])

        # Draw on this figure's axes explicitly instead of pyplot's implicit
        # "current" figure.
        plot_lines(top_fitting_lines_data, ax, converted_data)
        plot_lines(bottom_fitting_lines_data, ax, converted_data)
        ax.set_title("{}_USDT".format(currency))

        file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
        file_location = os.path.join(PLOT_DIR, file_name)
        fig.savefig(file_location)
        st.pyplot(fig)
    finally:
        # Every rerun creates a new figure; close it so figures do not pile up.
        plt.close(fig)
168
+
169
def main():
    """Render the input form and generate the chart on request.

    Collects currency, interval, date range and minimum channel length from
    Streamlit widgets and calls ``testcasecase`` when the button is pressed.
    """
    st.title("Cryptocurrency Regression Analysis")
    st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")

    # The symbol becomes part of the "<SYMBOL>_USDT" pair name; normalise
    # stray whitespace and lower-case input.
    currency = st.text_input("Currency", "BTC").strip().upper()
    interval = st.selectbox("Interval", ["1h", "4h", "1d", "1w"])
    startdate = st.text_input("Start Date (MM/DD/YYYY)", "01/01/2022")
    enddate = st.text_input("End Date (MM/DD/YYYY)", "12/31/2022")
    # A regression channel needs at least two candles to define a line.
    threshold_channel_len = st.number_input("Threshold Channel Length", min_value=2, max_value=1000, value=10)

    if st.button("Generate Plot"):
        if not currency:
            st.error("Please enter a currency symbol")
            return
        testcasecase(currency, interval, startdate, enddate, threshold_channel_len, 1)


if __name__ == "__main__":
    main()