Ashoka74 committed on
Commit
a0b5bde
·
verified ·
1 Parent(s): 6d2b558

Update rag_search.py

Browse files
Files changed (1) hide show
  1. rag_search.py +43 -1
rag_search.py CHANGED
@@ -157,7 +157,8 @@ def plot_line(df, x_column, y_columns, figsize=(12, 10), color='orange', title=N
157
 
158
  return fig
159
 
160
- def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None):
 
161
  fig, ax = plt.subplots(figsize=figsize)
162
 
163
  sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)
@@ -169,6 +170,8 @@ def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=Non
169
  ax.tick_params(axis='x', colors=color)
170
  ax.tick_params(axis='y', colors=color)
171
 
 
 
172
  # Remove background
173
  fig.patch.set_alpha(0)
174
  ax.patch.set_alpha(0)
@@ -311,6 +314,45 @@ def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
311
  if len(user_date_input) == 2:
312
  user_date_input = tuple(map(pd.to_datetime, user_date_input))
313
  start_date, end_date = user_date_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  df_ = df_.loc[df_[column].between(start_date, end_date)]
315
 
316
  date_column = column
 
157
 
158
  return fig
159
 
160
+
161
+ def plot_bar(df, x_column, y_column, figsize=(12, 10), color='orange', title=None, rotation=45):
162
  fig, ax = plt.subplots(figsize=figsize)
163
 
164
  sns.barplot(data=df, x=x_column, y=y_column, color=color, ax=ax)
 
170
  ax.tick_params(axis='x', colors=color)
171
  ax.tick_params(axis='y', colors=color)
172
 
173
+ plt.xticks(rotation=rotation)
174
+
175
  # Remove background
176
  fig.patch.set_alpha(0)
177
  ax.patch.set_alpha(0)
 
314
  if len(user_date_input) == 2:
315
  user_date_input = tuple(map(pd.to_datetime, user_date_input))
316
  start_date, end_date = user_date_input
317
+
318
+ # Determine the most appropriate time unit for plot
319
+ time_units = {
320
+ 'year': df_[column].dt.year,
321
+ 'month': df_[column].dt.to_period('M'),
322
+ 'day': df_[column].dt.date
323
+ }
324
+ unique_counts = {unit: col.nunique() for unit, col in time_units.items()}
325
+ closest_to_36 = min(unique_counts, key=lambda k: abs(unique_counts[k] - 36))
326
+
327
+ # Group by the most appropriate time unit and count occurrences
328
+ grouped = df_.groupby(time_units[closest_to_36]).size().reset_index(name='count')
329
+ grouped.columns = [column, 'count']
330
+
331
+ # Create a complete date range
332
+ if closest_to_36 == 'year':
333
+ date_range = pd.date_range(start=f"{start_date.year}-01-01", end=f"{end_date.year}-12-31", freq='YS')
334
+ elif closest_to_36 == 'month':
335
+ date_range = pd.date_range(start=start_date.replace(day=1), end=end_date + pd.offsets.MonthEnd(0), freq='MS')
336
+ else: # day
337
+ date_range = pd.date_range(start=start_date, end=end_date, freq='D')
338
+
339
+ # Create a DataFrame with the complete date range
340
+ complete_range = pd.DataFrame({column: date_range})
341
+
342
+ # Convert the date column to the appropriate format based on closest_to_36
343
+ if closest_to_36 == 'year':
344
+ complete_range[column] = complete_range[column].dt.year
345
+ elif closest_to_36 == 'month':
346
+ complete_range[column] = complete_range[column].dt.to_period('M')
347
+
348
+ # Merge the complete range with the grouped data
349
+ final_data = pd.merge(complete_range, grouped, on=column, how='left').fillna(0)
350
+
351
+ with st.status(f"Date Distributions: {column}", expanded=False) as stat:
352
+ try:
353
+ st.pyplot(plot_bar(final_data, column, 'count'))
354
+ except Exception as e:
355
+ st.error(f"Error plotting bar chart: {e}")
356
  df_ = df_.loc[df_[column].between(start_date, end_date)]
357
 
358
  date_column = column