import os
import json
import time
import requests
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
import gradio as gr
from dotenv import load_dotenv
# Vector DB and embedding imports
from langchain.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# Visualization imports
import plotly.graph_objects as go
from sklearn.manifold import TSNE
# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file, if present.
load_dotenv()
# Only warn here rather than fail: load_dataset_gradio re-checks the key through
# validate_api_key() before creating any OpenAI client.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if not OPENAI_API_KEY:
    print("⚠️ Warning: OPENAI_API_KEY not found in environment variables.")
# Configuration
# NOTE(review): DEFAULT_DATASET_ID is not referenced in the visible part of this
# file — confirm it is still used before relying on it.
DEFAULT_DATASET_ID = "2457ea29-fc82-48b0-86ec-3b0755de7515"
# Chat model name passed to the ChatOpenAI instances created in this module.
DEFAULT_MODEL = "gpt-4o-mini"
# Root of the CMS data API; query_cms_api() appends /dataset/<version_id>/data.
API_BASE_URL = "https://data.cms.gov/data-api/v1"
INITIAL_SAMPLE_SIZE = 100 # Record limit used instead of max_records when "use sample" is chosen
# Dataset version mapping: UI label -> CMS data-API dataset version UUID.
# Labels must keep the "Q<n> <year>" shape: process_records() and
# generate_dataset_metadata_summary() split them on the space to get
# quarter and year.
# NOTE(review): coverage is uneven (only Q1 for 2020-2022, Q1/Q4 for 2019) —
# confirm against the CMS catalogue whether other quarters exist.
DATASET_VERSIONS = {
    # 2025 Data
    "Q1 2025": "74edb053-bd01-40a0-91a0-4961c1fe6281",
    # 2024 Data
    "Q1 2024": "6d6e0e8d-64cf-43fb-9ba8-e2ad9b9bb21e",
    "Q2 2024": "04405289-5635-4b2a-a64f-c4b6415ab6ff",
    "Q3 2024": "e87f09c2-5ff7-4ddf-b60c-6130995b15cf",
    "Q4 2024": "e9d278e4-90e8-47ab-9c5b-af2ca64bf352",
    # 2023 Data
    "Q1 2023": "0b6caf2f-8948-4603-922e-d7f0c52c0a45",
    "Q2 2023": "46339a0c-0f07-40ed-8975-ddb387c367a4",
    "Q3 2023": "70efac57-6093-4e1d-ad6a-36f8261f53eb",
    "Q4 2023": "1df8331a-ed44-41ec-971f-158349658949",
    # 2022 Data
    "Q1 2022": "5b678653-aa36-455b-9144-1d073ef7991b",
    # 2021 Data
    "Q1 2021": "7b409bba-ca00-426e-9493-1dc10e5340cc",
    # 2020 Data
    "Q1 2020": "3870b29c-4312-4fb1-a956-71c148ae5b50",
    # 2019 Data
    "Q1 2019": "017e6ab7-7e19-4e98-b4fa-30578b47e578",
    "Q4 2019": "2c209bdb-ed0c-42e0-b027-8a97024b8035"
}
# Display name for every selectable region code. The empty code means
# "no state filter"; insertion order is the order offered in the UI.
STATE_NAMES = {
    "": "All States",
    "AL": "Alabama",
    "AK": "Alaska",
    "AZ": "Arizona",
    "AR": "Arkansas",
    "CA": "California",
    "CO": "Colorado",
    "CT": "Connecticut",
    "DE": "Delaware",
    "FL": "Florida",
    "GA": "Georgia",
    "HI": "Hawaii",
    "ID": "Idaho",
    "IL": "Illinois",
    "IN": "Indiana",
    "IA": "Iowa",
    "KS": "Kansas",
    "KY": "Kentucky",
    "LA": "Louisiana",
    "ME": "Maine",
    "MD": "Maryland",
    "MA": "Massachusetts",
    "MI": "Michigan",
    "MN": "Minnesota",
    "MS": "Mississippi",
    "MO": "Missouri",
    "MT": "Montana",
    "NE": "Nebraska",
    "NV": "Nevada",
    "NH": "New Hampshire",
    "NJ": "New Jersey",
    "NM": "New Mexico",
    "NY": "New York",
    "NC": "North Carolina",
    "ND": "North Dakota",
    "OH": "Ohio",
    "OK": "Oklahoma",
    "OR": "Oregon",
    "PA": "Pennsylvania",
    "RI": "Rhode Island",
    "SC": "South Carolina",
    "SD": "South Dakota",
    "TN": "Tennessee",
    "TX": "Texas",
    "UT": "Utah",
    "VT": "Vermont",
    "VA": "Virginia",
    "WA": "Washington",
    "WV": "West Virginia",
    "WI": "Wisconsin",
    "WY": "Wyoming",
    "DC": "District of Columbia",
    "PR": "Puerto Rico",
    "VI": "Virgin Islands",
}

# Ordered list of region codes (leading "" = all states), derived from
# STATE_NAMES so the two tables cannot drift apart.
US_STATES = list(STATE_NAMES)
# Registry of loaded datasets: dataset key -> {'vector_store', 'conversation_chain', 'metadata'}.
rag_systems: Dict[str, Dict[str, Any]] = {}
# Key (into rag_systems) of the dataset that single-dataset questions use; None until one is loaded.
current_dataset_key: Optional[str] = None

# Gradio theme configuration
theme = gr.themes.Soft(
    font=gr.themes.GoogleFont("Inter"),
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
)
def query_cms_api(version_id, state_filter="", max_records=100, timeout=30):
    """Fetch up to *max_records* rows of one CMS dataset version, page by page.

    Args:
        version_id: CMS dataset version UUID (a value of DATASET_VERSIONS).
        state_filter: Optional state code; when non-empty it is sent as the
            ``filter[STATE_CD]`` query parameter.
        max_records: Upper bound on the number of records returned. It is
            coerced to ``int`` because a float would break the final slice.
        timeout: Seconds to wait for each HTTP request before giving up.
            Without it, ``requests`` can block indefinitely.

    Returns:
        ``(records, message)``.
        - On success, *records* is a list (at most *max_records* long) of
          whatever the API returned per row, and *message* reports the count.
        - If the status is non-200, the first page is empty, or the
          request/JSON decoding fails, *records* is ``[]`` and *message*
          describes the problem.
    """
    url = f"{API_BASE_URL}/dataset/{version_id}/data"
    max_records = int(max_records)
    # Request at most 100 rows per page.
    page_size = min(max_records, 100)

    params = {'size': page_size, 'offset': 0}
    if state_filter:
        params['filter[STATE_CD]'] = state_filter

    all_records = []
    offset = 0
    while len(all_records) < max_records:
        params['offset'] = offset
        try:
            response = requests.get(url, params=params, timeout=timeout)
            if response.status_code != 200:
                return [], f"Error: Status {response.status_code}"
            # The API returns the page as a JSON list of rows.
            records = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decoding failures.
            return [], f"Error querying API: {str(e)}"

        if not records or not isinstance(records, list):
            if not all_records:
                return [], "No records found"
            break

        all_records.extend(records)
        # A short page means the dataset is exhausted; stop as soon as we have
        # enough rows so we don't sleep before a request we won't make.
        if len(records) < page_size or len(all_records) >= max_records:
            break

        offset += len(records)
        # Pause between page requests to be gentle on the public API.
        time.sleep(0.5)

    final_records = all_records[:max_records]
    return final_records, f"Successfully retrieved {len(final_records)} records"
def process_records(records, version):
    """Turn raw CMS API records into a FAISS vector store for the RAG system.

    Each record becomes one LangChain ``Document``:
    - Its text starts with the quarter/year parsed from *version* and then
      lists every non-empty field as ``key: value``.
    - Its metadata holds the dataset version, quarter, year, record id
      (``ENRLMT_ID``) and every non-empty field.

    The documents are split into ~1000-character chunks (200 overlap) and
    embedded with OpenAI embeddings.

    Args:
        records: List of record dicts, as returned by query_cms_api().
        version: Dataset label such as ``"Q1 2024"``. A label that is not
            exactly two space-separated parts gives quarter/year ``"Unknown"``.

    Returns:
        ``(vector_store, document_count, chunk_count)``.

    Raises:
        ValueError: If *records* is empty. FAISS cannot build an index from
            zero documents, so we fail early with a clear message.
    """
    if not records:
        raise ValueError("No records to process.")

    quarter, year = "Unknown", "Unknown"
    parts = version.split(' ')
    if len(parts) == 2:
        quarter, year = parts

    documents = []
    for record in records:
        lines = [
            f"Medicare Provider Data from {quarter} {year}",
            f"Time Period: {quarter} of {year}",
        ]
        metadata = {
            'dataset_version': version,
            'quarter': quarter,
            'year': year,
            'record_id': record.get('ENRLMT_ID', 'unknown'),
        }
        # One pass builds both the text and the metadata. Record fields may
        # overwrite the base metadata keys, exactly as before.
        for key, value in record.items():
            if value is None or value == "":
                continue
            lines.append(f"{key}: {value}")
            # Keep metadata flat and serializable: stringify anything that
            # is not a primitive.
            metadata[key] = value if isinstance(value, (str, int, float, bool)) else str(value)
        documents.append(Document(page_content="\n".join(lines), metadata=metadata))

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_documents(documents)

    vector_store = FAISS.from_documents(chunks, OpenAIEmbeddings())
    return vector_store, len(documents), len(chunks)
def create_progress_callback():
    """Return a function that reports progress messages on stdout.

    The returned callable prints ``"Progress: <message>"`` and returns None.
    It is a placeholder for a real Gradio progress bar.
    """
    return lambda message: print(f"Progress: {message}")
def validate_api_key():
    """Check that OPENAI_API_KEY is set to a non-empty value in the environment.

    Returns:
        ``(True, success_message)`` when the key is present, otherwise
        ``(False, instructions_message)``.
    """
    if os.getenv('OPENAI_API_KEY'):
        return True, "API key validated successfully."
    return False, "OpenAI API key not found. Please set it in your environment variables or .env file."
def get_dataset_summary(rag_systems):
    """Render a Markdown overview of the datasets in *rag_systems*.

    Each dataset gets a numbered line with its version, state, record count
    and chunk count. The module-level current dataset is tagged
    ``*(Current)*``, and a total count closes the summary.
    """
    if not rag_systems:
        return "No datasets currently loaded."

    def describe(position, key, system):
        meta = system['metadata']
        marker = " *(Current)*" if key == current_dataset_key else ""
        return (
            f"{position}. **{meta['dataset_version']}** - "
            f"State: {meta['state_filter']} - "
            f"Records: {meta['record_count']} - "
            f"Chunks: {meta['chunk_count']}{marker}"
        )

    entries = [
        describe(position, key, system)
        for position, (key, system) in enumerate(rag_systems.items(), start=1)
    ]
    footer = f"\n**Total datasets loaded:** {len(rag_systems)}"
    return "\n".join(["### Currently Loaded Datasets:\n", *entries, footer])
def format_state_options():
    """Build ``(label, value)`` pairs for the state dropdown.

    The empty code becomes ``("All States", "")``. Every other code becomes
    ``("<Name> (<CODE>)", "<CODE>")``, in US_STATES order.
    """
    return [
        (f"{STATE_NAMES[code]} ({code})", code) if code != "" else ("All States", "")
        for code in US_STATES
    ]
def load_dataset_gradio(version, state_filter, max_records, use_sample):
    """Load a CMS dataset version and register a RAG system for it (Gradio handler).

    Steps:
    1. Fetch records with query_cms_api() and embed them with process_records().
    2. Build a ConversationalRetrievalChain over the resulting FAISS store.
    3. Store everything in ``rag_systems`` under a key made from the version,
       the state filter and the *effective* record limit.

    The newly loaded (or already cached) dataset becomes the current one.

    Args:
        version: Label from DATASET_VERSIONS, e.g. ``"Q1 2024"``.
        state_filter: State code, or ``""``/None for all states.
        max_records: Maximum number of records to fetch when not sampling.
        use_sample: When truthy, fetch only INITIAL_SAMPLE_SIZE records.

    Returns:
        ``(status_message, summary_markdown)``; the summary comes from
        get_dataset_summary().
    """
    global rag_systems, current_dataset_key

    valid, message = validate_api_key()
    if not valid:
        return message, get_dataset_summary(rag_systems)

    # Normalise a missing selection (None) to the empty "all states" code.
    state_filter = state_filter or ""
    state_label = STATE_NAMES.get(state_filter, 'All States')

    # Resolve the effective record limit *before* building the cache key.
    # Keying on the requested max_records let a 100-record sample be returned
    # later as if it were the full load of the same version/state.
    actual_max = INITIAL_SAMPLE_SIZE if use_sample else max_records
    dataset_key = f"{version}_{state_filter}_{actual_max}"

    if dataset_key in rag_systems:
        current_dataset_key = dataset_key
        return (
            f"✅ Dataset already loaded and set as current: {version} - {state_label}",
            get_dataset_summary(rag_systems),
        )

    version_id = DATASET_VERSIONS.get(version)
    if not version_id:
        return f"❌ Invalid version: {version}", get_dataset_summary(rag_systems)

    try:
        records, api_message = query_cms_api(version_id, state_filter, actual_max)
        if not records:
            return f"❌ Failed to load data: {api_message}", get_dataset_summary(rag_systems)

        vector_store, doc_count, chunk_count = process_records(records, version)

        llm = ChatOpenAI(temperature=0.7, model_name=DEFAULT_MODEL)
        memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vector_store.as_retriever(),
            memory=memory,
        )

        rag_systems[dataset_key] = {
            'vector_store': vector_store,
            'conversation_chain': conversation_chain,
            'metadata': {
                'dataset_version': version,
                'version_id': version_id,
                'state_filter': state_label,
                'record_count': len(records),
                'document_count': doc_count,
                'chunk_count': chunk_count,
                'loaded_at': datetime.now().isoformat(),
            },
        }
        current_dataset_key = dataset_key

        success_msg = (
            f"✅ Successfully loaded {version} - {state_label}\n"
            f"📊 Created {chunk_count} chunks from {len(records)} records"
        )
        return success_msg, get_dataset_summary(rag_systems)
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the handler.
        return f"❌ Error loading data: {str(e)}", get_dataset_summary(rag_systems)
def switch_dataset_gradio(dataset_index):
    """Make another loaded dataset the current one (Gradio handler).

    Args:
        dataset_index: Dropdown label of the form ``"<n>. <description>"``,
            where ``<n>`` is the 1-based position in ``rag_systems``.

    Returns:
        ``(status_message, summary_markdown)``.
    """
    global rag_systems, current_dataset_key

    if not rag_systems:
        return "❌ No datasets loaded.", get_dataset_summary(rag_systems)
    if not dataset_index:
        return "❌ Please select a dataset.", get_dataset_summary(rag_systems)

    try:
        index = int(dataset_index.split(".")[0])
    except (ValueError, AttributeError):
        # Not a "<n>. ..." string.
        return "❌ Invalid selection format.", get_dataset_summary(rag_systems)

    keys = list(rag_systems)
    if not 1 <= index <= len(keys):
        return "❌ Invalid selection.", get_dataset_summary(rag_systems)

    key = keys[index - 1]
    current_dataset_key = key
    meta = rag_systems[key]['metadata']
    return (
        f"✅ Switched to: {meta['dataset_version']} - {meta['state_filter']}",
        get_dataset_summary(rag_systems),
    )
def remove_dataset_gradio(dataset_index):
    """Drop a loaded dataset from memory (Gradio handler).

    If the removed dataset was current, the first remaining dataset (if any)
    becomes current.

    Args:
        dataset_index: Dropdown label of the form ``"<n>. <description>"``,
            where ``<n>`` is the 1-based position in ``rag_systems``.

    Returns:
        ``(status_message, summary_markdown)``.
    """
    global rag_systems, current_dataset_key

    if not rag_systems:
        return "❌ No datasets loaded.", get_dataset_summary(rag_systems)
    if not dataset_index:
        return "❌ Please select a dataset to remove.", get_dataset_summary(rag_systems)

    try:
        index = int(dataset_index.split(".")[0])
        keys = list(rag_systems)
        if not 1 <= index <= len(keys):
            return "❌ Invalid selection.", get_dataset_summary(rag_systems)

        key = keys[index - 1]
        meta = rag_systems[key]['metadata']
        del rag_systems[key]

        if key == current_dataset_key:
            # Fall back to the first remaining dataset, or None if none are left.
            current_dataset_key = next(iter(rag_systems), None)

        return (
            f"✅ Removed: {meta['dataset_version']} - {meta['state_filter']}",
            get_dataset_summary(rag_systems),
        )
    except Exception as e:
        return f"❌ Error removing dataset: {str(e)}", get_dataset_summary(rag_systems)
def get_dataset_choices():
    """List dropdown labels for every loaded dataset.

    Each label reads ``"<n>. <version> - <state> (<records> records)"``, with
    ``" [CURRENT]"`` appended for the current dataset. Other handlers parse
    ``<n>`` back out of these labels.
    """
    if not rag_systems:
        return []

    def label_for(position, key, meta):
        text = f"{position}. {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)"
        return text + " [CURRENT]" if key == current_dataset_key else text

    return [
        label_for(position, key, system['metadata'])
        for position, (key, system) in enumerate(rag_systems.items(), start=1)
    ]
def clear_all_datasets_gradio():
    """Forget every loaded dataset and clear the current selection (Gradio handler).

    Returns:
        ``(status_message, "")``; the empty string resets the second output.
    """
    global rag_systems, current_dataset_key

    if not rag_systems:
        return "ℹ️ No datasets to clear.", ""

    count = len(rag_systems)
    rag_systems.clear()
    current_dataset_key = None
    return f"✅ Cleared {count} dataset(s) from memory.", ""
def get_current_dataset_info():
    """Describe the current dataset as a short Markdown block.

    Shows version, state, record and chunk counts, and the load timestamp
    truncated to whole seconds. Returns a placeholder sentence when no valid
    dataset is selected.
    """
    global rag_systems, current_dataset_key

    if current_dataset_key and current_dataset_key in rag_systems:
        meta = rag_systems[current_dataset_key]['metadata']
        return "\n".join([
            f"**Current Dataset:** {meta['dataset_version']} - {meta['state_filter']}",
            f"- Records: {meta['record_count']}",
            f"- Chunks: {meta['chunk_count']}",
            f"- Loaded: {meta['loaded_at'][:19]}",
        ])
    return "No dataset currently selected."
# def ask_question_gradio(question, chat_history):
# """Ask a question to the current dataset - Gradio version."""
# global rag_systems, current_dataset_key
# if not current_dataset_key or current_dataset_key not in rag_systems:
# response = "โ No dataset selected. Please load a dataset first."
# chat_history.append((question, response))
# return "", chat_history
# # Get the dataset
# system = rag_systems[current_dataset_key]
# meta = system['metadata']
# try:
# # Use the chain to get a response
# result = system['conversation_chain'].invoke({"question": question})
# answer = result["answer"]
# # Add dataset source information
# answer += f"\n\n*Source: {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)*"
# # Update chat history
# chat_history.append((question, answer))
# return "", chat_history
# except Exception as e:
# error_response = f"โ Error processing query: {str(e)}"
# chat_history.append((question, error_response))
# return "", chat_history
# def ask_global_question_gradio(question, chat_history):
# """Ask a question that might require knowledge from all loaded datasets."""
# global rag_systems
# if not rag_systems:
# response = "โ No datasets loaded. Please load datasets first."
# chat_history.append((question, response))
# return "", chat_history
# # Check if this is a global question about the datasets themselves
# global_keywords = ['how many', 'which years', 'what years', 'what quarters', 'how many years',
# 'which quarters', 'time period', 'date range', 'all datasets', 'datasets',
# 'compare', 'comparison', 'difference', 'trend', 'over time']
# is_global_question = any(keyword in question.lower() for keyword in global_keywords)
# # Check if the question mentions a specific state
# mentioned_state = None
# question_lower = question.lower()
# # Check for state names
# for code, name in STATE_NAMES.items():
# if code and (code.lower() in question_lower or name.lower() in question_lower):
# mentioned_state = code
# break
# try:
# if mentioned_state and not is_global_question:
# # Find all datasets for that state
# suitable_datasets = []
# for key, system in rag_systems.items():
# meta = system['metadata']
# state_filter = meta['state_filter']
# # Check if this dataset matches the mentioned state
# if mentioned_state in state_filter or STATE_NAMES[mentioned_state] in state_filter:
# suitable_datasets.append(key)
# if suitable_datasets:
# response = f"๐ Found {len(suitable_datasets)} dataset(s) for {STATE_NAMES[mentioned_state]}:\n\n"
# # Query each suitable dataset
# all_results = []
# for dataset_key in suitable_datasets:
# system = rag_systems[dataset_key]
# meta = system['metadata']
# try:
# result = system['conversation_chain'].invoke({"question": question})
# answer = result["answer"]
# all_results.append({
# 'dataset': f"{meta['dataset_version']} - {meta['state_filter']}",
# 'answer': answer
# })
# except Exception as e:
# all_results.append({
# 'dataset': f"{meta['dataset_version']} - {meta['state_filter']}",
# 'answer': f"Error: {str(e)}"
# })
# # Format combined response
# for result in all_results:
# response += f"**{result['dataset']}**\n{result['answer']}\n\n---\n\n"
# chat_history.append((question, response))
# return "", chat_history
# else:
# response = f"โน๏ธ No datasets found for {STATE_NAMES[mentioned_state]}. Please load data for this state first."
# chat_history.append((question, response))
# return "", chat_history
# elif is_global_question:
# # Create a summary of all available datasets
# dataset_summary = generate_dataset_metadata_summary()
# # Create a system message that includes this metadata
# llm = ChatOpenAI(temperature=0.7, model_name=DEFAULT_MODEL)
# system_message = f"""You are an expert on Medicare Provider data. You have access to multiple datasets spanning different quarters and years.
# {dataset_summary}
# When answering questions, consider the metadata about all available datasets. For questions about time periods, years, quarters, or trends, use the information about which datasets are loaded."""
# messages = [
# {"role": "system", "content": system_message},
# {"role": "user", "content": question}
# ]
# response = llm.invoke(messages)
# answer = response.content
# chat_history.append((question, answer))
# return "", chat_history
# else:
# # For non-global questions without specific state mention, use the current dataset
# return ask_question_gradio(question, chat_history)
# except Exception as e:
# error_response = f"โ Error processing global query: {str(e)}"
# chat_history.append((question, error_response))
# return "", chat_history
def ask_question_gradio(question, chat_history):
    """Answer *question* from the current dataset (Gradio chat handler).

    Every call builds a fresh LLM (temperature 0.3), a top-10 retriever and a
    ConversationalRetrievalChain whose prompt embeds the dataset's metadata.
    The chain's ConversationBufferMemory is also created per call, and
    *chat_history* is not passed to the chain. So apart from the retry
    below, earlier turns are NOT visible to the model.

    Args:
        question: The user's question.
        chat_history: Chatbot history as a list of ``(question, answer)``
            tuples; the new exchange is appended in place.

    Returns:
        ``("", chat_history)``. The empty string presumably clears the input box.
    """
    global rag_systems, current_dataset_key

    if not current_dataset_key or current_dataset_key not in rag_systems:
        chat_history.append((question, "❌ No dataset selected. Please load a dataset first."))
        return "", chat_history

    system = rag_systems[current_dataset_key]
    meta = system['metadata']

    try:
        from langchain.prompts import PromptTemplate

        system_prompt = (
            "You are a helpful assistant analyzing Medicare Provider data.\n"
            "Current Dataset Information:\n"
            f"- Dataset: {meta['dataset_version']} - {meta['state_filter']}\n"
            f"- Total Records: {meta['record_count']}\n"
            f"- Total Chunks: {meta['chunk_count']}\n"
            "Important Instructions:\n"
            "1. ALWAYS respond in English\n"
            "2. Use the provided context to answer questions\n"
            "3. If you can find relevant information in the context, provide a detailed answer\n"
            "4. Only say \"I don't know\" if the information is genuinely not available in the context\n"
            "5. Be specific and cite numbers when available\n"
            "6. For questions about counts or statistics, check the context carefully\n"
            "Remember: You have access to Medicare provider data including provider types, names, locations, and other details."
        )

        # Lower temperature for more consistent, data-grounded answers.
        llm = ChatOpenAI(temperature=0.3, model_name=DEFAULT_MODEL)
        # output_key tells the memory which chain output to store.
        memory = ConversationBufferMemory(
            memory_key='chat_history',
            return_messages=True,
            output_key='answer',
        )
        # Retrieve more chunks than the default to give the model wider context.
        retriever = system['vector_store'].as_retriever(search_kwargs={"k": 10})

        # The system prompt goes into the QA template (not the memory) so it is
        # sent with every combine-documents call. {context}, {chat_history} and
        # {question} are PromptTemplate placeholders.
        qa_template = system_prompt + (
            "\nContext from the dataset:\n"
            "{context}\n"
            "Chat History:\n"
            "{chat_history}\n"
            "Human Question: {question}\n"
            "Instructions:\n"
            "- Answer based on the context provided\n"
            "- Be specific and mention numbers/counts when available\n"
            "- Respond ONLY in English\n"
            "- If the context contains relevant information, use it to provide a detailed answer\n"
            "- Only say you don't know if the information is truly not in the context\n"
            "Assistant Answer:"
        )
        qa_prompt = PromptTemplate(
            template=qa_template,
            input_variables=["context", "chat_history", "question"],
        )

        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            combine_docs_chain_kwargs={"prompt": qa_prompt},
            verbose=False,
        )

        answer = conversation_chain.invoke({"question": question})["answer"]

        # A suspiciously short answer gets one retry with the dataset spelled out.
        if not answer or len(answer) < 10:
            direct_query = (
                f"Based on the {meta['dataset_version']} {meta['state_filter']} Medicare data "
                f"with {meta['record_count']} records, {question}"
            )
            answer = conversation_chain.invoke({"question": direct_query})["answer"]

        answer += f"\n\n*Source: {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)*"
        chat_history.append((question, answer))
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing the handler.
        error_response = f"❌ Error processing query: {str(e)}\n\nPlease try rephrasing your question."
        chat_history.append((question, error_response))
    return "", chat_history
def _detect_mentioned_state(question):
    """Return the code of a state mentioned in *question*, or None.

    Matching rules:
    - Full names are matched case-insensitively as whole words, longest first,
      so "West Virginia" wins over "Virginia".
    - Two-letter codes are matched only as standalone UPPERCASE tokens.
      Plain substring matching made "al" in "total" or "in" in
      "providers in" look like Alabama/Indiana.
    """
    import re

    named = sorted(
        ((code, name) for code, name in STATE_NAMES.items() if code),
        key=lambda item: len(item[1]),
        reverse=True,
    )
    for code, name in named:
        if re.search(rf"\b{re.escape(name)}\b", question, re.IGNORECASE):
            return code
    for code in STATE_NAMES:
        if code and re.search(rf"\b{re.escape(code)}\b", question):
            return code
    return None


def ask_global_question_gradio(question, chat_history):
    """Route a question to one, several, or a summary of all loaded datasets.

    Routing:
    - If a state is mentioned and no "global" keyword appears, the question is
      asked of every dataset loaded for exactly that state, and the answers
      are combined.
    - If a global keyword appears (e.g. "compare", "which years"), an LLM
      answers from the metadata summary of all loaded datasets.
    - Otherwise the question goes to the current dataset via
      ask_question_gradio().

    Args:
        question: The user's question.
        chat_history: List of ``(question, answer)`` tuples, appended in place.

    Returns:
        ``("", chat_history)``.
    """
    # current_dataset_key must be global: the per-state branch temporarily
    # repoints it so ask_question_gradio() answers from each matching dataset.
    # Without the declaration it was a local and raised UnboundLocalError.
    global rag_systems, current_dataset_key

    if not rag_systems:
        chat_history.append((question, "❌ No datasets loaded. Please load datasets first."))
        return "", chat_history

    global_keywords = ['how many', 'which years', 'what years', 'what quarters', 'how many years',
                       'which quarters', 'time period', 'date range', 'all datasets', 'datasets',
                       'compare', 'comparison', 'difference', 'trend', 'over time']
    question_lower = question.lower()
    is_global_question = any(keyword in question_lower for keyword in global_keywords)
    mentioned_state = _detect_mentioned_state(question)

    try:
        if mentioned_state and not is_global_question:
            state_name = STATE_NAMES[mentioned_state]
            # Metadata stores the full state name. Compare exactly so that
            # "Virginia" does not match "West Virginia" datasets.
            suitable_datasets = [
                key for key, system in rag_systems.items()
                if system['metadata']['state_filter'] == state_name
            ]
            if not suitable_datasets:
                response = f"ℹ️ No datasets found for {state_name}. Please load data for this state first."
                chat_history.append((question, response))
                return "", chat_history

            response = f"🔍 Found {len(suitable_datasets)} dataset(s) for {state_name}:\n\n"
            original_key = current_dataset_key
            try:
                for dataset_key in suitable_datasets:
                    meta = rag_systems[dataset_key]['metadata']
                    current_dataset_key = dataset_key
                    _, temp_history = ask_question_gradio(question, [])
                    if not temp_history:
                        continue
                    answer = temp_history[0][1]
                    # Drop the per-dataset source footer; the heading below names the dataset.
                    if "*Source:" in answer:
                        answer = answer.split("*Source:")[0].strip()
                    response += f"**{meta['dataset_version']} - {meta['state_filter']}**\n{answer}\n\n---\n\n"
            finally:
                # Always restore the user's selection, even if a query fails.
                current_dataset_key = original_key
            chat_history.append((question, response))
            return "", chat_history

        if is_global_question:
            dataset_summary = generate_dataset_metadata_summary()
            llm = ChatOpenAI(
                temperature=0.3,
                model_name=DEFAULT_MODEL,
                model_kwargs={"response_format": {"type": "text"}},
            )
            system_message = (
                "You are an expert on Medicare Provider data analysis.\n"
                "Always respond in English.\n"
                f"{dataset_summary}\n"
                "When answering questions:\n"
                "1. Consider the metadata about all available datasets\n"
                "2. For questions about time periods, years, quarters, or trends, use the dataset information\n"
                "3. Be specific about which datasets contain what information\n"
                "4. Always respond in clear, professional English"
            )
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": question},
            ]
            chat_history.append((question, llm.invoke(messages).content))
            return "", chat_history

        # No specific state and not a global question: use the current dataset.
        return ask_question_gradio(question, chat_history)
    except Exception as e:
        error_response = f"❌ Error processing global query: {str(e)}\n\nPlease try rephrasing your question."
        chat_history.append((question, error_response))
        return "", chat_history
def generate_dataset_metadata_summary():
    """Summarise all loaded datasets as Markdown for the global-question prompt.

    The summary lists, in order:
    - the distinct years;
    - the quarters available per year;
    - the distinct states;
    - every dataset with its record count.

    Only versions labelled like ``"Q1 2025"`` (containing a space) contribute
    to the year, quarter and state sections.
    """
    if not rag_systems:
        return "No datasets loaded."

    metas = [system['metadata'] for system in rag_systems.values()]

    years = set()
    states = set()
    quarters_by_year = {}
    for meta in metas:
        version = meta['dataset_version']
        if ' ' not in version:
            continue
        pieces = version.split(' ')
        quarter, year = pieces[0], pieces[1]
        years.add(year)
        states.add(meta['state_filter'])
        quarters_by_year.setdefault(year, set()).add(quarter)

    parts = [
        "# Available Datasets\n\n",
        "The following datasets are currently loaded:\n\n",
        "## Years Available\n",
        ", ".join(sorted(years)) + "\n\n",
        "## Quarters Available by Year\n",
    ]
    parts.extend(
        f"- {year}: {', '.join(sorted(quarters_by_year[year]))}\n"
        for year in sorted(quarters_by_year)
    )
    parts.append("\n## States Available\n")
    parts.append(", ".join(sorted(states)) + "\n\n")
    parts.append("## Full Dataset List\n")
    parts.extend(
        f"- {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)\n"
        for meta in metas
    )
    return "".join(parts)
def compare_datasets_gradio(question, dataset_indices):
    """Ask the same question of two or more datasets and collate the answers.

    Each selected dataset's stored conversation chain is used, so the exchange
    is also recorded in that chain's memory.

    Args:
        question: Question to ask every selected dataset.
        dataset_indices: Dropdown labels of the form ``"<n>. <description>"``.
            Labels that cannot be parsed, or that are out of range, are skipped.

    Returns:
        Markdown with one section per dataset, or an error message string.
    """
    global rag_systems

    if not rag_systems:
        return "❌ No datasets loaded. Please load datasets first."
    if not dataset_indices or len(dataset_indices) < 2:
        return "❌ Please select at least 2 datasets to compare."

    keys = list(rag_systems)
    selected_keys = []
    for selection in dataset_indices:
        try:
            index = int(selection.split(".")[0])
        except (ValueError, AttributeError):
            # Not a "<n>. ..." label; ignore it.
            continue
        if 1 <= index <= len(keys):
            selected_keys.append(keys[index - 1])

    if len(selected_keys) < 2:
        return "❌ Could not parse selected datasets."

    parts = [f"# Comparison: {question}\n\n"]
    for key in selected_keys:
        system = rag_systems[key]
        meta = system['metadata']
        parts.append(f"## {meta['dataset_version']} - {meta['state_filter']}\n\n")
        try:
            result = system['conversation_chain'].invoke({"question": question})
            parts.append(f"{result['answer']}\n\n")
        except Exception as e:
            # One failing dataset should not abort the whole comparison.
            parts.append(f"Error: {str(e)}\n\n")
        parts.append("---\n\n")
    return "".join(parts)
# def analyze_provider_types_gradio(dataset_key=None):
# """Analyze provider types in a dataset - Gradio version."""
# global rag_systems, current_dataset_key
# # Determine which dataset to use
# target_key = dataset_key if dataset_key and dataset_key in rag_systems else current_dataset_key
# if not target_key or target_key not in rag_systems:
# return "โ No dataset selected."
# system = rag_systems[target_key]
# meta = system['metadata']
# analysis_question = """
# Analyze the provider types in this dataset:
# 1. What are the most common provider types?
# 2. How many unique provider types are there?
# 3. What percentage of providers fall into each major category?
# Please provide a detailed breakdown.
# """
# try:
# result = system['conversation_chain'].invoke({"question": analysis_question})
# analysis = f"# Provider Type Analysis\n"
# analysis += f"**Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n\n"
# analysis += result["answer"]
# return analysis
# except Exception as e:
# return f"โ Error analyzing provider types: {str(e)}"
def analyze_provider_types_gradio(dataset_key=None):
    """Ask for a provider-type breakdown of one dataset and format the result.

    The analysis is delegated to ask_question_gradio(), with the target
    dataset temporarily made current. The previous current dataset is always
    restored afterwards.

    Args:
        dataset_key: Key into ``rag_systems``. Falls back to the current
            dataset when omitted or unknown.

    Returns:
        A Markdown report, or an error message string.
    """
    global rag_systems, current_dataset_key

    target_key = dataset_key if dataset_key and dataset_key in rag_systems else current_dataset_key
    if not target_key or target_key not in rag_systems:
        return "❌ No dataset selected."

    meta = rag_systems[target_key]['metadata']
    analysis_prompt = (
        f"Analyze the Medicare provider data from {meta['dataset_version']} - {meta['state_filter']}.\n"
        "Please provide:\n"
        "1. A list of the most common provider types (with counts if available)\n"
        "2. The total number of unique provider types\n"
        "3. A breakdown by major categories (practitioners, facilities, suppliers, etc.)\n"
        "4. Any notable patterns or insights\n"
        "Use the actual data from the context to provide specific numbers and percentages.\n"
        "Respond only in English and be as detailed as possible based on the available data."
    )

    original_key = current_dataset_key
    current_dataset_key = target_key
    try:
        _, temp_history = ask_question_gradio(analysis_prompt, [])
    except Exception as e:
        return f"❌ Error analyzing provider types: {str(e)}"
    finally:
        # Restore the user's selection even if the query raised.
        current_dataset_key = original_key

    if not temp_history:
        return "❌ Could not analyze provider types. Please try again."

    analysis = temp_history[0][1]
    # Drop ask_question_gradio's source footer; the header below names the dataset.
    if "*Source:" in analysis:
        analysis = analysis.split("*Source:")[0].strip()

    return (
        "# Provider Type Analysis\n"
        f"**Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n"
        f"**Total Records:** {meta['record_count']}\n\n"
        + analysis
    )
def clear_chat_history():
    """Return a brand-new, empty chat history (a fresh list on every call)."""
    return list()
def visualize_datasets_gradio(dataset_indices, dimensions, sample_size=1000):
    """Create a t-SNE scatter plot of the embeddings of one or more loaded datasets.

    Args:
        dataset_indices: Selected entries from the dataset selector. Each entry is
            expected to start with the dataset's 1-based position in
            ``rag_systems`` followed by a period (e.g. ``"1. ..."``). Entries that
            cannot be parsed or are out of range are ignored.
        dimensions: Number of output dimensions; 2 produces a 2D plot, any other
            value is treated as 3D.
        sample_size: Maximum number of vectors taken from each dataset (the first
            ``sample_size`` entries of its FAISS index).

    Returns:
        A ``(figure, message)`` tuple. ``figure`` is a Plotly figure, or ``None``
        when nothing could be plotted, in which case ``message`` explains why.
    """
    global rag_systems
    if not rag_systems:
        return None, "โ No datasets loaded. Please load datasets first."
    if not dataset_indices:
        return None, "โ Please select at least one dataset to visualize."

    # Map "N. ..." selector entries back to dataset keys (1-based, in the
    # insertion order of rag_systems). Build the key list once.
    loaded_keys = list(rag_systems.keys())
    selected_keys = []
    for selection in dataset_indices:
        try:
            index = int(selection.split(".")[0])
        except (AttributeError, ValueError):
            continue
        if 1 <= index <= len(loaded_keys):
            selected_keys.append(loaded_keys[index - 1])
    if not selected_keys:
        return None, "โ Could not parse selected datasets."

    try:
        # Gradio components may hand numbers over as floats or strings;
        # range() and TSNE(n_components=...) need ints.
        dimensions = int(dimensions)
        sample_size = int(sample_size)
        is_2d = dimensions == 2

        all_vectors = []
        all_dataset_labels = []
        hover_texts = []
        hover_fields = ['STATE_CD', 'PROVIDER_TYPE_DESC', 'FIRST_NAME', 'LAST_NAME', 'ORG_NAME']

        # Collect vectors plus a hover label for each, from all requested datasets.
        for key in selected_keys:
            vector_store = rag_systems[key]['vector_store']
            meta = rag_systems[key]['metadata']
            dataset_label = f"{meta['dataset_version']} - {meta['state_filter']}"
            # Limit vectors per dataset to keep t-SNE fast.
            num_vectors = min(sample_size, vector_store.index.ntotal)
            for i in range(num_vectors):
                all_vectors.append(vector_store.index.reconstruct(i))
                document = vector_store.docstore.search(vector_store.index_to_docstore_id[i])
                all_dataset_labels.append(dataset_label)

                # Plotly hover labels break lines on "<br>", not on "\n".
                hover_lines = [f"Dataset: {dataset_label}"]
                hover_lines.extend(
                    f"{field}: {document.metadata[field]}"
                    for field in hover_fields
                    if document.metadata.get(field)
                )
                content = document.page_content
                content_preview = content[:200] + "..." if len(content) > 200 else content
                hover_texts.append("<br>".join(hover_lines) + f"<br><br>Preview: {content_preview}")

        if not all_vectors:
            return None, "โ No vectors to visualize."
        if len(all_vectors) < 2:
            # t-SNE needs 0 < perplexity < n_samples, impossible with one sample.
            return None, "โ At least 2 vectors are needed for a t-SNE visualization."

        vectors = np.array(all_vectors)
        tsne = TSNE(n_components=dimensions, random_state=42, perplexity=min(30, len(vectors) - 1))
        reduced_vectors = tsne.fit_transform(vectors)

        # dict.fromkeys keeps first-seen order, so colors and legend order are
        # stable across runs (set() order depends on string hash randomization).
        unique_labels = list(dict.fromkeys(all_dataset_labels))
        color_palette = [
            '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
            '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'
        ]

        fig = go.Figure()
        # One trace per dataset so each gets its own color and legend entry.
        for label_no, label in enumerate(unique_labels):
            indices = [i for i, lbl in enumerate(all_dataset_labels) if lbl == label]
            trace_kwargs = dict(
                mode='markers',
                marker=dict(
                    size=6 if is_2d else 5,
                    color=color_palette[label_no % len(color_palette)],
                    opacity=0.7,
                    line=dict(width=1, color='white'),
                ),
                text=[hover_texts[i] for i in indices],
                hoverinfo='text',
                hoverlabel=dict(bgcolor="white", font_size=12),
                name=label,
            )
            if is_2d:
                fig.add_trace(go.Scatter(
                    x=reduced_vectors[indices, 0],
                    y=reduced_vectors[indices, 1],
                    **trace_kwargs,
                ))
            else:
                fig.add_trace(go.Scatter3d(
                    x=reduced_vectors[indices, 0],
                    y=reduced_vectors[indices, 1],
                    z=reduced_vectors[indices, 2],
                    **trace_kwargs,
                ))

        dim_label = "2D" if is_2d else "3D"
        layout = dict(
            title={
                'text': f'Medicare Provider Data - {dim_label} Vector Space Visualization',
                'font': {'size': 20},
            },
            width=900,
            height=700,
            template='plotly_white',
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01,
                bgcolor="rgba(255,255,255,0.8)",
            ),
        )
        if is_2d:
            layout.update(xaxis_title='Dimension 1', yaxis_title='Dimension 2', hovermode='closest')
        else:
            layout.update(scene=dict(
                xaxis_title='Dimension 1',
                yaxis_title='Dimension 2',
                zaxis_title='Dimension 3',
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
            ))
        fig.update_layout(**layout)

        success_msg = (
            f"✅ Successfully created {dimensions}D visualization with "
            f"{len(all_vectors)} vectors from {len(selected_keys)} dataset(s)"
        )
        return fig, success_msg
    except Exception as e:
        return None, f"โ Error creating visualization: {str(e)}"
def create_dataset_statistics_plot(dataset_indices):
    """Plot record and chunk counts side by side for the selected datasets.

    Args:
        dataset_indices: Selected entries from the dataset selector. Each entry is
            expected to start with the dataset's 1-based position in
            ``rag_systems`` followed by a period; unparseable or out-of-range
            entries are skipped.

    Returns:
        A ``(figure, message)`` tuple; ``figure`` is ``None`` on failure and
        ``message`` explains why.
    """
    global rag_systems
    if not rag_systems:
        return None, "โ No datasets loaded."
    if not dataset_indices:
        return None, "โ Please select at least one dataset."

    # Map "N. ..." selector entries back to dataset keys (1-based, in the
    # insertion order of rag_systems). Build the key list once.
    loaded_keys = list(rag_systems.keys())
    selected_keys = []
    for selection in dataset_indices:
        try:
            index = int(selection.split(".")[0])
        except (AttributeError, ValueError):
            continue
        if 1 <= index <= len(loaded_keys):
            selected_keys.append(loaded_keys[index - 1])
    if not selected_keys:
        return None, "โ Could not parse selected datasets."

    try:
        from plotly.subplots import make_subplots

        dataset_names = []
        record_counts = []
        chunk_counts = []
        for key in selected_keys:
            meta = rag_systems[key]['metadata']
            # Plotly renders "<br>" as a line break in category tick labels.
            dataset_names.append(f"{meta['dataset_version']}<br>{meta['state_filter']}")
            record_counts.append(meta['record_count'])
            chunk_counts.append(meta['chunk_count'])

        fig = make_subplots(
            rows=1, cols=2,
            subplot_titles=('Records per Dataset', 'Chunks per Dataset'),
            specs=[[{'type': 'bar'}, {'type': 'bar'}]]
        )
        # Record count bars (left panel)
        fig.add_trace(
            go.Bar(
                x=dataset_names,
                y=record_counts,
                name='Records',
                marker_color='lightblue',
                text=record_counts,
                textposition='auto',
            ),
            row=1, col=1
        )
        # Chunk count bars (right panel)
        fig.add_trace(
            go.Bar(
                x=dataset_names,
                y=chunk_counts,
                name='Chunks',
                marker_color='lightgreen',
                text=chunk_counts,
                textposition='auto',
            ),
            row=1, col=2
        )
        fig.update_layout(
            title={
                'text': 'Dataset Statistics Overview',
                'font': {'size': 20}
            },
            showlegend=False,
            height=500,
            template='plotly_white'
        )
        fig.update_xaxes(tickangle=-45)
        return fig, f"✅ Created statistics plot for {len(selected_keys)} dataset(s)"
    except Exception as e:
        return None, f"โ Error creating statistics plot: {str(e)}"
def inspect_dataset_gradio(num_samples):
    """Render the first few documents of the current dataset as Markdown.

    Args:
        num_samples: How many documents to show, starting from FAISS position 0.
            Coerced to ``int`` because Gradio sliders may deliver floats. The
            count is clamped to the range [0, total documents].

    Returns:
        A Markdown string with a summary header and one section per document.
        A document that cannot be retrieved is replaced by a single error line.
    """
    global rag_systems, current_dataset_key
    if not current_dataset_key or current_dataset_key not in rag_systems:
        return "โ No dataset selected. Please load a dataset first."

    system = rag_systems[current_dataset_key]
    vector_store = system['vector_store']
    meta = system['metadata']
    total_docs = vector_store.index.ntotal
    shown = max(0, min(int(num_samples), total_docs))

    parts = [
        "# Dataset Inspection\n\n",
        f"**Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n",
        f"**Total documents:** {total_docs}\n",
        f"**Showing:** {shown} sample documents\n\n",
        "---\n\n",
    ]
    key_fields = ['PROVIDER_TYPE_DESC', 'STATE_CD', 'FIRST_NAME', 'LAST_NAME',
                  'ORG_NAME', 'NPI', 'ENRLMT_ID']

    for i in range(shown):
        try:
            doc_id = vector_store.index_to_docstore_id[i]
            document = vector_store.docstore.search(doc_id)
            # Build the section separately so a failure mid-way does not leave
            # a half-written section in the output.
            section = [f"### Document {i + 1}\n\n", "**Metadata:**\n"]
            section.extend(
                f"- **{field}:** {document.metadata[field]}\n"
                for field in key_fields
                if document.metadata.get(field)
            )
            content = document.page_content
            content_preview = content[:500] + "..." if len(content) > 500 else content
            section.append(f"\n**Content Preview:**\n```\n{content_preview}\n```\n\n")
            section.append("---\n\n")
            parts.extend(section)
        except Exception as e:
            # Best-effort: report the failing document (1-based, matching the
            # section headers) and keep going.
            parts.append(f"Error retrieving document {i + 1}: {str(e)}\n\n")

    return "".join(parts)
def create_gradio_interface():
"""Create the main Gradio interface."""
with gr.Blocks(theme=theme, title="Medicare Provider Data Analysis System") as app:
# Header
gr.Markdown(
"""
# ๐ฅ Medicare Provider Data Analysis System
This system allows you to load, query, and analyze Medicare provider data using advanced RAG (Retrieval-Augmented Generation) technology.
---
"""
)
# Main tabs
with gr.Tabs() as tabs:
# Tab 1: Dataset Management
with gr.Tab("๐ Dataset Management"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Load New Dataset")
version_dropdown = gr.Dropdown(
choices=list(DATASET_VERSIONS.keys()),
label="Select Quarter/Year",
value="Q1 2025"
)
state_dropdown = gr.Dropdown(
choices=format_state_options(),
label="Select State",
value=""
)
max_records_slider = gr.Slider(
minimum=100,
maximum=5000,
value=1000,
step=100,
label="Maximum Records"
)
use_sample_checkbox = gr.Checkbox(
label="Load sample only (100 records)",
value=True
)
load_button = gr.Button("๐ Load Dataset", variant="primary")
load_output = gr.Textbox(label="Loading Status", lines=3)
with gr.Column(scale=1):
gr.Markdown("### Manage Loaded Datasets")
dataset_summary = gr.Markdown(get_dataset_summary(rag_systems))
with gr.Row():
dataset_selector = gr.Dropdown(
choices=get_dataset_choices(),
label="Select Dataset",
interactive=True
)
with gr.Row():
switch_button = gr.Button("โ๏ธ Switch Dataset")
remove_button = gr.Button("๐๏ธ Remove Dataset")
clear_all_button = gr.Button("๐งน Clear All", variant="stop")
manage_output = gr.Textbox(label="Status", lines=2)
# Wire up dataset management events
def update_dataset_selector():
return gr.update(choices=get_dataset_choices())
load_button.click(
fn=load_dataset_gradio,
inputs=[version_dropdown, state_dropdown, max_records_slider, use_sample_checkbox],
outputs=[load_output, dataset_summary]
).then(
fn=update_dataset_selector,
outputs=dataset_selector
)
switch_button.click(
fn=switch_dataset_gradio,
inputs=dataset_selector,
outputs=[manage_output, dataset_summary]
)
remove_button.click(
fn=remove_dataset_gradio,
inputs=dataset_selector,
outputs=[manage_output, dataset_summary]
).then(
fn=update_dataset_selector,
outputs=dataset_selector
)
clear_all_button.click(
fn=clear_all_datasets_gradio,
outputs=[manage_output, dataset_summary]
).then(
fn=update_dataset_selector,
outputs=dataset_selector
)
# Tab 2: Query Interface
with gr.Tab("๐ฌ Query & Chat"):
gr.Markdown("### Ask Questions About Your Data")
current_dataset_info = gr.Markdown(get_current_dataset_info())
# Create a timer to update current dataset info
timer = gr.Timer(value=2)
timer.tick(fn=get_current_dataset_info, outputs=current_dataset_info)
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Conversation",
height=500,
show_copy_button=True
)
with gr.Row():
question_input = gr.Textbox(
label="Your Question",
placeholder="Ask about provider types, locations, statistics, etc.",
lines=2,
scale=4
)
with gr.Column(scale=1):
ask_button = gr.Button("๐ค Ask Current Dataset", variant="primary")
global_ask_button = gr.Button("๐ Ask All Datasets")
clear_chat_button = gr.Button("๐๏ธ Clear Chat")
with gr.Column(scale=1):
gr.Markdown("### Quick Actions")
analyze_providers_button = gr.Button("๐ Analyze Provider Types")
gr.Markdown("### Example Questions")
example_questions = [
"What are the most common provider types?",
"How many providers are in this dataset?",
"Show me all psychiatrists in the data",
"What types of medical facilities are included?",
"Compare provider counts across different quarters"
]
for eq in example_questions:
gr.Button(eq, size="sm").click(
lambda q=eq: (q, gr.update()),
outputs=[question_input, chatbot]
)
# Wire up query events
question_input.submit(
fn=ask_question_gradio,
inputs=[question_input, chatbot],
outputs=[question_input, chatbot]
)
ask_button.click(
fn=ask_question_gradio,
inputs=[question_input, chatbot],
outputs=[question_input, chatbot]
)
global_ask_button.click(
fn=ask_global_question_gradio,
inputs=[question_input, chatbot],
outputs=[question_input, chatbot]
)
clear_chat_button.click(
fn=clear_chat_history,
outputs=chatbot
)
analyze_providers_button.click(
fn=lambda: ("", [(
"Analyze provider types in the current dataset",
analyze_provider_types_gradio()
)]),
outputs=[question_input, chatbot]
)
# Tab 3: Comparison & Analysis
with gr.Tab("๐ Compare Datasets"):
gr.Markdown("### Compare Multiple Datasets")
with gr.Row():
compare_dataset_selector = gr.CheckboxGroup(
choices=get_dataset_choices(),
label="Select Datasets to Compare (choose 2 or more)",
value=[]
)
compare_question = gr.Textbox(
label="Comparison Question",
placeholder="Enter a question to ask all selected datasets",
lines=2
)
compare_button = gr.Button("๐ Compare Datasets", variant="primary")
comparison_output = gr.Markdown(label="Comparison Results")
# Update checkbox choices when datasets change
def update_compare_selector():
return gr.update(choices=get_dataset_choices())
timer.tick(fn=update_compare_selector, outputs=compare_dataset_selector)
compare_button.click(
fn=compare_datasets_gradio,
inputs=[compare_question, compare_dataset_selector],
outputs=comparison_output
)
# Tab 4: Visualization
with gr.Tab("๐ Visualizations"):
gr.Markdown("### Dataset Visualizations")
with gr.Row():
with gr.Column():
viz_dataset_selector = gr.CheckboxGroup(
choices=get_dataset_choices(),
label="Select Datasets to Visualize",
value=[]
)
viz_dimension = gr.Radio(
choices=[2, 3],
value=2,
label="Visualization Dimensions"
)
viz_sample_size = gr.Slider(
minimum=100,
maximum=2000,
value=500,
step=100,
label="Sample Size (per dataset)"
)
create_viz_button = gr.Button("๐จ Create Visualization", variant="primary")
stats_button = gr.Button("๐ Show Statistics")
viz_status = gr.Textbox(label="Status", lines=2)
with gr.Row():
viz_plot = gr.Plot(label="Vector Space Visualization")
stats_plot = gr.Plot(label="Dataset Statistics")
# Update visualization selector
def update_viz_selector():
return gr.update(choices=get_dataset_choices())
timer.tick(fn=update_viz_selector, outputs=viz_dataset_selector)
create_viz_button.click(
fn=visualize_datasets_gradio,
inputs=[viz_dataset_selector, viz_dimension, viz_sample_size],
outputs=[viz_plot, viz_status]
)
stats_button.click(
fn=create_dataset_statistics_plot,
inputs=[viz_dataset_selector],
outputs=[stats_plot, viz_status]
)
# Tab 5: Dataset Inspector
with gr.Tab("๐ Dataset Inspector"):
gr.Markdown("### Inspect Dataset Contents")
inspect_current_info = gr.Markdown(get_current_dataset_info())
timer.tick(fn=get_current_dataset_info, outputs=inspect_current_info)
num_samples_slider = gr.Slider(
minimum=1,
maximum=20,
value=5,
step=1,
label="Number of Sample Documents"
)
inspect_button = gr.Button("๐ Inspect Current Dataset", variant="primary")
inspection_output = gr.Markdown(label="Dataset Inspection Results")
inspect_button.click(
fn=inspect_dataset_gradio,
inputs=num_samples_slider,
outputs=inspection_output
)
# Tab 6: Settings & Help
with gr.Tab("โ๏ธ Settings & Help"):
gr.Markdown(
"""
### System Information
**Model:** GPT-4 Mini
**Embedding Model:** OpenAI Embeddings
**Vector Store:** FAISS
### API Configuration
This system uses the CMS.gov Data API to fetch Medicare provider information.
### Tips for Best Results
1. **Loading Data**: Start with sample data (100 records) to test queries quickly
2. **State Selection**: Load specific states for focused analysis
3. **Querying**: Be specific in your questions for better results
4. **Comparisons**: Load multiple quarters/states to analyze trends
### Common Use Cases
- **Provider Analysis**: Find specific types of healthcare providers
- **Geographic Distribution**: Analyze providers by state
- **Temporal Trends**: Compare data across different quarters
- **Provider Types**: Understand the distribution of specialties
### Troubleshooting
- **No API Key**: Ensure OPENAI_API_KEY is set in your environment
- **Loading Errors**: Check your internet connection and API limits
- **Query Errors**: Try rephrasing your question or check if data is loaded
"""
)
with gr.Row():
gr.Markdown("### Current Configuration")
config_info = gr.JSON(
value={
"api_key_set": bool(os.getenv('OPENAI_API_KEY')),
"default_model": DEFAULT_MODEL,
"api_base_url": API_BASE_URL,
"datasets_loaded": len(rag_systems)
},
label="System Configuration"
)
# Footer
gr.Markdown(
"""
---