import torch
from transformers import AutoProcessor, AutoModel, TextIteratorStreamer


class FullSequenceStreamer(TextIteratorStreamer):
    """Iterator streamer that enqueues the decoded full sequence on every call.

    Each call to :meth:`put` decodes whatever token ids it receives and pushes
    the decoded result onto ``self.text_queue`` as-is; nothing is buffered or
    diffed between calls. The constructor is inherited unchanged from
    ``TextIteratorStreamer``.
    """

    def put(self, value, stream_end=False):
        """Decode ``value`` and enqueue the decoded batch.

        Args:
            value: Token ids handed to ``tokenizer.batch_decode``. The caller
                is assumed to pass the *full* sequence(s) generated so far on
                every call (presumably a batch of id sequences -- TODO confirm
                against the generation loop that drives this streamer).
            stream_end: When true, the stop signal is enqueued right after the
                decoded batch.
        """
        decoded = self.tokenizer.batch_decode(value, **self.decode_kwargs)
        # Note: the queue item is the list returned by ``batch_decode``,
        # not a single string.
        # Pass the timeout here as well so every queue write in this class
        # is configured the same way as the stop-signal writes below.
        self.text_queue.put(decoded, timeout=self.timeout)
        if stream_end:
            # NOTE(review): if the caller also invokes ``end()`` afterwards,
            # a second stop signal is enqueued -- confirm this is intended.
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def end(self):
        """Enqueue the stop signal."""
        self.text_queue.put(self.stop_signal, timeout=self.timeout)


def get_model(device, model_name: str = "rp-yu/Dimple-7B"):
    """Load a model and its processor, and place the model on ``device``.

    Args:
        device: Target passed to ``model.to(...)``.
        model_name: Checkpoint identifier passed to ``from_pretrained``.
            Defaults to ``"rp-yu/Dimple-7B"``.

    Returns:
        A ``(model, processor)`` tuple; the model has had ``eval()`` called on
        it and has been moved to ``device``.
    """
    # SECURITY: ``trust_remote_code=True`` lets ``from_pretrained`` run Python
    # code shipped with the checkpoint repository. Only point ``model_name``
    # at sources you trust.
    processor = AutoProcessor.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model = model.eval()
    model = model.to(device)
    return model, processor