Spaces:
Runtime error
Runtime error
| import traceback | |
| from queue import Queue | |
| from threading import Thread | |
| import collections.abc | |
| import torch | |
| from transformers import StoppingCriteria | |
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once a stop-token sequence has been seen often enough.

    On every call, the tokens of the *first* sequence in the batch are checked
    to see whether they end with one of ``stops``. Each match increments a
    per-stop counter. Generation stops as soon as a counter reaches its
    required number of encounters.

    Args:
        stops: Token-id sequences marking a stopping point. Each entry may be
            anything ``torch.as_tensor`` accepts (tensor, list, scalar id). It
            is flattened to 1-D.
        encounters: Required match counts. ``stops[i]`` uses
            ``encounters[i % len(encounters)]``, so ``len(stops)`` must be a
            multiple of ``len(encounters)``. A missing or empty value means
            "stop on the first match" (``[1]``).
        device: Optional device to move the stop sequences to immediately.
            Whether or not it is given, each stop sequence is moved to the
            device of ``input_ids`` the first time it is compared.

    Raises:
        ValueError: If ``len(stops)`` is not a multiple of ``len(encounters)``.

    Note:
        The match counters (``num_stops``) are never reset, so use a fresh
        instance for each generation run.
    """

    def __init__(self, stops=None, encounters=None, device=None):
        super().__init__()
        stops = [] if stops is None else list(stops)
        # An empty encounters list previously caused a ZeroDivisionError in the
        # modulo check below; default to stopping on the first match instead.
        encounters = list(encounters) if encounters else [1]
        if len(stops) % len(encounters) != 0:
            raise ValueError(
                "Number of stops must be a multiple of the number of encounters"
            )
        self.encounters = encounters
        # Flatten so scalar ids and (1, k)-shaped tensors compare element-wise
        # against the tail of the generated sequence.
        self.stops = [torch.as_tensor(stop).reshape(-1) for stop in stops]
        if device is not None:
            self.stops = [stop.to(device) for stop in self.stops]
        self.num_stops = [0] * len(self.stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when any stop sequence has reached its encounter count.

        Only ``input_ids[0]`` (the first sequence in the batch) is inspected.
        ``scores`` and ``kwargs`` are accepted for interface compatibility and
        are not used.
        """
        tokens = input_ids[0]
        seq_len = tokens.shape[-1]
        for stopi, stop in enumerate(self.stops):
            stop_len = stop.shape[0]
            # A stop that is empty or longer than the sequence cannot match.
            # Comparing anyway would raise a shape error, or broadcast a
            # length-1 tail into a false-positive match.
            if stop_len == 0 or stop_len > seq_len:
                continue
            if stop.device != tokens.device:
                # Align devices once and cache, instead of hard-coding "cuda".
                stop = stop.to(tokens.device)
                self.stops[stopi] = stop
            if torch.all(stop == tokens[-stop_len:]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        return False