transZ committed
Commit 03539db · 1 parent: d691a4b

Working version

Files changed (1):
  1. sbert_cosine.py (+6 -4)
sbert_cosine.py CHANGED
@@ -104,7 +104,7 @@ class sbert_cosine(evaluate.Metric):
             input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
             return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

-        def batch_to_device(batch, target_device: device):
+        def batch_to_device(batch, target_device):
             """
             send a pytorch batch to a device (CPU/GPU)
             """
@@ -118,18 +118,20 @@ class sbert_cosine(evaluate.Metric):
         tokenizer = AutoTokenizer.from_pretrained(model_type)
         model = BertModel.from_pretrained(model_type)
         model = model.to(device)
-        cosine = nn.CosineSimilarity()
+        cosine = nn.CosineSimilarity(dim=0)

         def calculate(x: str, y: str):
             encoded_input = tokenizer([x, y], padding=True, truncation=True, return_tensors='pt')
-            encoded_input = batch_to_device(encode_input, device)
+            encoded_input = batch_to_device(encoded_input, device)
             model_output = model(**encoded_input)
             embeds = mean_pooling(model_output, encoded_input['attention_mask'])
             res = cosine(embeds[0, :], embeds[1, :]).item()
             return res

+        avg = lambda x: sum(x) / len(x)
+
         with torch.no_grad():
-            score = torch.mean([calculate(pred, ref) for pred, ref in zip(predictions, references)]).item()
+            score = avg([calculate(pred, ref) for pred, ref in zip(predictions, references)])

         return {
             "score": score,
 