Spaces:
Build error
Build error
conditional_gen adjustments
Browse files- src/Surveyor.py +16 -3
src/Surveyor.py
CHANGED
|
@@ -146,6 +146,16 @@ class Surveyor:
|
|
| 146 |
else:
|
| 147 |
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
| 148 |
self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
self.ledmodel.eval()
|
| 150 |
if not no_save_models:
|
| 151 |
self.ledmodel.save_pretrained(models_dir + "/ledmodel")
|
|
@@ -170,12 +180,15 @@ class Surveyor:
|
|
| 170 |
self.summ_model.eval()
|
| 171 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
| 172 |
|
| 173 |
-
if '
|
| 174 |
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
| 175 |
self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
| 176 |
-
|
| 177 |
-
self.ledtokenizer =
|
| 178 |
self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
|
|
|
|
|
|
|
|
|
| 179 |
self.ledmodel.eval()
|
| 180 |
|
| 181 |
self.embedder = SentenceTransformer(models_dir + "/embedder")
|
|
|
|
| 146 |
else:
|
| 147 |
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
| 148 |
self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
| 149 |
+
|
| 150 |
+
if 'led' in ledmodel_name:
|
| 151 |
+
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
| 152 |
+
self.ledmodel = LEDForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
| 153 |
+
elif 't5' in ledmodel_name:
|
| 154 |
+
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
| 155 |
+
self.ledmodel = T5ForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
| 156 |
+
elif 'bart' in ledmodel_name:
|
| 157 |
+
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
| 158 |
+
self.ledmodel = BartForConditionalGeneration.from_pretrained(ledmodel_name).to(self.torch_device)
|
| 159 |
self.ledmodel.eval()
|
| 160 |
if not no_save_models:
|
| 161 |
self.ledmodel.save_pretrained(models_dir + "/ledmodel")
|
|
|
|
| 180 |
self.summ_model.eval()
|
| 181 |
self.model = Summarizer(custom_model=self.summ_model, custom_tokenizer=self.summ_tokenizer)
|
| 182 |
|
| 183 |
+
if 'led' in ledmodel_name:
|
| 184 |
self.ledtokenizer = LEDTokenizer.from_pretrained(ledmodel_name)
|
| 185 |
self.ledmodel = LEDForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
| 186 |
+
elif 't5' in ledmodel_name:
|
| 187 |
+
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
| 188 |
self.ledmodel = T5ForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
| 189 |
+
elif 'bart' in ledmodel_name:
|
| 190 |
+
self.ledtokenizer = AutoTokenizer.from_pretrained(ledmodel_name)
|
| 191 |
+
self.ledmodel = BartForConditionalGeneration.from_pretrained(models_dir + "/ledmodel").to(self.torch_device)
|
| 192 |
self.ledmodel.eval()
|
| 193 |
|
| 194 |
self.embedder = SentenceTransformer(models_dir + "/embedder")
|