davebulaval commited on
Commit
1982c24
Β·
verified Β·
1 Parent(s): 46c63a3

Add device and remove dead verbose argument.

Browse files
Files changed (1) hide show
  1. meaningbert.py +4 -4
meaningbert.py CHANGED
@@ -66,7 +66,7 @@ MeaningBERT metric for assessing meaning preservation between sentences.
66
  Args:
67
  documents (list of str): Document sentences.
68
  simplifications (list of str): Simplification sentences (same number of element as documents).
69
- verbose (bool): Turn on intermediate status update.
70
 
71
  Returns:
72
  score: the meaning score between two sentences in alist format respecting the order of the documents and
@@ -77,7 +77,7 @@ Examples:
77
 
78
  >>> documents = ["hello there", "general kenobi"]
79
  >>> simplifications = ["hello there", "general kenobi"]
80
- >>> meaning_bert = evaluate.load("davebulaval/meaningbert")
81
  >>> results = meaning_bert.compute(documents=documents, simplifications=simplifications)
82
  """
83
 
@@ -112,7 +112,7 @@ class MeaningBERT(evaluate.Metric):
112
  self,
113
  documents: List,
114
  simplifications: List,
115
- verbose: bool = False,
116
  ) -> Dict:
117
  assert len(documents) == len(
118
  simplifications
@@ -126,7 +126,7 @@ class MeaningBERT(evaluate.Metric):
126
 
127
  # We load the MeaningBERT pretrained model
128
  scorer = AutoModelForSequenceClassification.from_pretrained(
129
- "davebulaval/MeaningBERT"
130
  )
131
  scorer.eval()
132
 
 
66
  Args:
67
  documents (list of str): Document sentences.
68
  simplifications (list of str): Simplification sentences (same number of element as documents).
69
+ device (str): Device to use for model inference. By default, set to "cuda".
70
 
71
  Returns:
72
  score: the meaning score between two sentences in alist format respecting the order of the documents and
 
77
 
78
  >>> documents = ["hello there", "general kenobi"]
79
  >>> simplifications = ["hello there", "general kenobi"]
80
+ >>> meaning_bert = evaluate.load("davebulaval/meaningbert", device="cuda:0")
81
  >>> results = meaning_bert.compute(documents=documents, simplifications=simplifications)
82
  """
83
 
 
112
  self,
113
  documents: List,
114
  simplifications: List,
115
+ device: str = "cuda",
116
  ) -> Dict:
117
  assert len(documents) == len(
118
  simplifications
 
126
 
127
  # We load the MeaningBERT pretrained model
128
  scorer = AutoModelForSequenceClassification.from_pretrained(
129
+ "davebulaval/MeaningBERT", device_map=device
130
  )
131
  scorer.eval()
132