obichimav committed
Commit 3d7db32 · verified · 1 Parent(s): 29db85a

Create app.py

Files changed (1):
1. app.py +1424 -0
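For reference, the pagination scheme the new app.py uses against the CMS data API can be sketched in isolation. The base URL, dataset UUID, and the `size`/`offset`/`filter[STATE_CD]` parameters below come straight from the committed code; the `build_page_request` helper itself is illustrative (it prepares a request without sending it) and is not part of the commit.

```python
import requests

API_BASE_URL = "https://data.cms.gov/data-api/v1"

def build_page_request(version_id, state=None, size=100, offset=0):
    # Illustrative helper: prepares (but does not send) one page request
    # using the same size/offset/filter[STATE_CD] scheme as app.py.
    params = {"size": size, "offset": offset}
    if state:
        params["filter[STATE_CD]"] = state
    req = requests.Request(
        "GET", f"{API_BASE_URL}/dataset/{version_id}/data", params=params
    )
    return req.prepare()

# Second page (offset 200) of the Q1 2025 dataset, filtered to California
prepared = build_page_request(
    "74edb053-bd01-40a0-91a0-4961c1fe6281", state="CA", offset=200
)
print(prepared.url)
```

app.py loops this request, advancing `offset` by the number of records returned, until it has `max_records` or the API returns a short page.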
app.py ADDED
@@ -0,0 +1,1424 @@
+ import os
+ import json
+ import time
+ import requests
+ import numpy as np
+ import pandas as pd
+ from datetime import datetime
+ from typing import Dict, List, Any, Optional, Tuple
+ import gradio as gr
+ from dotenv import load_dotenv
+
+ # Vector DB and embedding imports
+ from langchain.vectorstores import FAISS
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import Document
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+
+ # Visualization imports
+ import plotly.graph_objects as go
+ from sklearn.manifold import TSNE
+
+ # Load environment variables
+ load_dotenv()
+
+ # Check if OPENAI_API_KEY is set
+ OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
+ if not OPENAI_API_KEY:
+     print("⚠️ Warning: OPENAI_API_KEY not found in environment variables.")
+
+ # Configuration
+ DEFAULT_DATASET_ID = "2457ea29-fc82-48b0-86ec-3b0755de7515"
+ DEFAULT_MODEL = "gpt-4o-mini"
+ API_BASE_URL = "https://data.cms.gov/data-api/v1"
+ INITIAL_SAMPLE_SIZE = 100  # Start with a small sample
+
+ # Dataset version mapping
+ DATASET_VERSIONS = {
+     # 2025 Data
+     "Q1 2025": "74edb053-bd01-40a0-91a0-4961c1fe6281",
+
+     # 2024 Data
+     "Q1 2024": "6d6e0e8d-64cf-43fb-9ba8-e2ad9b9bb21e",
+     "Q2 2024": "04405289-5635-4b2a-a64f-c4b6415ab6ff",
+     "Q3 2024": "e87f09c2-5ff7-4ddf-b60c-6130995b15cf",
+     "Q4 2024": "e9d278e4-90e8-47ab-9c5b-af2ca64bf352",
+
+     # 2023 Data
+     "Q1 2023": "0b6caf2f-8948-4603-922e-d7f0c52c0a45",
+     "Q2 2023": "46339a0c-0f07-40ed-8975-ddb387c367a4",
+     "Q3 2023": "70efac57-6093-4e1d-ad6a-36f8261f53eb",
+     "Q4 2023": "1df8331a-ed44-41ec-971f-158349658949",
+
+     # 2022 Data
+     "Q1 2022": "5b678653-aa36-455b-9144-1d073ef7991b",
+
+     # 2021 Data
+     "Q1 2021": "7b409bba-ca00-426e-9493-1dc10e5340cc",
+
+     # 2020 Data
+     "Q1 2020": "3870b29c-4312-4fb1-a956-71c148ae5b50",
+
+     # 2019 Data
+     "Q1 2019": "017e6ab7-7e19-4e98-b4fa-30578b47e578",
+     "Q4 2019": "2c209bdb-ed0c-42e0-b027-8a97024b8035"
+ }
+
+ # US States for reference
+ US_STATES = [
+     "", "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA",
+     "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD",
+     "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ",
+     "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC",
+     "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
+     "DC", "PR", "VI"
+ ]
+
+ # State names mapping for better UI
+ STATE_NAMES = {
+     "": "All States",
+     "AL": "Alabama", "AK": "Alaska", "AZ": "Arizona", "AR": "Arkansas",
+     "CA": "California", "CO": "Colorado", "CT": "Connecticut", "DE": "Delaware",
+     "FL": "Florida", "GA": "Georgia", "HI": "Hawaii", "ID": "Idaho",
+     "IL": "Illinois", "IN": "Indiana", "IA": "Iowa", "KS": "Kansas",
+     "KY": "Kentucky", "LA": "Louisiana", "ME": "Maine", "MD": "Maryland",
+     "MA": "Massachusetts", "MI": "Michigan", "MN": "Minnesota", "MS": "Mississippi",
+     "MO": "Missouri", "MT": "Montana", "NE": "Nebraska", "NV": "Nevada",
+     "NH": "New Hampshire", "NJ": "New Jersey", "NM": "New Mexico", "NY": "New York",
+     "NC": "North Carolina", "ND": "North Dakota", "OH": "Ohio", "OK": "Oklahoma",
+     "OR": "Oregon", "PA": "Pennsylvania", "RI": "Rhode Island", "SC": "South Carolina",
+     "SD": "South Dakota", "TN": "Tennessee", "TX": "Texas", "UT": "Utah",
+     "VT": "Vermont", "VA": "Virginia", "WA": "Washington", "WV": "West Virginia",
+     "WI": "Wisconsin", "WY": "Wyoming", "DC": "District of Columbia",
+     "PR": "Puerto Rico", "VI": "Virgin Islands"
+ }
+
+ # Dictionary to store multiple datasets
+ rag_systems = {}
+ current_dataset_key = None
+
+ # Gradio theme configuration
+ theme = gr.themes.Soft(
+     primary_hue="blue",
+     secondary_hue="gray",
+     neutral_hue="slate",
+     font=gr.themes.GoogleFont("Inter")
+ )
+
+ def query_cms_api(version_id, state_filter="", max_records=100):
+     """Query the CMS API with pagination."""
+     url = f"{API_BASE_URL}/dataset/{version_id}/data"
+     all_records = []
+     offset = 0
+     page_size = min(max_records, 100)  # Page size, max 100
+
+     # Set up filter parameters
+     params = {
+         'size': page_size,
+         'offset': 0
+     }
+
+     # Add state filter if provided
+     if state_filter and state_filter != "":
+         params['filter[STATE_CD]'] = state_filter
+
+     progress_text = "Querying CMS API...\n"
+
+     # Fetch data with pagination
+     while len(all_records) < max_records:
+         params['offset'] = offset
+
+         try:
+             response = requests.get(url, params=params, timeout=30)
+
+             if response.status_code != 200:
+                 error_msg = f"Error: Status {response.status_code}"
+                 return [], error_msg
+
+             # Parse the response - the API returns a list directly
+             records = response.json()
+
+             if not records or not isinstance(records, list):
+                 if len(all_records) == 0:
+                     return [], "No records found"
+                 break
+
+             progress_text += f"Retrieved {len(records)} records (offset: {offset})\n"
+             all_records.extend(records)
+
+             # If we got fewer records than requested, we've reached the end
+             if len(records) < page_size:
+                 break
+
+             # Move to next page
+             offset += len(records)
+
+             # Add delay to be nice to the API
+             time.sleep(0.5)
+
+         except Exception as e:
+             error_msg = f"Error querying API: {str(e)}"
+             return [], error_msg
+
+     final_records = all_records[:max_records]
+     success_msg = f"Successfully retrieved {len(final_records)} records"
+
+     return final_records, success_msg
+
+ def process_records(records, version):
+     """Process CMS API records into documents for the RAG system."""
+     # Parse version into quarter and year
+     quarter = "Unknown"
+     year = "Unknown"
+     if ' ' in version:
+         parts = version.split(' ')
+         if len(parts) == 2:
+             quarter, year = parts
+
+     embeddings = OpenAIEmbeddings()
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+
+     # Convert records to documents
+     documents = []
+
+     for record in records:
+         # Format the record as text with explicit time information
+         content = [f"Medicare Provider Data from {quarter} {year}"]
+         content.append(f"Time Period: {quarter} of {year}")
+
+         # Add all fields from the record
+         for key, value in record.items():
+             if value is not None and value != "":
+                 content.append(f"{key}: {value}")
+
+         text = "\n".join(content)
+
+         # Create metadata with explicit time fields
+         metadata = {
+             'dataset_version': version,
+             'quarter': quarter,
+             'year': year,
+             'record_id': record.get('ENRLMT_ID', 'unknown')
+         }
+
+         # Add all fields to metadata for better searchability
+         for key, value in record.items():
+             if value is not None and value != "":
+                 try:
+                     # Convert complex values to strings to avoid serialization issues
+                     if not isinstance(value, (str, int, float, bool, type(None))):
+                         metadata[key] = str(value)
+                     else:
+                         metadata[key] = value
+                 except Exception:
+                     # If there's any issue, convert to string
+                     metadata[key] = str(value)
+
+         documents.append(Document(page_content=text, metadata=metadata))
+
+     # Chunk documents
+     chunks = text_splitter.split_documents(documents)
+
+     # Create vector store
+     vector_store = FAISS.from_documents(chunks, embeddings)
+
+     return vector_store, len(documents), len(chunks)
+
+ def create_progress_callback():
+     """Create a progress callback for long-running operations."""
+     def callback(message):
+         # In a real Gradio app, this would update a progress bar
+         print(f"Progress: {message}")
+     return callback
+
+ def validate_api_key():
+     """Validate that the OpenAI API key is set."""
+     api_key = os.getenv('OPENAI_API_KEY')
+     if not api_key:
+         return False, "OpenAI API key not found. Please set it in your environment variables or .env file."
+     return True, "API key validated successfully."
+
+ def get_dataset_summary(rag_systems):
+     """Generate a summary of all loaded datasets."""
+     if not rag_systems:
+         return "No datasets currently loaded."
+
+     summary_lines = ["### Currently Loaded Datasets:\n"]
+
+     for i, (key, system) in enumerate(rag_systems.items(), 1):
+         meta = system['metadata']
+         summary_lines.append(
+             f"{i}. **{meta['dataset_version']}** - "
+             f"State: {meta['state_filter']} - "
+             f"Records: {meta['record_count']} - "
+             f"Chunks: {meta['chunk_count']}"
+         )
+
+         if key == current_dataset_key:
+             summary_lines[-1] += " *(Current)*"
+
+     summary_lines.append(f"\n**Total datasets loaded:** {len(rag_systems)}")
+
+     return "\n".join(summary_lines)
+
+ def format_state_options():
+     """Format state options for Gradio dropdown."""
+     options = []
+     for code in US_STATES:
+         if code == "":
+             options.append(("All States", ""))
+         else:
+             options.append((f"{STATE_NAMES[code]} ({code})", code))
+     return options
+
+ def load_dataset_gradio(version, state_filter, max_records, use_sample):
+     """Load data from CMS API and set up the RAG system - Gradio version."""
+     global rag_systems, current_dataset_key
+
+     # Validate API key first
+     valid, message = validate_api_key()
+     if not valid:
+         return message, get_dataset_summary(rag_systems)
+
+     # Generate a unique key for this dataset
+     dataset_key = f"{version}_{state_filter}_{max_records}"
+
+     # Check if dataset already loaded
+     if dataset_key in rag_systems:
+         current_dataset_key = dataset_key
+         return f"✅ Dataset already loaded and set as current: {version} - {STATE_NAMES.get(state_filter, 'All States')}", get_dataset_summary(rag_systems)
+
+     # Get version ID
+     version_id = DATASET_VERSIONS.get(version)
+     if not version_id:
+         return f"❌ Invalid version: {version}", get_dataset_summary(rag_systems)
+
+     # Adjust max records if sample
+     actual_max = INITIAL_SAMPLE_SIZE if use_sample else max_records
+
+     # Status message
+     status_msg = f"🔄 Loading {version} data"
+     if state_filter:
+         status_msg += f" for {STATE_NAMES.get(state_filter, state_filter)}"
+     status_msg += f" (max {actual_max} records)..."
+
+     try:
+         # Fetch data from API
+         records, api_message = query_cms_api(version_id, state_filter, actual_max)
+
+         if not records:
+             return f"❌ Failed to load data: {api_message}", get_dataset_summary(rag_systems)
+
+         status_msg += f"\n✅ {api_message}"
+
+         # Process records and create vector store
+         status_msg += "\n🔄 Processing records and creating vector store..."
+         vector_store, doc_count, chunk_count = process_records(records, version)
+
+         # Set up RAG system
+         llm = ChatOpenAI(temperature=0.7, model_name=DEFAULT_MODEL)
+         memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
+         retriever = vector_store.as_retriever()
+
+         conversation_chain = ConversationalRetrievalChain.from_llm(
+             llm=llm,
+             retriever=retriever,
+             memory=memory
+         )
+
+         # Store in the dictionary
+         rag_systems[dataset_key] = {
+             'vector_store': vector_store,
+             'conversation_chain': conversation_chain,
+             'metadata': {
+                 'dataset_version': version,
+                 'version_id': version_id,
+                 'state_filter': STATE_NAMES.get(state_filter, "All States") if state_filter else "All States",
+                 'record_count': len(records),
+                 'document_count': doc_count,
+                 'chunk_count': chunk_count,
+                 'loaded_at': datetime.now().isoformat()
+             }
+         }
+
+         # Set as current dataset
+         current_dataset_key = dataset_key
+
+         success_msg = f"✅ Successfully loaded {version} - {STATE_NAMES.get(state_filter, 'All States')}\n"
+         success_msg += f"📊 Created {chunk_count} chunks from {len(records)} records"
+
+         return success_msg, get_dataset_summary(rag_systems)
+
+     except Exception as e:
+         error_msg = f"❌ Error loading data: {str(e)}"
+         return error_msg, get_dataset_summary(rag_systems)
+
+ def switch_dataset_gradio(dataset_index):
+     """Switch to a different dataset - Gradio version."""
+     global rag_systems, current_dataset_key
+
+     if not rag_systems:
+         return "❌ No datasets loaded.", get_dataset_summary(rag_systems)
+
+     if not dataset_index:
+         return "❌ Please select a dataset.", get_dataset_summary(rag_systems)
+
+     try:
+         # Parse the index from the selection (format: "1. Dataset Name")
+         index = int(dataset_index.split(".")[0])
+
+         if 1 <= index <= len(rag_systems):
+             key = list(rag_systems.keys())[index - 1]
+             current_dataset_key = key
+             meta = rag_systems[key]['metadata']
+             return f"✅ Switched to: {meta['dataset_version']} - {meta['state_filter']}", get_dataset_summary(rag_systems)
+         else:
+             return "❌ Invalid selection.", get_dataset_summary(rag_systems)
+     except Exception:
+         return "❌ Invalid selection format.", get_dataset_summary(rag_systems)
+
+ def remove_dataset_gradio(dataset_index):
+     """Remove a dataset from memory - Gradio version."""
+     global rag_systems, current_dataset_key
+
+     if not rag_systems:
+         return "❌ No datasets loaded.", get_dataset_summary(rag_systems)
+
+     if not dataset_index:
+         return "❌ Please select a dataset to remove.", get_dataset_summary(rag_systems)
+
+     try:
+         # Parse the index from the selection
+         index = int(dataset_index.split(".")[0])
+
+         if 1 <= index <= len(rag_systems):
+             key = list(rag_systems.keys())[index - 1]
+             meta = rag_systems[key]['metadata']
+
+             # Remove the dataset
+             del rag_systems[key]
+
+             # If this was the current dataset, clear the current key
+             if key == current_dataset_key:
+                 current_dataset_key = None
+                 # Set another dataset as current if available
+                 if rag_systems:
+                     current_dataset_key = list(rag_systems.keys())[0]
+
+             return f"✅ Removed: {meta['dataset_version']} - {meta['state_filter']}", get_dataset_summary(rag_systems)
+         else:
+             return "❌ Invalid selection.", get_dataset_summary(rag_systems)
+     except Exception as e:
+         return f"❌ Error removing dataset: {str(e)}", get_dataset_summary(rag_systems)
+
+ def get_dataset_choices():
+     """Get formatted dataset choices for Gradio dropdown."""
+     if not rag_systems:
+         return []
+
+     choices = []
+     for i, (key, system) in enumerate(rag_systems.items(), 1):
+         meta = system['metadata']
+         choice_text = f"{i}. {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)"
+         if key == current_dataset_key:
+             choice_text += " [CURRENT]"
+         choices.append(choice_text)
+
+     return choices
+
+ def clear_all_datasets_gradio():
+     """Clear all loaded datasets - Gradio version."""
+     global rag_systems, current_dataset_key
+
+     if not rag_systems:
+         return "ℹ️ No datasets to clear.", ""
+
+     count = len(rag_systems)
+     rag_systems.clear()
+     current_dataset_key = None
+
+     return f"✅ Cleared {count} dataset(s) from memory.", ""
+
+ def get_current_dataset_info():
+     """Get information about the current dataset."""
+     global rag_systems, current_dataset_key
+
+     if not current_dataset_key or current_dataset_key not in rag_systems:
+         return "No dataset currently selected."
+
+     meta = rag_systems[current_dataset_key]['metadata']
+     info = f"**Current Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n"
+     info += f"- Records: {meta['record_count']}\n"
+     info += f"- Chunks: {meta['chunk_count']}\n"
+     info += f"- Loaded: {meta['loaded_at'][:19]}"
+
+     return info
+
+ def ask_question_gradio(question, chat_history):
+     """Ask a question to the current dataset - Gradio version."""
+     global rag_systems, current_dataset_key
+
+     if not current_dataset_key or current_dataset_key not in rag_systems:
+         response = "❌ No dataset selected. Please load a dataset first."
+         chat_history.append((question, response))
+         return "", chat_history
+
+     # Get the dataset
+     system = rag_systems[current_dataset_key]
+     meta = system['metadata']
+
+     try:
+         # Use the chain to get a response
+         result = system['conversation_chain'].invoke({"question": question})
+         answer = result["answer"]
+
+         # Add dataset source information
+         answer += f"\n\n*Source: {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)*"
+
+         # Update chat history
+         chat_history.append((question, answer))
+
+         return "", chat_history
+
+     except Exception as e:
+         error_response = f"❌ Error processing query: {str(e)}"
+         chat_history.append((question, error_response))
+         return "", chat_history
+
+ def ask_global_question_gradio(question, chat_history):
+     """Ask a question that might require knowledge from all loaded datasets."""
+     global rag_systems
+
+     if not rag_systems:
+         response = "❌ No datasets loaded. Please load datasets first."
+         chat_history.append((question, response))
+         return "", chat_history
+
+     # Check if this is a global question about the datasets themselves
+     global_keywords = ['how many', 'which years', 'what years', 'what quarters', 'how many years',
+                        'which quarters', 'time period', 'date range', 'all datasets', 'datasets',
+                        'compare', 'comparison', 'difference', 'trend', 'over time']
+
+     is_global_question = any(keyword in question.lower() for keyword in global_keywords)
+
+     # Check if the question mentions a specific state
+     mentioned_state = None
+     question_lower = question.lower()
+
+     # Check for state names
+     for code, name in STATE_NAMES.items():
+         if code and (code.lower() in question_lower or name.lower() in question_lower):
+             mentioned_state = code
+             break
+
+     try:
+         if mentioned_state and not is_global_question:
+             # Find all datasets for that state
+             suitable_datasets = []
+
+             for key, system in rag_systems.items():
+                 meta = system['metadata']
+                 state_filter = meta['state_filter']
+
+                 # Check if this dataset matches the mentioned state
+                 if mentioned_state in state_filter or STATE_NAMES[mentioned_state] in state_filter:
+                     suitable_datasets.append(key)
+
+             if suitable_datasets:
+                 response = f"🔄 Found {len(suitable_datasets)} dataset(s) for {STATE_NAMES[mentioned_state]}:\n\n"
+
+                 # Query each suitable dataset
+                 all_results = []
+                 for dataset_key in suitable_datasets:
+                     system = rag_systems[dataset_key]
+                     meta = system['metadata']
+
+                     try:
+                         result = system['conversation_chain'].invoke({"question": question})
+                         answer = result["answer"]
+                         all_results.append({
+                             'dataset': f"{meta['dataset_version']} - {meta['state_filter']}",
+                             'answer': answer
+                         })
+                     except Exception as e:
+                         all_results.append({
+                             'dataset': f"{meta['dataset_version']} - {meta['state_filter']}",
+                             'answer': f"Error: {str(e)}"
+                         })
+
+                 # Format combined response
+                 for result in all_results:
+                     response += f"**{result['dataset']}**\n{result['answer']}\n\n---\n\n"
+
+                 chat_history.append((question, response))
+                 return "", chat_history
+             else:
+                 response = f"ℹ️ No datasets found for {STATE_NAMES[mentioned_state]}. Please load data for this state first."
+                 chat_history.append((question, response))
+                 return "", chat_history
+
+         elif is_global_question:
+             # Create a summary of all available datasets
+             dataset_summary = generate_dataset_metadata_summary()
+
+             # Create a system message that includes this metadata
+             llm = ChatOpenAI(temperature=0.7, model_name=DEFAULT_MODEL)
+
+             system_message = f"""You are an expert on Medicare Provider data. You have access to multiple datasets spanning different quarters and years.
+
+ {dataset_summary}
+
+ When answering questions, consider the metadata about all available datasets. For questions about time periods, years, quarters, or trends, use the information about which datasets are loaded."""
+
+             messages = [
+                 {"role": "system", "content": system_message},
+                 {"role": "user", "content": question}
+             ]
+
+             response = llm.invoke(messages)
+             answer = response.content
+
+             chat_history.append((question, answer))
+             return "", chat_history
+
+         else:
+             # For non-global questions without specific state mention, use the current dataset
+             return ask_question_gradio(question, chat_history)
+
+     except Exception as e:
+         error_response = f"❌ Error processing global query: {str(e)}"
+         chat_history.append((question, error_response))
+         return "", chat_history
+
+ def generate_dataset_metadata_summary():
+     """Generate a detailed summary of dataset metadata."""
+     if not rag_systems:
+         return "No datasets loaded."
+
+     summary = "# Available Datasets\n\n"
+     summary += "The following datasets are currently loaded:\n\n"
+
+     # Group by year
+     years = set()
+     quarters_by_year = {}
+     states = set()
+
+     for key, system in rag_systems.items():
+         meta = system['metadata']
+         version = meta['dataset_version']
+         state = meta['state_filter']
+
+         # Extract year from version (e.g., "Q1 2025" -> "2025")
+         if ' ' in version:
+             year = version.split(' ')[1]
+             quarter = version.split(' ')[0]
+
+             years.add(year)
+             states.add(state)
+
+             if year not in quarters_by_year:
+                 quarters_by_year[year] = set()
+
+             quarters_by_year[year].add(quarter)
+
+     # Format the summary
+     summary += "## Years Available\n"
+     summary += ", ".join(sorted(list(years))) + "\n\n"
+
+     summary += "## Quarters Available by Year\n"
+     for year in sorted(quarters_by_year.keys()):
+         summary += f"- {year}: {', '.join(sorted(list(quarters_by_year[year])))}\n"
+
+     summary += "\n## States Available\n"
+     summary += ", ".join(sorted(list(states))) + "\n\n"
+
+     summary += "## Full Dataset List\n"
+     for key, system in rag_systems.items():
+         meta = system['metadata']
+         summary += f"- {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)\n"
+
+     return summary
+
+ def compare_datasets_gradio(question, dataset_indices):
+     """Compare multiple datasets by asking the same question - Gradio version."""
+     global rag_systems
+
+     if not rag_systems:
+         return "❌ No datasets loaded. Please load datasets first."
+
+     if not dataset_indices or len(dataset_indices) < 2:
+         return "❌ Please select at least 2 datasets to compare."
+
+     # Parse indices and get dataset keys
+     selected_keys = []
+     for selection in dataset_indices:
+         try:
+             index = int(selection.split(".")[0])
+             if 1 <= index <= len(rag_systems):
+                 key = list(rag_systems.keys())[index - 1]
+                 selected_keys.append(key)
+         except Exception:
+             continue
+
+     if len(selected_keys) < 2:
+         return "❌ Could not parse selected datasets."
+
+     comparison_result = f"# Comparison: {question}\n\n"
+
+     # Query each selected dataset
+     for key in selected_keys:
+         system = rag_systems[key]
+         meta = system['metadata']
+         dataset_name = f"{meta['dataset_version']} - {meta['state_filter']}"
+
+         comparison_result += f"## {dataset_name}\n\n"
+
+         try:
+             result = system['conversation_chain'].invoke({"question": question})
+             answer = result["answer"]
+             comparison_result += f"{answer}\n\n"
+         except Exception as e:
+             comparison_result += f"Error: {str(e)}\n\n"
+
+         comparison_result += "---\n\n"
+
+     return comparison_result
+
+ def analyze_provider_types_gradio(dataset_key=None):
+     """Analyze provider types in a dataset - Gradio version."""
+     global rag_systems, current_dataset_key
+
+     # Determine which dataset to use
+     target_key = dataset_key if dataset_key and dataset_key in rag_systems else current_dataset_key
+
+     if not target_key or target_key not in rag_systems:
+         return "❌ No dataset selected."
+
+     system = rag_systems[target_key]
+     meta = system['metadata']
+
+     analysis_question = """
+     Analyze the provider types in this dataset:
+     1. What are the most common provider types?
+     2. How many unique provider types are there?
+     3. What percentage of providers fall into each major category?
+     Please provide a detailed breakdown.
+     """
+
+     try:
+         result = system['conversation_chain'].invoke({"question": analysis_question})
+
+         analysis = "# Provider Type Analysis\n"
+         analysis += f"**Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n\n"
+         analysis += result["answer"]
+
+         return analysis
+     except Exception as e:
+         return f"❌ Error analyzing provider types: {str(e)}"
+
+ def clear_chat_history():
+     """Clear the chat history."""
+     return []
+
+ def visualize_datasets_gradio(dataset_indices, dimensions, sample_size=1000):
+     """Create a visualization of one or more datasets - Gradio version."""
+     global rag_systems
+
+     if not rag_systems:
+         return None, "❌ No datasets loaded. Please load datasets first."
+
+     if not dataset_indices:
+         return None, "❌ Please select at least one dataset to visualize."
+
+     # Parse indices and get dataset keys
+     selected_keys = []
+     for selection in dataset_indices:
+         try:
+             index = int(selection.split(".")[0])
+             if 1 <= index <= len(rag_systems):
+                 key = list(rag_systems.keys())[index - 1]
+                 selected_keys.append(key)
+         except Exception:
+             continue
+
+     if not selected_keys:
+         return None, "❌ Could not parse selected datasets."
+
+     try:
+         # Create a combined visualization
+         all_vectors = []
+         all_metadata = []
+         all_contents = []
+         all_dataset_labels = []
+
+         status_msg = f"Processing {len(selected_keys)} dataset(s)...\n"
+
+         # Collect vectors from all requested datasets
+         for key in selected_keys:
+             vector_store = rag_systems[key]['vector_store']
+             meta = rag_systems[key]['metadata']
+             dataset_label = f"{meta['dataset_version']} - {meta['state_filter']}"
+
+             # Limit vectors for performance
+             num_vectors = min(sample_size, vector_store.index.ntotal)
+             status_msg += f"- {dataset_label}: {num_vectors} vectors\n"
+
+             for i in range(num_vectors):
+                 all_vectors.append(vector_store.index.reconstruct(i))
+
+                 doc_id = vector_store.index_to_docstore_id[i]
+                 document = vector_store.docstore.search(doc_id)
+
+                 all_metadata.append(document.metadata)
+                 all_contents.append(document.page_content)
+                 all_dataset_labels.append(dataset_label)
+
+         if not all_vectors:
+             return None, "❌ No vectors to visualize."
+
+         vectors = np.array(all_vectors)
+         status_msg += f"\nTotal vectors: {len(all_vectors)}\n"
+
+         # Reduce dimensionality
+         status_msg += f"Reducing dimensionality to {dimensions}D using t-SNE..."
+         tsne = TSNE(n_components=dimensions, random_state=42, perplexity=min(30, len(all_vectors)-1))
+         reduced_vectors = tsne.fit_transform(vectors)
+
+         # Create color mapping based on dataset
+         unique_labels = list(set(all_dataset_labels))
+         color_palette = [
+             '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
+             '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
+         ]
+         color_map = {label: color_palette[i % len(color_palette)]
+                      for i, label in enumerate(unique_labels)}
+
+         colors = [color_map[label] for label in all_dataset_labels]
+
+         # Create hover text
+         hover_texts = []
+         for meta, content, label in zip(all_metadata, all_contents, all_dataset_labels):
+             text = f"<b>Dataset:</b> {label}<br>"
+
+             # Add key metadata fields
+             key_fields = ['STATE_CD', 'PROVIDER_TYPE_DESC', 'FIRST_NAME', 'LAST_NAME', 'ORG_NAME']
+             for field in key_fields:
+                 if field in meta and meta[field]:
+                     text += f"<b>{field}:</b> {meta[field]}<br>"
+
+             # Add a preview of the content
+             content_preview = content[:200] + "..." if len(content) > 200 else content
+             text += f"<br><b>Preview:</b> {content_preview}"
+
+             hover_texts.append(text)
+
+         # Create visualization
+         if dimensions == 2:
+             fig = go.Figure()
+
+             # Add a trace for each dataset
+             for label in unique_labels:
+                 # Get indices for this dataset
+                 indices = [i for i, l in enumerate(all_dataset_labels) if l == label]
+
+                 # Add the scatter trace
+                 fig.add_trace(go.Scatter(
+                     x=reduced_vectors[indices, 0],
+                     y=reduced_vectors[indices, 1],
+                     mode='markers',
+                     marker=dict(
+                         size=6,
+                         color=color_map[label],
+                         opacity=0.7,
+                         line=dict(width=1, color='white')
+                     ),
+                     text=[hover_texts[i] for i in indices],
+                     hoverinfo='text',
+                     hoverlabel=dict(bgcolor="white", font_size=12),
+                     name=label
+                 ))
+
+             fig.update_layout(
+                 title={
+                     'text': 'Medicare Provider Data - 2D Vector Space Visualization',
+                     'font': {'size': 20}
+                 },
+                 xaxis_title='Dimension 1',
+                 yaxis_title='Dimension 2',
+                 width=900,
+                 height=700,
+                 hovermode='closest',
+                 template='plotly_white',
+                 legend=dict(
+                     yanchor="top",
+                     y=0.99,
+                     xanchor="left",
+                     x=0.01,
+                     bgcolor="rgba(255,255,255,0.8)"
+                 )
+             )
+         else:  # 3D
+             fig = go.Figure()
+
+             # Add a trace for each dataset
+             for label in unique_labels:
+                 # Get indices for this dataset
+                 indices = [i for i, l in enumerate(all_dataset_labels) if l == label]
+
+                 # Add the scatter trace
+                 fig.add_trace(go.Scatter3d(
+                     x=reduced_vectors[indices, 0],
+                     y=reduced_vectors[indices, 1],
876
+ z=reduced_vectors[indices, 2],
877
+ mode='markers',
878
+ marker=dict(
879
+ size=5,
880
+ color=color_map[label],
881
+ opacity=0.7,
882
+ line=dict(width=1, color='white')
883
+ ),
884
+ text=[hover_texts[i] for i in indices],
885
+ hoverinfo='text',
886
+ hoverlabel=dict(bgcolor="white", font_size=12),
887
+ name=label
888
+ ))
889
+
890
+ fig.update_layout(
891
+ title={
892
+ 'text': 'Medicare Provider Data - 3D Vector Space Visualization',
893
+ 'font': {'size': 20}
894
+ },
895
+ scene=dict(
896
+ xaxis_title='Dimension 1',
897
+ yaxis_title='Dimension 2',
898
+ zaxis_title='Dimension 3',
899
+ camera=dict(
900
+ eye=dict(x=1.5, y=1.5, z=1.5)
901
+ )
902
+ ),
903
+ width=900,
904
+ height=700,
905
+ template='plotly_white',
906
+ legend=dict(
907
+ yanchor="top",
908
+ y=0.99,
909
+ xanchor="left",
910
+ x=0.01,
911
+ bgcolor="rgba(255,255,255,0.8)"
912
+ )
913
+ )
914
+
915
+ success_msg = f"βœ… Successfully created {dimensions}D visualization with {len(all_vectors)} vectors from {len(selected_keys)} dataset(s)"
916
+ return fig, success_msg
917
+
918
+ except Exception as e:
919
+ return None, f"❌ Error creating visualization: {str(e)}"
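The visualization function above reduces raw embedding vectors to 2-3 dimensions with t-SNE and then colors each point by its dataset label. The same reduce-then-color pipeline can be sketched without scikit-learn, Plotly, or a FAISS index: this minimal stand-in uses PCA via numpy's SVD in place of t-SNE and synthetic vectors in place of reconstructed embeddings (function and variable names here are illustrative, not from app.py).

```python
import numpy as np

def reduce_to_2d(vectors: np.ndarray) -> np.ndarray:
    """Project vectors onto their top-2 principal components (PCA via SVD)."""
    centered = vectors - vectors.mean(axis=0)
    # Rows of vt are principal directions; keep the first two.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

rng = np.random.default_rng(42)
vectors = rng.normal(size=(20, 64))   # 20 fake 64-dim embeddings
reduced = reduce_to_2d(vectors)

# Same label -> color scheme as the app: palette index modulo palette length.
labels = ["Q1 2025 - TX"] * 10 + ["Q2 2025 - CA"] * 10
palette = ['#1f77b4', '#ff7f0e', '#2ca02c']
color_map = {lab: palette[i % len(palette)]
             for i, lab in enumerate(dict.fromkeys(labels))}

print(reduced.shape)  # (20, 2)
```

t-SNE preserves local neighborhoods better than PCA but is stochastic and much slower, which is why the app caps the sample size per dataset.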
+
+def create_dataset_statistics_plot(dataset_indices):
+    """Create statistical plots for selected datasets."""
+    global rag_systems
+
+    if not rag_systems:
+        return None, "❌ No datasets loaded."
+
+    if not dataset_indices:
+        return None, "❌ Please select at least one dataset."
+
+    # Parse indices and get dataset keys
+    selected_keys = []
+    for selection in dataset_indices:
+        try:
+            index = int(selection.split(".")[0])
+            if 1 <= index <= len(rag_systems):
+                key = list(rag_systems.keys())[index - 1]
+                selected_keys.append(key)
+        except (ValueError, IndexError):  # skip malformed selections
+            continue
+
+    if not selected_keys:
+        return None, "❌ Could not parse selected datasets."
+
+    try:
+        # Collect statistics
+        dataset_names = []
+        record_counts = []
+        chunk_counts = []
+
+        for key in selected_keys:
+            meta = rag_systems[key]['metadata']
+            dataset_names.append(f"{meta['dataset_version']}<br>{meta['state_filter']}")
+            record_counts.append(meta['record_count'])
+            chunk_counts.append(meta['chunk_count'])
+
+        # Create subplots
+        from plotly.subplots import make_subplots
+
+        fig = make_subplots(
+            rows=1, cols=2,
+            subplot_titles=('Records per Dataset', 'Chunks per Dataset'),
+            specs=[[{'type': 'bar'}, {'type': 'bar'}]]
+        )
+
+        # Add record count bars
+        fig.add_trace(
+            go.Bar(
+                x=dataset_names,
+                y=record_counts,
+                name='Records',
+                marker_color='lightblue',
+                text=record_counts,
+                textposition='auto',
+            ),
+            row=1, col=1
+        )
+
+        # Add chunk count bars
+        fig.add_trace(
+            go.Bar(
+                x=dataset_names,
+                y=chunk_counts,
+                name='Chunks',
+                marker_color='lightgreen',
+                text=chunk_counts,
+                textposition='auto',
+            ),
+            row=1, col=2
+        )
+
+        fig.update_layout(
+            title={
+                'text': 'Dataset Statistics Overview',
+                'font': {'size': 20}
+            },
+            showlegend=False,
+            height=500,
+            template='plotly_white'
+        )
+
+        fig.update_xaxes(tickangle=-45)
+
+        return fig, f"✅ Created statistics plot for {len(selected_keys)} dataset(s)"
+
+    except Exception as e:
+        return None, f"❌ Error creating statistics plot: {str(e)}"
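Several handlers above repeat the same idiom to map a dropdown string back to a dataset: `int(selection.split(".")[0])` takes the leading integer before the first dot as a 1-based index into the loaded dataset keys. A standalone sketch of that parsing, extracted into a hypothetical `parse_selections` helper (the `"N. label"` choice format is inferred from the parsing code):

```python
def parse_selections(selections, dataset_keys):
    """Map dropdown strings like '1. Q1 2025 - TX' back to dataset keys.

    The integer before the first '.' is a 1-based index into dataset_keys;
    malformed or out-of-range entries are silently skipped.
    """
    selected = []
    for selection in selections:
        try:
            index = int(selection.split(".")[0])
        except ValueError:  # no leading integer
            continue
        if 1 <= index <= len(dataset_keys):
            selected.append(dataset_keys[index - 1])
    return selected

keys = ["Q1 2025|TX", "Q1 2025|CA"]
result = parse_selections(["1. Q1 2025 - TX", "bogus", "9. missing"], keys)
print(result)  # ['Q1 2025|TX']
```

Centralizing this in one helper would also remove the bare `except:` clauses scattered through the per-tab handlers.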
+
+def inspect_dataset_gradio(num_samples):
+    """Display sample documents from the current dataset - Gradio version."""
+    global rag_systems, current_dataset_key
+
+    if not current_dataset_key or current_dataset_key not in rag_systems:
+        return "❌ No dataset selected. Please load a dataset first."
+
+    # Get the dataset
+    system = rag_systems[current_dataset_key]
+    vector_store = system['vector_store']
+    meta = system['metadata']
+
+    # Slider values arrive as floats; range() needs an int
+    num_samples = int(min(num_samples, vector_store.index.ntotal))
+
+    inspection_result = "# Dataset Inspection\n\n"
+    inspection_result += f"**Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n"
+    inspection_result += f"**Total documents:** {vector_store.index.ntotal}\n"
+    inspection_result += f"**Showing:** {num_samples} sample documents\n\n"
+    inspection_result += "---\n\n"
+
+    for i in range(num_samples):
+        try:
+            doc_id = vector_store.index_to_docstore_id[i]
+            document = vector_store.docstore.search(doc_id)
+
+            inspection_result += f"### Document {i+1}\n\n"
+            inspection_result += "**Metadata:**\n"
+
+            # Show key metadata fields
+            key_fields = ['PROVIDER_TYPE_DESC', 'STATE_CD', 'FIRST_NAME', 'LAST_NAME',
+                          'ORG_NAME', 'NPI', 'ENRLMT_ID']
+
+            for field in key_fields:
+                if field in document.metadata and document.metadata[field]:
+                    inspection_result += f"- **{field}:** {document.metadata[field]}\n"
+
+            # Show content preview
+            content_preview = (document.page_content[:500] + "..."
+                               if len(document.page_content) > 500
+                               else document.page_content)
+            inspection_result += f"\n**Content Preview:**\n```\n{content_preview}\n```\n\n"
+            inspection_result += "---\n\n"
+
+        except Exception as e:
+            inspection_result += f"Error retrieving document {i}: {str(e)}\n\n"
+
+    return inspection_result
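Both the hover text and the inspector above apply the same two rendering rules: truncate long page content with an ellipsis, and emit only metadata fields that are present and non-empty. A self-contained sketch of those rules as hypothetical helpers (`preview` and `metadata_lines` are names introduced here, not in app.py):

```python
def preview(text: str, limit: int = 200) -> str:
    """Truncate long content with a trailing ellipsis, as the app does."""
    return text[:limit] + "..." if len(text) > limit else text

def metadata_lines(meta: dict, fields) -> list:
    """Render only present, non-empty metadata fields as markdown bullets."""
    return [f"- **{f}:** {meta[f]}" for f in fields if meta.get(f)]

meta = {"STATE_CD": "TX", "ORG_NAME": "", "NPI": "1234567890"}
print(metadata_lines(meta, ["STATE_CD", "ORG_NAME", "NPI"]))
print(len(preview("x" * 250)))  # 203: 200 chars + "..."
```

`meta.get(f)` covers both the "field missing" and "field empty" cases that the app checks with `field in meta and meta[field]`.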
+
+def create_gradio_interface():
+    """Create the main Gradio interface."""
+
+    with gr.Blocks(theme=theme, title="Medicare Provider Data Analysis System") as app:
+        # Header
+        gr.Markdown(
+            """
+            # 🏥 Medicare Provider Data Analysis System
+
+            This system allows you to load, query, and analyze Medicare provider data using advanced RAG (Retrieval-Augmented Generation) technology.
+
+            ---
+            """
+        )
+
+        # Main tabs
+        with gr.Tabs() as tabs:
+            # Tab 1: Dataset Management
+            with gr.Tab("📊 Dataset Management"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Load New Dataset")
+
+                        version_dropdown = gr.Dropdown(
+                            choices=list(DATASET_VERSIONS.keys()),
+                            label="Select Quarter/Year",
+                            value="Q1 2025"
+                        )
+
+                        state_dropdown = gr.Dropdown(
+                            choices=format_state_options(),
+                            label="Select State",
+                            value=""
+                        )
+
+                        max_records_slider = gr.Slider(
+                            minimum=100,
+                            maximum=5000,
+                            value=1000,
+                            step=100,
+                            label="Maximum Records"
+                        )
+
+                        use_sample_checkbox = gr.Checkbox(
+                            label="Load sample only (100 records)",
+                            value=True
+                        )
+
+                        load_button = gr.Button("🔄 Load Dataset", variant="primary")
+                        load_output = gr.Textbox(label="Loading Status", lines=3)
+
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Manage Loaded Datasets")
+
+                        dataset_summary = gr.Markdown(get_dataset_summary(rag_systems))
+
+                        with gr.Row():
+                            dataset_selector = gr.Dropdown(
+                                choices=get_dataset_choices(),
+                                label="Select Dataset",
+                                interactive=True
+                            )
+
+                        with gr.Row():
+                            switch_button = gr.Button("↔️ Switch Dataset")
+                            remove_button = gr.Button("🗑️ Remove Dataset")
+                            clear_all_button = gr.Button("🧹 Clear All", variant="stop")
+
+                        manage_output = gr.Textbox(label="Status", lines=2)
+
+                # Wire up dataset management events
+                def update_dataset_selector():
+                    return gr.update(choices=get_dataset_choices())
+
+                load_button.click(
+                    fn=load_dataset_gradio,
+                    inputs=[version_dropdown, state_dropdown, max_records_slider, use_sample_checkbox],
+                    outputs=[load_output, dataset_summary]
+                ).then(
+                    fn=update_dataset_selector,
+                    outputs=dataset_selector
+                )
+
+                switch_button.click(
+                    fn=switch_dataset_gradio,
+                    inputs=dataset_selector,
+                    outputs=[manage_output, dataset_summary]
+                )
+
+                remove_button.click(
+                    fn=remove_dataset_gradio,
+                    inputs=dataset_selector,
+                    outputs=[manage_output, dataset_summary]
+                ).then(
+                    fn=update_dataset_selector,
+                    outputs=dataset_selector
+                )
+
+                clear_all_button.click(
+                    fn=clear_all_datasets_gradio,
+                    outputs=[manage_output, dataset_summary]
+                ).then(
+                    fn=update_dataset_selector,
+                    outputs=dataset_selector
+                )
+
+            # Tab 2: Query Interface
+            with gr.Tab("💬 Query & Chat"):
+                gr.Markdown("### Ask Questions About Your Data")
+
+                current_dataset_info = gr.Markdown(get_current_dataset_info())
+
+                # Create a timer to update current dataset info
+                timer = gr.Timer(value=2)
+                timer.tick(fn=get_current_dataset_info, outputs=current_dataset_info)
+
+                with gr.Row():
+                    with gr.Column(scale=3):
+                        chatbot = gr.Chatbot(
+                            label="Conversation",
+                            height=500,
+                            show_copy_button=True
+                        )
+
+                        with gr.Row():
+                            question_input = gr.Textbox(
+                                label="Your Question",
+                                placeholder="Ask about provider types, locations, statistics, etc.",
+                                lines=2,
+                                scale=4
+                            )
+
+                            with gr.Column(scale=1):
+                                ask_button = gr.Button("📤 Ask Current Dataset", variant="primary")
+                                global_ask_button = gr.Button("🌐 Ask All Datasets")
+                                clear_chat_button = gr.Button("🗑️ Clear Chat")
+
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Quick Actions")
+
+                        analyze_providers_button = gr.Button("📊 Analyze Provider Types")
+
+                        gr.Markdown("### Example Questions")
+                        example_questions = [
+                            "What are the most common provider types?",
+                            "How many providers are in this dataset?",
+                            "Show me all psychiatrists in the data",
+                            "What types of medical facilities are included?",
+                            "Compare provider counts across different quarters"
+                        ]
+
+                        for eq in example_questions:
+                            gr.Button(eq, size="sm").click(
+                                lambda q=eq: (q, gr.update()),
+                                outputs=[question_input, chatbot]
+                            )
+
+                # Wire up query events
+                question_input.submit(
+                    fn=ask_question_gradio,
+                    inputs=[question_input, chatbot],
+                    outputs=[question_input, chatbot]
+                )
+
+                ask_button.click(
+                    fn=ask_question_gradio,
+                    inputs=[question_input, chatbot],
+                    outputs=[question_input, chatbot]
+                )
+
+                global_ask_button.click(
+                    fn=ask_global_question_gradio,
+                    inputs=[question_input, chatbot],
+                    outputs=[question_input, chatbot]
+                )
+
+                clear_chat_button.click(
+                    fn=clear_chat_history,
+                    outputs=chatbot
+                )
+
+                analyze_providers_button.click(
+                    fn=lambda: ("", [(
+                        "Analyze provider types in the current dataset",
+                        analyze_provider_types_gradio()
+                    )]),
+                    outputs=[question_input, chatbot]
+                )
+
+            # Tab 3: Comparison & Analysis
+            with gr.Tab("🔍 Compare Datasets"):
+                gr.Markdown("### Compare Multiple Datasets")
+
+                with gr.Row():
+                    compare_dataset_selector = gr.CheckboxGroup(
+                        choices=get_dataset_choices(),
+                        label="Select Datasets to Compare (choose 2 or more)",
+                        value=[]
+                    )
+
+                compare_question = gr.Textbox(
+                    label="Comparison Question",
+                    placeholder="Enter a question to ask all selected datasets",
+                    lines=2
+                )
+
+                compare_button = gr.Button("🔄 Compare Datasets", variant="primary")
+
+                comparison_output = gr.Markdown(label="Comparison Results")
+
+                # Update checkbox choices when datasets change
+                def update_compare_selector():
+                    return gr.update(choices=get_dataset_choices())
+
+                timer.tick(fn=update_compare_selector, outputs=compare_dataset_selector)
+
+                compare_button.click(
+                    fn=compare_datasets_gradio,
+                    inputs=[compare_question, compare_dataset_selector],
+                    outputs=comparison_output
+                )
+
+            # Tab 4: Visualization
+            with gr.Tab("📈 Visualizations"):
+                gr.Markdown("### Dataset Visualizations")
+
+                with gr.Row():
+                    with gr.Column():
+                        viz_dataset_selector = gr.CheckboxGroup(
+                            choices=get_dataset_choices(),
+                            label="Select Datasets to Visualize",
+                            value=[]
+                        )
+
+                        viz_dimension = gr.Radio(
+                            choices=[2, 3],
+                            value=2,
+                            label="Visualization Dimensions"
+                        )
+
+                        viz_sample_size = gr.Slider(
+                            minimum=100,
+                            maximum=2000,
+                            value=500,
+                            step=100,
+                            label="Sample Size (per dataset)"
+                        )
+
+                        create_viz_button = gr.Button("🎨 Create Visualization", variant="primary")
+                        stats_button = gr.Button("📊 Show Statistics")
+
+                        viz_status = gr.Textbox(label="Status", lines=2)
+
+                with gr.Row():
+                    viz_plot = gr.Plot(label="Vector Space Visualization")
+                    stats_plot = gr.Plot(label="Dataset Statistics")
+
+                # Update visualization selector
+                def update_viz_selector():
+                    return gr.update(choices=get_dataset_choices())
+
+                timer.tick(fn=update_viz_selector, outputs=viz_dataset_selector)
+
+                create_viz_button.click(
+                    fn=visualize_datasets_gradio,
+                    inputs=[viz_dataset_selector, viz_dimension, viz_sample_size],
+                    outputs=[viz_plot, viz_status]
+                )
+
+                stats_button.click(
+                    fn=create_dataset_statistics_plot,
+                    inputs=[viz_dataset_selector],
+                    outputs=[stats_plot, viz_status]
+                )
+
+            # Tab 5: Dataset Inspector
+            with gr.Tab("🔎 Dataset Inspector"):
+                gr.Markdown("### Inspect Dataset Contents")
+
+                inspect_current_info = gr.Markdown(get_current_dataset_info())
+                timer.tick(fn=get_current_dataset_info, outputs=inspect_current_info)
+
+                num_samples_slider = gr.Slider(
+                    minimum=1,
+                    maximum=20,
+                    value=5,
+                    step=1,
+                    label="Number of Sample Documents"
+                )
+
+                inspect_button = gr.Button("🔍 Inspect Current Dataset", variant="primary")
+
+                inspection_output = gr.Markdown(label="Dataset Inspection Results")
+
+                inspect_button.click(
+                    fn=inspect_dataset_gradio,
+                    inputs=num_samples_slider,
+                    outputs=inspection_output
+                )
+
+            # Tab 6: Settings & Help
+            with gr.Tab("⚙️ Settings & Help"):
+                gr.Markdown(
+                    """
+                    ### System Information
+
+                    - **Model:** GPT-4 Mini
+                    - **Embedding Model:** OpenAI Embeddings
+                    - **Vector Store:** FAISS
+
+                    ### API Configuration
+
+                    This system uses the CMS.gov Data API to fetch Medicare provider information.
+
+                    ### Tips for Best Results
+
+                    1. **Loading Data**: Start with sample data (100 records) to test queries quickly
+                    2. **State Selection**: Load specific states for focused analysis
+                    3. **Querying**: Be specific in your questions for better results
+                    4. **Comparisons**: Load multiple quarters/states to analyze trends
+
+                    ### Common Use Cases
+
+                    - **Provider Analysis**: Find specific types of healthcare providers
+                    - **Geographic Distribution**: Analyze providers by state
+                    - **Temporal Trends**: Compare data across different quarters
+                    - **Provider Types**: Understand the distribution of specialties
+
+                    ### Troubleshooting
+
+                    - **No API Key**: Ensure OPENAI_API_KEY is set in your environment
+                    - **Loading Errors**: Check your internet connection and API limits
+                    - **Query Errors**: Try rephrasing your question or check if data is loaded
+                    """
+                )
+
+                with gr.Row():
+                    gr.Markdown("### Current Configuration")
+                    config_info = gr.JSON(
+                        value={
+                            "api_key_set": bool(os.getenv('OPENAI_API_KEY')),
+                            "default_model": DEFAULT_MODEL,
+                            "api_base_url": API_BASE_URL,
+                            "datasets_loaded": len(rag_systems)
+                        },
+                        label="System Configuration"
+                    )
+
+        # Footer
+        gr.Markdown(
+            """
+            ---
+
+            <center>
+            Medicare Provider Data Analysis System | Powered by LangChain & OpenAI
+            </center>
+            """
+        )
+
+    return app
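One subtlety in the interface above is the example-question loop, which binds each button with `lambda q=eq: ...`. The default argument is essential: Python closures bind names, not values, so without it every button would fill in the last question in the list. A self-contained sketch of the difference (hypothetical list and names, not from app.py):

```python
questions = ["alpha?", "beta?", "gamma?"]

# Late binding: each lambda looks up `q` when called, after the loop ended,
# so all three see the final value.
late = [lambda: q for q in questions]

# Default-argument capture: `q=q` freezes the current value at definition
# time, which is why the buttons above use `lambda q=eq: ...`.
bound = [lambda q=q: q for q in questions]

print([f() for f in late])   # ['gamma?', 'gamma?', 'gamma?']
print([f() for f in bound])  # ['alpha?', 'beta?', 'gamma?']
```

The same pattern applies to any per-iteration callback registration, including Gradio event handlers created in a loop.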
+
+# Main execution
+if __name__ == "__main__":
+    # Create and launch the app
+    app = create_gradio_interface()
+
+    # Launch with appropriate settings
+    app.launch(
+        server_name="0.0.0.0",  # Allow external connections
+        server_port=7860        # Default Gradio port
+    )