import streamlit as st
import pandas as pd
import torch
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from text_processing import TextProcessor
import gc
from pathlib import Path

# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide"
)

# Initialize session state
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = None
if 'summaries' not in st.session_state:
    st.session_state.summaries = None
if 'text_processor' not in st.session_state:
    st.session_state.text_processor = None
if 'processing_started' not in st.session_state:
    st.session_state.processing_started = False
if 'focused_summary_generated' not in st.session_state:
    st.session_state.focused_summary_generated = False


def load_model(model_type):
    """Load appropriate model based on type with proper memory management"""
    try:
        # Clear any existing cached data
        gc.collect()
        torch.cuda.empty_cache()

        device = "cpu"  # Force CPU usage

        if model_type == "summarize":
            # Load the fine-tuned summarization model directly
            model = AutoModelForSeq2SeqLM.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models"
            )
        else:  # question_focused
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/biobart-finetune",
                is_trainable=False
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models"
            )

        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise


def cleanup_model(model, tokenizer):
    """Properly clean up model resources"""
    try:
        del model
        del tokenizer
        torch.cuda.empty_cache()
        gc.collect()
    except Exception:
        pass


@st.cache_data
def process_excel(uploaded_file):
    """Process uploaded Excel file"""
    try:
        df = pd.read_excel(uploaded_file)
        required_columns = ['Abstract', 'Article Title', 'Authors', 'Source Title',
                            'Publication Year', 'DOI', 'Times Cited, All Databases']

        # Check required columns
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            st.error(f"Missing required columns: {', '.join(missing_columns)}")
            return None

        return df[required_columns]
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return None


def preprocess_text(text):
    """Preprocess text to add appropriate formatting before summarization"""
    if not isinstance(text, str) or not text.strip():
        return text

    # Split text into sentences (basic implementation)
    sentences = [s.strip() for s in text.replace('. ', '.\n').split('\n')]

    # Remove empty sentences
    sentences = [s for s in sentences if s]

    # Join with proper line breaks
    formatted_text = '\n'.join(sentences)

    return formatted_text


def post_process_summary(summary):
    """Clean up and improve summary coherence."""
    if not summary:
        return summary

    # Split into sentences
    sentences = [s.strip() for s in summary.split('.')]
    sentences = [s for s in sentences if s]  # Remove empty sentences

    # Correct common issues
    processed_sentences = []
    for sentence in sentences:
        # Remove redundant phrases
        sentence = re.sub(r"\b(and and|appointment and appointment)\b", "and", sentence)

        # Ensure first-letter capitalization
        sentence = sentence.capitalize()

        # Avoid duplicates
        if sentence not in processed_sentences:
            processed_sentences.append(sentence)

    # Join sentences with proper punctuation
    cleaned_summary = '. '.join(processed_sentences)
    return cleaned_summary if cleaned_summary.endswith('.') else cleaned_summary + '.'


def improve_summary_generation(text, model, tokenizer):
    """Generate improved summary with better prompt and validation."""
    if not isinstance(text, str) or not text.strip():
        return "No abstract available to summarize."

    # Add a structured prompt for summarization
    formatted_text = (
        "Summarize this biomedical research abstract into the following structure:\n"
        "1. Background and Objectives\n"
        "2. Methods\n"
        "3. Key Findings (include any percentages or numbers)\n"
        "4. Conclusions\n"
        f"Abstract:\n{text.strip()}"
    )

    # Prepare input tokens
    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate summary with adjusted parameters
    try:
        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=300,   # Increased for more detailed summaries
                min_length=100,   # Ensure summaries are not too short
                num_beams=5,
                length_penalty=1.5,
                no_repeat_ngram_size=3,
                temperature=0.7,
                repetition_penalty=1.3,
            )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error in generation: {str(e)}"

    # Post-process the summary
    processed_summary = post_process_summary(summary)

    # Validate the summary; if it fails, retry with alternate generation parameters
    if not validate_summary(processed_summary, text):
        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=250,
                min_length=50,
                num_beams=4,
                length_penalty=2.0,
                no_repeat_ngram_size=4,
                temperature=0.8,
                repetition_penalty=1.5,
            )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        processed_summary = post_process_summary(summary)

    return processed_summary


def validate_summary(summary, original_text):
    """Validate summary content against original text."""
    # Check for common validation points
    if not summary or len(summary.split()) < 20:
        return False  # Too short
    if len(summary.split()) > len(original_text.split()) * 0.8:
        return False  # Too long

    # Ensure structure is maintained (e.g., headings are present)
    required_sections = ["background and objectives", "methods", "key findings", "conclusions"]
    if not all(section in summary.lower() for section in required_sections):
        return False

    # Ensure no repetitive sentences
    sentences = summary.split('.')
    if len(sentences) != len(set(sentences)):
        return False

    return True


def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate focused summary based on question"""
    # Preprocess each abstract
    formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts]
    combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(formatted_abstracts)

    inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=200,
            min_length=50,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True
        )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def create_filter_controls(df, sort_column):
    """Create appropriate filter controls based on the selected column"""
    filtered_df = df.copy()

    if sort_column == 'Publication Year':
        # Year range inputs
        year_min = int(df['Publication Year'].min())
        year_max = int(df['Publication Year'].max())
        col1, col2 = st.columns(2)
        with col1:
            start_year = st.number_input('From Year',
                                         min_value=year_min,
                                         max_value=year_max,
                                         value=year_min)
        with col2:
            end_year = st.number_input('To Year',
                                       min_value=year_min,
                                       max_value=year_max,
                                       value=year_max)
        filtered_df = filtered_df[
            (filtered_df['Publication Year'] >= start_year) &
            (filtered_df['Publication Year'] <= end_year)
        ]

    elif sort_column == 'Authors':
        # Multi-select for authors
        unique_authors = sorted(set(
            author.strip()
            for authors in df['Authors'].dropna()
            for author in authors.split(';')
        ))
        selected_authors = st.multiselect('Select Authors', unique_authors)
        if selected_authors:
            filtered_df = filtered_df[
                filtered_df['Authors'].apply(
                    lambda x: any(author in str(x) for author in selected_authors)
                )
            ]

    elif sort_column == 'Source Title':
        # Multi-select for source titles
        unique_sources = sorted(df['Source Title'].unique())
        selected_sources = st.multiselect('Select Sources', unique_sources)
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]

    elif sort_column == 'Article Title':
        # Only alphabetical sorting, no filtering
        pass

    elif sort_column == 'Times Cited':
        # Citation count range inputs
        cited_min = int(df['Times Cited'].min())
        cited_max = int(df['Times Cited'].max())
        col1, col2 = st.columns(2)
        with col1:
            start_cited = st.number_input('From Cited Count',
                                          min_value=cited_min,
                                          max_value=cited_max,
                                          value=cited_min)
        with col2:
            end_cited = st.number_input('To Cited Count',
                                        min_value=cited_min,
                                        max_value=cited_max,
                                        value=cited_max)
        filtered_df = filtered_df[
            (filtered_df['Times Cited'] >= start_cited) &
            (filtered_df['Times Cited'] <= end_cited)
        ]

    return filtered_df


def main():
    st.title("🔬 Biomedical Papers Analysis")

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )

    # Question input - placed early but left empty until a file is uploaded
    question_container = st.empty()
    question = ""

    if uploaded_file is not None:
        # Process Excel file
        if st.session_state.processed_data is None:
            with st.spinner("Processing file..."):
                df = process_excel(uploaded_file)
                if df is not None:
                    st.session_state.processed_data = df.dropna(subset=["Abstract"])

        if st.session_state.processed_data is not None:
            df = st.session_state.processed_data
            st.write(f"📊 Loaded {len(df)} papers with abstracts")

            # Get question before processing
            with question_container:
                question = st.text_input(
                    "Enter your research question (optional):",
                    help="If provided, a question-focused summary will be generated after individual summaries"
                )
summaries" ) # Single button for both processes if not st.session_state.get('processing_started', False): if st.button("Start Analysis"): st.session_state.processing_started = True # Show processing status and results if st.session_state.get('processing_started', False): # Individual Summaries Section st.header("📝 Individual Paper Summaries") # Generate summaries if not already done if st.session_state.summaries is None: try: with st.spinner("Generating individual paper summaries..."): model, tokenizer = load_model("summarize") summaries = [] progress_bar = st.progress(0) for idx, abstract in enumerate(df['Abstract']): summary = improve_summary_generation(abstract, model, tokenizer) summaries.append(summary) progress_bar.progress((idx + 1) / len(df)) st.session_state.summaries = summaries cleanup_model(model, tokenizer) progress_bar.empty() except Exception as e: st.error(f"Error generating summaries: {str(e)}") st.session_state.processing_started = False # Display summaries with improved sorting and filtering if st.session_state.summaries is not None: col1, col2 = st.columns(2) with col1: sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited'] sort_column = st.selectbox("Sort/Filter by:", sort_options) with col2: # Only show A-Z/Z-A option for Article Title if sort_column == 'Article Title': ascending = st.radio( "Sort order", ["A to Z", "Z to A"], horizontal=True ) == "A to Z" elif sort_column == 'Times Cited': ascending = st.radio( "Sort order", ["Most cited", "Least cited"], horizontal=True ) == "Least cited" else: ascending = True # Default for other columns # Create display dataframe display_df = df.copy() display_df['Summary'] = st.session_state.summaries display_df['Publication Year'] = display_df['Publication Year'].astype(int) display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True) display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int) # Apply filters filtered_df = create_filter_controls(display_df, sort_column) if sort_column == 'Article Title': # Sort alphabetically sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending) else: # Keep original order for other columns after filtering # Keep original order for other columns after filtering sorted_df = filtered_df # Show number of filtered results if len(sorted_df) != len(display_df): st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers") # Apply custom styling st.markdown(""" """, unsafe_allow_html=True) # Display papers using the filtered and sorted dataframe for _, row in sorted_df.iterrows(): paper_info_cols = st.columns([1, 1]) with paper_info_cols[0]: # PAPER column st.markdown('
PAPER
', unsafe_allow_html=True) st.markdown(f"""
{row['Article Title']}
Authors: {row['Authors']}
Source: {row['Source Title']}
Publication Year: {row['Publication Year']}
Times Cited: {row['Times Cited']}
DOI: {row['DOI'] if pd.notna(row['DOI']) else 'None'}
""", unsafe_allow_html=True) with paper_info_cols[1]: # SUMMARY column st.markdown('
SUMMARY
', unsafe_allow_html=True) st.markdown(f"""
{row['Summary']}
""", unsafe_allow_html=True) # Add spacing between papers st.markdown("
", unsafe_allow_html=True) # Question-focused Summary Section (only if question provided) if question.strip(): st.header("❓ Question-focused Summary") if not st.session_state.get('focused_summary_generated', False): try: with st.spinner("Analyzing relevant papers..."): # Initialize text processor if needed if st.session_state.text_processor is None: st.session_state.text_processor = TextProcessor() # Find relevant abstracts results = st.session_state.text_processor.find_most_relevant_abstracts( question, df['Abstract'].tolist(), top_k=5 ) # Load question-focused model model, tokenizer = load_model("question_focused") # Generate focused summary relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist() focused_summary = generate_focused_summary( question, relevant_abstracts, model, tokenizer ) # Store results st.session_state.focused_summary = focused_summary st.session_state.relevant_papers = df.iloc[results['top_indices']] st.session_state.relevance_scores = results['scores'] st.session_state.focused_summary_generated = True # Cleanup second model cleanup_model(model, tokenizer) except Exception as e: st.error(f"Error generating focused summary: {str(e)}") # Display focused summary results if st.session_state.get('focused_summary_generated', False): st.subheader("Summary") st.write(st.session_state.focused_summary) st.subheader("Most Relevant Papers") relevant_papers = st.session_state.relevant_papers[ ['Article Title', 'Authors', 'Publication Year', 'DOI'] ].copy() relevant_papers['Relevance Score'] = st.session_state.relevance_scores relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int) st.dataframe(relevant_papers, hide_index=True) if __name__ == "__main__": main()