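# app.py — "UAP Feature Extraction" Streamlit app.
# Uploads a raw CSV/XLSX of sighting reports, lets the user filter it interactively,
# parses the free-text descriptions into structured JSON via the OpenAI API
# (UAPParser / CachedUAPParser), visualizes distributions (treemap, histogram,
# bar and line charts), and can summarize selections with Google Gemini.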
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from uap_analyzer import UAPParser, UAPAnalyzer, UAPVisualizer
# import ChartGen
# from ChartGen import ChartGPT
from Levenshtein import distance
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from stqdm import stqdm
stqdm.pandas()
import streamlit.components.v1 as components
from dateutil import parser
from sentence_transformers import SentenceTransformer
import torch
import squarify
import matplotlib.colors as mcolors
import textwrap
import datamapplot
import openai
from openai import OpenAI
import os
import json
import plotly.graph_objects as go

st.set_option('deprecation.showPyplotGlobalUse', False)

from pandas.api.types import (
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)
def load_data(file_path, key='df'):
    return pd.read_hdf(file_path, key=key)


def gemini_query(question, selected_data, gemini_key):
    if question == "":
        question = "Summarize the following data in relevant bullet points"
    import pathlib
    import textwrap
    import google.generativeai as genai
    from IPython.display import display
    from IPython.display import Markdown

    def to_markdown(text):
        text = text.replace('•', ' *')
        return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

    # selected_data is a list; drop None and empty entries, then join into one context string
    filtered = [str(x) for x in selected_data if x is not None and str(x) != '']
    context = '\n'.join(filtered)
    genai.configure(api_key=gemini_key)
    query_model = genai.GenerativeModel('models/gemini-1.5-pro-latest')
    response = query_model.generate_content([f"{question}\n Answer based on this context: {context}\n\n"])
    return response.text
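# Minimal usage sketch (the 'description' column name is illustrative; GEMINI_KEY comes from config.py):
#   summary = gemini_query("", filtered_data['description'].tolist(), GEMINI_KEY)
#   st.markdown(summary)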
def plot_treemap(df, column, top_n=32):
    # Get the value counts and the top N labels
    value_counts = df[column].value_counts()
    top_labels = value_counts.iloc[:top_n].index

    # Use np.where to replace all values not in the top N with 'Other'
    revised_column = f'{column}_revised'
    df[revised_column] = np.where(df[column].isin(top_labels), df[column], 'Other')

    # Get the value counts including the 'Other' category
    sizes = df[revised_column].value_counts().values
    labels = df[revised_column].value_counts().index

    # Get a gradient of colors
    # colors = list(mcolors.TABLEAU_COLORS.values())
    n_colors = len(sizes)
    colors = plt.cm.Oranges(np.linspace(0.3, 0.9, n_colors))[::-1]

    # Get % of each category and prepare labels with percentages
    percents = sizes / sizes.sum()
    labels = [f'{label}\n {percent:.1%}' for label, percent in zip(labels, percents)]

    fig, ax = plt.subplots(figsize=(20, 12))
    # Plot the treemap
    squarify.plot(sizes=sizes, label=labels, alpha=0.7, pad=True, color=colors, text_kwargs={'fontsize': 10}, ax=ax)

    # Iterate over text elements and rectangles (patches) in the axes for color adjustment
    for text, rect in zip(ax.texts, ax.patches):
        background_color = rect.get_facecolor()
        r, g, b, _ = mcolors.to_rgba(background_color)
        brightness = np.average([r, g, b])
        text.set_color('white' if brightness < 0.5 else 'black')
        # Wrap long labels so they fit inside their rectangles
        text.set_text(textwrap.fill(text.get_text(), width=15))

    ax.axis('off')
    return fig
class CachedUAPParser(UAPParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if 'parsed_responses' not in st.session_state:
            st.session_state['parsed_responses'] = {}

    def parse_responses(self):
        parsed_responses = {}
        not_parsed = 0
        try:
            for k, v in self.responses.items():
                try:
                    parsed_responses[k] = json.loads(v)
                except Exception:
                    try:
                        parsed_responses[k] = json.loads(v.replace("'", '"'))
                    except Exception:
                        not_parsed += 1
            # Update the cached responses
            st.session_state['parsed_responses'] = parsed_responses
        except Exception as e:
            st.error(f"Error parsing responses: {e}")
        st.write(f"Number of unparsed responses: {not_parsed}")
        st.write(f"Number of parsed responses: {len(parsed_responses)}")
        return st.session_state['parsed_responses']

    def responses_to_df(self, col, parsed_responses):
        try:
            parsed_df = pd.DataFrame(parsed_responses).T
            if col is not None:
                parsed_df2 = pd.json_normalize(parsed_df[col])
                parsed_df2.index = parsed_df.index
            else:
                parsed_df2 = pd.json_normalize(parsed_df)
                parsed_df2.index = parsed_df.index
            # Convert problematic columns to string
            for column in parsed_df2.columns:
                if parsed_df2[column].dtype == 'object':
                    parsed_df2[column] = parsed_df2[column].astype(str)
            return parsed_df2
        except Exception as e:
            st.error(f"Error converting responses to DataFrame: {e}")
            return pd.DataFrame()  # Return an empty DataFrame if conversion fails
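# Typical flow (mirrors the "Parse Dataset" handler below; model name and kwargs as used there):
#   parser = CachedUAPParser(api_key=API_KEY, model='gpt-3.5-turbo-0125', col=descriptions)
#   parser.process_descriptions(descriptions, FORMAT_LONG)
#   parsed = parser.parse_responses()                        # cached in st.session_state
#   df = parser.responses_to_df('sightingDetails', parsed)   # falls back to col=None on error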
def plot_hist(df, column, bins=10, kde=True):
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.histplot(data=df, x=column, kde=kde, bins=bins, color='orange', ax=ax)
    # Set the spines, ticks, labels and title in orange
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')
    # Set transparent background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig
def is_api_key_valid(api_key, model='gpt-3.5-turbo'):
    try:
        os.environ['OPENAI_API_KEY'] = api_key
        client = OpenAI()
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": 'Say Hello World!'}])
        text = response.choices[0].message.content
        return text is not None
    except Exception as e:
        st.error(f'Error with the API key: {e}')
        return False


def download_json(data):
    json_str = json.dumps(data, indent=2)
    return json_str
def convert_cached_data_to_df(parser):
    if 'parsed_responses' in st.session_state:
        # parser = CachedUAPParser(api_key=API_KEY, model='gpt-3.5-turbo-0125')
        try:
            responses_df = parser.responses_to_df('sightingDetails', st.session_state['parsed_responses'])
        except Exception as e:
            st.warning(f"Error parsing with 'sightingDetails': {e}")
            responses_df = parser.responses_to_df(None, st.session_state['parsed_responses'])
        if not responses_df.empty:
            st.dataframe(responses_df)
            st.session_state['parsed_responses_df'] = responses_df.copy()
            st.success("Successfully converted cached data to DataFrame.")
        else:
            st.error("Failed to create DataFrame from cached responses.")
    else:
        st.warning("No cached data available. Please parse the dataset first.")
def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=None, rolling_mean_value=2):
    import matplotlib.cm as cm

    # Sort the dataframe by the date column
    df = df.sort_values(by=x_column)

    # Calculate rolling mean for each y_column
    if rolling_mean_value:
        df[y_columns] = df[y_columns].rolling(len(df) // rolling_mean_value).mean()

    # Format x_column as a date if it is datetime-like (before plotting so the labels reflect it)
    if np.issubdtype(df[x_column].dtype, np.datetime64) or np.issubdtype(df[x_column].dtype, np.timedelta64):
        df[x_column] = pd.to_datetime(df[x_column]).dt.date

    # Create the plot
    fig, ax = plt.subplots(figsize=figsize)
    colors = cm.Oranges(np.linspace(0.2, 1, len(y_columns)))

    # Plot each y_column as a separate line with a different color
    for i, y_column in enumerate(y_columns):
        df.plot(x=x_column, y=y_column, ax=ax, color=colors[i], label=y_column, linewidth=.5)

    # Rotate x-axis labels
    ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha='right')

    # Set title, labels, and legend
    ax.set_title(title or f'{", ".join(y_columns)} over {x_column}', color=color, fontweight='bold')
    ax.set_xlabel(x_column, color=color)
    ax.set_ylabel(', '.join(y_columns), color=color)
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')

    # Remove background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig
def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None):
    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)

    # Rotate x-axis labels
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.set_title(title if title else f'{y_column} by {x_column}', color=color, fontweight='bold')
    ax.set_xlabel(x_column, color=color)
    ax.set_ylabel(y_column, color=color)
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')
    ax.title.set_color('orange')

    # Remove background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig
def plot_grouped_bar(df, x_columns, y_column, figsize=(12, 10), colors=None, title=None):
    fig, ax = plt.subplots(figsize=figsize)
    width = 0.8 / len(x_columns)  # the width of the bars
    x = np.arange(len(df), dtype=float)  # the label locations (float so the width offset below works)

    for i, x_column in enumerate(x_columns):
        sns.barplot(data=df, x=x, y=y_column, color=colors[i] if colors else None, ax=ax, width=width, label=x_column)
        x += width  # shift the x positions for the next group of bars

    ax.set_title(title if title else f'{y_column} by {", ".join(x_columns)}', color='orange', fontweight='bold')
    ax.set_xlabel('Groups', color='orange')
    ax.set_ylabel(y_column, color='orange')
    ax.set_xticks(x - width * len(x_columns) / 2)
    ax.set_xticklabels(df.index)
    ax.tick_params(axis='x', colors='orange')
    ax.tick_params(axis='y', colors='orange')

    # Remove background
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    for spine in ax.spines.values():
        spine.set_color('orange')
    ax.xaxis.label.set_color('orange')
    ax.yaxis.label.set_color('orange')
    ax.title.set_color('orange')
    ax.legend(loc='upper right', bbox_to_anchor=(1, 1), facecolor='black', framealpha=.4, labelcolor='orange', edgecolor='orange')
    return fig
@st.cache_data
def convert_df(df):
    # IMPORTANT: Cache the conversion to prevent computation on every rerun
    try:
        csv = df.to_csv().encode("utf-8")
    except Exception:
        csv = df.to_csv().encode("utf-8-sig")
    return csv
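# Sketch of how convert_df feeds st.download_button further down (file name as used there):
#   st.download_button("Save CSV", data=convert_df(df), file_name="uap_data.csv", mime="text/csv")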
def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns.

    Args:
        df (pd.DataFrame): Original dataframe

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    title_font = "Arial"
    body_font = "Arial"
    title_size = 32
    colors = ["red", "green", "blue"]
    interpretation = False
    extract_docx = False
    title = "My Chart"
    regex = ".*"
    img_path = 'default_image.png'

    # try:
    #     modify = st.checkbox("Add filters on raw data")
    # except:
    #     try:
    #         modify = st.checkbox("Add filters on processed data")
    #     except:
    #         try:
    #             modify = st.checkbox("Add filters on parsed data")
    #         except:
    #             pass
    # if not modify:
    #     return df

    df_ = df.copy()

    # Try to convert datetimes into a standard format (datetime, no timezone)
    # modification_container = st.container()
    # with modification_container:
    to_filter_columns = st.multiselect("Filter dataframe on", df_.columns)
    date_column = None
    filtered_columns = []
    for column in to_filter_columns:
        left, right = st.columns((1, 20))
        # Treat columns with < 120 unique values as categorical if not date or numeric
        if is_categorical_dtype(df_[column]) or (df_[column].nunique() < 120 and not is_datetime64_any_dtype(df_[column]) and not is_numeric_dtype(df_[column])):
            user_cat_input = right.multiselect(
                f"Values for {column}",
                df_[column].value_counts().index.tolist(),
                default=list(df_[column].value_counts().index)
            )
            df_ = df_[df_[column].isin(user_cat_input)]
            filtered_columns.append(column)
            with st.status(f"Category Distribution: {column}", expanded=False) as stat:
                st.pyplot(plot_treemap(df_, column))
        elif is_numeric_dtype(df_[column]):
            _min = float(df_[column].min())
            _max = float(df_[column].max())
            step = (_max - _min) / 100
            user_num_input = right.slider(
                f"Values for {column}",
                min_value=_min,
                max_value=_max,
                value=(_min, _max),
                step=step,
            )
            df_ = df_[df_[column].between(*user_num_input)]
            filtered_columns.append(column)
            # Chart_GPT = ChartGPT(df_, title_font, body_font, title_size,
            #                      colors, interpretation, extract_docx, img_path)
            with st.status(f"Numerical Distribution: {column}", expanded=False) as stat_:
                st.pyplot(plot_hist(df_, column, bins=max(1, (df_[column].nunique() - 1) // 2)))
        elif is_object_dtype(df_[column]):
            try:
                df_[column] = pd.to_datetime(df_[column], infer_datetime_format=True, errors='coerce')
            except Exception:
                try:
                    df_[column] = df_[column].apply(parser.parse)
                except Exception:
                    pass

            if is_datetime64_any_dtype(df_[column]):
                df_[column] = df_[column].dt.tz_localize(None)
                min_date = df_[column].min().date()
                max_date = df_[column].max().date()
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(min_date, max_date),
                    min_value=min_date,
                    max_value=max_date,
                )
                # if len(user_date_input) == 2:
                #     start_date, end_date = user_date_input
                #     df_ = df_.loc[df_[column].dt.date.between(start_date, end_date)]
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input

                    # Determine the most appropriate time unit for the plot
                    time_units = {
                        'year': df_[column].dt.year,
                        'month': df_[column].dt.to_period('M'),
                        'day': df_[column].dt.date
                    }
                    unique_counts = {unit: col.nunique() for unit, col in time_units.items()}
                    closest_to_36 = min(unique_counts, key=lambda k: abs(unique_counts[k] - 36))

                    # Group by the most appropriate time unit and count occurrences
                    grouped = df_.groupby(time_units[closest_to_36]).size().reset_index(name='count')
                    grouped.columns = [column, 'count']

                    # Create a complete date range
                    if closest_to_36 == 'year':
                        date_range = pd.date_range(start=f"{start_date.year}-01-01", end=f"{end_date.year}-12-31", freq='YS')
                    elif closest_to_36 == 'month':
                        date_range = pd.date_range(start=start_date.replace(day=1), end=end_date + pd.offsets.MonthEnd(0), freq='MS')
                    else:  # day
                        date_range = pd.date_range(start=start_date, end=end_date, freq='D')

                    # Create a DataFrame with the complete date range
                    complete_range = pd.DataFrame({column: date_range})

                    # Convert the date column to the appropriate format based on closest_to_36
                    if closest_to_36 == 'year':
                        complete_range[column] = complete_range[column].dt.year
                    elif closest_to_36 == 'month':
                        complete_range[column] = complete_range[column].dt.to_period('M')

                    # Merge the complete range with the grouped data
                    final_data = pd.merge(complete_range, grouped, on=column, how='left').fillna(0)
                    with st.status(f"Date Distributions: {column}", expanded=False) as stat:
                        try:
                            st.pyplot(plot_bar(final_data, column, 'count'))
                        except Exception as e:
                            st.error(f"Error plotting bar chart: {e}")
                    df_ = df_.loc[df_[column].between(start_date, end_date)]
                    date_column = column
                    if date_column and filtered_columns:
                        numeric_columns = [col for col in filtered_columns if is_numeric_dtype(df_[col])]
                        if numeric_columns:
                            fig = plot_line(df_, date_column, numeric_columns)
                            # st.pyplot(fig)
                        # now deal with categorical columns
                        categorical_columns = [col for col in filtered_columns if is_categorical_dtype(df_[col])]
                        if categorical_columns:
                            fig2 = plot_bar(df_, date_column, categorical_columns[0])
                            # st.pyplot(fig2)
                        with st.status(f"Date Distribution: {column}", expanded=False) as stat:
                            try:
                                st.pyplot(fig)
                            except Exception as e:
                                st.error(f"Error plotting line chart: {e}")
                            try:
                                st.pyplot(fig2)
                            except Exception as e:
                                st.error(f"Error plotting bar chart: {e}")
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    df_ = df_[df_[column].astype(str).str.contains(user_text_input)]

    # Report the number of rows remaining after filtering, with % of the original
    st.write(f"{len(df_)} rows ({len(df_) / len(df) * 100:.2f}%)")
    return df_
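# Usage sketch (this is how the uploader section below calls it; variable names illustrative):
#   data = pd.read_csv(uploaded_file)
#   filtered_data = filter_dataframe(data)
#   st.dataframe(filtered_data)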
from config import API_KEY, GEMINI_KEY, FORMAT_LONG

with torch.no_grad():
    torch.cuda.empty_cache()

# st.set_page_config(
#     page_title="UAP ANALYSIS",
#     page_icon=":alien:",
#     layout="wide",
#     initial_sidebar_state="expanded",
# )
st.title('UAP Feature Extraction')

# Initialize session state
if 'analyzers' not in st.session_state:
    st.session_state['analyzers'] = []
if 'col_names' not in st.session_state:
    st.session_state['col_names'] = []
if 'clusters' not in st.session_state:
    st.session_state['clusters'] = {}
if 'new_data' not in st.session_state:
    st.session_state['new_data'] = pd.DataFrame()
if 'dataset' not in st.session_state:
    st.session_state['dataset'] = pd.DataFrame()
if 'data_processed' not in st.session_state:
    st.session_state['data_processed'] = False
if 'stage' not in st.session_state:
    st.session_state['stage'] = 0
if 'filtered_data' not in st.session_state:
    st.session_state['filtered_data'] = None
if 'gemini_answer' not in st.session_state:
    st.session_state['gemini_answer'] = None
if 'parsed_responses' not in st.session_state:
    st.session_state['parsed_responses'] = None
if 'parsed_responses_df' not in st.session_state:
    st.session_state['parsed_responses_df'] = None
if 'json_format' not in st.session_state:
    st.session_state['json_format'] = None
if 'api_key_valid' not in st.session_state:
    st.session_state['api_key_valid'] = False
if 'previous_api_key' not in st.session_state:
    st.session_state['previous_api_key'] = None
# Unparsed data
# unparsed_tickbox = st.checkbox('Data Parsing')
# if unparsed_tickbox:
unparsed = st.file_uploader("Upload Raw DataFrame", type=["csv", "xlsx"])
if unparsed is not None:
    try:
        data = pd.read_csv(unparsed) if unparsed.type == "text/csv" else pd.read_excel(unparsed)
        filtered_data = filter_dataframe(data)
        st.dataframe(filtered_data)
    except Exception as e:
        st.error(f"An error occurred while reading the file: {e}")

    modify_json = st.checkbox('Custom JSON')
    API_KEY = st.text_input('OpenAI API Key', API_KEY, type='password', help="Enter your OpenAI API key")
    if modify_json:
        FORMAT_LONG = st.text_area('Custom JSON', FORMAT_LONG, height=500)
        st.download_button("Save Format", FORMAT_LONG)
    try:
        json.loads(FORMAT_LONG)
        st.session_state['json_format'] = True
    except json.JSONDecodeError as e:
        st.error(f"Invalid JSON format: {str(e)}")
        st.session_state['json_format'] = False
        st.stop()  # Stop execution if JSON is invalid
    # If the DataFrame is successfully created, allow the user to select a column
    col_unparsed = st.selectbox("Select column corresponding to text", data.columns)

    if st.button("Parse Dataset") and st.session_state['json_format']:
        if API_KEY:
            # Only validate if the API key has changed
            if API_KEY != st.session_state['previous_api_key']:
                if is_api_key_valid(API_KEY):
                    st.session_state['api_key_valid'] = True
                    st.session_state['previous_api_key'] = API_KEY
                    st.success("API key is valid!")
                else:
                    st.session_state['api_key_valid'] = False
                    st.error("Invalid API key. Please check and try again.")
            elif st.session_state['api_key_valid']:
                st.success("API key is valid!")
        if not API_KEY:  # or not st.session_state['api_key_valid']:
            st.warning("Please enter your API key to proceed.")
            st.stop()

        selected_column_data = filtered_data[col_unparsed].tolist()
        st.session_state.result = selected_column_data
        with st.status("Parsing...", expanded=True) as stat:
            try:
                st.write("Parsing descriptions...")
                parser = CachedUAPParser(api_key=API_KEY, model='gpt-3.5-turbo-0125', col=st.session_state.result)
                descriptions = st.session_state.result
                format_long = FORMAT_LONG
                parser.process_descriptions(descriptions, format_long)
                st.session_state['parsed_responses'] = parser.parse_responses()
                try:
                    responses_df = parser.responses_to_df('sightingDetails', st.session_state['parsed_responses'])
                except Exception as e:
                    st.warning(f"Error parsing with 'sightingDetails': {e}")
                    responses_df = parser.responses_to_df(None, st.session_state['parsed_responses'])
                if not responses_df.empty:
                    st.dataframe(responses_df)
                    st.session_state['parsed_responses_df'] = responses_df.copy()
                    stat.update(label="Parsing complete", state="complete", expanded=False)
                else:
                    st.error("Failed to create DataFrame from parsed responses.")
            except Exception as e:
                st.error(f"An error occurred during parsing: {str(e)}")
    # Add download button for parsed data
    if st.session_state['parsed_responses'] is not None:
        json_str = download_json(st.session_state['parsed_responses'])
        st.download_button(
            label="Download Parsed Data as JSON",
            data=json_str,
            file_name="parsed_responses.json",
            mime="application/json"
        )

    # Add button to convert cached data to DataFrame
    if st.button("Convert Cached Data to DataFrame"):
        convert_cached_data_to_df(st.session_state['parsed_responses'])

    if st.session_state['parsed_responses_df'] is not None:
        st.download_button(
            label="Save CSV",
            data=convert_df(st.session_state['parsed_responses_df']),
            file_name="uap_data.csv",
            mime="text/csv",
        )
    # except Exception as e:
    #     stat.update(label=f"Parsing failed: {e}", state="error")
    #     st.write("Parsing descriptions...")
    #     st.update_status("Parsing descriptions...")