import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load model and tokenizer once at import time; every request reuses them.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
# A fast (Rust-backed) tokenizer is required: the scoring function relies on
# BatchEncoding.sequence_ids() and offset mappings, which slow tokenizers lack.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # Set the model to evaluation mode


def _paragraph_span(encoding):
    """Return the half-open [start, end) token positions of the paragraph.

    Uses the tokenizer's own sequence ids for the first (only) example
    (0 = query, 1 = paragraph, None = special token), so the span stays
    aligned with the model input even when truncation shortened either text.
    Returns (0, 0) when no paragraph token survived encoding.
    """
    para_positions = [
        pos for pos, seq_id in enumerate(encoding.sequence_ids(0)) if seq_id == 1
    ]
    if not para_positions:
        return 0, 0
    return para_positions[0], para_positions[-1] + 1


def _highlight_paragraph(paragraph, token_offsets, relevant):
    """Render *paragraph* as HTML with the relevant tokens wrapped in <b> tags.

    Args:
        paragraph: the original paragraph string the offsets refer to.
        token_offsets: one (start, end) character span into *paragraph* per
            paragraph token, in token order.
        relevant: indices into *token_offsets* whose text should be bolded.

    The original text (casing, spacing, punctuation) is reproduced exactly;
    every piece is HTML-escaped because it is user input shown in gr.HTML.
    """
    pieces = []
    cursor = 0
    for idx, (start, end) in enumerate(token_offsets):
        # Guard against overlapping spans so no character is emitted twice.
        start = max(start, cursor)
        end = max(end, start)
        pieces.append(html.escape(paragraph[cursor:start]))  # gap, e.g. whitespace
        token_text = html.escape(paragraph[start:end])
        pieces.append(f"<b>{token_text}</b>" if idx in relevant else token_text)
        cursor = end
    # Text after the last encoded token (e.g. dropped by truncation) stays plain.
    pieces.append(html.escape(paragraph[cursor:]))
    # pre-wrap keeps the user's line breaks, which HTML would otherwise collapse.
    return '<div style="white-space: pre-wrap;">' + "".join(pieces) + "</div>"


# Function to compute relevance score and dynamically adjust threshold
def get_relevance_score_and_excerpt(query, paragraph):
    """Score *paragraph* against *query* and highlight high-attention tokens.

    Args:
        query: search query text (None is treated as empty).
        paragraph: document paragraph text (None is treated as empty).

    Returns:
        A 2-tuple for the Gradio outputs:
        - (message, "") if either input is blank;
        - otherwise (score rounded to 4 decimals, HTML string). The score is the
          sigmoid of the cross-encoder logit, multiplied by the mean attention
          of the highlighted tokens when any token passes the threshold. The
          HTML is the original paragraph with those tokens in <b> tags, or the
          plain message "No relevant tokens extracted." when the encoding
          contains no paragraph tokens.
    """
    query = query or ""
    paragraph = paragraph or ""
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Encode as a (query, paragraph) pair, truncated to the tokenizer's maximum
    # length. Offsets map each token back to characters of its source text.
    inputs = tokenizer(
        query,
        paragraph,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_offsets_mapping=True,
    )
    # The model's forward() does not accept offset_mapping, so take it out first.
    offsets = inputs.pop("offset_mapping")[0].tolist()

    # NOTE(review): some newer transformers versions only return attentions with
    # the eager attention implementation — confirm against the installed version.
    with torch.no_grad():
        output = model(**inputs, output_attentions=True)

    # Single-logit cross-encoder head -> probability-like relevance in (0, 1).
    logit = output.logits.squeeze().item()
    base_relevance_score = torch.sigmoid(torch.tensor(logit)).item()

    # Heuristic: more relevant pairs demand more attention before a token is bolded.
    dynamic_threshold = max(0.02, base_relevance_score * 0.1)  # Example formula

    # Last layer attention, averaged over heads (dim 1) then batch (dim 0)
    # -> (seq_len, seq_len) matrix of row=attending token, col=attended token.
    attention = output.attentions[-1]
    attention_scores = attention.mean(dim=1).mean(dim=0)

    para_start_idx, para_end_idx = _paragraph_span(inputs)
    if para_end_idx <= para_start_idx:
        return round(base_relevance_score, 4), "No relevant tokens extracted."

    # For each paragraph token: mean attention it receives from paragraph tokens.
    para_attention_scores = attention_scores[
        para_start_idx:para_end_idx, para_start_idx:para_end_idx
    ].mean(dim=0)

    # Indices (relative to the paragraph span) above the dynamic threshold.
    relevant_indices = (
        (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
    )

    if relevant_indices:
        relevant_attention_values = para_attention_scores[relevant_indices]
        attention_weighted_score = (
            relevant_attention_values.mean().item() * base_relevance_score
        )
    else:
        attention_weighted_score = base_relevance_score  # No relevant tokens found

    highlighted_text = _highlight_paragraph(
        paragraph,
        offsets[para_start_idx:para_end_idx],
        set(relevant_indices),
    )
    return round(attention_weighted_score, 4), highlighted_text


# Define Gradio interface
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
    ],
    outputs=[
        gr.Textbox(label="Attention-Weighted Relevance Score"),
        gr.HTML(label="Highlighted Document Paragraph"),
    ],
    title="Cross-Encoder with Dynamic Attention Threshold",
    description="Enter a query and document paragraph to get a relevance score with relevant tokens in bold.",
    allow_flagging="never",
    live=True,
)

if __name__ == "__main__":
    interface.launch()