import os
import json
import time
import requests
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple

import gradio as gr
from dotenv import load_dotenv

# Vector DB and embedding imports
from langchain.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_openai import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Visualization imports
import plotly.graph_objects as go
from sklearn.manifold import TSNE

# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file, if any.
load_dotenv()

# Check if OPENAI_API_KEY is set. This only warns at import time; the app keeps
# running so the UI can still start and report the problem to the user.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if not OPENAI_API_KEY:
    print("⚠️ Warning: OPENAI_API_KEY not found in environment variables.")

# Configuration
DEFAULT_DATASET_ID = "2457ea29-fc82-48b0-86ec-3b0755de7515"
DEFAULT_MODEL = "gpt-4o-mini"
API_BASE_URL = "https://data.cms.gov/data-api/v1"
INITIAL_SAMPLE_SIZE = 100  # Start with a small sample

# Dataset version mapping: UI label ("Q<quarter> <year>") -> CMS dataset
# version id, which is interpolated into the data API URL.
DATASET_VERSIONS = {
    # 2025 Data
    "Q1 2025": "74edb053-bd01-40a0-91a0-4961c1fe6281",
    # 2024 Data
    "Q1 2024": "6d6e0e8d-64cf-43fb-9ba8-e2ad9b9bb21e",
    "Q2 2024": "04405289-5635-4b2a-a64f-c4b6415ab6ff",
    "Q3 2024": "e87f09c2-5ff7-4ddf-b60c-6130995b15cf",
    "Q4 2024": "e9d278e4-90e8-47ab-9c5b-af2ca64bf352",
    # 2023 Data
    "Q1 2023": "0b6caf2f-8948-4603-922e-d7f0c52c0a45",
    "Q2 2023": "46339a0c-0f07-40ed-8975-ddb387c367a4",
    "Q3 2023": "70efac57-6093-4e1d-ad6a-36f8261f53eb",
    "Q4 2023": "1df8331a-ed44-41ec-971f-158349658949",
    # 2022 Data
    "Q1 2022": "5b678653-aa36-455b-9144-1d073ef7991b",
    # 2021 Data
    "Q1 2021": "7b409bba-ca00-426e-9493-1dc10e5340cc",
    # 2020 Data
    "Q1 2020": "3870b29c-4312-4fb1-a956-71c148ae5b50",
    # 2019 Data
    "Q1 2019": "017e6ab7-7e19-4e98-b4fa-30578b47e578",
    "Q4 2019": "2c209bdb-ed0c-42e0-b027-8a97024b8035"
}

# US States for reference
US_STATES = [
    "", "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA",
    "HI", "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD",
    "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ",
    "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI", "SC",
    "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
    "DC", "PR", "VI"
]

# State names mapping for better UI ("" is the "no filter" choice)
STATE_NAMES = {
    "": "All States",
    "AL": "Alabama", "AK": "Alaska", "AZ": "Arizona", "AR": "Arkansas",
    "CA": "California", "CO": "Colorado", "CT": "Connecticut", "DE": "Delaware",
    "FL": "Florida", "GA": "Georgia", "HI": "Hawaii", "ID": "Idaho",
    "IL": "Illinois", "IN": "Indiana", "IA": "Iowa", "KS": "Kansas",
    "KY": "Kentucky", "LA": "Louisiana", "ME": "Maine", "MD": "Maryland",
    "MA": "Massachusetts", "MI": "Michigan", "MN": "Minnesota", "MS": "Mississippi",
    "MO": "Missouri", "MT": "Montana", "NE": "Nebraska", "NV": "Nevada",
    "NH": "New Hampshire", "NJ": "New Jersey", "NM": "New Mexico", "NY": "New York",
    "NC": "North Carolina", "ND": "North Dakota", "OH": "Ohio", "OK": "Oklahoma",
    "OR": "Oregon", "PA": "Pennsylvania", "RI": "Rhode Island", "SC": "South Carolina",
    "SD": "South Dakota", "TN": "Tennessee", "TX": "Texas", "UT": "Utah",
    "VT": "Vermont", "VA": "Virginia", "WA": "Washington", "WV": "West Virginia",
    "WI": "Wisconsin", "WY": "Wyoming", "DC": "District of Columbia",
    "PR": "Puerto Rico", "VI": "Virgin Islands"
}

# Dictionary to store multiple datasets: cache key -> {'vector_store',
# 'conversation_chain', 'metadata'} (populated by load_dataset_gradio).
rag_systems = {}
current_dataset_key = None

# Gradio theme configuration
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter")
)


def query_cms_api(version_id, state_filter="", max_records=100, timeout=30):
    """Fetch up to ``max_records`` rows of one CMS dataset version, page by page.

    Args:
        version_id: CMS dataset version id (a value of ``DATASET_VERSIONS``).
        state_filter: Optional state code sent as ``filter[STATE_CD]``;
            empty or None means no filter.
        max_records: Upper bound on the number of rows returned. Coerced to
            ``int`` because UI widgets may deliver numbers as floats.
        timeout: Per-request timeout in seconds, passed to ``requests.get`` so
            a stalled connection cannot hang the caller indefinitely.

    Returns:
        ``(records, message)``. On an HTTP, network or JSON-decoding error the
        record list is empty (pages fetched earlier are discarded) and
        ``message`` describes the failure.
    """
    url = f"{API_BASE_URL}/dataset/{version_id}/data"
    max_records = int(max_records)
    if max_records <= 0:
        return [], "No records requested"

    page_size = min(max_records, 100)  # request at most 100 rows per page
    params = {'size': page_size, 'offset': 0}
    if state_filter:
        params['filter[STATE_CD]'] = state_filter

    all_records = []
    offset = 0
    while len(all_records) < max_records:
        params['offset'] = offset
        try:
            response = requests.get(url, params=params, timeout=timeout)
            if response.status_code != 200:
                return [], f"Error: Status {response.status_code}"
            # Parse the response - the API returns a list directly
            records = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decoding failures from response.json()
            return [], f"Error querying API: {str(e)}"

        if not records or not isinstance(records, list):
            if not all_records:
                return [], "No records found"
            break

        all_records.extend(records)

        # A short page means we've reached the end of the data
        if len(records) < page_size:
            break

        offset += len(records)
        time.sleep(0.5)  # be nice to the API between pages

    final_records = all_records[:max_records]
    return final_records, f"Successfully retrieved {len(final_records)} records"


def process_records(records, version):
    """Embed CMS API records into a FAISS vector store.

    Each record becomes one ``Document`` whose text starts with the quarter and
    year parsed from ``version`` (e.g. "Q1 2024"; "Unknown" if it does not
    split into exactly two words) followed by one ``key: value`` line per
    non-empty field. The metadata holds the same fields plus
    ``dataset_version``, ``quarter``, ``year`` and ``record_id`` (taken from
    ``ENRLMT_ID``). Documents are split into 1000-character chunks with a
    200-character overlap before embedding.

    Returns:
        ``(vector_store, document_count, chunk_count)``.

    Raises:
        ValueError: if ``records`` is empty (there is nothing to index).
    """
    if not records:
        raise ValueError("No records to process")

    quarter, year = "Unknown", "Unknown"
    parts = version.split(' ')
    if len(parts) == 2:
        quarter, year = parts

    embeddings = OpenAIEmbeddings()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    documents = []
    for record in records:
        # Only non-empty fields are embedded or stored as metadata
        fields = {key: value for key, value in record.items() if value is not None and value != ""}

        lines = [
            f"Medicare Provider Data from {quarter} {year}",
            f"Time Period: {quarter} of {year}",
        ]
        lines.extend(f"{key}: {value}" for key, value in fields.items())

        metadata = {
            'dataset_version': version,
            'quarter': quarter,
            'year': year,
            'record_id': record.get('ENRLMT_ID', 'unknown'),
        }
        for key, value in fields.items():
            # Convert complex values to strings to avoid serialization issues
            metadata[key] = value if isinstance(value, (str, int, float, bool)) else str(value)

        documents.append(Document(page_content="\n".join(lines), metadata=metadata))

    chunks = text_splitter.split_documents(documents)
    vector_store = FAISS.from_documents(chunks, embeddings)
    return vector_store, len(documents), len(chunks)


def create_progress_callback():
    """Return a callback that prints progress messages to stdout."""
    def callback(message):
        # In a real Gradio app, this would update a progress bar
        print(f"Progress: {message}")
    return callback


def validate_api_key():
    """Return ``(ok, message)`` telling whether OPENAI_API_KEY is set in the environment."""
    if not os.getenv('OPENAI_API_KEY'):
        return False, "OpenAI API key not found. Please set it in your environment variables or .env file."
    return True, "API key validated successfully."


def get_dataset_summary(rag_systems):
    """Return a Markdown list of the given loaded datasets, marking the current one."""
    if not rag_systems:
        return "No datasets currently loaded."

    summary_lines = ["### Currently Loaded Datasets:\n"]
    for i, (key, system) in enumerate(rag_systems.items(), 1):
        meta = system['metadata']
        line = (
            f"{i}. **{meta['dataset_version']}** - "
            f"State: {meta['state_filter']} - "
            f"Records: {meta['record_count']} - "
            f"Chunks: {meta['chunk_count']}"
        )
        if key == current_dataset_key:
            line += " *(Current)*"
        summary_lines.append(line)

    summary_lines.append(f"\n**Total datasets loaded:** {len(rag_systems)}")
    return "\n".join(summary_lines)


def format_state_options():
    """Return ``(label, code)`` pairs for the state dropdown, "All States" first."""
    return [
        ("All States", "") if code == "" else (f"{STATE_NAMES[code]} ({code})", code)
        for code in US_STATES
    ]


def load_dataset_gradio(version, state_filter, max_records, use_sample):
    """Fetch a dataset version from the CMS API and build a RAG system for it.

    The result is cached in ``rag_systems`` under a key built from the
    version, the state filter and the number of records actually requested
    (``INITIAL_SAMPLE_SIZE`` when ``use_sample`` is set). It becomes the
    current dataset.

    Returns:
        ``(status_message, datasets_summary_markdown)`` for the Gradio UI.
    """
    global rag_systems, current_dataset_key

    # Validate API key first
    valid, message = validate_api_key()
    if not valid:
        return message, get_dataset_summary(rag_systems)

    # Resolve the effective record limit BEFORE building the cache key, so a
    # 100-row sample is never returned later as an "already loaded" full load.
    actual_max = INITIAL_SAMPLE_SIZE if use_sample else max_records
    state_label = STATE_NAMES.get(state_filter, 'All States')
    dataset_key = f"{version}_{state_filter}_{actual_max}"

    if dataset_key in rag_systems:
        current_dataset_key = dataset_key
        return (f"✅ Dataset already loaded and set as current: {version} - {state_label}",
                get_dataset_summary(rag_systems))

    version_id = DATASET_VERSIONS.get(version)
    if not version_id:
        return f"❌ Invalid version: {version}", get_dataset_summary(rag_systems)

    try:
        records, api_message = query_cms_api(version_id, state_filter, actual_max)
        if not records:
            return f"❌ Failed to load data: {api_message}", get_dataset_summary(rag_systems)

        vector_store, doc_count, chunk_count = process_records(records, version)

        # Default RAG chain for this dataset; it keeps its own conversation memory.
        llm = ChatOpenAI(temperature=0.7, model_name=DEFAULT_MODEL)
        memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vector_store.as_retriever(),
            memory=memory
        )

        rag_systems[dataset_key] = {
            'vector_store': vector_store,
            'conversation_chain': conversation_chain,
            'metadata': {
                'dataset_version': version,
                'version_id': version_id,
                'state_filter': state_label,
                'record_count': len(records),
                'document_count': doc_count,
                'chunk_count': chunk_count,
                'loaded_at': datetime.now().isoformat()
            }
        }
        current_dataset_key = dataset_key

        success_msg = (
            f"✅ Successfully loaded {version} - {state_label}\n"
            f"📊 Created {chunk_count} chunks from {len(records)} records"
        )
        return success_msg, get_dataset_summary(rag_systems)

    except Exception as e:
        # UI boundary: report any failure instead of crashing the handler
        return f"❌ Error loading data: {str(e)}", get_dataset_summary(rag_systems)


def _dataset_key_from_choice(choice):
    """Map a "<n>. ..." dropdown label to its key in ``rag_systems``.

    Returns None when the 1-based index is out of range. Raises ValueError
    (non-integer prefix) or AttributeError (non-string ``choice``) when the
    label cannot be parsed.
    """
    index = int(choice.split(".")[0])
    keys = list(rag_systems.keys())
    if 1 <= index <= len(keys):
        return keys[index - 1]
    return None


def switch_dataset_gradio(dataset_index):
    """Make the dataset chosen in the dropdown the current one.

    Returns:
        ``(status_message, datasets_summary_markdown)``.
    """
    global rag_systems, current_dataset_key

    if not rag_systems:
        return "❌ No datasets loaded.", get_dataset_summary(rag_systems)
    if not dataset_index:
        return "❌ Please select a dataset.", get_dataset_summary(rag_systems)

    try:
        key = _dataset_key_from_choice(dataset_index)
    except (ValueError, AttributeError):
        return "❌ Invalid selection format.", get_dataset_summary(rag_systems)
    if key is None:
        return "❌ Invalid selection.", get_dataset_summary(rag_systems)

    current_dataset_key = key
    meta = rag_systems[key]['metadata']
    return f"✅ Switched to: {meta['dataset_version']} - {meta['state_filter']}", get_dataset_summary(rag_systems)


def remove_dataset_gradio(dataset_index):
    """Drop the dataset chosen in the dropdown from memory.

    If it was the current dataset, the first remaining dataset (if any)
    becomes current.

    Returns:
        ``(status_message, datasets_summary_markdown)``.
    """
    global rag_systems, current_dataset_key

    if not rag_systems:
        return "❌ No datasets loaded.", get_dataset_summary(rag_systems)
    if not dataset_index:
        return "❌ Please select a dataset to remove.", get_dataset_summary(rag_systems)

    try:
        key = _dataset_key_from_choice(dataset_index)
    except (ValueError, AttributeError) as e:
        return f"❌ Error removing dataset: {str(e)}", get_dataset_summary(rag_systems)
    if key is None:
        return "❌ Invalid selection.", get_dataset_summary(rag_systems)

    meta = rag_systems.pop(key)['metadata']
    if key == current_dataset_key:
        current_dataset_key = next(iter(rag_systems), None)

    return f"✅ Removed: {meta['dataset_version']} - {meta['state_filter']}", get_dataset_summary(rag_systems)


def get_dataset_choices():
    """Return "<n>. <version> - <state> (<count> records)" labels for the dataset dropdowns."""
    if not rag_systems:
        return []

    choices = []
    for i, (key, system) in enumerate(rag_systems.items(), 1):
        meta = system['metadata']
        choice_text = f"{i}. {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)"
        if key == current_dataset_key:
            choice_text += " [CURRENT]"
        choices.append(choice_text)
    return choices


def clear_all_datasets_gradio():
    """Forget every loaded dataset; returns ``(status_message, "")``."""
    global rag_systems, current_dataset_key

    if not rag_systems:
        return "ℹ️ No datasets to clear.", ""

    count = len(rag_systems)
    rag_systems.clear()
    current_dataset_key = None
    return f"✅ Cleared {count} dataset(s) from memory.", ""


def get_current_dataset_info():
    """Return a short Markdown description of the current dataset."""
    global rag_systems, current_dataset_key

    if not current_dataset_key or current_dataset_key not in rag_systems:
        return "No dataset currently selected."

    meta = rag_systems[current_dataset_key]['metadata']
    info = f"**Current Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n"
    info += f"- Records: {meta['record_count']}\n"
    info += f"- Chunks: {meta['chunk_count']}\n"
    info += f"- Loaded: {meta['loaded_at'][:19]}"  # trim ISO timestamp to seconds
    return info


# NOTE: superseded legacy version, kept for reference only; the active
# ask_question_gradio is defined further down in this file.
# def ask_question_gradio(question, chat_history):
#     """Ask a question to the current dataset - Gradio version."""
#     global rag_systems, current_dataset_key
#     if not current_dataset_key or current_dataset_key not in rag_systems:
#         response = "❌ No dataset selected. Please load a dataset first."
# (continued) Legacy, commented-out implementations of ask_question_gradio and
# ask_global_question_gradio. They are superseded by the active definitions
# later in this file and are kept for reference only.
#         chat_history.append((question, response))
#         return "", chat_history
#     # Get the dataset
#     system = rag_systems[current_dataset_key]
#     meta = system['metadata']
#     try:
#         # Use the chain to get a response
#         result = system['conversation_chain'].invoke({"question": question})
#         answer = result["answer"]
#         # Add dataset source information
#         answer += f"\n\n*Source: {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)*"
#         # Update chat history
#         chat_history.append((question, answer))
#         return "", chat_history
#     except Exception as e:
#         error_response = f"❌ Error processing query: {str(e)}"
#         chat_history.append((question, error_response))
#         return "", chat_history
# def ask_global_question_gradio(question, chat_history):
#     """Ask a question that might require knowledge from all loaded datasets."""
#     global rag_systems
#     if not rag_systems:
#         response = "❌ No datasets loaded. Please load datasets first."
#         chat_history.append((question, response))
#         return "", chat_history
#     # Check if this is a global question about the datasets themselves
#     global_keywords = ['how many', 'which years', 'what years', 'what quarters', 'how many years',
#                        'which quarters', 'time period', 'date range', 'all datasets', 'datasets',
#                        'compare', 'comparison', 'difference', 'trend', 'over time']
#     is_global_question = any(keyword in question.lower() for keyword in global_keywords)
#     # Check if the question mentions a specific state
#     mentioned_state = None
#     question_lower = question.lower()
#     # Check for state names
#     for code, name in STATE_NAMES.items():
#         if code and (code.lower() in question_lower or name.lower() in question_lower):
#             mentioned_state = code
#             break
#     try:
#         if mentioned_state and not is_global_question:
#             # Find all datasets for that state
#             suitable_datasets = []
#             for key, system in rag_systems.items():
#                 meta = system['metadata']
#                 state_filter = meta['state_filter']
#                 # Check if this dataset matches the mentioned state
#                 if mentioned_state in state_filter or STATE_NAMES[mentioned_state] in state_filter:
#                     suitable_datasets.append(key)
#             if suitable_datasets:
#                 response = f"🔄 Found {len(suitable_datasets)} dataset(s) for {STATE_NAMES[mentioned_state]}:\n\n"
#                 # Query each suitable dataset
#                 all_results = []
#                 for dataset_key in suitable_datasets:
#                     system = rag_systems[dataset_key]
#                     meta = system['metadata']
#                     try:
#                         result = system['conversation_chain'].invoke({"question": question})
#                         answer = result["answer"]
#                         all_results.append({
#                             'dataset': f"{meta['dataset_version']} - {meta['state_filter']}",
#                             'answer': answer
#                         })
#                     except Exception as e:
#                         all_results.append({
#                             'dataset': f"{meta['dataset_version']} - {meta['state_filter']}",
#                             'answer': f"Error: {str(e)}"
#                         })
#                 # Format combined response
#                 for result in all_results:
#                     response += f"**{result['dataset']}**\n{result['answer']}\n\n---\n\n"
#                 chat_history.append((question, response))
#                 return "", chat_history
#             else:
#                 response = f"ℹ️ No datasets found for {STATE_NAMES[mentioned_state]}. Please load data for this state first."
#                 chat_history.append((question, response))
#                 return "", chat_history
#         elif is_global_question:
#             # Create a summary of all available datasets
#             dataset_summary = generate_dataset_metadata_summary()
#             # Create a system message that includes this metadata
#             llm = ChatOpenAI(temperature=0.7, model_name=DEFAULT_MODEL)
#             system_message = f"""You are an expert on Medicare Provider data. You have access to multiple datasets spanning different quarters and years.
# {dataset_summary}
# When answering questions, consider the metadata about all available datasets.
# For questions about time periods, years, quarters, or trends, use the information about which datasets are loaded."""
#             messages = [
#                 {"role": "system", "content": system_message},
#                 {"role": "user", "content": question}
#             ]
#             response = llm.invoke(messages)
#             answer = response.content
#             chat_history.append((question, answer))
#             return "", chat_history
#         else:
#             # For non-global questions without specific state mention, use the current dataset
#             return ask_question_gradio(question, chat_history)
#     except Exception as e:
#         error_response = f"❌ Error processing global query: {str(e)}"
#         chat_history.append((question, error_response))
#         return "", chat_history


def _answer_for_dataset(dataset_key, question, include_source=True):
    """Answer ``question`` from one loaded dataset with retrieval-augmented QA.

    A new LLM (temperature 0.3), conversation memory, top-10 retriever and
    ``ConversationalRetrievalChain`` are built on every call. If the first
    answer is empty or shorter than 10 characters, the question is retried
    once with the dataset described inline.

    Args:
        dataset_key: Key into ``rag_systems``.
        question: The question text.
        include_source: When True, append a "*Source: ...*" footer naming the
            dataset. When False, return the bare, stripped answer.

    Returns:
        The answer text.

    Raises:
        KeyError: if ``dataset_key`` is not loaded. Errors from LangChain or
        OpenAI propagate unchanged.
    """
    from langchain.prompts import PromptTemplate

    system = rag_systems[dataset_key]
    meta = system['metadata']

    system_prompt = f"""You are a helpful assistant analyzing Medicare Provider data.

Current Dataset Information:
- Dataset: {meta['dataset_version']} - {meta['state_filter']}
- Total Records: {meta['record_count']}
- Total Chunks: {meta['chunk_count']}

Important Instructions:
1. ALWAYS respond in English
2. Use the provided context to answer questions
3. If you can find relevant information in the context, provide a detailed answer
4. Only say "I don't know" if the information is genuinely not available in the context
5. Be specific and cite numbers when available
6. For questions about counts or statistics, check the context carefully

Remember: You have access to Medicare provider data including provider types, names, locations, and other details."""

    llm = ChatOpenAI(
        temperature=0.3,  # Lower temperature for more consistent answers
        model_name=DEFAULT_MODEL
    )
    # Fresh memory per call; the system prompt lives in the QA template, not in memory.
    # NOTE(review): because memory is per call, follow-up questions do not see
    # earlier turns of the UI conversation. Seed it if multi-turn context is wanted.
    memory = ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=True,
        output_key='answer'
    )
    retriever = system['vector_store'].as_retriever(
        search_kwargs={"k": 10}  # Retrieve more documents for better context
    )
    qa_prompt = PromptTemplate(
        template=f"""{system_prompt}

Context from the dataset:
{{context}}

Chat History:
{{chat_history}}

Human Question: {{question}}

Instructions:
- Answer based on the context provided
- Be specific and mention numbers/counts when available
- Respond ONLY in English
- If the context contains relevant information, use it to provide a detailed answer
- Only say you don't know if the information is truly not in the context

Assistant Answer:""",
        input_variables=["context", "chat_history", "question"]
    )
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        verbose=False
    )

    answer = conversation_chain.invoke({"question": question})["answer"]
    if not answer or len(answer) < 10:
        # Answer looks degenerate; retry once with the dataset spelled out
        direct_query = (
            f"Based on the {meta['dataset_version']} {meta['state_filter']} Medicare data "
            f"with {meta['record_count']} records, {question}"
        )
        answer = conversation_chain.invoke({"question": direct_query})["answer"]
    answer = answer or ""

    if not include_source:
        return answer.strip()
    return answer + f"\n\n*Source: {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)*"


def ask_question_gradio(question, chat_history):
    """Answer a chat question from the current dataset.

    Args:
        question: The user's question.
        chat_history: List of ``(question, answer)`` pairs. It is appended to
            in place; None is treated as an empty history.

    Returns:
        ``("", chat_history)``: the empty string clears the input box.
    """
    global rag_systems, current_dataset_key

    if chat_history is None:
        chat_history = []

    if not current_dataset_key or current_dataset_key not in rag_systems:
        chat_history.append((question, "❌ No dataset selected. Please load a dataset first."))
        return "", chat_history

    try:
        answer = _answer_for_dataset(current_dataset_key, question)
    except Exception as e:
        answer = f"❌ Error processing query: {str(e)}\n\nPlease try rephrasing your question."

    chat_history.append((question, answer))
    return "", chat_history


def _find_mentioned_state(question):
    """Return the code of the state mentioned in ``question``, or None.

    Full state names match case-insensitively as whole words, longest names
    first, so "West Virginia" is not read as "Virginia" and "Arkansas" does
    not contain "Kansas". Two-letter codes only match as whole, upper-case
    words, so ordinary words such as "are", "in" or "or" are not mistaken for
    AR / IN / OR.
    """
    import re

    by_length = sorted(STATE_NAMES.items(), key=lambda item: len(item[1]), reverse=True)
    for code, name in by_length:
        if code and re.search(rf"\b{re.escape(name)}\b", question, re.IGNORECASE):
            return code
    for code in STATE_NAMES:
        if code and re.search(rf"\b{re.escape(code)}\b", question):
            return code
    return None


def ask_global_question_gradio(question, chat_history):
    """Route a chat question across all loaded datasets.

    Routing rules:
    - A question that names a state and contains no global keyword is asked
      of every dataset loaded for that state.
    - A global question (dataset years, quarters, comparisons, ...) is
      answered from dataset metadata alone.
    - Anything else falls back to the current dataset.

    Returns:
        ``("", chat_history)`` with the new exchange appended. ``chat_history``
        is mutated in place; None is treated as empty.
    """
    global rag_systems

    if chat_history is None:
        chat_history = []

    if not rag_systems:
        chat_history.append((question, "❌ No datasets loaded. Please load datasets first."))
        return "", chat_history

    # Check if this is a global question about the datasets themselves
    global_keywords = ['how many', 'which years', 'what years', 'what quarters', 'how many years',
                       'which quarters', 'time period', 'date range', 'all datasets', 'datasets',
                       'compare', 'comparison', 'difference', 'trend', 'over time']
    question_lower = question.lower()
    is_global_question = any(keyword in question_lower for keyword in global_keywords)
    mentioned_state = _find_mentioned_state(question)

    try:
        if mentioned_state and not is_global_question:
            state_name = STATE_NAMES[mentioned_state]
            # metadata['state_filter'] stores the full state name, so compare exactly
            suitable_datasets = [
                key for key, system in rag_systems.items()
                if system['metadata']['state_filter'] == state_name
            ]
            if not suitable_datasets:
                response = f"ℹ️ No datasets found for {state_name}. Please load data for this state first."
                chat_history.append((question, response))
                return "", chat_history

            response = f"🔄 Found {len(suitable_datasets)} dataset(s) for {state_name}:\n\n"
            for dataset_key in suitable_datasets:
                meta = rag_systems[dataset_key]['metadata']
                # Query each dataset directly instead of swapping the global
                # current dataset (which this function never declared global).
                try:
                    answer = _answer_for_dataset(dataset_key, question, include_source=False)
                except Exception as e:
                    answer = f"❌ Error processing query: {str(e)}\n\nPlease try rephrasing your question."
                response += f"**{meta['dataset_version']} - {meta['state_filter']}**\n{answer}\n\n---\n\n"

            chat_history.append((question, response))
            return "", chat_history

        if is_global_question:
            dataset_summary = generate_dataset_metadata_summary()
            llm = ChatOpenAI(
                temperature=0.3,
                model_name=DEFAULT_MODEL,
                model_kwargs={"response_format": {"type": "text"}}
            )
            system_message = f"""You are an expert on Medicare Provider data analysis. Always respond in English.

{dataset_summary}

When answering questions:
1. Consider the metadata about all available datasets
2. For questions about time periods, years, quarters, or trends, use the dataset information
3. Be specific about which datasets contain what information
4. Always respond in clear, professional English"""
            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": question}
            ]
            answer = llm.invoke(messages).content
            chat_history.append((question, answer))
            return "", chat_history

        # For non-global questions without specific state mention, use the current dataset
        return ask_question_gradio(question, chat_history)

    except Exception as e:
        error_response = f"❌ Error processing global query: {str(e)}\n\nPlease try rephrasing your question."
        chat_history.append((question, error_response))
        return "", chat_history


def generate_dataset_metadata_summary():
    """Return a Markdown overview of loaded datasets: years, quarters per year, states and a full list."""
    if not rag_systems:
        return "No datasets loaded."

    years = set()
    quarters_by_year = {}
    states = set()
    for system in rag_systems.values():
        meta = system['metadata']
        states.add(meta['state_filter'])
        # Versions look like "Q1 2025" -> quarter "Q1", year "2025"
        quarter, sep, year = meta['dataset_version'].partition(' ')
        if sep:
            years.add(year)
            quarters_by_year.setdefault(year, set()).add(quarter)

    parts = [
        "# Available Datasets\n\n",
        "The following datasets are currently loaded:\n\n",
        "## Years Available\n",
        ", ".join(sorted(years)) + "\n\n",
        "## Quarters Available by Year\n",
    ]
    for year in sorted(quarters_by_year):
        parts.append(f"- {year}: {', '.join(sorted(quarters_by_year[year]))}\n")
    parts.append("\n## States Available\n")
    parts.append(", ".join(sorted(states)) + "\n\n")
    parts.append("## Full Dataset List\n")
    for system in rag_systems.values():
        meta = system['metadata']
        parts.append(f"- {meta['dataset_version']} - {meta['state_filter']} ({meta['record_count']} records)\n")
    return "".join(parts)


def compare_datasets_gradio(question, dataset_indices):
    """Ask the same question of several datasets and return a Markdown comparison.

    Each dataset is queried through the same stateless QA path as the chat
    tab. The chain stored at load time carries ConversationBufferMemory, so
    earlier questions would otherwise leak into each comparison.

    Args:
        question: The question to ask every selected dataset.
        dataset_indices: Dropdown labels of the form "<n>. ..."; labels that
            do not parse or are out of range are ignored.
    """
    global rag_systems

    if not rag_systems:
        return "❌ No datasets loaded. Please load datasets first."
    if not dataset_indices or len(dataset_indices) < 2:
        return "❌ Please select at least 2 datasets to compare."

    keys = list(rag_systems.keys())
    selected_keys = []
    for selection in dataset_indices:
        try:
            index = int(selection.split(".")[0])
        except (ValueError, AttributeError):
            continue  # label does not start with "<n>."
        if 1 <= index <= len(keys):
            selected_keys.append(keys[index - 1])

    if len(selected_keys) < 2:
        return "❌ Could not parse selected datasets."

    comparison_result = f"# Comparison: {question}\n\n"
    for key in selected_keys:
        meta = rag_systems[key]['metadata']
        comparison_result += f"## {meta['dataset_version']} - {meta['state_filter']}\n\n"
        try:
            answer = _answer_for_dataset(key, question, include_source=False)
            comparison_result += f"{answer}\n\n"
        except Exception as e:
            comparison_result += f"Error: {str(e)}\n\n"
        comparison_result += "---\n\n"

    return comparison_result


# NOTE: superseded legacy version, kept for reference only; the active
# analyze_provider_types_gradio is defined further down in this file.
# def analyze_provider_types_gradio(dataset_key=None):
#     """Analyze provider types in a dataset - Gradio version."""
#     global rag_systems, current_dataset_key
#     # Determine which dataset to use
#     target_key = dataset_key if dataset_key and dataset_key in rag_systems else current_dataset_key
#     if not target_key or target_key not in rag_systems:
#         return "❌ No dataset selected."
#     system = rag_systems[target_key]
#     meta = system['metadata']
#     analysis_question = """
#     Analyze the provider types in this dataset:
#     1. What are the most common provider types?
#     2. How many unique provider types are there?
#     3. What percentage of providers fall into each major category?
#     Please provide a detailed breakdown.
#     """
#     try:
#         result = system['conversation_chain'].invoke({"question": analysis_question})
#         analysis = f"# Provider Type Analysis\n"
#         analysis += f"**Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n\n"
#         analysis += result["answer"]
#         return analysis
#     except Exception as e:
#         return f"❌ Error analyzing provider types: {str(e)}"


def analyze_provider_types_gradio(dataset_key=None):
    """Ask the LLM for a provider-type breakdown of one dataset, as Markdown.

    Args:
        dataset_key: Key into ``rag_systems``. When missing or unknown, the
            current dataset is used.

    Returns:
        A Markdown report, or a message starting with "❌" on failure.
    """
    global rag_systems, current_dataset_key

    target_key = dataset_key if dataset_key and dataset_key in rag_systems else current_dataset_key
    if not target_key or target_key not in rag_systems:
        return "❌ No dataset selected."

    meta = rag_systems[target_key]['metadata']

    analysis_prompt = f"""Analyze the Medicare provider data from {meta['dataset_version']} - {meta['state_filter']}.

Please provide:
1. A list of the most common provider types (with counts if available)
2. The total number of unique provider types
3. A breakdown by major categories (practitioners, facilities, suppliers, etc.)
4. Any notable patterns or insights

Use the actual data from the context to provide specific numbers and percentages.
Respond only in English and be as detailed as possible based on the available data."""

    original_key = current_dataset_key
    try:
        # ask_question_gradio answers from the *current* dataset, so point it at
        # the target temporarily. Always restore the user's selection, even on error.
        current_dataset_key = target_key
        try:
            _, temp_history = ask_question_gradio(analysis_prompt, [])
        finally:
            current_dataset_key = original_key

        if not temp_history:
            return "❌ Could not analyze provider types. Please try again."

        analysis = temp_history[0][1]
        if analysis.startswith("❌"):
            # Propagate the query error instead of presenting it as an analysis
            return analysis
        # Drop the per-answer source footer; the dataset is named in the header below
        if "*Source:" in analysis:
            analysis = analysis.split("*Source:")[0].strip()

        formatted_analysis = "# Provider Type Analysis\n"
        formatted_analysis += f"**Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n"
        formatted_analysis += f"**Total Records:** {meta['record_count']}\n\n"
        formatted_analysis += analysis
        return formatted_analysis

    except Exception as e:
        return f"❌ Error analyzing provider types: {str(e)}"


def clear_chat_history():
    """Return an empty history to reset the chat component."""
    return []


def visualize_datasets_gradio(dataset_indices, dimensions, sample_size=1000):
    """Create a visualization of one or more datasets - Gradio version."""
    global rag_systems

    if not rag_systems:
        return None, "❌ No datasets loaded. Please load datasets first."

    if not dataset_indices:
        return None, "❌ Please select at least one dataset to visualize."

    # Parse indices and get dataset keys
    selected_keys = []
    for selection in dataset_indices:
        try:
            index = int(selection.split(".")[0])
            if 1 <= index <= len(rag_systems):
                key = list(rag_systems.keys())[index - 1]
                selected_keys.append(key)
        except:
            continue

    if not selected_keys:
        return None, "❌ Could not parse selected datasets."
try: # Create a combined visualization all_vectors = [] all_metadata = [] all_contents = [] all_dataset_labels = [] status_msg = f"Processing {len(selected_keys)} dataset(s)...\n" # Collect vectors from all requested datasets for key in selected_keys: vector_store = rag_systems[key]['vector_store'] meta = rag_systems[key]['metadata'] dataset_label = f"{meta['dataset_version']} - {meta['state_filter']}" # Limit vectors for performance num_vectors = min(sample_size, vector_store.index.ntotal) status_msg += f"- {dataset_label}: {num_vectors} vectors\n" for i in range(num_vectors): all_vectors.append(vector_store.index.reconstruct(i)) doc_id = vector_store.index_to_docstore_id[i] document = vector_store.docstore.search(doc_id) all_metadata.append(document.metadata) all_contents.append(document.page_content) all_dataset_labels.append(dataset_label) if not all_vectors: return None, "โŒ No vectors to visualize." vectors = np.array(all_vectors) status_msg += f"\nTotal vectors: {len(all_vectors)}\n" # Reduce dimensionality status_msg += f"Reducing dimensionality to {dimensions}D using t-SNE..." tsne = TSNE(n_components=dimensions, random_state=42, perplexity=min(30, len(all_vectors)-1)) reduced_vectors = tsne.fit_transform(vectors) # Create color mapping based on dataset unique_labels = list(set(all_dataset_labels)) colors = [] color_palette = [ '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf' ] color_map = {label: color_palette[i % len(color_palette)] for i, label in enumerate(unique_labels)} colors = [color_map[label] for label in all_dataset_labels] # Create hover text hover_texts = [] for meta, content, label in zip(all_metadata, all_contents, all_dataset_labels): text = f"Dataset: {label}
" # Add key metadata fields key_fields = ['STATE_CD', 'PROVIDER_TYPE_DESC', 'FIRST_NAME', 'LAST_NAME', 'ORG_NAME'] for field in key_fields: if field in meta and meta[field]: text += f"{field}: {meta[field]}
" # Add a preview of the content content_preview = content[:200] + "..." if len(content) > 200 else content text += f"
Preview: {content_preview}" hover_texts.append(text) # Create visualization if dimensions == 2: fig = go.Figure() # Add a trace for each dataset for label in unique_labels: # Get indices for this dataset indices = [i for i, l in enumerate(all_dataset_labels) if l == label] # Add the scatter trace fig.add_trace(go.Scatter( x=reduced_vectors[indices, 0], y=reduced_vectors[indices, 1], mode='markers', marker=dict( size=6, color=color_map[label], opacity=0.7, line=dict(width=1, color='white') ), text=[hover_texts[i] for i in indices], hoverinfo='text', hoverlabel=dict(bgcolor="white", font_size=12), name=label )) fig.update_layout( title={ 'text': 'Medicare Provider Data - 2D Vector Space Visualization', 'font': {'size': 20} }, xaxis_title='Dimension 1', yaxis_title='Dimension 2', width=900, height=700, hovermode='closest', template='plotly_white', legend=dict( yanchor="top", y=0.99, xanchor="left", x=0.01, bgcolor="rgba(255,255,255,0.8)" ) ) else: # 3D fig = go.Figure() # Add a trace for each dataset for label in unique_labels: # Get indices for this dataset indices = [i for i, l in enumerate(all_dataset_labels) if l == label] # Add the scatter trace fig.add_trace(go.Scatter3d( x=reduced_vectors[indices, 0], y=reduced_vectors[indices, 1], z=reduced_vectors[indices, 2], mode='markers', marker=dict( size=5, color=color_map[label], opacity=0.7, line=dict(width=1, color='white') ), text=[hover_texts[i] for i in indices], hoverinfo='text', hoverlabel=dict(bgcolor="white", font_size=12), name=label )) fig.update_layout( title={ 'text': 'Medicare Provider Data - 3D Vector Space Visualization', 'font': {'size': 20} }, scene=dict( xaxis_title='Dimension 1', yaxis_title='Dimension 2', zaxis_title='Dimension 3', camera=dict( eye=dict(x=1.5, y=1.5, z=1.5) ) ), width=900, height=700, template='plotly_white', legend=dict( yanchor="top", y=0.99, xanchor="left", x=0.01, bgcolor="rgba(255,255,255,0.8)" ) ) success_msg = f"โœ… Successfully created {dimensions}D visualization with 
{len(all_vectors)} vectors from {len(selected_keys)} dataset(s)" return fig, success_msg except Exception as e: return None, f"โŒ Error creating visualization: {str(e)}" def create_dataset_statistics_plot(dataset_indices): """Create statistical plots for selected datasets.""" global rag_systems if not rag_systems: return None, "โŒ No datasets loaded." if not dataset_indices: return None, "โŒ Please select at least one dataset." # Parse indices and get dataset keys selected_keys = [] for selection in dataset_indices: try: index = int(selection.split(".")[0]) if 1 <= index <= len(rag_systems): key = list(rag_systems.keys())[index - 1] selected_keys.append(key) except: continue if not selected_keys: return None, "โŒ Could not parse selected datasets." try: # Collect statistics dataset_names = [] record_counts = [] chunk_counts = [] for key in selected_keys: meta = rag_systems[key]['metadata'] dataset_names.append(f"{meta['dataset_version']}
{meta['state_filter']}") record_counts.append(meta['record_count']) chunk_counts.append(meta['chunk_count']) # Create subplots from plotly.subplots import make_subplots fig = make_subplots( rows=1, cols=2, subplot_titles=('Records per Dataset', 'Chunks per Dataset'), specs=[[{'type': 'bar'}, {'type': 'bar'}]] ) # Add record count bars fig.add_trace( go.Bar( x=dataset_names, y=record_counts, name='Records', marker_color='lightblue', text=record_counts, textposition='auto', ), row=1, col=1 ) # Add chunk count bars fig.add_trace( go.Bar( x=dataset_names, y=chunk_counts, name='Chunks', marker_color='lightgreen', text=chunk_counts, textposition='auto', ), row=1, col=2 ) fig.update_layout( title={ 'text': 'Dataset Statistics Overview', 'font': {'size': 20} }, showlegend=False, height=500, template='plotly_white' ) fig.update_xaxes(tickangle=-45) return fig, f"โœ… Created statistics plot for {len(selected_keys)} dataset(s)" except Exception as e: return None, f"โŒ Error creating statistics plot: {str(e)}" def inspect_dataset_gradio(num_samples): """Display sample documents from the current dataset - Gradio version.""" global rag_systems, current_dataset_key if not current_dataset_key or current_dataset_key not in rag_systems: return "โŒ No dataset selected. Please load a dataset first." 
# Get the dataset system = rag_systems[current_dataset_key] vector_store = system['vector_store'] meta = system['metadata'] inspection_result = f"# Dataset Inspection\n\n" inspection_result += f"**Dataset:** {meta['dataset_version']} - {meta['state_filter']}\n" inspection_result += f"**Total documents:** {vector_store.index.ntotal}\n" inspection_result += f"**Showing:** {min(num_samples, vector_store.index.ntotal)} sample documents\n\n" inspection_result += "---\n\n" for i in range(min(num_samples, vector_store.index.ntotal)): try: doc_id = vector_store.index_to_docstore_id[i] document = vector_store.docstore.search(doc_id) inspection_result += f"### Document {i+1}\n\n" inspection_result += "**Metadata:**\n" # Show key metadata fields key_fields = ['PROVIDER_TYPE_DESC', 'STATE_CD', 'FIRST_NAME', 'LAST_NAME', 'ORG_NAME', 'NPI', 'ENRLMT_ID'] for field in key_fields: if field in document.metadata and document.metadata[field]: inspection_result += f"- **{field}:** {document.metadata[field]}\n" # Show content preview content_preview = document.page_content[:500] + "..." if len(document.page_content) > 500 else document.page_content inspection_result += f"\n**Content Preview:**\n```\n{content_preview}\n```\n\n" inspection_result += "---\n\n" except Exception as e: inspection_result += f"Error retrieving document {i}: {str(e)}\n\n" return inspection_result def create_gradio_interface(): """Create the main Gradio interface.""" with gr.Blocks(theme=theme, title="Medicare Provider Data Analysis System") as app: # Header gr.Markdown( """ # ๐Ÿฅ Medicare Provider Data Analysis System This system allows you to load, query, and analyze Medicare provider data using advanced RAG (Retrieval-Augmented Generation) technology. 
--- """ ) # Main tabs with gr.Tabs() as tabs: # Tab 1: Dataset Management with gr.Tab("๐Ÿ“Š Dataset Management"): with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Load New Dataset") version_dropdown = gr.Dropdown( choices=list(DATASET_VERSIONS.keys()), label="Select Quarter/Year", value="Q1 2025" ) state_dropdown = gr.Dropdown( choices=format_state_options(), label="Select State", value="" ) max_records_slider = gr.Slider( minimum=100, maximum=5000, value=1000, step=100, label="Maximum Records" ) use_sample_checkbox = gr.Checkbox( label="Load sample only (100 records)", value=True ) load_button = gr.Button("๐Ÿ”„ Load Dataset", variant="primary") load_output = gr.Textbox(label="Loading Status", lines=3) with gr.Column(scale=1): gr.Markdown("### Manage Loaded Datasets") dataset_summary = gr.Markdown(get_dataset_summary(rag_systems)) with gr.Row(): dataset_selector = gr.Dropdown( choices=get_dataset_choices(), label="Select Dataset", interactive=True ) with gr.Row(): switch_button = gr.Button("โ†”๏ธ Switch Dataset") remove_button = gr.Button("๐Ÿ—‘๏ธ Remove Dataset") clear_all_button = gr.Button("๐Ÿงน Clear All", variant="stop") manage_output = gr.Textbox(label="Status", lines=2) # Wire up dataset management events def update_dataset_selector(): return gr.update(choices=get_dataset_choices()) load_button.click( fn=load_dataset_gradio, inputs=[version_dropdown, state_dropdown, max_records_slider, use_sample_checkbox], outputs=[load_output, dataset_summary] ).then( fn=update_dataset_selector, outputs=dataset_selector ) switch_button.click( fn=switch_dataset_gradio, inputs=dataset_selector, outputs=[manage_output, dataset_summary] ) remove_button.click( fn=remove_dataset_gradio, inputs=dataset_selector, outputs=[manage_output, dataset_summary] ).then( fn=update_dataset_selector, outputs=dataset_selector ) clear_all_button.click( fn=clear_all_datasets_gradio, outputs=[manage_output, dataset_summary] ).then( fn=update_dataset_selector, outputs=dataset_selector ) # 
Tab 2: Query Interface with gr.Tab("๐Ÿ’ฌ Query & Chat"): gr.Markdown("### Ask Questions About Your Data") current_dataset_info = gr.Markdown(get_current_dataset_info()) # Create a timer to update current dataset info timer = gr.Timer(value=2) timer.tick(fn=get_current_dataset_info, outputs=current_dataset_info) with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot( label="Conversation", height=500, show_copy_button=True ) with gr.Row(): question_input = gr.Textbox( label="Your Question", placeholder="Ask about provider types, locations, statistics, etc.", lines=2, scale=4 ) with gr.Column(scale=1): ask_button = gr.Button("๐Ÿ“ค Ask Current Dataset", variant="primary") global_ask_button = gr.Button("๐ŸŒ Ask All Datasets") clear_chat_button = gr.Button("๐Ÿ—‘๏ธ Clear Chat") with gr.Column(scale=1): gr.Markdown("### Quick Actions") analyze_providers_button = gr.Button("๐Ÿ“Š Analyze Provider Types") gr.Markdown("### Example Questions") example_questions = [ "What are the most common provider types?", "How many providers are in this dataset?", "Show me all psychiatrists in the data", "What types of medical facilities are included?", "Compare provider counts across different quarters" ] for eq in example_questions: gr.Button(eq, size="sm").click( lambda q=eq: (q, gr.update()), outputs=[question_input, chatbot] ) # Wire up query events question_input.submit( fn=ask_question_gradio, inputs=[question_input, chatbot], outputs=[question_input, chatbot] ) ask_button.click( fn=ask_question_gradio, inputs=[question_input, chatbot], outputs=[question_input, chatbot] ) global_ask_button.click( fn=ask_global_question_gradio, inputs=[question_input, chatbot], outputs=[question_input, chatbot] ) clear_chat_button.click( fn=clear_chat_history, outputs=chatbot ) analyze_providers_button.click( fn=lambda: ("", [( "Analyze provider types in the current dataset", analyze_provider_types_gradio() )]), outputs=[question_input, chatbot] ) # Tab 3: Comparison & Analysis with 
gr.Tab("๐Ÿ” Compare Datasets"): gr.Markdown("### Compare Multiple Datasets") with gr.Row(): compare_dataset_selector = gr.CheckboxGroup( choices=get_dataset_choices(), label="Select Datasets to Compare (choose 2 or more)", value=[] ) compare_question = gr.Textbox( label="Comparison Question", placeholder="Enter a question to ask all selected datasets", lines=2 ) compare_button = gr.Button("๐Ÿ”„ Compare Datasets", variant="primary") comparison_output = gr.Markdown(label="Comparison Results") # Update checkbox choices when datasets change def update_compare_selector(): return gr.update(choices=get_dataset_choices()) timer.tick(fn=update_compare_selector, outputs=compare_dataset_selector) compare_button.click( fn=compare_datasets_gradio, inputs=[compare_question, compare_dataset_selector], outputs=comparison_output ) # Tab 4: Visualization with gr.Tab("๐Ÿ“ˆ Visualizations"): gr.Markdown("### Dataset Visualizations") with gr.Row(): with gr.Column(): viz_dataset_selector = gr.CheckboxGroup( choices=get_dataset_choices(), label="Select Datasets to Visualize", value=[] ) viz_dimension = gr.Radio( choices=[2, 3], value=2, label="Visualization Dimensions" ) viz_sample_size = gr.Slider( minimum=100, maximum=2000, value=500, step=100, label="Sample Size (per dataset)" ) create_viz_button = gr.Button("๐ŸŽจ Create Visualization", variant="primary") stats_button = gr.Button("๐Ÿ“Š Show Statistics") viz_status = gr.Textbox(label="Status", lines=2) with gr.Row(): viz_plot = gr.Plot(label="Vector Space Visualization") stats_plot = gr.Plot(label="Dataset Statistics") # Update visualization selector def update_viz_selector(): return gr.update(choices=get_dataset_choices()) timer.tick(fn=update_viz_selector, outputs=viz_dataset_selector) create_viz_button.click( fn=visualize_datasets_gradio, inputs=[viz_dataset_selector, viz_dimension, viz_sample_size], outputs=[viz_plot, viz_status] ) stats_button.click( fn=create_dataset_statistics_plot, inputs=[viz_dataset_selector], 
outputs=[stats_plot, viz_status] ) # Tab 5: Dataset Inspector with gr.Tab("๐Ÿ”Ž Dataset Inspector"): gr.Markdown("### Inspect Dataset Contents") inspect_current_info = gr.Markdown(get_current_dataset_info()) timer.tick(fn=get_current_dataset_info, outputs=inspect_current_info) num_samples_slider = gr.Slider( minimum=1, maximum=20, value=5, step=1, label="Number of Sample Documents" ) inspect_button = gr.Button("๐Ÿ” Inspect Current Dataset", variant="primary") inspection_output = gr.Markdown(label="Dataset Inspection Results") inspect_button.click( fn=inspect_dataset_gradio, inputs=num_samples_slider, outputs=inspection_output ) # Tab 6: Settings & Help with gr.Tab("โš™๏ธ Settings & Help"): gr.Markdown( """ ### System Information **Model:** GPT-4 Mini **Embedding Model:** OpenAI Embeddings **Vector Store:** FAISS ### API Configuration This system uses the CMS.gov Data API to fetch Medicare provider information. ### Tips for Best Results 1. **Loading Data**: Start with sample data (100 records) to test queries quickly 2. **State Selection**: Load specific states for focused analysis 3. **Querying**: Be specific in your questions for better results 4. 
**Comparisons**: Load multiple quarters/states to analyze trends ### Common Use Cases - **Provider Analysis**: Find specific types of healthcare providers - **Geographic Distribution**: Analyze providers by state - **Temporal Trends**: Compare data across different quarters - **Provider Types**: Understand the distribution of specialties ### Troubleshooting - **No API Key**: Ensure OPENAI_API_KEY is set in your environment - **Loading Errors**: Check your internet connection and API limits - **Query Errors**: Try rephrasing your question or check if data is loaded """ ) with gr.Row(): gr.Markdown("### Current Configuration") config_info = gr.JSON( value={ "api_key_set": bool(os.getenv('OPENAI_API_KEY')), "default_model": DEFAULT_MODEL, "api_base_url": API_BASE_URL, "datasets_loaded": len(rag_systems) }, label="System Configuration" ) # Footer gr.Markdown( """ ---
Medicare Provider Data Analysis System | Powered by LangChain & OpenAI
""" ) return app # Main execution if __name__ == "__main__": # Create and launch the app app = create_gradio_interface() # Launch with appropriate settings app.launch( server_name="0.0.0.0", # Allow external connections server_port=7860, # Default Gradio port )