import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Cross-encoder fine-tuned for MS MARCO passage ranking: scores a
# (query, passage) pair jointly with a single relevance logit.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # inference only — disable dropout etc.


def get_relevance_score_and_excerpt(query, paragraph):
    """Score query/paragraph relevance and extract an attention-based excerpt.

    Args:
        query: Search query string.
        paragraph: Candidate document paragraph to score against the query.

    Returns:
        A ``(score, excerpt)`` tuple: the sigmoid of the cross-encoder logit
        rounded to 4 decimals, and a short excerpt built from the paragraph
        tokens that receive the highest average attention in the final layer.
        If either input is blank, returns an error message and ``""`` instead.
    """
    if not query.strip() or not paragraph.strip():
        return "Please provide both a query and a document paragraph.", ""

    # Encode the pair; truncation keeps us within the model's max length.
    inputs = tokenizer(
        query,
        paragraph,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_attention_mask=True,
    )

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)

    # Single relevance logit -> probability-like score via sigmoid.
    logit = output.logits.squeeze().item()
    relevance_score = torch.sigmoid(torch.tensor(logit)).item()

    # Last attention layer, averaged over heads: (seq_len, seq_len).
    attention = output.attentions[-1]
    attention_scores = attention.mean(dim=1).squeeze(0)

    input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
    # NOTE(review): assumes the pair is encoded as [CLS] query [SEP] paragraph
    # [SEP]; if truncation clips the query itself, this offset can drift —
    # confirm against the tokenizer's pair-encoding behavior.
    query_length = len(tokenizer.tokenize(query))
    para_start = query_length + 2  # skip [CLS] + query tokens + first [SEP]

    paragraph_tokens = input_tokens[para_start:-1]  # drop trailing [SEP]
    # Mean attention RECEIVED by each paragraph token from the paragraph span.
    paragraph_attention = attention_scores[para_start:-1, para_start:-1].mean(dim=0)

    # Guard: paragraph span can be empty after aggressive truncation.
    if paragraph_attention.numel() == 0:
        return round(relevance_score, 4), ""

    # Bug fix: pick the top-k tokens by attention, but emit them in their
    # ORIGINAL document order so the excerpt reads coherently (previously
    # they were concatenated in descending-attention order, producing a
    # jumbled excerpt). topk replaces the manual argsort-then-slice.
    k = min(5, paragraph_attention.numel())
    top_token_indices = torch.topk(paragraph_attention, k).indices
    highlighted_tokens = [paragraph_tokens[i] for i in sorted(top_token_indices.tolist())]

    excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)
    return round(relevance_score, 4), excerpt


# Gradio UI: two text inputs -> score + excerpt. live=True re-runs the model
# on every edit, which is expensive for large inputs — kept as designed.
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
    ],
    outputs=[
        gr.Textbox(label="Relevance Score"),
        gr.Textbox(label="Most Relevant Excerpt"),
    ],
    title="Cross-Encoder Relevance Scoring with Attention-Based Excerpt Extraction",
    description="Enter a query and a document paragraph to get a relevance score and a relevant excerpt using attention scores.",
    allow_flagging="never",
    live=True,
)

if __name__ == "__main__":
    interface.launch()