Ashoka74 committed
Commit 6d2b558 · verified · 1 Parent(s): bb019ba

Update magnetic.py

Files changed (1)
  1. magnetic.py +951 -907
magnetic.py CHANGED
@@ -1,907 +1,951 @@
-
- import math
- import pandas as pd
- import numpy as np
- import json
- import requests
- import datetime
- from datetime import timedelta
- from PIL import Image
- # alternative to PIL
- import matplotlib.pyplot as plt
- import matplotlib.image as mpimg
- import os
- import matplotlib.dates as mdates
- import seaborn as sns
- from IPython.display import Image as image_display
- path = os.getcwd()
- from fastdtw import fastdtw
- from scipy.spatial.distance import euclidean
- from IPython.display import display
- from dateutil import parser
- from Levenshtein import distance
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import confusion_matrix
- from stqdm import stqdm
- stqdm.pandas()
- import streamlit.components.v1 as components
- from dateutil import parser
- from sentence_transformers import SentenceTransformer
- import torch
- import squarify
- import matplotlib.colors as mcolors
- import textwrap
- import datamapplot
- import streamlit as st
-
-
- st.title('Magnetic Correlations Dashboard')
-
- st.set_option('deprecation.showPyplotGlobalUse', False)
-
-
- from pandas.api.types import (
-     is_categorical_dtype,
-     is_datetime64_any_dtype,
-     is_numeric_dtype,
-     is_object_dtype,
- )
-
-
- def plot_treemap(df, column, top_n=32):
-     # Get the value counts and the top N labels
-     value_counts = df[column].value_counts()
-     top_labels = value_counts.iloc[:top_n].index
-
-     # Use np.where to replace all values not in the top N with 'Other'
-     revised_column = f'{column}_revised'
-     df[revised_column] = np.where(df[column].isin(top_labels), df[column], 'Other')
-
-     # Get the value counts including the 'Other' category
-     sizes = df[revised_column].value_counts().values
-     labels = df[revised_column].value_counts().index
-
-     # Get a gradient of colors
-     # colors = list(mcolors.TABLEAU_COLORS.values())
-
-     n_colors = len(sizes)
-     colors = plt.cm.Oranges(np.linspace(0.3, 0.9, n_colors))[::-1]
-
-
-     # Get % of each category
-     percents = sizes / sizes.sum()
-
-     # Prepare labels with percentages
-     labels = [f'{label}\n {percent:.1%}' for label, percent in zip(labels, percents)]
-
-     fig, ax = plt.subplots(figsize=(20, 12))
-
-     # Plot the treemap
-     squarify.plot(sizes=sizes, label=labels, alpha=0.7, pad=True, color=colors, text_kwargs={'fontsize': 10})
-
-     ax = plt.gca()
-     # Iterate over text elements and rectangles (patches) in the axes for color adjustment
-     for text, rect in zip(ax.texts, ax.patches):
-         background_color = rect.get_facecolor()
-         r, g, b, _ = mcolors.to_rgba(background_color)
-         brightness = np.average([r, g, b])
-         text.set_color('white' if brightness < 0.5 else 'black')
-
-
- def plot_hist(df, column, bins=10, kde=True):
-     fig, ax = plt.subplots(figsize=(12, 6))
-     sns.histplot(data=df, x=column, kde=True, bins=bins,color='orange')
-     # set the ticks and frame in orange
-     ax.spines['bottom'].set_color('orange')
-     ax.spines['top'].set_color('orange')
-     ax.spines['right'].set_color('orange')
-     ax.spines['left'].set_color('orange')
-     ax.xaxis.label.set_color('orange')
-     ax.yaxis.label.set_color('orange')
-     ax.tick_params(axis='x', colors='orange')
-     ax.tick_params(axis='y', colors='orange')
-     ax.title.set_color('orange')
-
-     # Set transparent background
-     fig.patch.set_alpha(0)
-     ax.patch.set_alpha(0)
-     return fig
-
-
-
-
- def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=None, rolling_mean_value=2):
-     import matplotlib.cm as cm
-     # Sort the dataframe by the date column
-     df = df.sort_values(by=x_column)
-
-     # Calculate rolling mean for each y_column
-     if rolling_mean_value:
-         df[y_columns] = df[y_columns].rolling(len(df) // rolling_mean_value).mean()
-
-     # Create the plot
-     fig, ax = plt.subplots(figsize=figsize)
-
-     colors = cm.Oranges(np.linspace(0.2, 1, len(y_columns)))
-
-     # Plot each y_column as a separate line with a different color
-     for i, y_column in enumerate(y_columns):
-         df.plot(x=x_column, y=y_column, ax=ax, color=colors[i], label=y_column, linewidth=.5)
-
-     # Rotate x-axis labels
-     ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right')
-
-     # Format x_column as date if it is
-     if np.issubdtype(df[x_column].dtype, np.datetime64) or np.issubdtype(df[x_column].dtype, np.timedelta64):
-         df[x_column] = pd.to_datetime(df[x_column]).dt.date
-
-     # Set title, labels, and legend
-     ax.set_title(title or f'{", ".join(y_columns)} over {x_column}', color=color, fontweight='bold')
-     ax.set_xlabel(x_column, color=color)
-     ax.set_ylabel(', '.join(y_columns), color=color)
-     ax.spines['bottom'].set_color('orange')
-     ax.spines['top'].set_color('orange')
-     ax.spines['right'].set_color('orange')
-     ax.spines['left'].set_color('orange')
-     ax.xaxis.label.set_color('orange')
-     ax.yaxis.label.set_color('orange')
-     ax.tick_params(axis='x', colors='orange')
-     ax.tick_params(axis='y', colors='orange')
-     ax.title.set_color('orange')
-
-     ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
-
-     # Remove background
-     fig.patch.set_alpha(0)
-     ax.patch.set_alpha(0)
-
-     return fig
-
- def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None):
-     fig, ax = plt.subplots(figsize=figsize)
-
-     sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)
-
-     ax.set_title(title if title else f'{y_column} by {x_column}', color=color, fontweight='bold')
-     ax.set_xlabel(x_column, color=color)
-     ax.set_ylabel(y_column, color=color)
-
-     ax.tick_params(axis='x', colors=color)
-     ax.tick_params(axis='y', colors=color)
-
-     # Remove background
-     fig.patch.set_alpha(0)
-     ax.patch.set_alpha(0)
-     ax.spines['bottom'].set_color('orange')
-     ax.spines['top'].set_color('orange')
-     ax.spines['right'].set_color('orange')
-     ax.spines['left'].set_color('orange')
-     ax.xaxis.label.set_color('orange')
-     ax.yaxis.label.set_color('orange')
-     ax.tick_params(axis='x', colors='orange')
-     ax.tick_params(axis='y', colors='orange')
-     ax.title.set_color('orange')
-     ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
-
-     return fig
-
- def plot_grouped_bar(df, x_columns, y_column, figsize=(12, 10), colors=None, title=None):
-     fig, ax = plt.subplots(figsize=figsize)
-
-     width = 0.8 / len(x_columns) # the width of the bars
-     x = np.arange(len(df)) # the label locations
-
-     for i, x_column in enumerate(x_columns):
-         sns.barplot(data=df, x=x, y=y_column, color=colors[i] if colors else None, ax=ax, width=width, label=x_column)
-         x += width # add the width of the bar to the x position for the next bar
-
-     ax.set_title(title if title else f'{y_column} by {", ".join(x_columns)}', color='orange', fontweight='bold')
-     ax.set_xlabel('Groups', color='orange')
-     ax.set_ylabel(y_column, color='orange')
-
-     ax.set_xticks(x - width * len(x_columns) / 2)
-     ax.set_xticklabels(df.index)
-
-     ax.tick_params(axis='x', colors='orange')
-     ax.tick_params(axis='y', colors='orange')
-
-     # Remove background
-     fig.patch.set_alpha(0)
-     ax.patch.set_alpha(0)
-     ax.spines['bottom'].set_color('orange')
-     ax.spines['top'].set_color('orange')
-     ax.spines['right'].set_color('orange')
-     ax.spines['left'].set_color('orange')
-     ax.xaxis.label.set_color('orange')
-     ax.yaxis.label.set_color('orange')
-     ax.title.set_color('orange')
-     ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
-
-     return fig
-
-
- def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
-     """
-     Adds a UI on top of a dataframe to let viewers filter columns
-
-     Args:
-         df (pd.DataFrame): Original dataframe
-
-     Returns:
-         pd.DataFrame: Filtered dataframe
-     """
-
-     title_font = "Arial"
-     body_font = "Arial"
-     title_size = 32
-     colors = ["red", "green", "blue"]
-     interpretation = False
-     extract_docx = False
-     title = "My Chart"
-     regex = ".*"
-     img_path = 'default_image.png'
-
-
-     #try:
-     #    modify = st.checkbox("Add filters on raw data")
-     #except:
-     #    try:
-     #        modify = st.checkbox("Add filters on processed data")
-     #    except:
-     #        try:
-     #            modify = st.checkbox("Add filters on parsed data")
-     #        except:
-     #            pass
-
-     #if not modify:
-     #    return df
-
-     df_ = df.copy()
-     # Try to convert datetimes into a standard format (datetime, no timezone)
-
-     #modification_container = st.container()
-
-     #with modification_container:
-     to_filter_columns = st.multiselect("Filter dataframe on", df_.columns)
-
-     date_column = None
-     filtered_columns = []
-
-     for column in to_filter_columns:
-         left, right = st.columns((1, 20))
-         # Treat columns with < 200 unique values as categorical if not date or numeric
-         if is_categorical_dtype(df_[column]) or (df_[column].nunique() < 120 and not is_datetime64_any_dtype(df_[column]) and not is_numeric_dtype(df_[column])):
-             user_cat_input = right.multiselect(
-                 f"Values for {column}",
-                 df_[column].value_counts().index.tolist(),
-                 default=list(df_[column].value_counts().index)
-             )
-             df_ = df_[df_[column].isin(user_cat_input)]
-             filtered_columns.append(column)
-
-             with st.status(f"Category Distribution: {column}", expanded=False) as stat:
-                 st.pyplot(plot_treemap(df_, column))
-
-         elif is_numeric_dtype(df_[column]):
-             _min = float(df_[column].min())
-             _max = float(df_[column].max())
-             step = (_max - _min) / 100
-             user_num_input = right.slider(
-                 f"Values for {column}",
-                 min_value=_min,
-                 max_value=_max,
-                 value=(_min, _max),
-                 step=step,
-             )
-             df_ = df_[df_[column].between(*user_num_input)]
-             filtered_columns.append(column)
-
-             # Chart_GPT = ChartGPT(df_, title_font, body_font, title_size,
-             #                      colors, interpretation, extract_docx, img_path)
-
-             with st.status(f"Numerical Distribution: {column}", expanded=False) as stat_:
-                 st.pyplot(plot_hist(df_, column, bins=int(round(len(df_[column].unique())-1)/2)))
-
-         elif is_object_dtype(df_[column]):
-             try:
-                 df_[column] = pd.to_datetime(df_[column], infer_datetime_format=True, errors='coerce')
-             except Exception:
-                 try:
-                     df_[column] = df_[column].apply(parser.parse)
-                 except Exception:
-                     pass
-
-             if is_datetime64_any_dtype(df_[column]):
-                 df_[column] = df_[column].dt.tz_localize(None)
-                 min_date = df_[column].min().date()
-                 max_date = df_[column].max().date()
-                 user_date_input = right.date_input(
-                     f"Values for {column}",
-                     value=(min_date, max_date),
-                     min_value=min_date,
-                     max_value=max_date,
-                 )
-                 # if len(user_date_input) == 2:
-                 #     start_date, end_date = user_date_input
-                 #     df_ = df_.loc[df_[column].dt.date.between(start_date, end_date)]
-                 if len(user_date_input) == 2:
-                     user_date_input = tuple(map(pd.to_datetime, user_date_input))
-                     start_date, end_date = user_date_input
-                     df_ = df_.loc[df_[column].between(start_date, end_date)]
-
-                 date_column = column
-
-                 if date_column and filtered_columns:
-                     numeric_columns = [col for col in filtered_columns if is_numeric_dtype(df_[col])]
-                     if numeric_columns:
-                         fig = plot_line(df_, date_column, numeric_columns)
-                         #st.pyplot(fig)
-                     # now to deal with categorical columns
-                     categorical_columns = [col for col in filtered_columns if is_categorical_dtype(df_[col])]
-                     if categorical_columns:
-                         fig2 = plot_bar(df_, date_column, categorical_columns[0])
-                         #st.pyplot(fig2)
-                     with st.status(f"Date Distribution: {column}", expanded=False) as stat:
-                         try:
-                             st.pyplot(fig)
-                         except Exception as e:
-                             st.error(f"Error plotting line chart: {e}")
-                             pass
-                         try:
-                             st.pyplot(fig2)
-                         except Exception as e:
-                             st.error(f"Error plotting bar chart: {e}")
-
-
-             else:
-                 user_text_input = right.text_input(
-                     f"Substring or regex in {column}",
-                 )
-                 if user_text_input:
-                     df_ = df_[df_[column].astype(str).str.contains(user_text_input)]
-     # write len of df after filtering with % of original
-     st.write(f"{len(df_)} rows ({len(df_) / len(df) * 100:.2f}%)")
-     return df_
-
-
- def get_stations():
-     base_url = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetCapabilities&format=json'
-     response = requests.get(base_url)
-     data = response.json()
-     dataframe_stations = pd.DataFrame.from_dict(data['ObservatoryList'])
-     return dataframe_stations
-
- def get_haversine_distance(lat1, lon1, lat2, lon2):
-     R = 6371
-     dlat = math.radians(lat2 - lat1)
-     dlon = math.radians(lon2 - lon1)
-     a = math.sin(dlat/2) * math.sin(dlat/2) + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon/2) * math.sin(dlon/2)
-     c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))
-     d = R * c
-     return d
-
- def compare_stations(test_lat_lon, data_table, distance=1000, closest=False):
-     table_updated = pd.DataFrame()
-     distances = dict()
-     for lat,lon,names in data_table[['Latitude', 'Longitude', 'Name']].values:
-         harv_distance = get_haversine_distance(test_lat_lon[0], test_lat_lon[1], lat, lon)
-         if harv_distance < distance:
-             #print(f"Station {names} is at {round(harv_distance,2)} km from the test point")
-             table_updated = pd.concat([table_updated, data_table[data_table['Name'] == names]])
-             distances[names] = harv_distance
-     if closest:
-         closest_station = min(distances, key=distances.get)
-         #print(f"The closest station is {closest_station} at {round(distances[closest_station],2)} km")
-         table_updated = data_table[data_table['Name'] == closest_station]
-         table_updated['Distance'] = distances[closest_station]
-     return table_updated
-
- def get_data(IagaCode, start_date, end_date):
-     try:
-         start_date_ = datetime.datetime.strptime(start_date, '%Y-%m-%d')
-     except ValueError as e:
-         print(f"Error: {e}")
-         start_date_ = pd.to_datetime(start_date)
-     try:
-         end_date_ = datetime.datetime.strptime(end_date, '%Y-%m-%d')
-     except ValueError as e:
-         print(f"Error: {e}")
-         end_date_ = pd.to_datetime(end_date)
-
-     duration = end_date_ - start_date_
-     # Define the parameters for the request
-     params = {
-         'Request': 'GetData',
-         'format': 'PNG',
-         'testObsys': '0',
-         'observatoryIagaCode': IagaCode,
-         'samplesPerDay': 'minute',
-         'publicationState': 'Best available',
-         'dataStartDate': start_date,
-         # make substraction
-         'dataDuration': duration.days,
-         'traceList': '1234',
-         'colourTraces': 'true',
-         'pictureSize': 'Automatic',
-         'dataScale': 'Automatic',
-         'pdfSize': '21,29.7',
-     }
-
-     base_url_json = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetData&format=json'
-     #base_url_img = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetData&format=png'
-
-     for base_url in [base_url_json]:#, base_url_img]:
-         response = requests.get(base_url, params=params)
-         if response.status_code == 200:
-             content_type = response.headers.get('Content-Type')
-             if 'image' in content_type:
-                 # f"custom_plot_{new_dataset.iloc[0]['IagaCode']}_{str_date.replace(':', '_')}.png"
-                 # output_image_path = "plot_image.png"
-                 # with open(output_image_path, 'wb') as file:
-                 #     file.write(response.content)
-                 # print(f"Image successfully saved as {output_image_path}")
-
-                 # # Display the image
-                 # img = mpimg.imread(output_image_path)
-                 # plt.imshow(img)
-                 # plt.axis('off') # Hide axes
-                 # plt.show()
-                 # img_answer = Image.open(output_image_path)
-                 img_answer = None
-             else:
-                 print(f"Unexpected content type: {content_type}")
-                 #print("Response content:")
-                 #print(response.content.decode('utf-8')) # Attempt to print response as text
-             # return json
-             answer = response.json()
-         else:
-             print(f"Failed to retrieve data. HTTP Status code: {response.status_code}")
-             print("Response content:")
-             print(response.content.decode('utf-8'))
-     return answer#, img_answer
-
-
- # def get_data(IagaCode, start_date, end_date):
- #     # Convert dates to datetime
- #     try:
- #         start_date_ = pd.to_datetime(start_date)
- #         end_date_ = pd.to_datetime(end_date)
- #     except ValueError as e:
- #         print(f"Error: {e}")
- #         return None, None
-
- #     duration = (end_date_ - start_date_).days
-
- #     # Define the parameters for the request
- #     params = {
- #         'Request': 'GetData',
- #         'format': 'json',
- #         'testObsys': '0',
- #         'observatoryIagaCode': IagaCode,
- #         'samplesPerDay': 'minute',
- #         'publicationState': 'Best available',
- #         'dataStartDate': start_date_.strftime('%Y-%m-%d'),
- #         'dataDuration': duration,
- #         'traceList': '1234',
- #         'colourTraces': 'true',
- #         'pictureSize': 'Automatic',
- #         'dataScale': 'Automatic',
- #         'pdfSize': '21,29.7',
- #     }
-
- #     base_url_json = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetData&format=json'
- #     base_url_img = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetData&format=png'
-
- #     try:
- #         # Request JSON data
- #         response_json = requests.get(base_url_json, params=params)
- #         response_json.raise_for_status() # Raises an error for bad status codes
- #         data = response_json.json()
-
- #         # Request Image
- #         params['format'] = 'png'
- #         response_img = requests.get(base_url_img, params=params)
- #         response_img.raise_for_status()
-
- #         # Save and display image if response is successful
- #         if 'image' in response_img.headers.get('Content-Type'):
- #             output_image_path = "plot_image.png"
- #             with open(output_image_path, 'wb') as file:
- #                 file.write(response_img.content)
- #             print(f"Image successfully saved as {output_image_path}")
-
- #             img = mpimg.imread(output_image_path)
- #             plt.imshow(img)
- #             plt.axis('off')
- #             plt.show()
- #             img_answer = Image.open(output_image_path)
- #         else:
- #             img_answer = None
-
- #         return data, img_answer
-
- #     except requests.RequestException as e:
- #         print(f"Request failed: {e}")
- #         return None, None
- #     except ValueError as e:
- #         print(f"JSON decode error: {e}")
- #         return None, None
-
- def clean_uap_data(dataset, lat, lon, date):
-     # Assuming 'nuforc' is already defined
-     processed = dataset[dataset[[lat, lon, date]].notnull().all(axis=1)]
-     # Converting 'Lat' and 'Long' columns to floats, handling errors
-     processed[lat] = pd.to_numeric(processed[lat], errors='coerce')
-     processed[lon] = pd.to_numeric(processed[lon], errors='coerce')
-
-     # if processed[date].min() < pd.to_datetime('1677-09-22'):
-     #     processed.loc[processed[date] < pd.to_datetime('1677-09-22'), 'corrected_date'] = pd.to_datetime('1677-09-22 00:00:00')
-
-     procesed = processed[processed[date] >= '1677-09-22']
-
-     # convert date to str
-     #processed[date] = processed[date].astype(str)
-     # Dropping rows where 'Lat' or 'Long' conversion failed (i.e., became NaN)
-     processed = processed.dropna(subset=[lat, lon])
-     return processed
-
-
- def plot_overlapped_timeseries(data_list, event_times, window_hours=12, save_path=None):
-     fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
-     fig.patch.set_alpha(0) # Make figure background transparent
-
-     components = ['X', 'Y', 'Z', 'S']
-     colors = ['red', 'green', 'blue', 'black']
-
-     for i, component in enumerate(components):
-         axs[i].patch.set_alpha(0) # Make subplot background transparent
-         axs[i].set_ylabel(component, color='orange')
-         axs[i].grid(True, color='orange', alpha=0.3)
-
-         for spine in axs[i].spines.values():
-             spine.set_color('orange')
-
-         axs[i].tick_params(axis='both', colors='orange') # Change tick color
-         axs[i].set_title(f'{component}', color='orange')
-         axs[i].set_xlabel('Time Difference from Event (hours)', color='orange')
-
-         for j, (df, event_time) in enumerate(zip(data_list, event_times)):
-             # Convert datetime column to UTC if it has timezone info, otherwise assume it's UTC
-             df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
-
-             # Convert event_time to UTC if it has timezone info, otherwise assume it's UTC
-             event_time = pd.to_datetime(event_time).tz_localize(None)
-
-             # Calculate time difference from event
-             df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600 # Convert to hours
-
-             # Filter data within the specified window
-             df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)]
-
-             # normalize component data
-             df_window[component] = (df_window[component] - df_window[component].mean()) / df_window[component].std()
-
-             axs[i].plot(df_window['time_diff'], df_window[component], color=colors[i], alpha=0.7, label=f'Event {j+1}', linewidth=1)
-
-         axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
-         axs[i].set_xlim(-window_hours, window_hours)
-         #axs[i].legend(loc='upper left', bbox_to_anchor=(1, 1))
-
-     axs[-1].set_xlabel('Hours from Event', color='orange')
-     fig.suptitle('Overlapped Time Series of Components', fontsize=16, color='orange')
-
-     plt.tight_layout()
-     plt.subplots_adjust(top=0.95, right=0.85)
-
-     if save_path:
-         fig.savefig(save_path, transparent=True, bbox_inches='tight')
-         plt.close(fig)
-         return save_path
-     else:
-         return fig
-
- def plot_average_timeseries(data_list, event_times, window_hours=12, save_path=None):
-     fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
-     fig.patch.set_alpha(0) # Make figure background transparent
-
-     components = ['X', 'Y', 'Z', 'S']
-     colors = ['red', 'green', 'blue', 'black']
-
-     for i, component in enumerate(components):
-         axs[i].patch.set_alpha(0)
-         axs[i].set_ylabel(component, color='orange')
-         axs[i].grid(True, color='orange', alpha=0.3)
-
-         for spine in axs[i].spines.values():
-             spine.set_color('orange')
-
-         axs[i].tick_params(axis='both', colors='orange')
-
-         all_data = []
-         time_diffs = []
-
-         for j, (df, event_time) in enumerate(zip(data_list, event_times)):
-             # Convert datetime column to UTC if it has timezone info, otherwise assume it's UTC
-             df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
-
-             # Convert event_time to UTC if it has timezone info, otherwise assume it's UTC
-             event_time = pd.to_datetime(event_time).tz_localize(None)
-
-             # Calculate time difference from event
-             df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600 # Convert to hours
-
-             # Filter data within the specified window
-             df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)]
-
-             # Normalize component data
-             df_window[component] = (df_window[component] - df_window[component].mean())# / df_window[component].std()
-
-             all_data.append(df_window[component].values)
-             time_diffs.append(df_window['time_diff'].values)
-
-         # Calculate average and standard deviation
-         try:
-             avg_data = np.mean(all_data, axis=0)
-         except:
-             avg_data = np.zeros_like(all_data[0])
-         try:
-             std_data = np.std(all_data, axis=0)
-         except:
-             std_data = np.zeros_like(avg_data)
-
-         axs[-1].set_xlabel('Hours from Event', color='orange')
-         fig.suptitle('Average Time Series of Components', fontsize=16, color='orange')
-
-         # Plot average line
-         axs[i].plot(time_diffs[0], avg_data, color=colors[i], label='Average')
-
-         # Plot standard deviation as shaded region
-         try:
-             axs[i].fill_between(time_diffs[0], avg_data - std_data, avg_data + std_data, color=colors[i], alpha=0.2)
-         except:
-             pass
-
-         axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
-         axs[i].set_xlim(-window_hours, window_hours)
-         # orange frame, orange label legend
-         axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
-
-     plt.tight_layout()
-     plt.subplots_adjust(top=0.95, right=0.85)
-
-     if save_path:
-         fig.savefig(save_path, transparent=True, bbox_inches='tight')
-         plt.close(fig)
-         return save_path
-     else:
-         return fig
-
- def align_series(reference, series):
-     reference = reference.flatten()
-     series = series.flatten()
-     _, path = fastdtw(reference, series, dist=euclidean)
-     aligned = np.zeros(len(reference))
-     for ref_idx, series_idx in path:
-         aligned[ref_idx] = series[series_idx]
-     return aligned
-
- def plot_average_timeseries_with_dtw(data_list, event_times, window_hours=12, save_path=None):
-     fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
-     fig.patch.set_alpha(0) # Make figure background transparent
-
-     components = ['X', 'Y', 'Z', 'S']
-     colors = ['red', 'green', 'blue', 'black']
-     fig.text(0.02, 0.5, 'Geomagnetic Variation (nT)', va='center', rotation='vertical', color='orange')
-
-
-     for i, component in enumerate(components):
-         axs[i].patch.set_alpha(0)
-         axs[i].set_ylabel(component, color='orange', rotation=90)
-         axs[i].grid(True, color='orange', alpha=0.3)
-
-         for spine in axs[i].spines.values():
-             spine.set_color('orange')
-
-         axs[i].tick_params(axis='both', colors='orange')
-
-         all_aligned_data = []
-         reference_df = None
-
-         for j, (df, event_time) in enumerate(zip(data_list, event_times)):
-             df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
-             event_time = pd.to_datetime(event_time).tz_localize(None)
-             df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600
-             df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)]
-             df_window[component] = (df_window[component] - df_window[component].mean())# / df_window[component].std()
-
-             if reference_df is None:
-                 reference_df = df_window
-                 all_aligned_data.append(reference_df[component].values)
-             else:
-                 try:
-                     aligned_series = align_series(reference_df[component].values, df_window[component].values)
-                     all_aligned_data.append(aligned_series)
-                 except:
-                     pass
-
-         # Calculate average and standard deviation of aligned data
-         all_aligned_data = np.array(all_aligned_data)
-         avg_data = np.mean(all_aligned_data, axis=0)
-
-         # round float to avoid sqrt errors
-         def calculate_std(data):
-             if data is not None and len(data) > 0:
-                 data = np.array(data)
-                 std_data = np.std(data)
-                 return std_data
-             else:
-                 return "Data is empty or not a list"
-
-         std_data = calculate_std(all_aligned_data)
-
-         # Plot average line
-         axs[i].plot(reference_df['time_diff'], avg_data, color=colors[i], label='Average')
-
-         # Plot standard deviation as shaded region
-         try:
-             axs[i].fill_between(reference_df['time_diff'], avg_data - std_data, avg_data + std_data, color=colors[i], alpha=0.2)
-         except TypeError as e:
-             #print(f"Error: {e}")
-             pass
-
-
-         axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
-         axs[i].set_xlim(-window_hours, window_hours)
-         axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.2, labelcolor='orange', edgecolor='orange')
-
-
-     axs[-1].set_xlabel('Hours from Event', color='orange')
-     fig.suptitle('Average Time Series of Components (FastDTW Aligned)', fontsize=16, color='orange')
-
-     plt.tight_layout()
-     plt.subplots_adjust(top=0.85, right=0.85, left=0.1)
-
-     if save_path:
-         fig.savefig(save_path, transparent=True, bbox_inches='tight')
-         plt.close(fig)
-         return save_path
-     else:
-         return fig
-
- def plot_data_custom(df, date, save_path=None, subtitle=None):
-     df['datetime'] = pd.to_datetime(df['datetime'])
-     event = pd.to_datetime(date)
-     window = timedelta(hours=12)
-     x_min = event - window
-     x_max = event + window
-
-     fig, axs = plt.subplots(4, 1, figsize=(12, 12), sharex=True)
-     fig.patch.set_alpha(0) # Make figure background transparent
-
-     components = ['X', 'Y', 'Z', 'S']
-     colors = ['red', 'green', 'blue', 'black']
-
-     fig.text(0.02, 0.5, 'Geomagnetic Variation (nT)', va='center', rotation='vertical', color='orange')
-
-     # if df[component].isnull().all().all():
-     #     return None
-
-     for i, component in enumerate(components):
-         axs[i].plot(df['datetime'], df[component], label=component, color=colors[i])
-         axs[i].axvline(x=event, color='red', linewidth=2, label='Event', linestyle='--')
-         axs[i].set_ylabel(component, color='orange', rotation=90)
-         axs[i].set_xlim(x_min, x_max)
-         axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.2, labelcolor='orange', edgecolor='orange')
-         axs[i].grid(True, color='orange', alpha=0.3)
-         axs[i].patch.set_alpha(0) # Make subplot background transparent
-
-         for spine in axs[i].spines.values():
-             spine.set_color('orange')
-
-         axs[i].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
-         axs[i].xaxis.set_major_locator(mdates.HourLocator(interval=1))
-         axs[i].tick_params(axis='both', colors='orange')
-
-     plt.setp(axs[-1].xaxis.get_majorticklabels(), rotation=45)
-     axs[-1].set_xlabel('Hours', color='orange')
-     fig.suptitle(f'Time Series of Components with Event Marks\n{subtitle}', fontsize=12, color='orange')
-
-     plt.tight_layout()
-     #plt.subplots_adjust(top=0.85)
-     plt.subplots_adjust(top=0.85, right=0.85, left=0.1)
-
-
-     if save_path:
-         fig.savefig(save_path, transparent=True)
-         plt.close(fig)
-         return save_path
-     else:
-         return fig
-
-
- def batch_requests(stations, dataset, lon, lat, date, distance=100):
-     results = {"station": [], "data": [], "image": [], "custom_image": []}
-     all_data = []
-     all_event_times = []
-
-     for lon_, lat_, date_ in dataset[[lon, lat, date]].values:
-         test_lat_lon = (lat_, lon_)
-         try:
-             str_date = pd.to_datetime(date_).strftime('%Y-%m-%dT%H:%M:%S')
-         except:
-             str_date = date_
-         twelve_hours = pd.Timedelta(hours=12)
-         forty_eight_hours = pd.Timedelta(hours=48)
-         try:
-             str_date_start = (pd.to_datetime(str_date) - twelve_hours).strftime('%Y-%m-%dT%H:%M:%S')
-             str_date_end = (pd.to_datetime(str_date) + forty_eight_hours).strftime('%Y-%m-%dT%H:%M:%S')
-         except Exception as e:
-             print(f"Error: {e}")
-             pass
-
-         try:
-             new_dataset = compare_stations(test_lat_lon, stations, distance=distance, closest=True)
-             station_name = new_dataset['Name']
-             station_distance = new_dataset['Distance']
-             test_ = get_data(new_dataset.iloc[0]['IagaCode'], str_date_start, str_date_end)
-
-             if test_:
-                 results["station"].append(new_dataset.iloc[0]['IagaCode'])
-                 results["data"].append(test_)
-                 plotted = pd.DataFrame({
-                     'datetime': test_['datetime'],
-                     'X': test_['X'],
-                     'Y': test_['Y'],
-                     'Z': test_['Z'],
-                     'S': test_['S'],
-                 })
-                 all_data.append(plotted)
-                 all_event_times.append(pd.to_datetime(date_))
-                 # print(date_)
-                 additional_data = f"Date: {date_}\nLat/Lon: {lat_}, {lon_}\nClosest station: {station_name.values[0]}\n Distance:{round(station_distance.values[0],2)} km"
-                 fig = plot_data_custom(plotted, date=pd.to_datetime(date_), save_path=None, subtitle=additional_data)
-                 with st.status(f'Magnetic Data: {date_}', expanded=False) as status:
-                     st.pyplot(fig)
-                     status.update(f'Magnetic Data: {date_} - Finished!')
-         except Exception as e:
-             #print(f"An error occurred: {e}")
-             pass
-
-     if all_data:
-         fig_overlapped = plot_overlapped_timeseries(all_data, all_event_times)
-         display(fig_overlapped)
-         plt.close(fig_overlapped)
-         # fig_average = plot_average_timeseries(all_data, all_event_times)
-         # st.pyplot(fig_average)
-         fig_average_aligned = plot_average_timeseries_with_dtw(all_data, all_event_times)
-         with st.status(f'Dynamic Time Warping Data', expanded=False) as stts:
-             st.pyplot(fig_average_aligned)
-     return results
-
-
- df = pd.DataFrame()
-
-
- # Upload dataset
- uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
-
- if uploaded_file is not None:
-     if uploaded_file.name.endswith('.csv'):
-         df = pd.read_csv(uploaded_file)
-     else:
-         df = pd.read_excel(uploaded_file)
-     stations = get_stations()
-     st.write("Dataset Loaded:")
-     df = filter_dataframe(df)
-     st.dataframe(df)
-
-     # Select columns
-     lon_col = st.selectbox("Select Longitude Column", df.columns)
-     lat_col = st.selectbox("Select Latitude Column", df.columns)
-     date_col = st.selectbox("Select Date Column", df.columns)
-     distance = st.number_input("Enter Distance", min_value=0, value=100)
-
-     # Process data
-     if st.button("Process Data"):
-         cases = clean_uap_data(df, lat_col, lon_col, date_col)
-         results = batch_requests(stations, cases, lon_col, lat_col, date_col, distance=distance)
+
+ import math
+ import pandas as pd
+ import numpy as np
+ import json
+ import requests
+ import datetime
+ from datetime import timedelta
+ from PIL import Image
+ # alternative to PIL
+ import matplotlib.pyplot as plt
+ import matplotlib.image as mpimg
+ import os
+ import matplotlib.dates as mdates
+ import seaborn as sns
+ from IPython.display import Image as image_display
+ path = os.getcwd()
+ from fastdtw import fastdtw
+ from scipy.spatial.distance import euclidean
+ from IPython.display import display
+ from dateutil import parser
+ from Levenshtein import distance
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import confusion_matrix
+ from stqdm import stqdm
+ stqdm.pandas()
+ import streamlit.components.v1 as components
+ from dateutil import parser
+ from sentence_transformers import SentenceTransformer
+ import torch
+ import squarify
+ import matplotlib.colors as mcolors
+ import textwrap
+ import datamapplot
+ import streamlit as st
+
+
+ if 'form_submitted' not in st.session_state:
+     st.session_state['form_submitted'] = False
+
+
+ st.title('Magnetic Correlations Dashboard')
+
+ st.set_option('deprecation.showPyplotGlobalUse', False)
+
+
+ from pandas.api.types import (
+     is_categorical_dtype,
+     is_datetime64_any_dtype,
+     is_numeric_dtype,
+     is_object_dtype,
+ )
+
+
+ def plot_treemap(df, column, top_n=32):
+     # Get the value counts and the top N labels
+     value_counts = df[column].value_counts()
+     top_labels = value_counts.iloc[:top_n].index
+
+     # Use np.where to replace all values not in the top N with 'Other'
+     revised_column = f'{column}_revised'
+     df[revised_column] = np.where(df[column].isin(top_labels), df[column], 'Other')
+
+     # Get the value counts including the 'Other' category
+     sizes = df[revised_column].value_counts().values
+     labels = df[revised_column].value_counts().index
+
+     # Get a gradient of colors
+     # colors = list(mcolors.TABLEAU_COLORS.values())
+
+     n_colors = len(sizes)
+     colors = plt.cm.Oranges(np.linspace(0.3, 0.9, n_colors))[::-1]
+
+
+     # Get % of each category
+     percents = sizes / sizes.sum()
+
+     # Prepare labels with percentages
+     labels = [f'{label}\n {percent:.1%}' for label, percent in zip(labels, percents)]
+
+     fig, ax = plt.subplots(figsize=(20, 12))
+
+     # Plot the treemap
+     squarify.plot(sizes=sizes, label=labels, alpha=0.7, pad=True, color=colors, text_kwargs={'fontsize': 10})
+
+     ax = plt.gca()
+     # Iterate over text elements and rectangles (patches) in the axes for color adjustment
+     for text, rect in zip(ax.texts, ax.patches):
+         background_color = rect.get_facecolor()
+         r, g, b, _ = mcolors.to_rgba(background_color)
+         brightness = np.average([r, g, b])
+         text.set_color('white' if brightness < 0.5 else 'black')
+
+
+ def plot_hist(df, column, bins=10, kde=True):
+     fig, ax = plt.subplots(figsize=(12, 6))
+     # Honor the kde argument instead of hard-coding it
+     sns.histplot(data=df, x=column, kde=kde, bins=bins, color='orange')
+     # set the ticks and frame in orange
+     ax.spines['bottom'].set_color('orange')
+     ax.spines['top'].set_color('orange')
+     ax.spines['right'].set_color('orange')
+     ax.spines['left'].set_color('orange')
+     ax.xaxis.label.set_color('orange')
+     ax.yaxis.label.set_color('orange')
+     ax.tick_params(axis='x', colors='orange')
+     ax.tick_params(axis='y', colors='orange')
+     ax.title.set_color('orange')
+
+     # Set transparent background
+     fig.patch.set_alpha(0)
+     ax.patch.set_alpha(0)
+     return fig
+
+
+
+
+ def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=None, rolling_mean_value=2):
+     import matplotlib.cm as cm
+     # Sort the dataframe by the date column
+     df = df.sort_values(by=x_column)
+
+     # Calculate rolling mean for each y_column
+     if rolling_mean_value:
+         df[y_columns] = df[y_columns].rolling(len(df) // rolling_mean_value).mean()
+
+     # Create the plot
+     fig, ax = plt.subplots(figsize=figsize)
+
+     colors = cm.Oranges(np.linspace(0.2, 1, len(y_columns)))
+
+     # Plot each y_column as a separate line with a different color
+     for i, y_column in enumerate(y_columns):
+         df.plot(x=x_column, y=y_column, ax=ax, color=colors[i], label=y_column, linewidth=.5)
+
+     # Rotate x-axis labels
+     ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right')
+
+     # Format x_column as date if it is
+     if np.issubdtype(df[x_column].dtype, np.datetime64) or np.issubdtype(df[x_column].dtype, np.timedelta64):
+         df[x_column] = pd.to_datetime(df[x_column]).dt.date
+
+     # Set title, labels, and legend
+     ax.set_title(title or f'{", ".join(y_columns)} over {x_column}', color=color, fontweight='bold')
+     ax.set_xlabel(x_column, color=color)
+     ax.set_ylabel(', '.join(y_columns), color=color)
+     ax.spines['bottom'].set_color('orange')
+     ax.spines['top'].set_color('orange')
+     ax.spines['right'].set_color('orange')
+     ax.spines['left'].set_color('orange')
+     ax.xaxis.label.set_color('orange')
+     ax.yaxis.label.set_color('orange')
+     ax.tick_params(axis='x', colors='orange')
+     ax.tick_params(axis='y', colors='orange')
+     ax.title.set_color('orange')
+
+     ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
+
+     # Remove background
+     fig.patch.set_alpha(0)
+     ax.patch.set_alpha(0)
+
+     return fig
+
+ def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None, rotation=45):
+     fig, ax = plt.subplots(figsize=figsize)
+
+     sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)
+
+     ax.set_title(title if title else f'{y_column} by {x_column}', color=color, fontweight='bold')
+     ax.set_xlabel(x_column, color=color)
+     ax.set_ylabel(y_column, color=color)
+
+     ax.tick_params(axis='x', colors=color)
+     ax.tick_params(axis='y', colors=color)
+
+     plt.xticks(rotation=rotation)
+
+     # Remove background
+     fig.patch.set_alpha(0)
+     ax.patch.set_alpha(0)
+     ax.spines['bottom'].set_color('orange')
+     ax.spines['top'].set_color('orange')
+     ax.spines['right'].set_color('orange')
+     ax.spines['left'].set_color('orange')
+     ax.xaxis.label.set_color('orange')
+     ax.yaxis.label.set_color('orange')
+     ax.tick_params(axis='x', colors='orange')
+     ax.tick_params(axis='y', colors='orange')
+     ax.title.set_color('orange')
+     ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
+
+     return fig
+
+ def plot_grouped_bar(df, x_columns, y_column, figsize=(12, 10), colors=None, title=None):
+     fig, ax = plt.subplots(figsize=figsize)
+
+     width = 0.8 / len(x_columns) # the width of the bars
+     x = np.arange(len(df)) # the label locations
+
+     for i, x_column in enumerate(x_columns):
+         sns.barplot(data=df, x=x, y=y_column, color=colors[i] if colors else None, ax=ax, width=width, label=x_column)
+         x = x + width # shift the label locations for the next series (out-of-place: in-place += would fail on the int array)
+
+     ax.set_title(title if title else f'{y_column} by {", ".join(x_columns)}', color='orange', fontweight='bold')
+     ax.set_xlabel('Groups', color='orange')
+     ax.set_ylabel(y_column, color='orange')
+
+     ax.set_xticks(x - width * len(x_columns) / 2)
+     ax.set_xticklabels(df.index)
+
+     ax.tick_params(axis='x', colors='orange')
+     ax.tick_params(axis='y', colors='orange')
+
+     # Remove background
+     fig.patch.set_alpha(0)
+     ax.patch.set_alpha(0)
+     ax.spines['bottom'].set_color('orange')
+     ax.spines['top'].set_color('orange')
+     ax.spines['right'].set_color('orange')
+     ax.spines['left'].set_color('orange')
+     ax.xaxis.label.set_color('orange')
+     ax.yaxis.label.set_color('orange')
+     ax.title.set_color('orange')
+     ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
+
+     return fig
+
+
+ def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Adds a UI on top of a dataframe to let viewers filter columns
+
+     Args:
+         df (pd.DataFrame): Original dataframe
+
+     Returns:
+         pd.DataFrame: Filtered dataframe
+     """
+
+     title_font = "Arial"
+     body_font = "Arial"
+     title_size = 32
+     colors = ["red", "green", "blue"]
+     interpretation = False
+     extract_docx = False
+     title = "My Chart"
+     regex = ".*"
+     img_path = 'default_image.png'
+
+
+     #try:
+     #    modify = st.checkbox("Add filters on raw data")
+     #except:
+     #    try:
+     #        modify = st.checkbox("Add filters on processed data")
+     #    except:
+     #        try:
+     #            modify = st.checkbox("Add filters on parsed data")
+     #        except:
+     #            pass
+
+     #if not modify:
+     #    return df
+
+     df_ = df.copy()
+     # Try to convert datetimes into a standard format (datetime, no timezone)
+
+     #modification_container = st.container()
+
+     #with modification_container:
+     to_filter_columns = st.multiselect("Filter dataframe on", df_.columns)
+
+     date_column = None
+     filtered_columns = []
+
+     for column in to_filter_columns:
+         left, right = st.columns((1, 20))
+         # Treat columns with < 120 unique values as categorical if not date or numeric
+         if is_categorical_dtype(df_[column]) or (df_[column].nunique() < 120 and not is_datetime64_any_dtype(df_[column]) and not is_numeric_dtype(df_[column])):
+             user_cat_input = right.multiselect(
+                 f"Values for {column}",
+                 df_[column].value_counts().index.tolist(),
+                 default=list(df_[column].value_counts().index)
+             )
+             df_ = df_[df_[column].isin(user_cat_input)]
+             filtered_columns.append(column)
+
+             with st.status(f"Category Distribution: {column}", expanded=False) as stat:
+                 st.pyplot(plot_treemap(df_, column))
+
+         elif is_numeric_dtype(df_[column]):
+             _min = float(df_[column].min())
+             _max = float(df_[column].max())
+             step = (_max - _min) / 100
+             user_num_input = right.slider(
+                 f"Values for {column}",
+                 min_value=_min,
+                 max_value=_max,
+                 value=(_min, _max),
+                 step=step,
+             )
+             df_ = df_[df_[column].between(*user_num_input)]
+             filtered_columns.append(column)
+
+             # Chart_GPT = ChartGPT(df_, title_font, body_font, title_size,
+             #                      colors, interpretation, extract_docx, img_path)
+
+             with st.status(f"Numerical Distribution: {column}", expanded=False) as stat_:
+                 st.pyplot(plot_hist(df_, column, bins=int(round(len(df_[column].unique())-1)/2)))
+
+         elif is_object_dtype(df_[column]):
+             try:
+                 df_[column] = pd.to_datetime(df_[column], infer_datetime_format=True, errors='coerce')
+             except Exception:
+                 try:
+                     df_[column] = df_[column].apply(parser.parse)
+                 except Exception:
+                     pass
+
+             if is_datetime64_any_dtype(df_[column]):
+                 df_[column] = df_[column].dt.tz_localize(None)
+                 min_date = df_[column].min().date()
+                 max_date = df_[column].max().date()
+                 user_date_input = right.date_input(
+                     f"Values for {column}",
+                     value=(min_date, max_date),
+                     min_value=min_date,
+                     max_value=max_date,
+                 )
+
+
+                 if len(user_date_input) == 2:
+                     user_date_input = tuple(map(pd.to_datetime, user_date_input))
+                     start_date, end_date = user_date_input
+
+                     # Determine the most appropriate time unit for plot
+                     time_units = {
+                         'year': df_[column].dt.year,
+                         'month': df_[column].dt.to_period('M'),
+                         'day': df_[column].dt.date
+                     }
+                     unique_counts = {unit: col.nunique() for unit, col in time_units.items()}
+                     closest_to_36 = min(unique_counts, key=lambda k: abs(unique_counts[k] - 36))
+
+                     # Group by the most appropriate time unit and count occurrences
+                     grouped = df_.groupby(time_units[closest_to_36]).size().reset_index(name='count')
+                     grouped.columns = [column, 'count']
+
+                     # Create a complete date range
+                     if closest_to_36 == 'year':
+                         date_range = pd.date_range(start=f"{start_date.year}-01-01", end=f"{end_date.year}-12-31", freq='YS')
+                     elif closest_to_36 == 'month':
+                         date_range = pd.date_range(start=start_date.replace(day=1), end=end_date + pd.offsets.MonthEnd(0), freq='MS')
+                     else: # day
+                         date_range = pd.date_range(start=start_date, end=end_date, freq='D')
+
+                     # Create a DataFrame with the complete date range
+                     complete_range = pd.DataFrame({column: date_range})
+
+                     # Convert the date column to the appropriate format based on closest_to_36
+                     if closest_to_36 == 'year':
+                         complete_range[column] = complete_range[column].dt.year
+                     elif closest_to_36 == 'month':
+                         complete_range[column] = complete_range[column].dt.to_period('M')
+
+                     # Merge the complete range with the grouped data
+                     final_data = pd.merge(complete_range, grouped, on=column, how='left').fillna(0)
+
+                     with st.status(f"Date Distributions: {column}", expanded=False) as stat:
+                         try:
+                             st.pyplot(plot_bar(final_data, column, 'count'))
+                         except Exception as e:
+                             st.error(f"Error plotting bar chart: {e}")
+
+                     df_ = df_.loc[df_[column].between(start_date, end_date)]
+
+                 date_column = column
+
+                 if date_column and filtered_columns:
+                     numeric_columns = [col for col in filtered_columns if is_numeric_dtype(df_[col])]
+                     if numeric_columns:
+                         fig = plot_line(df_, date_column, numeric_columns)
+                         #st.pyplot(fig)
+                     # now to deal with categorical columns
+                     categorical_columns = [col for col in filtered_columns if is_categorical_dtype(df_[col])]
+                     if categorical_columns:
+                         fig2 = plot_bar(df_, date_column, categorical_columns[0])
+                         #st.pyplot(fig2)
+                     with st.status(f"Date Distribution: {column}", expanded=False) as stat:
+                         try:
+                             st.pyplot(fig)
+                         except Exception as e:
+                             st.error(f"Error plotting line chart: {e}")
+                             pass
+                         try:
+                             st.pyplot(fig2)
+                         except Exception as e:
+                             st.error(f"Error plotting bar chart: {e}")
+
+
+             else:
+                 user_text_input = right.text_input(
+                     f"Substring or regex in {column}",
+                 )
+                 if user_text_input:
+                     df_ = df_[df_[column].astype(str).str.contains(user_text_input)]
+     # write len of df after filtering with % of original
+     st.write(f"{len(df_)} rows ({len(df_) / len(df) * 100:.2f}%)")
+     return df_
+
+
+ def get_stations():
+     base_url = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetCapabilities&format=json'
+     response = requests.get(base_url)
+     data = response.json()
+     dataframe_stations = pd.DataFrame.from_dict(data['ObservatoryList'])
+     return dataframe_stations
+
+ def get_haversine_distance(lat1, lon1, lat2, lon2):
+     R = 6371
+     dlat = math.radians(lat2 - lat1)
+     dlon = math.radians(lon2 - lon1)
+     a = math.sin(dlat/2) * math.sin(dlat/2) + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon/2) * math.sin(dlon/2)
+     c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))
+     d = R * c
+     return d
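+ # Illustrative sanity check (hypothetical coordinates, not app data):
+ # get_haversine_distance(48.85, 2.35, 51.51, -0.13), i.e. Paris to London,
+ # should return roughly 340 km.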
+
+ def compare_stations(test_lat_lon, data_table, distance=1000, closest=False):
+     table_updated = pd.DataFrame()
+     distances = dict()
+     for lat,lon,names in data_table[['Latitude', 'Longitude', 'Name']].values:
+         harv_distance = get_haversine_distance(test_lat_lon[0], test_lat_lon[1], lat, lon)
+         if harv_distance < distance:
+             #print(f"Station {names} is at {round(harv_distance,2)} km from the test point")
+             table_updated = pd.concat([table_updated, data_table[data_table['Name'] == names]])
+             distances[names] = harv_distance
+     if closest:
+         closest_station = min(distances, key=distances.get)
+         #print(f"The closest station is {closest_station} at {round(distances[closest_station],2)} km")
+         # copy to avoid pandas SettingWithCopyWarning when adding the Distance column
+         table_updated = data_table[data_table['Name'] == closest_station].copy()
+         table_updated['Distance'] = distances[closest_station]
+     return table_updated
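+ # Illustrative usage (hypothetical point): the app calls this as, e.g.,
+ # compare_stations((lat, lon), get_stations(), distance=100, closest=True),
+ # which returns the single nearest observatory row plus a 'Distance' column in km.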
+
+ def get_data(IagaCode, start_date, end_date):
+     try:
+         start_date_ = datetime.datetime.strptime(start_date, '%Y-%m-%d')
+     except ValueError as e:
+         print(f"Error: {e}")
+         start_date_ = pd.to_datetime(start_date)
+     try:
+         end_date_ = datetime.datetime.strptime(end_date, '%Y-%m-%d')
+     except ValueError as e:
+         print(f"Error: {e}")
+         end_date_ = pd.to_datetime(end_date)
+
+     duration = end_date_ - start_date_
+     # Define the parameters for the request
+     params = {
+         'Request': 'GetData',
+         'format': 'PNG',
+         'testObsys': '0',
+         'observatoryIagaCode': IagaCode,
+         'samplesPerDay': 'minute',
+         'publicationState': 'Best available',
+         'dataStartDate': start_date,
+         # duration in days (end date minus start date)
+         'dataDuration': duration.days,
+         'traceList': '1234',
+         'colourTraces': 'true',
+         'pictureSize': 'Automatic',
+         'dataScale': 'Automatic',
+         'pdfSize': '21,29.7',
+     }
+
+     base_url_json = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetData&format=json'
+     #base_url_img = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetData&format=png'
+
+     answer = None  # ensure a defined return value even if the request fails
+     for base_url in [base_url_json]:#, base_url_img]:
+         response = requests.get(base_url, params=params)
+         if response.status_code == 200:
+             content_type = response.headers.get('Content-Type')
+             if 'image' in content_type:
+                 # f"custom_plot_{new_dataset.iloc[0]['IagaCode']}_{str_date.replace(':', '_')}.png"
+                 # output_image_path = "plot_image.png"
+                 # with open(output_image_path, 'wb') as file:
+                 #     file.write(response.content)
+                 # print(f"Image successfully saved as {output_image_path}")
+
+                 # # Display the image
+                 # img = mpimg.imread(output_image_path)
+                 # plt.imshow(img)
+                 # plt.axis('off') # Hide axes
+                 # plt.show()
+                 # img_answer = Image.open(output_image_path)
+                 img_answer = None
+             else:
+                 print(f"Unexpected content type: {content_type}")
+                 #print("Response content:")
+                 #print(response.content.decode('utf-8')) # Attempt to print response as text
+             # return json
+             answer = response.json()
+         else:
+             print(f"Failed to retrieve data. HTTP Status code: {response.status_code}")
+             print("Response content:")
+             print(response.content.decode('utf-8'))
+     return answer#, img_answer
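+ # Illustrative call (hypothetical code and dates): get_data('ESK', '2020-01-01', '2020-01-03')
+ # would request minute data for the Eskdalemuir observatory; valid IAGA codes
+ # come from the 'IagaCode' column of get_stations().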
+
+
+ # def get_data(IagaCode, start_date, end_date):
+ #     # Convert dates to datetime
+ #     try:
+ #         start_date_ = pd.to_datetime(start_date)
+ #         end_date_ = pd.to_datetime(end_date)
+ #     except ValueError as e:
+ #         print(f"Error: {e}")
+ #         return None, None
+
+ #     duration = (end_date_ - start_date_).days
+
+ #     # Define the parameters for the request
+ #     params = {
+ #         'Request': 'GetData',
+ #         'format': 'json',
+ #         'testObsys': '0',
+ #         'observatoryIagaCode': IagaCode,
+ #         'samplesPerDay': 'minute',
+ #         'publicationState': 'Best available',
+ #         'dataStartDate': start_date_.strftime('%Y-%m-%d'),
+ #         'dataDuration': duration,
+ #         'traceList': '1234',
+ #         'colourTraces': 'true',
+ #         'pictureSize': 'Automatic',
+ #         'dataScale': 'Automatic',
+ #         'pdfSize': '21,29.7',
+ #     }
+
+ #     base_url_json = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetData&format=json'
+ #     base_url_img = 'https://imag-data.bgs.ac.uk:/GIN_V1/GINServices?Request=GetData&format=png'
+
+ #     try:
+ #         # Request JSON data
+ #         response_json = requests.get(base_url_json, params=params)
+ #         response_json.raise_for_status() # Raises an error for bad status codes
+ #         data = response_json.json()
+
+ #         # Request Image
+ #         params['format'] = 'png'
+ #         response_img = requests.get(base_url_img, params=params)
+ #         response_img.raise_for_status()
+
+ #         # Save and display image if response is successful
+ #         if 'image' in response_img.headers.get('Content-Type'):
+ #             output_image_path = "plot_image.png"
+ #             with open(output_image_path, 'wb') as file:
+ #                 file.write(response_img.content)
+ #             print(f"Image successfully saved as {output_image_path}")
+
+ #             img = mpimg.imread(output_image_path)
+ #             plt.imshow(img)
+ #             plt.axis('off')
+ #             plt.show()
+ #             img_answer = Image.open(output_image_path)
+ #         else:
+ #             img_answer = None
+
+ #         return data, img_answer
+
+ #     except requests.RequestException as e:
+ #         print(f"Request failed: {e}")
+ #         return None, None
+ #     except ValueError as e:
+ #         print(f"JSON decode error: {e}")
+ #         return None, None
+
+ def clean_uap_data(dataset, lat, lon, date):
+     # Assuming 'nuforc' is already defined
+     processed = dataset[dataset[[lat, lon, date]].notnull().all(axis=1)].copy()
+     # Converting 'Lat' and 'Long' columns to floats, handling errors
+     processed[lat] = pd.to_numeric(processed[lat], errors='coerce')
+     processed[lon] = pd.to_numeric(processed[lon], errors='coerce')
+
+     # if processed[date].min() < pd.to_datetime('1677-09-22'):
+     #     processed.loc[processed[date] < pd.to_datetime('1677-09-22'), 'corrected_date'] = pd.to_datetime('1677-09-22 00:00:00')
+
+     # Keep only dates representable as pandas Timestamps (minimum is 1677-09-22)
+     processed = processed[processed[date] >= '1677-09-22']
+
+     # convert date to str
+     #processed[date] = processed[date].astype(str)
+     # Dropping rows where 'Lat' or 'Long' conversion failed (i.e., became NaN)
+     processed = processed.dropna(subset=[lat, lon])
+     return processed
592
+
593
+
594
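+ # Uncalled usage sketch for clean_uap_data: 'Lat', 'Long' and 'EventDate' are
+ # hypothetical column names; pass whichever columns your uploaded dataset uses.
+ def _example_clean_uap_data():
+     raw = pd.DataFrame({
+         'Lat': ['51.5', 'bad', None],
+         'Long': ['-0.1', '2.3', '4.5'],
+         'EventDate': ['2020-01-01', '2020-01-02', '2020-01-03'],
+     })
+     cleaned = clean_uap_data(raw, 'Lat', 'Long', 'EventDate')
+     print(cleaned)  # only the first row survives the null and parse filters
+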
+ def plot_overlapped_timeseries(data_list, event_times, window_hours=12, save_path=None):
+     fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
+     fig.patch.set_alpha(0)  # Make figure background transparent
+
+     components = ['X', 'Y', 'Z', 'S']
+     colors = ['red', 'green', 'blue', 'black']
+
+     for i, component in enumerate(components):
+         axs[i].patch.set_alpha(0)  # Make subplot background transparent
+         axs[i].set_ylabel(component, color='orange')
+         axs[i].grid(True, color='orange', alpha=0.3)
+
+         for spine in axs[i].spines.values():
+             spine.set_color('orange')
+
+         axs[i].tick_params(axis='both', colors='orange')
+         axs[i].set_title(component, color='orange')
+
+         for j, (df, event_time) in enumerate(zip(data_list, event_times)):
+             # Strip any timezone info so all timestamps are compared as naive UTC
+             df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
+             event_time = pd.to_datetime(event_time).tz_localize(None)
+
+             # Time difference from the event, in hours
+             df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600
+
+             # Keep only the data inside the +/- window (copy to avoid SettingWithCopyWarning)
+             df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)].copy()
+
+             # Normalise the component data (z-score) so different events are comparable
+             df_window[component] = (df_window[component] - df_window[component].mean()) / df_window[component].std()
+
+             axs[i].plot(df_window['time_diff'], df_window[component], color=colors[i], alpha=0.7, label=f'Event {j+1}', linewidth=1)
+
+         axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
+         axs[i].set_xlim(-window_hours, window_hours)
+         # axs[i].legend(loc='upper left', bbox_to_anchor=(1, 1))
+
+     axs[-1].set_xlabel('Hours from Event', color='orange')
+     fig.suptitle('Overlapped Time Series of Components', fontsize=16, color='orange')
+
+     plt.tight_layout()
+     plt.subplots_adjust(top=0.95, right=0.85)
+
+     if save_path:
+         fig.savefig(save_path, transparent=True, bbox_inches='tight')
+         plt.close(fig)
+         return save_path
+     else:
+         return fig
+
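+ # Uncalled sketch of the input shape plot_overlapped_timeseries expects: one
+ # DataFrame per event with a 'datetime' column plus X/Y/Z/S traces, and one
+ # event timestamp per DataFrame. The synthetic values are illustrative only.
+ def _example_plot_overlapped():
+     times = pd.date_range('2020-01-01', periods=48, freq='h')
+     values = np.random.randn(48)
+     frame = pd.DataFrame({'datetime': times, 'X': values, 'Y': values, 'Z': values, 'S': values})
+     fig = plot_overlapped_timeseries([frame], [pd.Timestamp('2020-01-01 12:00')])
+     st.pyplot(fig)
+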
+ def plot_average_timeseries(data_list, event_times, window_hours=12, save_path=None):
+     fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
+     fig.patch.set_alpha(0)  # Make figure background transparent
+
+     components = ['X', 'Y', 'Z', 'S']
+     colors = ['red', 'green', 'blue', 'black']
+
+     for i, component in enumerate(components):
+         axs[i].patch.set_alpha(0)
+         axs[i].set_ylabel(component, color='orange')
+         axs[i].grid(True, color='orange', alpha=0.3)
+
+         for spine in axs[i].spines.values():
+             spine.set_color('orange')
+
+         axs[i].tick_params(axis='both', colors='orange')
+
+         all_data = []
+         time_diffs = []
+
+         for j, (df, event_time) in enumerate(zip(data_list, event_times)):
+             # Strip any timezone info so all timestamps are compared as naive UTC
+             df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
+             event_time = pd.to_datetime(event_time).tz_localize(None)
+
+             # Time difference from the event, in hours
+             df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600
+
+             # Keep only the data inside the +/- window
+             df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)].copy()
+
+             # Centre the component data (mean removal only; no variance scaling)
+             df_window[component] = df_window[component] - df_window[component].mean()
+
+             all_data.append(df_window[component].values)
+             time_diffs.append(df_window['time_diff'].values)
+
+         # Average and standard deviation across events; np.mean/np.std raise on
+         # ragged inputs (windows of different lengths), hence the zero fallbacks
+         try:
+             avg_data = np.mean(all_data, axis=0)
+         except Exception:
+             avg_data = np.zeros_like(all_data[0])
+         try:
+             std_data = np.std(all_data, axis=0)
+         except Exception:
+             std_data = np.zeros_like(avg_data)
+
+         # Plot the average line
+         axs[i].plot(time_diffs[0], avg_data, color=colors[i], label='Average')
+
+         # Plot the standard deviation as a shaded region
+         try:
+             axs[i].fill_between(time_diffs[0], avg_data - std_data, avg_data + std_data, color=colors[i], alpha=0.2)
+         except Exception:
+             pass
+
+         axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
+         axs[i].set_xlim(-window_hours, window_hours)
+         # Legend with an orange frame and orange labels
+         axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
+
+     axs[-1].set_xlabel('Hours from Event', color='orange')
+     fig.suptitle('Average Time Series of Components', fontsize=16, color='orange')
+
+     plt.tight_layout()
+     plt.subplots_adjust(top=0.95, right=0.85)
+
+     if save_path:
+         fig.savefig(save_path, transparent=True, bbox_inches='tight')
+         plt.close(fig)
+         return save_path
+     else:
+         return fig
+
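+ # Note: the plain average above assumes every event window has the same length
+ # and sampling; the zeros fallback triggers otherwise. An uncalled sketch of a
+ # simple alternative is to truncate all series to the shortest window first:
+ def _example_truncate_and_average(series_list):
+     shortest = min(len(s) for s in series_list)
+     return np.mean([s[:shortest] for s in series_list], axis=0)
+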
+ def align_series(reference, series):
+     # Align `series` onto `reference` with FastDTW, then map every reference
+     # index to its matched sample so both series share a common time axis.
+     reference = np.asarray(reference).flatten()
+     series = np.asarray(series).flatten()
+     # Reshape to column vectors: scipy's euclidean expects 1-D vectors and can
+     # raise on the bare scalars that fastdtw would otherwise pass it.
+     _, path = fastdtw(reference.reshape(-1, 1), series.reshape(-1, 1), dist=euclidean)
+     aligned = np.zeros(len(reference))
+     for ref_idx, series_idx in path:
+         aligned[ref_idx] = series[series_idx]
+     return aligned
+
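+ # Uncalled sketch of align_series on two offset sine waves: the warped copy
+ # keeps the reference's length, so it can be averaged sample-by-sample.
+ def _example_align_series():
+     t = np.linspace(0, 2 * np.pi, 100)
+     reference = np.sin(t)
+     shifted = np.sin(t + 0.5)
+     aligned = align_series(reference, shifted)
+     print(aligned.shape)  # (100,)
+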
+ def plot_average_timeseries_with_dtw(data_list, event_times, window_hours=12, save_path=None):
+     fig, axs = plt.subplots(4, 1, figsize=(12, 16), sharex=True)
+     fig.patch.set_alpha(0)  # Make figure background transparent
+
+     components = ['X', 'Y', 'Z', 'S']
+     colors = ['red', 'green', 'blue', 'black']
+     fig.text(0.02, 0.5, 'Geomagnetic Variation (nT)', va='center', rotation='vertical', color='orange')
+
+     for i, component in enumerate(components):
+         axs[i].patch.set_alpha(0)
+         axs[i].set_ylabel(component, color='orange', rotation=90)
+         axs[i].grid(True, color='orange', alpha=0.3)
+
+         for spine in axs[i].spines.values():
+             spine.set_color('orange')
+
+         axs[i].tick_params(axis='both', colors='orange')
+
+         all_aligned_data = []
+         reference_df = None
+
+         for j, (df, event_time) in enumerate(zip(data_list, event_times)):
+             df['datetime'] = pd.to_datetime(df['datetime']).dt.tz_localize(None)
+             event_time = pd.to_datetime(event_time).tz_localize(None)
+             df['time_diff'] = (df['datetime'] - event_time).dt.total_seconds() / 3600
+             df_window = df[(df['time_diff'] >= -window_hours) & (df['time_diff'] <= window_hours)].copy()
+             df_window[component] = df_window[component] - df_window[component].mean()  # centre only
+
+             if reference_df is None:
+                 # The first event defines the reference time axis
+                 reference_df = df_window
+                 all_aligned_data.append(reference_df[component].values)
+             else:
+                 try:
+                     aligned_series = align_series(reference_df[component].values, df_window[component].values)
+                     all_aligned_data.append(aligned_series)
+                 except Exception:
+                     pass  # skip events that cannot be aligned (e.g. empty windows)
+
+         # Average and per-timepoint standard deviation of the aligned series
+         all_aligned_data = np.array(all_aligned_data)
+         avg_data = np.mean(all_aligned_data, axis=0)
+
+         def calculate_std(data):
+             if data is not None and len(data) > 0:
+                 return np.std(np.array(data), axis=0)
+             return None  # nothing to shade when no series could be aligned
+
+         std_data = calculate_std(all_aligned_data)
+
+         # Plot the average line
+         axs[i].plot(reference_df['time_diff'], avg_data, color=colors[i], label='Average')
+
+         # Plot the standard deviation as a shaded region
+         if std_data is not None:
+             axs[i].fill_between(reference_df['time_diff'], avg_data - std_data, avg_data + std_data, color=colors[i], alpha=0.2)
+
+         axs[i].axvline(x=0, color='red', linewidth=2, linestyle='--', label='Event Time')
+         axs[i].set_xlim(-window_hours, window_hours)
+         axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.2, labelcolor='orange', edgecolor='orange')
+
+     axs[-1].set_xlabel('Hours from Event', color='orange')
+     fig.suptitle('Average Time Series of Components (FastDTW Aligned)', fontsize=16, color='orange')
+
+     plt.tight_layout()
+     plt.subplots_adjust(top=0.85, right=0.85, left=0.1)
+
+     if save_path:
+         fig.savefig(save_path, transparent=True, bbox_inches='tight')
+         plt.close(fig)
+         return save_path
+     else:
+         return fig
+
+ def plot_data_custom(df, date, save_path=None, subtitle=None):
+     df['datetime'] = pd.to_datetime(df['datetime'])
+     event = pd.to_datetime(date)
+     window = timedelta(hours=12)
+     x_min = event - window
+     x_max = event + window
+
+     fig, axs = plt.subplots(4, 1, figsize=(12, 12), sharex=True)
+     fig.patch.set_alpha(0)  # Make figure background transparent
+
+     components = ['X', 'Y', 'Z', 'S']
+     colors = ['red', 'green', 'blue', 'black']
+
+     fig.text(0.02, 0.5, 'Geomagnetic Variation (nT)', va='center', rotation='vertical', color='orange')
+
+     for i, component in enumerate(components):
+         axs[i].plot(df['datetime'], df[component], label=component, color=colors[i])
+         axs[i].axvline(x=event, color='red', linewidth=2, label='Event', linestyle='--')
+         axs[i].set_ylabel(component, color='orange', rotation=90)
+         axs[i].set_xlim(x_min, x_max)
+         axs[i].legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.2, labelcolor='orange', edgecolor='orange')
+         axs[i].grid(True, color='orange', alpha=0.3)
+         axs[i].patch.set_alpha(0)  # Make subplot background transparent
+
+         for spine in axs[i].spines.values():
+             spine.set_color('orange')
+
+         axs[i].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
+         axs[i].xaxis.set_major_locator(mdates.HourLocator(interval=1))
+         axs[i].tick_params(axis='both', colors='orange')
+
+     plt.setp(axs[-1].xaxis.get_majorticklabels(), rotation=45)
+     axs[-1].set_xlabel('Hours', color='orange')
+     fig.suptitle(f'Time Series of Components with Event Marks\n{subtitle}', fontsize=12, color='orange')
+
+     plt.tight_layout()
+     plt.subplots_adjust(top=0.85, right=0.85, left=0.1)
+
+     if save_path:
+         fig.savefig(save_path, transparent=True)
+         plt.close(fig)
+         return save_path
+     else:
+         return fig
+
+
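+ # Uncalled sketch for plot_data_custom: the frame needs a 'datetime' column and
+ # X/Y/Z/S traces; the synthetic values below are illustrative only.
+ def _example_plot_data_custom():
+     times = pd.date_range('2021-06-01', periods=24 * 60, freq='min')
+     values = np.random.randn(len(times))
+     frame = pd.DataFrame({'datetime': times, 'X': values, 'Y': values, 'Z': values, 'S': values})
+     fig = plot_data_custom(frame, date='2021-06-01 12:00', subtitle='Synthetic example')
+     st.pyplot(fig)
+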
+ def batch_requests(stations, dataset, lon, lat, date, distance=100):
+     results = {"station": [], "data": [], "image": [], "custom_image": []}
+     all_data = []
+     all_event_times = []
+
+     for lon_, lat_, date_ in dataset[[lon, lat, date]].values:
+         test_lat_lon = (lat_, lon_)
+         try:
+             str_date = pd.to_datetime(date_).strftime('%Y-%m-%dT%H:%M:%S')
+         except Exception:
+             str_date = date_
+         # Request a window from 12 hours before to 48 hours after the event
+         twelve_hours = pd.Timedelta(hours=12)
+         forty_eight_hours = pd.Timedelta(hours=48)
+         try:
+             str_date_start = (pd.to_datetime(str_date) - twelve_hours).strftime('%Y-%m-%dT%H:%M:%S')
+             str_date_end = (pd.to_datetime(str_date) + forty_eight_hours).strftime('%Y-%m-%dT%H:%M:%S')
+         except Exception as e:
+             print(f"Error: {e}")
+             continue  # no usable window without a parseable date
+
+         try:
+             new_dataset = compare_stations(test_lat_lon, stations, distance=distance, closest=True)
+             station_name = new_dataset['Name']
+             station_distance = new_dataset['Distance']
+             test_ = get_data(new_dataset.iloc[0]['IagaCode'], str_date_start, str_date_end)
+
+             if test_:
+                 results["station"].append(new_dataset.iloc[0]['IagaCode'])
+                 results["data"].append(test_)
+                 plotted = pd.DataFrame({
+                     'datetime': test_['datetime'],
+                     'X': test_['X'],
+                     'Y': test_['Y'],
+                     'Z': test_['Z'],
+                     'S': test_['S'],
+                 })
+                 all_data.append(plotted)
+                 all_event_times.append(pd.to_datetime(date_))
+                 additional_data = f"Date: {date_}\nLat/Lon: {lat_}, {lon_}\nClosest station: {station_name.values[0]}\nDistance: {round(station_distance.values[0], 2)} km"
+                 fig = plot_data_custom(plotted, date=pd.to_datetime(date_), save_path=None, subtitle=additional_data)
+                 with st.status(f'Magnetic Data: {date_}', expanded=False) as status:
+                     st.pyplot(fig)
+                     # st.status updates take keyword arguments only
+                     status.update(label=f'Magnetic Data: {date_} - Finished!', state='complete')
+         except Exception as e:
+             # print(f"An error occurred: {e}")
+             pass
+
+     if all_data:
+         fig_overlapped = plot_overlapped_timeseries(all_data, all_event_times)
+         st.pyplot(fig_overlapped)  # render via Streamlit rather than IPython's display()
+         plt.close(fig_overlapped)
+         # fig_average = plot_average_timeseries(all_data, all_event_times)
+         # st.pyplot(fig_average)
+         fig_average_aligned = plot_average_timeseries_with_dtw(all_data, all_event_times)
+         with st.status('Dynamic Time Warping Data', expanded=False) as stts:
+             st.pyplot(fig_average_aligned)
+     return results
+
+
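+ # Uncalled sketch of the batch flow: 'Lat', 'Long' and 'EventDate' are
+ # hypothetical column names; it mirrors the upload-driven flow below.
+ def _example_batch_requests(df_events):
+     stations = get_stations()
+     cases = clean_uap_data(df_events, 'Lat', 'Long', 'EventDate')
+     return batch_requests(stations, cases, 'Long', 'Lat', 'EventDate', distance=150)
+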
+ df = pd.DataFrame()
+
+
+ # Upload dataset
+ uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
+
+ if uploaded_file is not None:
+     if uploaded_file.name.endswith('.csv'):
+         df = pd.read_csv(uploaded_file)
+     else:
+         df = pd.read_excel(uploaded_file)
+     stations = get_stations()
+     st.write("Dataset Loaded:")
+     df = filter_dataframe(df)
+     st.dataframe(df)
+
+     # Select columns
+     with st.form(border=True, key='Select Columns for Analysis'):
+         lon_col = st.selectbox("Select Longitude Column", df.columns)
+         lat_col = st.selectbox("Select Latitude Column", df.columns)
+         date_col = st.selectbox("Select Date Column", df.columns)
+         distance = st.number_input("Enter Distance (km)", min_value=0, value=100)
+         if st.form_submit_button("Process Data"):
+             cases = clean_uap_data(df, lat_col, lon_col, date_col)
+             results = batch_requests(stations, cases, lon_col, lat_col, date_col, distance=distance)