import datetime
import os
from pathlib import Path

import matplotlib
import matplotlib.dates
import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
from mplfinance.original_flavor import candlestick_ohlc
from sklearn.linear_model import LinearRegression
# Directory where rendered figures are saved; created at import time if missing.
PLOT_DIR = Path("./Plots")
# exist_ok=True avoids the check-then-create race of os.path.exists + os.mkdir.
PLOT_DIR.mkdir(exist_ok=True)

# Pieces of the candlestick endpoint URL and the headers sent with each request.
host = "https://api.gateio.ws"
prefix = "/api/v4"
headers = {"Accept": "application/json", "Content-Type": "application/json"}
endpoint = "/spot/candlesticks"
url = host + prefix + endpoint

# Number of candles covered by a single API request (chunk size used when paging).
max_API_request_allowed = 900
def lin_reg(data, threshold_channel_len):
    """Greedily split ``data`` into consecutive best-fit linear channels.

    For the current start index, a line is fitted to every window that begins
    there and contains at least ``threshold_channel_len`` rows. The window with
    the highest ``LinearRegression.score`` (R^2) wins; on ties the shortest
    window is kept. The next channel starts right after the winning window.
    Trailing rows that cannot fill a window of ``threshold_channel_len`` rows
    are left without a channel.

    Args:
        data: Sequence of rows. ``row[0]`` is the x value; ``row[2]`` and
            ``row[3]`` are averaged to form the y value (high and low in the
            ``[date, open, high, low, close]`` layout used by the caller).
        threshold_channel_len: Minimum number of rows per channel (integer).

    Returns:
        A list of ``[slope, intercept, start_index, end_index]`` entries, with
        ``end_index`` inclusive. Empty when ``data`` is shorter than
        ``threshold_channel_len`` or the threshold is not positive.
    """
    xs = np.array([row[0] for row in data]).reshape(-1, 1)
    ys = np.array([(row[2] + row[3]) / 2 for row in data]).reshape(-1, 1)

    total = len(data)
    channels = []
    start = 0
    stop = threshold_channel_len
    while start < stop <= total:
        best_score = float("-inf")
        best_fit = None
        best_stop = None
        for end in range(stop, total + 1):
            window_x = xs[start:end]
            window_y = ys[start:end]
            reg = LinearRegression().fit(window_x, window_y)
            r2 = reg.score(window_x, window_y)
            # Strict ">" keeps the earliest window on ties and is False for a
            # NaN score (R^2 is undefined for a one-sample window).
            if r2 > best_score:
                best_score = r2
                best_fit = [reg.coef_[0][0], reg.intercept_[0], start, end - 1]
                best_stop = end
        if best_fit is None:
            # No window produced a comparable score; stop instead of
            # continuing from an invalid index.
            break
        channels.append(best_fit)
        start = best_stop
        stop = best_stop + threshold_channel_len
    return channels
def binary_search(data, line_type, m, b, epsilon):
    """Return the intercept of the tightest channel boundary with slope ``m``.

    A "top" line ``y = m * x + c`` must not lie more than ``epsilon`` below
    any ``row[2]``; the smallest such ``c`` is returned. A "bottom" line must
    not lie more than ``epsilon`` above any ``row[3]``; the largest such ``c``
    is returned.

    The boundary has an exact closed form, so it is computed directly instead
    of being searched for. The previous unit-step search depended on the price
    scale, never ran for "bottom" lines (its bounds were swapped), and did not
    terminate when all intercepts were equal.

    Args:
        data: Non-empty sequence of rows; ``row[0]`` is x, ``row[2]`` the
            value bounded by a "top" line, ``row[3]`` the value bounded by a
            "bottom" line.
        line_type: Either ``"top"`` or ``"bottom"``.
        m: Slope shared with the fitted regression line.
        b: Intercept of the fitted regression line. Unused; kept so existing
            callers keep working.
        epsilon: Allowed overshoot of a candle beyond the line. Negative
            values behave like ``0``.

    Returns:
        The intercept ``c`` of the boundary line.

    Raises:
        ValueError: If ``line_type`` is not recognized or ``data`` is empty.
    """
    if line_type not in ("top", "bottom"):
        raise ValueError("line_type must be 'top' or 'bottom', got {!r}".format(line_type))
    if not data:
        raise ValueError("data must contain at least one candle")

    tolerance = max(epsilon, 0)
    if line_type == "top":
        # Highest intercept among the lines of slope m through each upper value.
        return max(row[2] - m * row[0] for row in data) - tolerance
    # Lowest intercept among the lines of slope m through each lower value.
    return min(row[3] - m * row[0] for row in data) + tolerance
def plot_lines(lines, plt, converted_data):
    """Draw each line over the x-range of the segment it belongs to.

    Args:
        lines: Iterable of ``(m, b, start, end)`` entries: slope, intercept
            and the inclusive row indices of the segment in ``converted_data``.
        plt: Object exposing ``plot(x, y)``, e.g. ``matplotlib.pyplot`` or an
            ``Axes`` instance.
        converted_data: Sequence of rows whose first element is the x value.
    """
    for m, b, start, end in lines:
        x_data = np.linspace(converted_data[start][0], converted_data[end][0], 10)
        # Evaluate y = m * x + b for all sample points at once.
        plt.plot(x_data, m * x_data + b)
def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime):
    """Download candlesticks for ``<currency>_USDT`` between two datetimes.

    The range is split into chunks of ``max_API_request_allowed`` candles and
    one request is sent per chunk.

    Args:
        currency: Base currency symbol; the pair requested is
            ``"<currency>_USDT"``.
        interval_timedelta: Positive ``datetime.timedelta`` equal to one
            candle interval.
        interval: Interval string forwarded to the API (e.g. ``"1h"``).
        start_datetime: First candle time (inclusive).
        end_datetime: Last candle time (inclusive).

    Returns:
        The concatenated JSON payloads of all chunks, or ``[]`` when the range
        is empty or a request fails (an error is shown through Streamlit).

    Raises:
        ValueError: If ``interval_timedelta`` is not positive.
    """
    if interval_timedelta <= datetime.timedelta(0):
        raise ValueError("interval_timedelta must be positive")
    if end_datetime < start_datetime:
        return []

    # Number of candle timestamps in [start_datetime, end_datetime].
    total_dates = (end_datetime - start_datetime) // interval_timedelta + 1
    request_timeout_seconds = 30

    # NOTE(review): naive datetimes are converted with the machine's local
    # timezone by .timestamp(); confirm whether UTC is intended.
    data = []
    for first in range(0, total_dates, max_API_request_allowed):
        # Clamp the last chunk so it never asks for candles past end_datetime.
        last = min(first + max_API_request_allowed, total_dates) - 1
        query_param = {
            "currency_pair": "{}_USDT".format(currency),
            "from": int((start_datetime + first * interval_timedelta).timestamp()),
            "to": int((start_datetime + last * interval_timedelta).timestamp()),
            "interval": interval,
        }
        try:
            r = requests.get(
                url=url,
                headers=headers,
                params=query_param,
                timeout=request_timeout_seconds,
            )
        except requests.RequestException as exc:
            st.error("Could not reach the Gate.io API: {}".format(exc))
            return []
        if r.status_code != 200:
            st.error("Very Large Duration Selected. Please reduce Duration or increase Interval")
            return []
        data += r.json()
    return data
def _parse_mdy(text):
    """Parse an ``MM/DD/YYYY`` string into a naive ``datetime.datetime`` at midnight."""
    month, day, year = [int(part) for part in text.strip().split("/")]
    return datetime.datetime(year=year, month=month, day=day)


def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
    """Fetch candles, fit regression channels, then save and display the chart.

    Args:
        currency: Base currency symbol, quoted against USDT.
        interval: Candle interval: ``"1h"``, ``"4h"`` or ``"1d"``. Any other
            value is treated as one week.
        startdate: Start date as ``MM/DD/YYYY``.
        enddate: End date as ``MM/DD/YYYY``.
        threshold_channel_len: Minimum number of candles per channel.
        testcasecase_id: Identifier embedded in the saved figure's file name.

    Returns:
        None. On invalid input or when no data is available, an error may be
        shown through Streamlit and nothing is plotted.
    """
    try:
        start_datetime = _parse_mdy(startdate)
        end_datetime = _parse_mdy(enddate)
    except ValueError:
        # Covers malformed text and impossible dates such as 4/31.
        st.error("Invalid date. Please enter real calendar dates as MM/DD/YYYY.")
        return
    if end_datetime < start_datetime:
        st.error("End Date must not be earlier than Start Date.")
        return

    interval_timedelta = {
        "1h": datetime.timedelta(hours=1),
        "4h": datetime.timedelta(hours=4),
        "1d": datetime.timedelta(days=1),
    }.get(interval, datetime.timedelta(weeks=1))

    data = get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime)
    if not data:
        return

    # Rows are rearranged into [date, open, high, low, close] for candlestick_ohlc.
    converted_data = [
        [
            matplotlib.dates.date2num(
                datetime.datetime.fromtimestamp(float(d[0]), tz=datetime.timezone.utc)
            ),
            float(d[5]),
            float(d[3]),
            float(d[4]),
            float(d[2]),
        ]
        for d in data
    ]

    fig, ax = plt.subplots()
    try:
        # Candle width is expressed in days; scale it with the interval so
        # intraday candles do not overlap (daily candles keep a width of 0.4).
        candle_width = 0.4 * (interval_timedelta / datetime.timedelta(days=1))
        candlestick_ohlc(ax, converted_data, width=candle_width, colorup='#77d879', colordown='#db3f3f')

        top_fitting_lines_data = []
        bottom_fitting_lines_data = []
        epsilon = 0
        for m, b, start, end in lin_reg(converted_data, threshold_channel_len):
            segment = converted_data[start:end + 1]
            top_b = binary_search(segment, "top", m, b, epsilon)
            bottom_b = binary_search(segment, "bottom", m, b, epsilon)
            top_fitting_lines_data.append([m, top_b, start, end])
            bottom_fitting_lines_data.append([m, bottom_b, start, end])

        # Draw on this figure's axes rather than pyplot's global "current" axes.
        plot_lines(top_fitting_lines_data, ax, converted_data)
        plot_lines(bottom_fitting_lines_data, ax, converted_data)
        ax.set_title("{}_USDT".format(currency))

        file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
        file_location = os.path.join(PLOT_DIR, file_name)
        fig.savefig(file_location)
        st.pyplot(fig)
    finally:
        # Release the figure; otherwise every rerun keeps one more figure alive.
        plt.close(fig)
def main():
    """Render the input form and generate the plot when the button is pressed."""
    st.title("Cryptocurrency Regression Analysis")
    st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")

    # Normalize the symbol so " btc " and "BTC" request the same pair.
    currency = st.text_input("Currency", "BTC").strip().upper()
    interval = st.selectbox("Interval", ["1h", "4h", "1d", "1w"])
    startdate = st.text_input("Start Date (MM/DD/YYYY)", "01/01/2023")
    # April has 30 days; the previous default "4/31/2023" was not a valid date.
    enddate = st.text_input("End Date (MM/DD/YYYY)", "04/30/2023")
    # A regression needs at least two points, so the minimum is 2.
    threshold_channel_len = st.number_input("Threshold Channel Length", min_value=2, max_value=1000, value=10)

    if st.button("Generate Plot"):
        if not currency:
            st.error("Please enter a currency symbol.")
            return
        testcasecase(currency, interval, startdate, enddate, threshold_channel_len, 1)


if __name__ == "__main__":
    main()