|
|
|
import torch |
|
from transformers import AutoProcessor, AutoModel, TextIteratorStreamer |
|
|
|
class FullSequenceStreamer(TextIteratorStreamer):
    """Streamer that emits the *entire* decoded sequence on every `put`.

    Unlike `TextIteratorStreamer`, which yields only the incremental text
    delta, this variant decodes the full batch of token ids it receives and
    pushes the complete decoded strings onto the queue each step. Useful for
    diffusion-style / non-autoregressive decoders (e.g. Dimple) where the
    whole sequence is refined at every iteration rather than extended.
    """

    def put(self, value, stream_end=False):
        """Decode `value` (a batch of token-id tensors) and enqueue the
        full decoded texts.

        Args:
            value: token ids accepted by `tokenizer.batch_decode`
                (presumably a 2-D batch tensor — TODO confirm against caller).
            stream_end: when True, also enqueue the stop signal so the
                consuming iterator terminates.
        """
        decoded = self.tokenizer.batch_decode(value, **self.decode_kwargs)
        # Use the same timeout as the stop-signal puts below so this call
        # cannot block indefinitely on a bounded queue.
        self.text_queue.put(decoded, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def end(self):
        """Signal the consuming iterator that generation has finished."""
        self.text_queue.put(self.stop_signal, timeout=self.timeout)
|
|
|
def get_model(device, model_name="rp-yu/Dimple-7B"):
    """Load a pretrained multimodal model and its processor.

    Args:
        device: torch device (or device string) the model is moved to.
        model_name: Hugging Face Hub id of the checkpoint to load.
            Defaults to "rp-yu/Dimple-7B".

    Returns:
        Tuple of (model, processor). The model is loaded in bfloat16,
        set to eval mode, and placed on `device`.

    Note:
        `trust_remote_code=True` executes code shipped with the checkpoint;
        only use with model repos you trust.
    """
    processor = AutoProcessor.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model = model.eval().to(device)

    return model, processor
|
|
|
|