wilwork committed on
Commit
6a947e6
·
verified ·
1 Parent(s): 6e70d21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -51
app.py CHANGED
@@ -14,54 +14,4 @@ def get_relevance_score_and_excerpt(query, paragraph):
14
  return "Please provide both a query and a document paragraph.", ""
15
 
16
  # Tokenize the input
17
- inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True, return_attention_mask=True)
18
-
19
- with torch.no_grad():
20
- output = model(**inputs, output_attentions=True) # Get attention scores
21
-
22
- # Extract logits and calculate relevance score
23
- logit = output.logits.squeeze().item()
24
- relevance_score = torch.sigmoid(torch.tensor(logit)).item()
25
-
26
- # Extract attention scores (use the last attention layer)
27
- attention = output.attentions[-1] # Shape: (batch_size, num_heads, seq_len, seq_len)
28
-
29
- # Average across attention heads to get token importance
30
- attention_scores = attention.mean(dim=1).squeeze(0) # Shape: (seq_len, seq_len)
31
-
32
- # Focus on the paragraph part only (ignore query tokens)
33
- input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
34
- query_length = len(tokenizer.tokenize(query))
35
-
36
- # Extract attention for the paragraph tokens only
37
- paragraph_tokens = input_tokens[query_length + 2 : -1] # Skip query and special tokens like [SEP]
38
- paragraph_attention = attention_scores[query_length + 2 : -1, query_length + 2 : -1].mean(dim=0)
39
-
40
- # Get the top tokens with highest attention scores
41
- top_token_indices = torch.argsort(paragraph_attention, descending=True)[:5] # Top 5 tokens
42
- highlighted_tokens = [paragraph_tokens[i] for i in top_token_indices]
43
-
44
- # Reconstruct the excerpt from top attention tokens
45
- excerpt = tokenizer.convert_tokens_to_string(highlighted_tokens)
46
-
47
- return round(relevance_score, 4), excerpt
48
-
49
- # Define Gradio interface
50
- interface = gr.Interface(
51
- fn=get_relevance_score_and_excerpt,
52
- inputs=[
53
- gr.Textbox(label="Query", placeholder="Enter your search query..."),
54
- gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
55
- ],
56
- outputs=[
57
- gr.Textbox(label="Relevance Score"),
58
- gr.Textbox(label="Most Relevant Excerpt")
59
- ],
60
- title="Cross-Encoder Relevance Scoring with Attention-Based Excerpt Extraction",
61
- description="Enter a query and a document paragraph to get a relevance score and a relevant excerpt using attention scores.",
62
- allow_flagging="never",
63
- live=True
64
- )
65
-
66
- if __name__ == "__main__":
67
- interface.launch()
 
14
  return "Please provide both a query and a document paragraph.", ""
15
 
16
  # Tokenize the input
17
+ inputs = t