alexkueck commited on
Commit
90366fa
·
1 Parent(s): 5973677

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +27 -0
utils.py CHANGED
@@ -53,6 +53,33 @@ def generate_prompt_with_history(text, history, tokenizer, max_length=2048):
53
  else:
54
  return None
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  def load_tokenizer_and_model(base_model, load_8bit=False):
 
53
  else:
54
  return None
55
 
56
+ # Few-Shot Training wird durch ein spezielles Modell-Laden gemacht (setFit)
57
+ def load_tokenizer_and_model_setFit(base_model, load_8bit=False):
58
+ if torch.cuda.is_available():
59
+ device = "cuda"
60
+ else:
61
+ device = "cpu"
62
+
63
+
64
+ tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast = True, use_auth_token=True, bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')
65
+ if device == "cuda":
66
+ model = SetFitModel.from_pretrained(
67
+ base_model,
68
+ load_in_8bit=load_8bit,
69
+ torch_dtype=torch.float16,
70
+ device_map="auto",
71
+ use_auth_token=True,
72
+ )
73
+ else:
74
+ model = SetFitModel.from_pretrained(
75
+ base_model, device_map={"": device}, low_cpu_mem_usage=True
76
+ )
77
+
78
+ #if not load_8bit:
79
+ #model.half() # seems to fix bugs for some users.
80
+
81
+ model.eval()
82
+ return tokenizer,model, device
83
 
84
 
85
  def load_tokenizer_and_model(base_model, load_8bit=False):