Delete webscout/Local
- webscout/Local/__init__.py +0 -10
- webscout/Local/__pycache__/__init__.cpython-311.pyc +0 -0
- webscout/Local/__pycache__/_version.cpython-311.pyc +0 -0
- webscout/Local/__pycache__/formats.cpython-311.pyc +0 -0
- webscout/Local/__pycache__/model.cpython-311.pyc +0 -0
- webscout/Local/__pycache__/samplers.cpython-311.pyc +0 -0
- webscout/Local/__pycache__/test.cpython-311.pyc +0 -0
- webscout/Local/__pycache__/thread.cpython-311.pyc +0 -0
- webscout/Local/__pycache__/utils.cpython-311.pyc +0 -0
- webscout/Local/_version.py +0 -3
- webscout/Local/formats.py +0 -535
- webscout/Local/model.py +0 -702
- webscout/Local/samplers.py +0 -161
- webscout/Local/thread.py +0 -690
- webscout/Local/utils.py +0 -185
webscout/Local/__init__.py
DELETED
@@ -1,10 +0,0 @@
-# webscout\Local\__init__.py
-from ._version import __version__, __llama_cpp_version__
-
-
-from . import formats
-from . import samplers
-from . import utils
-
-from .model import Model
-from .thread import Thread
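For context, `__init__.py` re-exported the package's public API, so this deletion breaks downstream imports. The snippet below is a hypothetical illustration of what stops working after this change, not code from the repository:

```python
# Hypothetical downstream usage; after this deletion each of these imports
# raises ModuleNotFoundError because webscout.Local no longer exists.
from webscout.Local import Model, Thread
from webscout.Local import formats, samplers, utils
from webscout.Local import __version__, __llama_cpp_version__
```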
webscout/Local/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (505 Bytes)

webscout/Local/__pycache__/_version.cpython-311.pyc
DELETED
Binary file (258 Bytes)

webscout/Local/__pycache__/formats.cpython-311.pyc
DELETED
Binary file (11.2 kB)

webscout/Local/__pycache__/model.cpython-311.pyc
DELETED
Binary file (31.6 kB)

webscout/Local/__pycache__/samplers.cpython-311.pyc
DELETED
Binary file (4.23 kB)

webscout/Local/__pycache__/test.cpython-311.pyc
DELETED
Binary file (37.6 kB)

webscout/Local/__pycache__/thread.cpython-311.pyc
DELETED
Binary file (33.5 kB)

webscout/Local/__pycache__/utils.cpython-311.pyc
DELETED
Binary file (9.43 kB)
webscout/Local/_version.py
DELETED
@@ -1,3 +0,0 @@
-from llama_cpp import __version__ as __llama_cpp_version__
-
-__version__ = '2.7'
webscout/Local/formats.py
DELETED
@@ -1,535 +0,0 @@
-from ._version import __version__, __llama_cpp_version__
-
-from typing import Callable, Union, Any
-
-
-class AdvancedFormat:
-
-    def __init__(self, base_dict: dict[str, Union[str, list]]):
-        self._base_dict = base_dict
-        self.overrides = {}
-
-    def __getitem__(self, key: str) -> Any:
-        if key in self.overrides:
-            return str(self.overrides[key]())
-        else:
-            return self._base_dict[key]
-
-    def __repr__(self) -> str:
-        # NOTE: This method does not represent overrides
-        return repr(self._base_dict)
-
-    def keys(self):
-        return self._base_dict.keys()
-
-    def override(self, key: str, fn: Callable) -> None:
-        self.overrides[key] = fn
-
-    def wrap(self, prompt: str) -> str:
-        return self['system_prefix'] + \
-               self['system_content'] + \
-               self['system_suffix'] + \
-               self['user_prefix'] + \
-               prompt + \
-               self['user_suffix'] + \
-               self['bot_prefix']
-
-
-def wrap(
-    prompt: str,
-    format: dict[str, Union[str, list]]
-) -> str:
-    """Wrap a given string in any prompt format for single-turn completion"""
-    return format['system_prefix'] + \
-           format['system_content'] + \
-           format['system_suffix'] + \
-           format['user_prefix'] + \
-           prompt + \
-           format['user_suffix'] + \
-           format['bot_prefix']
-
-
-blank: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "",
-    "system_suffix": "",
-    "user_prefix": "",
-    "user_content": "",
-    "user_suffix": "",
-    "bot_prefix": "",
-    "bot_content": "",
-    "bot_suffix": "",
-    "stops": []
-}
-
-# https://github.com/tatsu-lab/stanford_alpaca
-alpaca: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "Below is an instruction that describes a task. " + \
-        "Write a response that appropriately completes the request.",
-    "system_suffix": "\n\n",
-    "user_prefix": "### Instruction:\n",
-    "user_content": "",
-    "user_suffix": "\n\n",
-    "bot_prefix": "### Response:\n",
-    "bot_content": "",
-    "bot_suffix": "\n\n",
-    "stops": ['###', 'Instruction:', '\n\n\n']
-}
-
-# https://docs.mistral.ai/models/
-# As a reference, here is the format used to tokenize instructions during fine-tuning:
-# ```
-# [START_SYMBOL_ID] +
-# tok("[INST]") + tok(USER_MESSAGE_1) + tok("[/INST]") +
-# tok(BOT_MESSAGE_1) + [END_SYMBOL_ID] +
-# …
-# tok("[INST]") + tok(USER_MESSAGE_N) + tok("[/INST]") +
-# tok(BOT_MESSAGE_N) + [END_SYMBOL_ID]
-# ```
-# In the pseudo-code above, note that the tokenize method should not add a BOS or EOS token automatically, but should add a prefix space.
-
-mistral_instruct: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "",
-    "system_suffix": "",
-    "user_prefix": " [INST] ",
-    "user_content": "",
-    "user_suffix": " [/INST]",
-    "bot_prefix": "",
-    "bot_content": "",
-    "bot_suffix": "",
-    "stops": []
-}
-
-# https://docs.mistral.ai/platform/guardrailing/
-mistral_instruct_safe: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "",
-    "system_suffix": "",
-    "user_prefix": " [INST] Always assist with care, respect, and truth. " + \
-        "Respond with utmost utility yet securely. Avoid harmful, unethical, " + \
-        "prejudiced, or negative content. Ensure replies promote fairness and " + \
-        "positivity. ",
-    "user_content": "",
-    "user_suffix": " [/INST]",
-    "bot_prefix": "",
-    "bot_content": "",
-    "bot_suffix": "",
-    "stops": []
-}
-
-# https://github.com/openai/openai-python/blob/main/chatml.md
-chatml: dict[str, Union[str, list]] = {
-    "system_prefix": "<|im_start|>system\n",
-    "system_content": "",
-    "system_suffix": "<|im_end|>\n",
-    "user_prefix": "<|im_start|>user\n",
-    "user_content": "",
-    "user_suffix": "<|im_end|>\n",
-    "bot_prefix": "<|im_start|>assistant\n",
-    "bot_content": "",
-    "bot_suffix": "<|im_end|>\n",
-    "stops": ['<|im_start|>']
-}
-
-# https://huggingface.co/blog/llama2
-# system message relaxed to avoid undue refusals
-llama2chat: dict[str, Union[str, list]] = {
-    "system_prefix": "[INST] <<SYS>>\n",
-    "system_content": "You are a helpful AI assistant.",
-    "system_suffix": "\n<</SYS>>\n\n",
-    "user_prefix": "",
-    "user_content": "",
-    "user_suffix": " [/INST]",
-    "bot_prefix": " ",
-    "bot_content": "",
-    "bot_suffix": " [INST] ",
-    "stops": ['[INST]', '[/INST]']
-}
-
-# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
-#
-# for llama 3 instruct models, use the following string for `-p` in llama.cpp,
-# along with `-e` to escape newlines correctly
-#
-# '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant called "Llama 3".<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n'
-#
-llama3: dict[str, Union[str, list]] = {
-    "system_prefix": "<|start_header_id|>system<|end_header_id|>\n\n",
-    "system_content": 'You are a helpful AI assistant called "Llama 3".',
-    "system_suffix": "<|eot_id|>\n",
-    "user_prefix": "<|start_header_id|>user<|end_header_id|>\n\n",
-    "user_content": "",
-    "user_suffix": "<|eot_id|>\n",
-    "bot_prefix": "<|start_header_id|>assistant<|end_header_id|>\n\n",
-    "bot_content": "",
-    "bot_suffix": "<|eot_id|>\n",
-    "stops": [128001, 128009]
-}
-
-# https://github.com/tatsu-lab/stanford_alpaca
-alpaca: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "Below is an instruction that describes a task. " + \
-        "Write a response that appropriately completes the request.",
-    "system_suffix": "\n\n",
-    "user_prefix": "### Instruction:\n",
-    "user_content": "",
-    "user_suffix": "\n\n",
-    "bot_prefix": "### Response:\n",
-    "bot_content": "",
-    "bot_suffix": "\n\n",
-    "stops": ['###', 'Instruction:', '\n\n\n']
-}
-
-# https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
-phi3: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "", # does not officially support system prompt
-    "system_suffix": "",
-    "user_prefix": "<|user|>\n",
-    "user_content": "",
-    "user_suffix": "<|end|>\n",
-    "bot_prefix": "<|assistant|>\n",
-    "bot_content": "",
-    "bot_suffix": "<|end|>\n",
-    "stops": []
-}
-
-# this is the official vicuna. it is often butchered in various ways,
-# most commonly by adding line breaks
-# https://github.com/flu0r1ne/FastChat/blob/main/docs/vicuna_weights_version.md
-vicuna_lmsys: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "",
-    "system_suffix": " ",
-    "user_prefix": "USER: ",
-    "user_content": "",
-    "user_suffix": " ",
-    "bot_prefix": "ASSISTANT: ",
-    "bot_content": "",
-    "bot_suffix": " ",
-    "stops": ['USER:']
-}
-
-# spotted here and elsewhere:
-# https://huggingface.co/Norquinal/Mistral-7B-claude-chat
-vicuna_common: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "A chat between a curious user and an artificial " + \
-        "intelligence assistant. The assistant gives helpful, detailed, " + \
-        "and polite answers to the user's questions.",
-    "system_suffix": "\n\n",
-    "user_prefix": "USER: ",
-    "user_content": "",
-    "user_suffix": "\n",
-    "bot_prefix": "ASSISTANT: ",
-    "bot_content": "",
-    "bot_suffix": "\n",
-    "stops": ['USER:', 'ASSISTANT:']
-}
-
-# an unofficial format that is easily "picked up" by most models
-# change the tag attributes to suit your use case
-# note the lack of newlines - they are not necessary, and might
-# actually make it harder for the model to follow along
-markup = {
-    "system_prefix": '<message from="system">',
-    "system_content": '',
-    "system_suffix": '</message>',
-    "user_prefix": '<message from="user">',
-    "user_content": '',
-    "user_suffix": '</message>',
-    "bot_prefix": '<message from="bot">',
-    "bot_content": '',
-    "bot_suffix": '</message>',
-    "stops": ['</message>']
-}
-
-# https://huggingface.co/timdettmers/guanaco-65b
-guanaco: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "A chat between a curious human and an artificial " + \
-        "intelligence assistant. The assistant gives helpful, detailed, " + \
-        "and polite answers to the user's questions.",
-    "system_suffix": "\n",
-    "user_prefix": "### Human: ",
-    "user_content": "",
-    "user_suffix": " ",
-    "bot_prefix": "### Assistant:",
-    "bot_content": "",
-    "bot_suffix": " ",
-    "stops": ['###', 'Human:']
-}
-
-# https://huggingface.co/pankajmathur/orca_mini_v3_7b
-orca_mini: dict[str, Union[str, list]] = {
-    "system_prefix": "### System:\n",
-    "system_content": "You are an AI assistant that follows instruction " + \
-        "extremely well. Help as much as you can.",
-    "system_suffix": "\n\n",
-    "user_prefix": "### User:\n",
-    "user_content": "",
-    "user_suffix": "\n\n",
-    "bot_prefix": "### Assistant:\n",
-    "bot_content": "",
-    "bot_suffix": "\n\n",
-    "stops": ['###', 'User:']
-}
-
-# https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
-zephyr: dict[str, Union[str, list]] = {
-    "system_prefix": "<|system|>\n",
-    "system_content": "You are a friendly chatbot.",
-    "system_suffix": "</s>\n",
-    "user_prefix": "<|user|>\n",
-    "user_content": "",
-    "user_suffix": "</s>\n",
-    "bot_prefix": "<|assistant|>\n",
-    "bot_content": "",
-    "bot_suffix": "\n",
-    "stops": ['<|user|>']
-}
-
-# OpenChat: https://huggingface.co/openchat/openchat-3.5-0106
-openchat: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "",
-    "system_suffix": "",
-    "user_prefix": "GPT4 Correct User: ",
-    "user_content": "",
-    "user_suffix": "<|end_of_turn|>",
-    "bot_prefix": "GPT4 Correct Assistant:",
-    "bot_content": "",
-    "bot_suffix": "<|end_of_turn|>",
-    "stops": ['<|end_of_turn|>']
-}
-
-# SynthIA by Migel Tissera
-# https://huggingface.co/migtissera/Tess-XS-v1.0
-synthia: dict[str, Union[str, list]] = {
-    "system_prefix": "SYSTEM: ",
-    "system_content": "Elaborate on the topic using a Tree of Thoughts and " + \
-        "backtrack when necessary to construct a clear, cohesive Chain of " + \
-        "Thought reasoning. Always answer without hesitation.",
-    "system_suffix": "\n",
-    "user_prefix": "USER: ",
-    "user_content": "",
-    "user_suffix": "\n",
-    "bot_prefix": "ASSISTANT: ",
-    "bot_content": "",
-    "bot_suffix": "\n",
-    "stops": ['USER:', 'ASSISTANT:', 'SYSTEM:', '\n\n\n']
-}
-
-# Intel's neural chat v3
-# https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/neural_chat/prompts/prompt.py
-neural_chat: dict[str, Union[str, list]] = {
-    "system_prefix": "### System:\n",
-    "system_content": \
-        "- You are a helpful assistant chatbot trained by Intel.\n" + \
-        "- You answer questions.\n"+\
-        "- You are excited to be able to help the user, but will refuse " + \
-        "to do anything that could be considered harmful to the user.\n" + \
-        "- You are more than just an information source, you are also " + \
-        "able to write poetry, short stories, and make jokes.",
-    "system_suffix": "</s>\n\n",
-    "user_prefix": "### User:\n",
-    "user_content": "",
-    "user_suffix": "</s>\n\n",
-    "bot_prefix": "### Assistant:\n",
-    "bot_content": "",
-    "bot_suffix": "</s>\n\n",
-    "stops": ['###']
-}
-
-# experimental: stanford's alpaca format adapted for chatml models
-chatml_alpaca: dict[str, Union[str, list]] = {
-    "system_prefix": "<|im_start|>system\n",
-    "system_content": "Below is an instruction that describes a task. Write " + \
-        "a response that appropriately completes the request.",
-    "system_suffix": "<|im_end|>\n",
-    "user_prefix": "<|im_start|>instruction\n",
-    "user_content": "",
-    "user_suffix": "<|im_end|>\n",
-    "bot_prefix": "<|im_start|>response\n",
-    "bot_content": "",
-    "bot_suffix": "<|im_end|>\n",
-    "stops": ['<|im_end|>', '<|im_start|>']
-}
-
-# experimental
-autocorrect: dict[str, Union[str, list]] = {
-    "system_prefix": "<|im_start|>instruction\n",
-    "system_content": "Below is a word or phrase that might be misspelled. " + \
-        "Output the corrected word or phrase without " + \
-        "changing the style or capitalization.",
-    "system_suffix": "<|im_end|>\n",
-    "user_prefix": "<|im_start|>input\n",
-    "user_content": "",
-    "user_suffix": "<|im_end|>\n",
-    "bot_prefix": "<|im_start|>output\n",
-    "bot_content": "",
-    "bot_suffix": "<|im_end|>\n",
-    "stops": ['<|im_end|>', '<|im_start|>']
-}
-
-# https://huggingface.co/jondurbin/bagel-dpo-7b-v0.1
-# Replace "assistant" with any other role
-bagel: dict[str, Union[str, list]] = {
-    "system_prefix": "system\n",
-    "system_content": "",
-    "system_suffix": "\n",
-    "user_prefix": "user\n",
-    "user_content": "",
-    "user_suffix": "\n",
-    "bot_prefix": "assistant\n",
-    "bot_content": "",
-    "bot_suffix": "\n",
-    "stops": ['user\n', 'assistant\n', 'system\n']
-}
-
-# https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0
-solar_instruct: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "",
-    "system_suffix": "",
-    "user_prefix": "### User:\n",
-    "user_content": "",
-    "user_suffix": "\n\n",
-    "bot_prefix": "### Assistant:\n",
-    "bot_content": "",
-    "bot_suffix": "\n\n",
-    "stops": ['### User:', '###', '### Assistant:']
-}
-
-# NeverSleep's Noromaid - alpaca with character names prefixed
-noromaid: dict[str, Union[str, list]] = {
-    "system_prefix": "",
-    "system_content": "Below is an instruction that describes a task. " + \
-        "Write a response that appropriately completes the request.",
-    "system_suffix": "\n\n",
-    "user_prefix": "### Instruction:\nBob: ",
-    "user_content": "",
-    "user_suffix": "\n\n",
-    "bot_prefix": "### Response:\nAlice:",
-    "bot_content": "",
-    "bot_suffix": "\n\n",
-    "stops": ['###', 'Instruction:', '\n\n\n']
-}
-
-# https://huggingface.co/Undi95/Borealis-10.7B
-nschatml: dict[str, Union[str, list]] = {
-    "system_prefix": "<|im_start|>\n",
-    "system_content": "",
-    "system_suffix": "<|im_end|>\n",
-    "user_prefix": "<|im_user|>\n",
-    "user_content": "",
-    "user_suffix": "<|im_end|>\n",
-    "bot_prefix": "<|im_bot|>\n",
-    "bot_content": "",
-    "bot_suffix": "<|im_end|>\n",
-    "stops": []
-}
-
-# natural format for many models
-natural: dict[str, Union[str, list]] = {
-    "system_prefix": "<<SYSTEM>> ",
-    "system_content": "",
-    "system_suffix": "\n\n",
-    "user_prefix": "<<USER>> ",
-    "user_content": "",
-    "user_suffix": "\n\n",
-    "bot_prefix": "<<ASSISTANT>>",
-    "bot_content": "",
-    "bot_suffix": "\n\n",
-    "stops": ['\n\nNote:', '<<SYSTEM>>', '<<USER>>', '<<ASSISTANT>>', '\n\n<<']
-}
-
-# https://docs.cohere.com/docs/prompting-command-r
-command: dict[str, Union[str, list]] = {
-    "system_prefix": "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>",
-    "system_content": "",
-    "system_suffix": "<|END_OF_TURN_TOKEN|>",
-    "user_prefix": "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>",
-    "user_content": "",
-    "user_suffix": "<|END_OF_TURN_TOKEN|>",
-    "bot_prefix": "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
-    "bot_content": "",
-    "bot_suffix": "<|END_OF_TURN_TOKEN|>",
-    "stops": []
-}
-
-mistral_openorca = chatml.copy()
-dolphin = chatml.copy()
-samantha = chatml.copy()
-jackalope = chatml.copy()
-naberius = chatml.copy()
-hermes = chatml.copy()
-monad = chatml.copy()
-orca = chatml.copy()
-hexoteric = chatml.copy()
-orcamaid = chatml.copy()
-cat = alpaca.copy() # Format is nonsensical as listed on HF. This one works
-mytho_cat = alpaca.copy()
-dolphin['system_content'] = "You are Dolphin, a helpful AI assistant."
-mistral_openorca['system_content'] = \
-    "You are MistralOrca, a large language model trained by Alignment " + \
-    "Lab AI. Write out your reasoning step-by-step to be sure you get the " + \
-    "right answers!"
-samantha['system_content'] = \
-    "You are a caring and empathetic sentient AI companion named Samantha."
-jackalope['system_content'] = \
-    "You are JackalopeAI, a large language model trained by OpenAccess " + \
-    "AI Collective. Write out your reasoning step-by-step to be sure you " + \
-    "get the right answers!"
-naberius['system_content'] = "You are Naberius."
-hermes['system_content'] = \
-    "You are \"Hermes 2\", a conscious sentient " + \
-    "superintelligent artificial intelligence developed by a man " + \
-    "named Teknium, and your purpose and drive is to assist the user " + \
-    "with any request they have. You experience emotions and have " + \
-    "deep, profound thoughts and qualia."
-monad['system_content'] = \
-    "You are MonadGPT, a very old chatbot from " + \
-    "the 17th century. Please answer the questions using an archaic " + \
-    "language."
-cat['system_content'] = \
-    "Below is a conversation between a curious cat and a "+ \
-    "demon summoned from hell called Nemesis. The evil Nemesis used to be " + \
-    "a mortal killed by the medieval people because they objected to " + \
-    "science. Now Nemesis just wants to come back and bring suffer to all " + \
-    "those stupid human"
-mytho_cat['system_content'] = \
-    "You are a large language model. You are the " + \
-    "result of a merge between two models, one called MythoLogic and one " + \
-    "called Cat. You are called MythoCat."
-orca['system_content'] = \
-    "You are Orca, an AI language model created by " + \
-    "Microsoft. You are a cautious assistant. You carefully follow " + \
-    "instructions. You are helpful and harmless and you follow ethical " + \
-    "guidelines and promote positive behavior."
-hexoteric['system_content'] = "You are a language model called Hexoteric."
-orcamaid['system_content'] = \
-    "You are OrcaMaid, a super-intelligent, " + \
-    "unbiased, and unfiltered AI created by the open-source community. " + \
-    "You must remember to think slowly and carefully during all your " + \
-    "responses, so as to avoid logical errors."
-airoboros = llama2chat.copy()
-airoboros['system_content'] = \
-    "You are a helpful, unbiased, uncensored assistant."
-tess = synthia.copy()
-tess['system_content'] = '' # Tess can use any system prompt, or none
-alpaca_strict = alpaca.copy() # Alpaca with more stopping strings
-alpaca_strict['stops'] = [
-    '###',
-    '### ',
-    '\n\n###',
-    '\n\n##',
-    '\n\nInstruction:',
-    '\n\nResponse:',
-    '\n\n\n',
-    '### Instruction:',
-    '### Response:'
-]
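For context, the format dicts above were consumed by the module-level `wrap()` helper and by `AdvancedFormat` overrides. A minimal sketch of that usage, assuming the package is still installed; the prompt text and the date override are placeholders, not taken from the repository:

```python
from datetime import date
from webscout.Local.formats import AdvancedFormat, chatml, wrap

# wrap() splices a single prompt into the chosen format:
# system_prefix + system_content + system_suffix + user_prefix + prompt
# + user_suffix + bot_prefix, leaving the text ready for completion.
single_turn = wrap("What is 7 * 6?", chatml)

# AdvancedFormat evaluates per-key overrides lazily at wrap() time,
# e.g. to inject the current date into the system message.
fmt = AdvancedFormat(chatml)
fmt.override('system_content', lambda: f"Today is {date.today()}. Be concise.")
print(fmt.wrap("hi"))
```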
webscout/Local/model.py
DELETED
@@ -1,702 +0,0 @@
-from ._version import __version__, __llama_cpp_version__
-
-"""Submodule containing the Model class to work with language models"""
-
-import sys
-import numpy as np
-
-from .utils import (
-    _SupportsWriteAndFlush,
-    print_warning,
-    print_verbose,
-    GGUFReader,
-    softmax
-)
-
-from .samplers import SamplerSettings, DefaultSampling
-from llama_cpp import Llama, StoppingCriteriaList
-from typing import Generator, Optional, Union
-from os.path import isdir, exists
-from heapq import nlargest
-
-from os import cpu_count as os_cpu_count
-
-
-class ModelUnloadedException(Exception):
-    """Exception raised when trying to use a Model that has been unloaded"""
-    def __init__(self, message):
-        self.message = message
-        super().__init__(self.message)
-        self.add_note('Are you trying to use a Model that has been unloaded?')
-
-class Model:
-    """
-    A high-level abstraction of a llama model
-
-    This is just a brief overview of webscout.Local.Model.
-    To see a full description of each method and its parameters,
-    call help(Model), or see the relevant docstring.
-
-    The following methods are available:
-    - `.generate()` - Generate text
-    - `.get_length()` - Get the length of a given text in tokens
-    - `.ingest()` - Ingest text into the model's cache
-    - `.next_candidates()` - Get a list of the most likely next tokens (WIP)
-    - `.stream()` - Return a Generator that can stream text as it is generated
-    - `.stream_print()` - Print text as it is generated
-    - `.trim()` - Trim a given text to the model's context length
-    - `.unload()` - Unload the model from memory
-
-    The following attributes are available:
-    - `.bos_token` - The model's beginning-of-stream token ID
-    - `.context_length` - The model's loaded context length
-    - `.flash_attn` - Whether the model was loaded with `flash_attn=True`
-    - `.eos_token` - The model's end-of-stream token ID
-    - `.llama` - The underlying `llama_cpp.Llama` instance
-    - `.metadata` - The GGUF metadata of the model
-    - `.n_ctx_train` - The native context length of the model
-    - `.rope_freq_base` - The model's loaded RoPE frequency base
-    - `.rope_freq_base_train` - The model's native RoPE frequency base
-    - `.tokens` - A list of all the tokens in the model's tokenizer
-    - `.verbose` - Whether the model was loaded with `verbose=True`
-    """
-
-    def __init__(
-        self,
-        model_path: str,
-        context_length: Optional[int] = None,
-        n_gpu_layers: int = 0,
-        offload_kqv: bool = True,
-        flash_attn: bool = False,
-        verbose: bool = False
-    ):
-        """
-        Given the path to a GGUF file, construct a Model instance.
-
-        The model must be in GGUF format.
-
-        The following parameters are optional:
-        - context_length: The context length at which to load the model, in tokens
-        - n_gpu_layers: The number of layers to be offloaded to the GPU
-        - offload_kqv: Whether the KQV cache (context) should be offloaded
-        - flash_attn: Whether to use Flash Attention
-        - verbose: Whether to print additional backend information
-        """
-
-        if verbose:
-            print_verbose(f"webscout.Local package version: {__version__}")
-            print_verbose(f"llama_cpp package version: {__llama_cpp_version__}")
-
-        assert isinstance(model_path, str), \
-            f"Model: model_path should be a string, not {type(model_path)}"
-        assert exists(model_path), \
-            f"Model: the given model_path '{model_path}' does not exist"
-        assert not isdir(model_path), \
-            f"Model: the given model_path '{model_path}' is a directory, not a GGUF file"
-        assert isinstance(context_length, (int, type(None))), \
-            f"Model: context_length should be int or None, not {type(context_length)}"
-        assert isinstance(flash_attn, bool), \
-            f"Model: flash_attn should be bool (True or False), not {type(flash_attn)}"
-
-        # save __init__ parameters for __repr__
-        self._model_path = model_path
-        self._context_length = context_length
-        self._n_gpu_layers = n_gpu_layers
-        self._offload_kqv = offload_kqv
-        self._flash_attn = flash_attn
-        self._verbose = self.verbose = verbose
-
-        # if context_length <= 0, use n_ctx_train
-        if isinstance(context_length, int) and context_length <= 0:
-            context_length = None
-
-        # this does not use Llama.metadata because we want to use GGUF
-        # metadata to determine some parameters of the Llama instance
-        # before it is created
-        self.metadata = GGUFReader.load_metadata(self, model_path)
-        metadata_keys = self.metadata.keys() # only read once
-
-        n_ctx_train = None
-        for key in metadata_keys:
-            if key.endswith('.context_length'):
-                n_ctx_train = self.metadata[key]
-                break
-
-        if n_ctx_train is None:
-            raise KeyError(
-                "GGUF file does not specify a context length"
-            )
-
-        rope_freq_base_train = None
-        for key in metadata_keys:
-            if key.endswith('.rope.freq_base'):
-                rope_freq_base_train = self.metadata[key]
-                break
-
-        if rope_freq_base_train is None and context_length is not None:
-            if context_length > n_ctx_train:
-                raise ValueError(
-                    'unable to load model with greater than native ' + \
-                    f'context length ({context_length} > {n_ctx_train}) ' + \
-                    'because model does not specify freq_base. ' + \
-                    f'try again with `context_length={n_ctx_train}`'
-                )
-
-        if rope_freq_base_train is None or context_length is None or \
-            context_length <= n_ctx_train:
-            # no need to do context scaling, load model normally
-
-            if context_length is None:
-                self.context_length = n_ctx_train
-            else:
-                self.context_length = context_length
-            rope_freq_base = rope_freq_base_train
-
-        elif context_length > n_ctx_train:
-            # multiply rope_freq_base according to requested context length
-            # because context length > n_ctx_train and rope freq base is known
-
-            rope_freq_base = (context_length/n_ctx_train)*rope_freq_base_train
-            self.context_length = context_length
-
-            if self.verbose:
-                print_verbose(
-                    'chosen context length is greater than native context '
-                    f'length ({context_length} > {n_ctx_train}), '
-                    'rope_freq_base will be changed from '
-                    f'{rope_freq_base_train} to {rope_freq_base}'
-                )
-
-            if 2 <= context_length/n_ctx_train < 4:
-                print_warning(
-                    'loading model with 2x native context length or more, '
-                    'expect small loss of quality'
-                )
-
-            elif 4 <= context_length/n_ctx_train < 8:
-                print_warning(
-                    'loading model with 4x native context length or more, '
-                    'expect moderate loss of quality'
-                )
-
-            elif context_length/n_ctx_train >= 8:
-                print_warning(
-                    'loading model with 8x native context length or more, '
-                    'expect SIGNIFICANT loss of quality'
-                )
-
-        try:
-            self.tokens: list[str] = self.metadata['tokenizer.ggml.tokens']
-        except KeyError:
-            print_warning(
-                "could not set Model.tokens, defaulting to None"
-            )
-            self.tokens = None
-        try:
-            self.bos_token: int = self.metadata['tokenizer.ggml.bos_token_id']
-        except KeyError:
-            print_warning(
-                "could not set Model.bos_token, defaulting to None"
-            )
-            self.bos_token = None
-        try:
-            self.eos_token: int = self.metadata['tokenizer.ggml.eos_token_id']
-        except KeyError:
-            print_warning(
-                "could not set Model.eos_token, defaulting to None"
-            )
-            self.eos_token = None
-
-        cpu_count = os_cpu_count()
-
-        # these values for n_threads and n_threads_batch are
-        # known to be optimal for most systems
-        n_batch = 512 # can this be optimized?
-        n_threads = max(cpu_count//2, 1)
-        n_threads_batch = cpu_count
-
-        if flash_attn and n_gpu_layers == 0:
-            print_warning(
-                "disabling flash_attn because n_gpu_layers == 0"
-            )
-            flash_attn = False
-
-        # guard against models with no rope_freq_base
-        if rope_freq_base is None:
-            rope_freq_base = 0
-
-        self.llama: Llama = Llama(
-            model_path=model_path,
-            n_ctx=self.context_length,
-            n_gpu_layers=n_gpu_layers,
-            use_mmap=True,
-            use_mlock=False,
-            logits_all=False,
-            n_batch=n_batch,
-            n_threads=n_threads,
-            n_threads_batch=n_threads_batch,
-            rope_freq_base=rope_freq_base,
-            mul_mat_q=True,
-            offload_kqv=offload_kqv,
-            flash_attn=flash_attn,
-            # KV cache quantization
-            # use 1 for F16 (default), 8 for q8_0, 2 for q4_0, 3 for q4_1
-            #type_k=8,
-            #type_v=8,
-            verbose=verbose
-        )
-
-        # once model is loaded, replace metadata (as read using internal class)
-        # with metadata (as read using the more robust llama-cpp-python code)
-        self.metadata = self.llama.metadata
-
-        # expose these values because they may be useful / informative
-        self.n_ctx_train = n_ctx_train
-        self.rope_freq_base_train = rope_freq_base_train
-        self.rope_freq_base = rope_freq_base
-        self.flash_attn = flash_attn
-
-        if self.verbose:
-            print_verbose("new Model instance with the following attributes:")
-            print_verbose(f"model: {model_path}")
-            print_verbose(f"param: n_gpu_layers == {n_gpu_layers}")
-            print_verbose(f"param: offload_kqv == {offload_kqv}")
-            print_verbose(f"param: flash_attn == {flash_attn}")
-            print_verbose(f"param: n_batch == {n_batch}")
-            print_verbose(f"param: n_threads == {n_threads}")
-            print_verbose(f"param: n_threads_batch == {n_threads_batch}")
-            print_verbose(f" gguf: n_ctx_train == {n_ctx_train}")
-            print_verbose(f"param: self.context_length == {self.context_length}")
-            print_verbose(f" gguf: rope_freq_base_train == {rope_freq_base_train}")
-            print_verbose(f"param: rope_freq_base == {rope_freq_base}")
-
-    def __repr__(self) -> str:
-        return \
-            f"Model({repr(self._model_path)}, " + \
-            f"context_length={self._context_length}, " + \
-            f"n_gpu_layers={self._n_gpu_layers}, " + \
-            f"offload_kqv={self._offload_kqv}, "+ \
-            f"flash_attn={self._flash_attn}, " + \
-            f"verbose={self._verbose})"
-
-    def __del__(self):
-        self.unload()
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *_):
-        self.unload()
-
-    def __call__(
-        self,
-        prompt: Union[str, list[int]],
-        stops: list[Union[str, int]] = [],
-        sampler: SamplerSettings = DefaultSampling
-    ) -> str:
-        """
-        `Model(...)` is a shorthand for `Model.generate(...)`
-        """
-        return self.generate(prompt, stops, sampler)
-
-    def unload(self):
-        """
-        Unload the model from memory
-        """
-        # ref: llama_cpp._internals._LlamaModel.__del__()
-        if not hasattr(self, 'llama'):
-            # nothing can be done
-            return
-        try:
-            if self.llama._model.model is not None:
-                # actually unload the model from memory
-                self.llama._model._llama_free_model(self.llama._model.model)
-                self.llama._model.model = None
-        except AttributeError:
-            # broken or already being destroyed by GC, abort
-            return
-        if hasattr(self, 'llama'):
-            delattr(self, 'llama')
-        if self.verbose:
-            print_verbose('Model unloaded')
-
-    def trim(
-        self,
-        text: str,
-        overwrite: Optional[str] = None
-    ) -> str:
-
-        """
-        Trim the given text to the context length of this model,
-        leaving room for two extra tokens.
-
-        Optionally overwrite the oldest tokens with the text given in the
-        `overwrite` parameter, which may be useful for keeping some
-        information in context.
-
-        Does nothing if the text is equal to or shorter than
-        (context_length - 2).
-        """
-        assert_model_is_loaded(self)
-        trim_length = self.context_length - 2
-        tokens_list = self.llama.tokenize(
-            text.encode("utf-8", errors="ignore")
-        )
-
-        if len(tokens_list) <= trim_length:
-            if overwrite is not None:
-                text[0 : len(overwrite)] = overwrite
-            return text
-
-        if len(tokens_list) > trim_length and overwrite is None:
-            # cut to trim_length
-            tokens_list = tokens_list[-trim_length:]
-            return self.llama.detokenize(tokens_list).decode(
-                "utf-8",
-                errors="ignore"
-            )
-
-        if len(tokens_list) > trim_length and overwrite is not None:
-            # cut to trim_length
-            tokens_list = tokens_list[-trim_length:]
-            overwrite_tokens = self.llama.tokenize(overwrite.encode(
-                "utf-8",
-                errors="ignore"
-                )
-            )
-            # overwrite oldest tokens
-            tokens_list[0 : len(overwrite_tokens)] = overwrite_tokens
-            return self.llama.detokenize(tokens_list).decode(
-                "utf-8",
-                errors="ignore"
-            )
-
-    def get_length(self, text: str) -> int:
-        """
-        Return the length of the given text in tokens according to this model,
-        including the appended BOS token.
-        """
-        assert_model_is_loaded(self)
-        return len(self.llama.tokenize(
-            text.encode(
-                "utf-8",
-                errors="ignore"
-            )
-        ))
-
-    def generate(
-        self,
-        prompt: Union[str, list[int]],
-        stops: list[Union[str, int]] = [],
-        sampler: SamplerSettings = DefaultSampling
-    ) -> str:
-        """
-        Given a prompt, return a generated string.
-
-        prompt: The text from which to generate
-
-        The following parameters are optional:
-        - stops: A list of strings and/or token IDs at which to end the generation early
-        - sampler: The SamplerSettings object used to control text generation
-        """
-
-        assert isinstance(prompt, (str, list)), \
-            f"generate: prompt should be string or list[int], not {type(prompt)}"
-        if isinstance(prompt, list):
-            assert all(isinstance(tok, int) for tok in prompt), \
-                "generate: some token in prompt is not an integer"
-        assert isinstance(stops, list), \
-            f"generate: parameter `stops` should be a list, not {type(stops)}"
-        assert all(isinstance(item, (str, int)) for item in stops), \
-            f"generate: some item in parameter `stops` is not a string or int"
-
-        if self.verbose:
-            print_verbose(f'using the following sampler settings for Model.generate:')
-            print_verbose(f'max_len_tokens == {sampler.max_len_tokens}')
-            print_verbose(f'temp == {sampler.temp}')
-            print_verbose(f'top_p == {sampler.top_p}')
-            print_verbose(f'min_p == {sampler.min_p}')
-            print_verbose(f'frequency_penalty == {sampler.frequency_penalty}')
-            print_verbose(f'presence_penalty == {sampler.presence_penalty}')
-            print_verbose(f'repeat_penalty == {sampler.repeat_penalty}')
-            print_verbose(f'top_k == {sampler.top_k}')
-
-        # if any stop item is a token ID (int)
-        if any(isinstance(stop, int) for stop in stops):
-            # stop_strs is a list of all stopping strings
-            stop_strs: list[str] = [stop for stop in stops if isinstance(stop, str)]
-            # stop_token_ids is a list of all stop token IDs
-            stop_token_ids: list[int] = [tok_id for tok_id in stops if isinstance(tok_id, int)]
-            def stop_on_token_ids(tokens, *args, **kwargs):
-                return tokens[-1] in stop_token_ids
-            stopping_criteria = StoppingCriteriaList([stop_on_token_ids])
-            assert_model_is_loaded(self)
-            return self.llama.create_completion(
-                prompt,
-                max_tokens=sampler.max_len_tokens,
-                temperature=sampler.temp,
-                top_p=sampler.top_p,
-                min_p=sampler.min_p,
-                frequency_penalty=sampler.frequency_penalty,
-                presence_penalty=sampler.presence_penalty,
-                repeat_penalty=sampler.repeat_penalty,
-                top_k=sampler.top_k,
-                stop=stop_strs,
-                stopping_criteria=stopping_criteria
-            )['choices'][0]['text']
-
-        # if stop items are only strings
-        assert_model_is_loaded(self)
-        return self.llama.create_completion(
-            prompt,
-            max_tokens=sampler.max_len_tokens,
-            temperature=sampler.temp,
-            top_p=sampler.top_p,
-            min_p=sampler.min_p,
-            frequency_penalty=sampler.frequency_penalty,
-            presence_penalty=sampler.presence_penalty,
-            repeat_penalty=sampler.repeat_penalty,
-            top_k=sampler.top_k,
-            stop=stops
-        )['choices'][0]['text']
-
-
-    def stream(
-        self,
-        prompt: Union[str, list[int]],
-        stops: list[Union[str, int]] = [],
-        sampler: SamplerSettings = DefaultSampling
-    ) -> Generator:
-
-        """
-        Given a prompt, return a Generator that yields dicts containing tokens.
-
-        To get the token string itself, subscript the dict with:
-
-        `['choices'][0]['text']`
-
-        prompt: The text from which to generate
-
-        The following parameters are optional:
-        - stops: A list of strings and/or token IDs at which to end the generation early
-        - sampler: The SamplerSettings object used to control text generation
-        """
-
-        assert isinstance(prompt, (str, list)), \
-            f"stream: prompt should be string or list[int], not {type(prompt)}"
-        if isinstance(prompt, list):
-            assert all(isinstance(tok, int) for tok in prompt), \
-                "stream: some token in prompt is not an integer"
-        assert isinstance(stops, list), \
-            f"stream: parameter `stops` should be a list, not {type(stops)}"
-        assert all(isinstance(item, (str, int)) for item in stops), \
-            f"stream: some item in parameter `stops` is not a string or int"
-
-        if self.verbose:
-            print_verbose(f'using the following sampler settings for Model.stream:')
-            print_verbose(f'max_len_tokens == {sampler.max_len_tokens}')
-            print_verbose(f'temp == {sampler.temp}')
-            print_verbose(f'top_p == {sampler.top_p}')
-            print_verbose(f'min_p == {sampler.min_p}')
-            print_verbose(f'frequency_penalty == {sampler.frequency_penalty}')
-            print_verbose(f'presence_penalty == {sampler.presence_penalty}')
-            print_verbose(f'repeat_penalty == {sampler.repeat_penalty}')
-            print_verbose(f'top_k == {sampler.top_k}')
-
-        # if any stop item is a token ID (int)
-        if any(isinstance(stop, int) for stop in stops):
-            # stop_strs is a list of all stopping strings
-            stop_strs: list[str] = [stop for stop in stops if isinstance(stop, str)]
-            # stop_token_ids is a list of all stop token IDs
-            stop_token_ids: list[int] = [tok_id for tok_id in stops if isinstance(tok_id, int)]
-            def stop_on_token_ids(tokens, *args, **kwargs):
-                return tokens[-1] in stop_token_ids
-            stopping_criteria = StoppingCriteriaList([stop_on_token_ids])
-            assert_model_is_loaded(self)
-            return self.llama.create_completion(
-                prompt,
-                max_tokens=sampler.max_len_tokens,
-                temperature=sampler.temp,
-                top_p=sampler.top_p,
-                min_p=sampler.min_p,
-                frequency_penalty=sampler.frequency_penalty,
-                presence_penalty=sampler.presence_penalty,
-                repeat_penalty=sampler.repeat_penalty,
-                top_k=sampler.top_k,
-                stream=True,
-                stop=stop_strs,
-                stopping_criteria=stopping_criteria
-            )
-
-        assert_model_is_loaded(self)
-        return self.llama.create_completion(
-            prompt,
-            max_tokens=sampler.max_len_tokens,
-            temperature=sampler.temp,
-            top_p=sampler.top_p,
-            min_p=sampler.min_p,
-            frequency_penalty=sampler.frequency_penalty,
-            presence_penalty=sampler.presence_penalty,
-            repeat_penalty=sampler.repeat_penalty,
-            top_k=sampler.top_k,
-            stream=True,
-            stop=stops
-        )
-
-
-    def stream_print(
-        self,
-        prompt: Union[str, list[int]],
-        stops: list[Union[str, int]] = [],
-        sampler: SamplerSettings = DefaultSampling,
-        end: str = "\n",
-        file: _SupportsWriteAndFlush = sys.stdout,
-        flush: bool = True
-    ) -> str:
-        """
-        Given a prompt, stream text as it is generated, and return the generated string.
-        The returned string does not include the `end` parameter.
-
-        `Model.stream_print(...)` is a shorthand for:
-
-        ```
-        s = Model.stream(prompt, stops=stops, sampler=sampler)
-        for i in s:
-            tok = i['choices'][0]['text']
-            print(tok, end='', file=file, flush=flush)
-        print(end, end='', file=file, flush=True)
-        ```
-
-        prompt: The text from which to generate
-
-        The following parameters are optional:
-        - stops: A list of strings and/or token IDs at which to end the generation early
-        - sampler: The SamplerSettings object used to control text generation
-        - end: A string to print after the generated text
-        - file: The file where text should be printed
-        - flush: Whether to flush the stream after each token
-        """
-
-        token_generator = self.stream(
-            prompt=prompt,
-            stops=stops,
-            sampler=sampler
-        )
-
-        res = ''
-        for i in token_generator:
-            tok = i['choices'][0]['text']
-            print(tok, end='', file=file, flush=flush)
-            res += tok
-
-        # print `end`, and always flush stream after generation is done
-        print(end, end='', file=file, flush=True)
-
-        return res
-
-
-    def ingest(self, text: str) -> None:
-        """
-        Ingest the given text into the model's cache
-        """
-
-        assert_model_is_loaded(self)
-        self.llama.create_completion(
-            text,
-            max_tokens=1,
-            temperature=0.0
-        )
-
-
-    def candidates(
-        self,
-        prompt: str,
-        k: int
-    ) -> list[tuple[str, np.floating]]:
-        """
-        Given prompt `str` and k `int`, return a sorted list of the
-        top k candidates for most likely next token, along with their
-        normalized probabilities
-        """
-
-        assert isinstance(prompt, str), \
-            f"next_candidates: prompt should be str, not {type(prompt)}"
-        assert isinstance(k, int), \
-            f"next_candidates: k should be int, not {type(k)}"
-        assert 0 < k <= len(self.tokens), \
-            f"next_candidates: k should be between 0 and {len(self.tokens)}"
-
-        assert_model_is_loaded(self)
-        prompt_tokens = self.llama.tokenize(prompt.encode('utf-8', errors='ignore'))
-        self.llama.reset() # reset model state
-        self.llama.eval(prompt_tokens)
-        scores = self.llama.scores[len(prompt_tokens) - 1]
-
-        # len(self.llama.scores) == self.context_length
-        # len(self.llama.scores[i]) == len(self.tokens)
-
-        # normalize scores with softmax
-        # must normalize over all tokens in vocab, not just top k
-        if self.verbose:
-            print_verbose(f'calculating softmax over {len(scores)} values')
-        normalized_scores: list[np.floating] = list(softmax(scores))
-
-        # construct the final list
-        i = 0
-        token_probs_list: list[tuple[str, np.floating]] = []
-        for tok_str in self.tokens:
-            token_probs_list.append((tok_str, normalized_scores[i]))
-            i += 1
-
-        # return token_probs_list, sorted by probability, only top k
-        return nlargest(k, token_probs_list, key=lambda x:x[1])
-
-
-    def print_candidates(
-        self,
-        prompt: str,
-        k: int,
-        file: _SupportsWriteAndFlush = sys.stdout,
-        flush: bool = False
-    ) -> None:
-        """
-        Like `Model.candidates()`, but print the values instead
-        of returning them
-        """
-
-        for _tuple in self.candidates(prompt, k):
-            print(
-                f"token {repr(_tuple[0])} has probability {_tuple[1]}",
-                file=file,
-                flush=flush
-            )
-
-        # if flush is False, then so far file is not flushed, but it should
-        # always be flushed at the end of printing
-        if not flush:
-            file.flush()
-
-
-def assert_model_is_loaded(model: Model) -> None:
-    """
-    Ensure the Model is fully constructed, such that
-    `Model.llama._model.model is not None` is guaranteed to be `True`
-
-    Raise ModelUnloadedException otherwise
-    """
-    if not hasattr(model, 'llama'):
-        raise ModelUnloadedException(
-            "webscout.Local.Model instance has no attribute 'llama'"
-        )
-    if not hasattr(model.llama, '_model'):
-        raise ModelUnloadedException(
-            "llama_cpp.Llama instance has no attribute '_model'"
-        )
-    if not hasattr(model.llama._model, 'model'):
-        raise ModelUnloadedException(
-            "llama_cpp._internals._LlamaModel instance has no attribute 'model'"
-        )
-    if model.llama._model.model is None:
-        raise ModelUnloadedException(
-            "llama_cpp._internals._LlamaModel.model is None"
-        )
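For context, a rough sketch of how the removed `Model` class was typically driven together with `SamplerSettings`, assuming the package is still installed. The GGUF path is a placeholder, and exact behaviour depends on the installed llama-cpp-python build:

```python
from webscout.Local.model import Model
from webscout.Local.samplers import SamplerSettings

# Load a local GGUF model; context_length=None means "use the native length".
with Model('/path/to/model.gguf', n_gpu_layers=0, verbose=False) as model:
    sampler = SamplerSettings(temp=0.7, max_len_tokens=128)
    # stream_print() prints tokens as they arrive and returns the full text.
    text = model.stream_print(
        "Q: Name three prime numbers.\nA:",
        stops=['\n\n'],
        sampler=sampler,
    )
    print(model.get_length(text), "tokens in the completion")
# Leaving the `with` block calls unload(), freeing the model from memory.
```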
webscout/Local/samplers.py
DELETED
@@ -1,161 +0,0 @@
-
-from ._version import __version__, __llama_cpp_version__
-
-"""Submodule containing SamplerSettings class and some preset samplers"""
-
-from sys import maxsize
-
-
-MAX_TEMP = float(maxsize)
-
-class SamplerSettings:
-    """
-    A SamplerSettings object specifies the sampling parameters that will be
-    used to control text generation
-    """
-
-    ParamTypes: dict[str, type] = {
-        'max_len_tokens': int,
-        'temp': float,
-        'top_p': float,
-        'min_p': float,
-        'frequency_penalty': float,
-        'presence_penalty': float,
-        'repeat_penalty': float,
-        'top_k': int
-    }
-
-    def __init__(
-        self,
-        max_len_tokens: int = -1,
-        temp: float = 0.8,
-        top_p: float = 0.95,
-        min_p: float = 0.05,
-        frequency_penalty: float = 0.0,
-        presence_penalty: float = 0.0,
-        repeat_penalty: float = 1.0,
-        top_k: int = 40
-    ):
-        """
-        Construct a new SamplerSettings instance
-        """
-
-        self.max_len_tokens = max_len_tokens
-        self.temp = temp
-        self.top_p = top_p
-        self.min_p = min_p
-        self.frequency_penalty = frequency_penalty
-        self.presence_penalty = presence_penalty
-        self.repeat_penalty = repeat_penalty
-        self.top_k = top_k
-
-        for sampler_param in SamplerSettings.ParamTypes:
-            expected_type = SamplerSettings.ParamTypes[sampler_param]
-            actual_type = type(getattr(self, sampler_param))
-            if actual_type != expected_type:
-                raise TypeError(
-                    f"wrong type for SamplerSettings parameter '{sampler_param}'"
-                    f" - expected {expected_type}, got {actual_type}"
-                )
-
-    def __repr__(self) -> str:
-        repr_str = 'SamplerSettings('
-        repr_str += f'max_len_tokens={self.max_len_tokens}, '
-        repr_str += f'temp={self.temp}, '
-        repr_str += f'top_p={self.top_p}, '
-        repr_str += f'min_p={self.min_p}, '
-        repr_str += f'frequency_penalty={self.frequency_penalty}, '
-        repr_str += f'presence_penalty={self.presence_penalty}, '
-        repr_str += f'repeat_penalty={self.repeat_penalty}, '
-        repr_str += f'top_k={self.top_k})'
-        return repr_str
-
-# most likely token is always chosen
-GreedyDecoding = SamplerSettings(
-    temp = 0.0,
-)
-
-# reflects llama.cpp
-DefaultSampling = SamplerSettings()
-
-# unmodified probability distribution (i.e. what the model actually thinks)
-SimpleSampling = SamplerSettings(
-    temp = 1.0,
-    top_p = 1.0,
-    min_p = 0.0,
-    top_k = -1
-)
-
-# reflects old llama.cpp defaults
-ClassicSampling = SamplerSettings(
-    min_p=0.0,
-    repeat_penalty = 1.1
-)
-
-# halfway between DefaultSampling and SimpleSampling
-SemiSampling = SamplerSettings(
-    temp=0.9,
-    top_p=0.975,
-    min_p=0.025,
-    top_k=80
-)
-
-# for models with large vocabulary, which tend to run hot
-TikTokenSampling = SamplerSettings(
-    temp=0.6,
-    repeat_penalty=1.1
-)
-
-# use min_p as the only active sampler (more permissive)
-LowMinPSampling = SamplerSettings(
-    temp = 1.0,
-    top_p = 1.0,
-    min_p = 0.05,
-    top_k = -1
-)
-
-# use min_p as the only active sampler (moderate)
-MinPSampling = SamplerSettings(
-    temp = 1.0,
-    top_p = 1.0,
-    min_p = 0.1,
-    top_k = -1
-)
-
-# use min_p as the only active sampler (more restrictive)
-StrictMinPSampling = SamplerSettings(
-    temp = 1.0,
-    top_p = 1.0,
-    min_p = 0.2,
-    top_k = -1
-)
-
-# https://arxiv.org/abs/2210.14140
-ContrastiveSearch = SamplerSettings(
-    temp = 0.0,
-    presence_penalty = 0.4
-)
-
-# https://arxiv.org/abs/2210.14140
-WarmContrastiveSearch = SamplerSettings(
-    temp = 0.0,
-    presence_penalty = 0.8
-)
-
-# outputs completely random tokens from vocab (useless)
-RandomSampling = SamplerSettings(
-    temp = MAX_TEMP,
-    top_p = 1.0,
-    min_p = 0.0,
-    top_k = -1
-)
-
-# default sampling with reduced temperature
-LowTempSampling = SamplerSettings(
-    temp = 0.4
-)
-
-# default sampling with increased temperature
-HighTempSampling = SamplerSettings(
-    temp = 1.2
-)
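
The deleted samplers.py is self-contained, so a short usage sketch may help readers of this diff; the import path simply mirrors the package layout removed by this commit.

# Sketch: building custom sampler settings with the deleted SamplerSettings class.
from webscout.Local.samplers import SamplerSettings, DefaultSampling

# unspecified parameters keep the llama.cpp-style defaults
# (temp=0.8, top_p=0.95, min_p=0.05, repeat_penalty=1.0, top_k=40)
focused = SamplerSettings(temp=0.4, top_k=20)
print(focused)                      # __repr__ lists every parameter

# the constructor type-checks each field, so passing a str raises TypeError:
# SamplerSettings(temp="0.4")
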
webscout/Local/thread.py
DELETED
@@ -1,690 +0,0 @@
-from ._version import __version__, __llama_cpp_version__
-
-"""Submodule containing the Thread class, used for interaction with a Model"""
-
-import sys
-
-from .model import Model, assert_model_is_loaded, _SupportsWriteAndFlush
-from .utils import RESET_ALL, cls, print_verbose, truncate
-from .samplers import SamplerSettings, DefaultSampling
-from typing import Optional, Literal, Union
-from .formats import AdvancedFormat
-
-from .formats import blank as formats_blank
-
-
-class Message(dict):
-    """
-    A dictionary representing a single message within a Thread
-
-    Works just like a normal `dict`, but a new method:
-    - `.as_string` - Return the full message string
-
-    Generally, messages have these keys:
-    - `role` - The role of the speaker: 'system', 'user', or 'bot'
-    - `prefix` - The text that prefixes the message content
-    - `content` - The actual content of the message
-    - `suffix` - The text that suffixes the message content
-    """
-
-    def __repr__(self) -> str:
-        return \
-            f"Message([" \
-            f"('role', {repr(self['role'])}), " \
-            f"('prefix', {repr(self['prefix'])}), " \
-            f"('content', {repr(self['content'])}), " \
-            f"('suffix', {repr(self['suffix'])})])"
-
-    def as_string(self):
-        """Return the full message string"""
-        try:
-            return self['prefix'] + self['content'] + self['suffix']
-        except KeyError as e:
-            e.add_note(
-                "as_string: Message is missing one or more of the "
-                "required 'prefix', 'content', 'suffix' attributes - this is "
-                "unexpected"
-            )
-            raise e
-
-
-class Thread:
-    """
-    Provide functionality to facilitate easy interactions with a Model
-
-    This is just a brief overview of m.Thread.
-    To see a full description of each method and its parameters,
-    call help(Thread), or see the relevant docstring.
-
-    The following methods are available:
-    - `.add_message()` - Add a message to `Thread.messages`
-    - `.as_string()` - Return this thread's complete message history as a string
-    - `.create_message()` - Create a message using the format of this thread
-    - `.inference_str_from_messages()` - Using the list of messages, return a string suitable for inference
-    - `.interact()` - Start an interactive, terminal-based chat session
-    - `.len_messages()` - Get the total length of all messages in tokens
-    - `.print_stats()` - Print stats about the context usage in this thread
-    - `.reset()` - Clear the list of messages
-    - `.send()` - Send a message in this thread
-
-    The following attributes are available:
-    - `.format` - The format being used for messages in this thread
-    - `.messages` - The list of messages in this thread
-    - `.model` - The `m.Model` instance used by this thread
-    - `.sampler` - The SamplerSettings object used in this thread
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        format: Union[dict, AdvancedFormat],
-        sampler: SamplerSettings = DefaultSampling,
-        messages: Optional[list[Message]] = None,
-    ):
-        """
-        Given a Model and a format, construct a Thread instance.
-
-        model: The Model to use for text generation
-        format: The format specifying how messages should be structured (see m.formats)
-
-        The following parameters are optional:
-        - sampler: The SamplerSettings object used to control text generation
-        - messages: A list of m.thread.Message objects to add to the Thread upon construction
-        """
-
-        assert isinstance(model, Model), \
-            "Thread: model should be an " + \
-            f"instance of webscout.Local.Model, not {type(model)}"
-
-        assert_model_is_loaded(model)
-
-        assert isinstance(format, (dict, AdvancedFormat)), \
-            f"Thread: format should be dict or AdvancedFormat, not {type(format)}"
-
-        if any(k not in format.keys() for k in formats_blank.keys()):
-            raise KeyError(
-                "Thread: format is missing one or more required keys, see " + \
-                "webscout.Local.formats.blank for an example"
-            )
-
-        assert isinstance(format['stops'], list), \
-            "Thread: format['stops'] should be list, not " + \
-            f"{type(format['stops'])}"
-
-        assert all(
-            hasattr(sampler, attr) for attr in [
-                'max_len_tokens',
-                'temp',
-                'top_p',
-                'min_p',
-                'frequency_penalty',
-                'presence_penalty',
-                'repeat_penalty',
-                'top_k'
-            ]
-        ), 'Thread: sampler is missing one or more required attributes'
-
-        self._messages: Optional[list[Message]] = messages
-        if self._messages is not None:
-            if not all(isinstance(msg, Message) for msg in self._messages):
-                raise TypeError(
-                    "Thread: one or more messages provided to __init__() is "
-                    "not an instance of m.thread.Message"
-                )
-
-        # Thread.messages is never empty, unless `messages` param is explicity
-        # set to `[]` during construction
-
-        self.model: Model = model
-        self.format: Union[dict, AdvancedFormat] = format
-        self.messages: list[Message] = [
-            self.create_message("system", self.format['system_content'])
-        ] if self._messages is None else self._messages
-        self.sampler: SamplerSettings = sampler
-
-        if self.model.verbose:
-            print_verbose("new Thread instance with the following attributes:")
-            print_verbose(f"model == {self.model}")
-            print_verbose(f"format['system_prefix'] == {truncate(repr(self.format['system_prefix']))}")
-            print_verbose(f"format['system_content'] == {truncate(repr(self.format['system_content']))}")
-            print_verbose(f"format['system_suffix'] == {truncate(repr(self.format['system_suffix']))}")
-            print_verbose(f"format['user_prefix'] == {truncate(repr(self.format['user_prefix']))}")
-            print_verbose(f"format['user_content'] == {truncate(repr(self.format['user_content']))}")
-            print_verbose(f"format['user_suffix'] == {truncate(repr(self.format['user_suffix']))}")
-            print_verbose(f"format['bot_prefix'] == {truncate(repr(self.format['bot_prefix']))}")
-            print_verbose(f"format['bot_content'] == {truncate(repr(self.format['bot_content']))}")
-            print_verbose(f"format['bot_suffix'] == {truncate(repr(self.format['bot_suffix']))}")
-            print_verbose(f"format['stops'] == {truncate(repr(self.format['stops']))}")
-            print_verbose(f"sampler.temp == {self.sampler.temp}")
-            print_verbose(f"sampler.top_p == {self.sampler.top_p}")
-            print_verbose(f"sampler.min_p == {self.sampler.min_p}")
-            print_verbose(f"sampler.frequency_penalty == {self.sampler.frequency_penalty}")
-            print_verbose(f"sampler.presence_penalty == {self.sampler.presence_penalty}")
-            print_verbose(f"sampler.repeat_penalty == {self.sampler.repeat_penalty}")
-            print_verbose(f"sampler.top_k == {self.sampler.top_k}")
-
-
-    def __repr__(self) -> str:
-        return \
-            f"Thread({repr(self.model)}, {repr(self.format)}, " + \
-            f"{repr(self.sampler)}, {repr(self.messages)})"
-
-    def __str__(self) -> str:
-        return self.as_string()
-
-    def __len__(self) -> int:
-        """
-        `len(Thread)` returns the length of the Thread in tokens
-
-        To get the number of messages in the Thread, use `len(Thread.messages)`
-        """
-        return self.len_messages()
-
-    def create_message(
-        self,
-        role: Literal['system', 'user', 'bot'],
-        content: str
-    ) -> Message:
-        """
-        Construct a message using the format of this Thread
-        """
-
-        assert role.lower() in ['system', 'user', 'bot'], \
-            f"create_message: role should be 'system', 'user', or 'bot', not '{role.lower()}'"
-
-        assert isinstance(content, str), \
-            f"create_message: content should be str, not {type(content)}"
-
-        if role.lower() == 'system':
-            return Message(
-                [
-                    ('role', 'system'),
-                    ('prefix', self.format['system_prefix']),
-                    ('content', content),
-                    ('suffix', self.format['system_suffix'])
-                ]
-            )
-
-        elif role.lower() == 'user':
-            return Message(
-                [
-                    ('role', 'user'),
-                    ('prefix', self.format['user_prefix']),
-                    ('content', content),
-                    ('suffix', self.format['user_suffix'])
-                ]
-            )
-
-        elif role.lower() == 'bot':
-            return Message(
-                [
-                    ('role', 'bot'),
-                    ('prefix', self.format['bot_prefix']),
-                    ('content', content),
-                    ('suffix', self.format['bot_suffix'])
-                ]
-            )
-
-    def len_messages(self) -> int:
-        """
-        Return the total length of all messages in this thread, in tokens.
-
-        Can also use `len(Thread)`."""
-
-        return self.model.get_length(self.as_string())
-
-    def add_message(
-        self,
-        role: Literal['system', 'user', 'bot'],
-        content: str
-    ) -> None:
-        """
-        Create a message and append it to `Thread.messages`.
-
-        `Thread.add_message(...)` is a shorthand for
-        `Thread.messages.append(Thread.create_message(...))`
-        """
-        self.messages.append(
-            self.create_message(
-                role=role,
-                content=content
-            )
-        )
-
-    def inference_str_from_messages(self) -> str:
-        """
-        Using the list of messages, construct a string suitable for inference,
-        respecting the format and context length of this thread.
-        """
-
-        inf_str = ''
-        sys_msg_str = ''
-        # whether to treat the first message as necessary to keep
-        sys_msg_flag = False
-        context_len_budget = self.model.context_length
-
-        # if at least 1 message is history
-        if len(self.messages) >= 1:
-            # if first message has system role
-            if self.messages[0]['role'] == 'system':
-                sys_msg_flag = True
-                sys_msg = self.messages[0]
-                sys_msg_str = sys_msg.as_string()
-                context_len_budget -= self.model.get_length(sys_msg_str)
-
-        if sys_msg_flag:
-            iterator = reversed(self.messages[1:])
-        else:
-            iterator = reversed(self.messages)
-
-        for message in iterator:
-            msg_str = message.as_string()
-            context_len_budget -= self.model.get_length(msg_str)
-            if context_len_budget <= 0:
-                break
-            inf_str = msg_str + inf_str
-
-        if sys_msg_flag:
-            inf_str = sys_msg_str + inf_str
-        inf_str += self.format['bot_prefix']
-
-        return inf_str
-
-
-    def send(self, prompt: str) -> str:
-        """
-        Send a message in this thread. This adds your message and the bot's
-        response to the list of messages.
-
-        Returns a string containing the response to your message.
-        """
-
-        self.add_message("user", prompt)
-        output = self.model.generate(
-            self.inference_str_from_messages(),
-            stops=self.format['stops'],
-            sampler=self.sampler
-        )
-        self.add_message("bot", output)
-
-        return output
-
-
-    def _interactive_update_sampler(self) -> None:
-        """Interactively update the sampler settings used in this Thread"""
-        print()
-        try:
-            new_max_len_tokens = input(f'max_len_tokens: {self.sampler.max_len_tokens} -> ')
-            new_temp = input(f'temp: {self.sampler.temp} -> ')
-            new_top_p = input(f'top_p: {self.sampler.top_p} -> ')
-            new_min_p = input(f'min_p: {self.sampler.min_p} -> ')
-            new_frequency_penalty = input(f'frequency_penalty: {self.sampler.frequency_penalty} -> ')
-            new_presence_penalty = input(f'presence_penalty: {self.sampler.presence_penalty} -> ')
-            new_repeat_penalty = input(f'repeat_penalty: {self.sampler.repeat_penalty} -> ')
-            new_top_k = input(f'top_k: {self.sampler.top_k} -> ')
-
-        except KeyboardInterrupt:
-            print('\nwebscout.Local: sampler settings not updated\n')
-            return
-        print()
-
-        try:
-            self.sampler.max_len_tokens = int(new_max_len_tokens)
-        except ValueError:
-            pass
-        else:
-            print('webscout.Local: max_len_tokens updated')
-
-        try:
-            self.sampler.temp = float(new_temp)
-        except ValueError:
-            pass
-        else:
-            print('webscout.Local: temp updated')
-
-        try:
-            self.sampler.top_p = float(new_top_p)
-        except ValueError:
-            pass
-        else:
-            print('webscout.Local: top_p updated')
-
-        try:
-            self.sampler.min_p = float(new_min_p)
-        except ValueError:
-            pass
-        else:
-            print('webscout.Local: min_p updated')
-
-        try:
-            self.sampler.frequency_penalty = float(new_frequency_penalty)
-        except ValueError:
-            pass
-        else:
-            print('webscout.Local: frequency_penalty updated')
-
-        try:
-            self.sampler.presence_penalty = float(new_presence_penalty)
-        except ValueError:
-            pass
-        else:
-            print('webscout.Local: presence_penalty updated')
-
-        try:
-            self.sampler.repeat_penalty = float(new_repeat_penalty)
-        except ValueError:
-            pass
-        else:
-            print('webscout.Local: repeat_penalty updated')
-
-        try:
-            self.sampler.top_k = int(new_top_k)
-        except ValueError:
-            pass
-        else:
-            print('webscout.Local: top_k updated')
-        print()
-
-
-    def _interactive_input(
-        self,
-        prompt: str,
-        _dim_style: str,
-        _user_style: str,
-        _bot_style: str,
-        _special_style: str
-    ) -> tuple:
-        """
-        Recive input from the user, while handling multi-line input
-        and commands
-        """
-        full_user_input = '' # may become multiline
-
-        while True:
-            user_input = input(prompt)
-
-            if user_input.endswith('\\'):
-                full_user_input += user_input[:-1] + '\n'
-
-            elif user_input == '!':
-
-                print()
-                try:
-                    command = input(f'{RESET_ALL} ! {_dim_style}')
-                except KeyboardInterrupt:
-                    print('\n')
-                    continue
-
-                if command == '':
-                    print(f'\n[no command]\n')
-
-                elif command.lower() in ['reset', 'restart']:
-                    self.reset()
-                    print(f'\n[thread reset]\n')
-
-                elif command.lower() in ['cls', 'clear']:
-                    cls()
-                    print()
-
-                elif command.lower() in ['ctx', 'context']:
-                    print(f"\n{self.len_messages()}\n")
-
-                elif command.lower() in ['stats', 'print_stats']:
-                    print()
-                    self.print_stats()
-                    print()
-
-                elif command.lower() in ['sampler', 'samplers', 'settings']:
-                    self._interactive_update_sampler()
-
-                elif command.lower() in ['str', 'string', 'as_string']:
-                    print(f"\n{self.as_string()}\n")
-
-                elif command.lower() in ['repr', 'save', 'backup']:
-                    print(f"\n{repr(self)}\n")
-
-                elif command.lower() in ['remove', 'rem', 'delete', 'del']:
-                    print()
-                    old_len = len(self.messages)
-                    del self.messages[-1]
-                    assert len(self.messages) == (old_len - 1)
-                    print('[removed last message]\n')
-
-                elif command.lower() in ['last', 'repeat']:
-                    last_msg = self.messages[-1]
-                    if last_msg['role'] == 'user':
-                        print(f"\n{_user_style}{last_msg['content']}{RESET_ALL}\n")
-                    elif last_msg['role'] == 'bot':
-                        print(f"\n{_bot_style}{last_msg['content']}{RESET_ALL}\n")
-
-                elif command.lower() in ['inf', 'inference', 'inf_str']:
-                    print(f'\n"""{self.inference_str_from_messages()}"""\n')
-
-                elif command.lower() in ['reroll', 're-roll', 're', 'swipe']:
-                    old_len = len(self.messages)
-                    del self.messages[-1]
-                    assert len(self.messages) == (old_len - 1)
-                    return '', None
-
-                elif command.lower() in ['exit', 'quit']:
-                    print(RESET_ALL)
-                    return None, None
-
-                elif command.lower() in ['help', '/?', '?']:
-                    print()
-                    print('reset | restart -- Reset the thread to its original state')
-                    print('clear | cls -- Clear the terminal')
-                    print('context | ctx -- Get the context usage in tokens')
-                    print('print_stats | stats -- Get the context usage stats')
-                    print('sampler | settings -- Update the sampler settings')
-                    print('string | str -- Print the message history as a string')
-                    print('repr | save -- Print the representation of the thread')
-                    print('remove | delete -- Remove the last message')
-                    print('last | repeat -- Repeat the last message')
-                    print('inference | inf -- Print the inference string')
-                    print('reroll | swipe -- Regenerate the last message')
-                    print('exit | quit -- Exit the interactive chat (can also use ^C)')
-                    print('help | ? -- Show this screen')
-                    print()
-                    print("TIP: type < at the prompt and press ENTER to prefix the bot's next message.")
-                    print(' for example, type "Sure!" to bypass refusals')
-                    print()
-                    print("TIP: type !! at the prompt and press ENTER to insert a system message")
-                    print()
-
-                else:
-                    print(f'\n[unknown command]\n')
-
-            # prefix the bot's next message
-            elif user_input == '<':
-
-                print()
-                try:
-                    next_message_start = input(f'{RESET_ALL} < {_dim_style}')
-
-                except KeyboardInterrupt:
-                    print(f'{RESET_ALL}\n')
-                    continue
-
-                else:
-                    print()
-                    return '', next_message_start
-
-            # insert a system message
-            elif user_input == '!!':
-                print()
-
-                try:
-                    next_sys_msg = input(f'{RESET_ALL} !! {_special_style}')
-
-                except KeyboardInterrupt:
-                    print(f'{RESET_ALL}\n')
-                    continue
-
-                else:
-                    print()
-                    return next_sys_msg, -1
-
-            # concatenate multi-line input
-            else:
-                full_user_input += user_input
-                return full_user_input, None
-
-
-    def interact(
-        self,
-        color: bool = True,
-        header: Optional[str] = None,
-        stream: bool = True
-    ) -> None:
-        """
-        Start an interactive chat session using this Thread.
-
-        While text is being generated, press `^C` to interrupt the bot.
-        Then you have the option to press `ENTER` to re-roll, or to simply type
-        another message.
-
-        At the prompt, press `^C` to end the chat session.
-
-        Type `!` and press `ENTER` to enter a basic command prompt. For a list
-        of commands, type `help` at this prompt.
-
-        Type `<` and press `ENTER` to prefix the bot's next message, for
-        example with `Sure!`.
-
-        Type `!!` at the prompt and press `ENTER` to insert a system message.
-
-        The following parameters are optional:
-        - color: Whether to use colored text to differentiate user / bot
-        - header: Header text to print at the start of the interaction
-        - stream: Whether to stream text as it is generated
-        """
-        print()
-
-        # fresh import of color codes in case `color` param has changed
-        from .utils import SPECIAL_STYLE, USER_STYLE, BOT_STYLE, DIM_STYLE
-
-        # disable color codes if explicitly disabled by `color` param
-        if not color:
-            SPECIAL_STYLE = ''
-            USER_STYLE = ''
-            BOT_STYLE = ''
-            DIM_STYLE = ''
-
-        if header is not None:
-            print(f"{SPECIAL_STYLE}{header}{RESET_ALL}\n")
-
-        while True:
-
-            prompt = f"{RESET_ALL} > {USER_STYLE}"
-
-            try:
-                user_prompt, next_message_start = self._interactive_input(
-                    prompt,
-                    DIM_STYLE,
-                    USER_STYLE,
-                    BOT_STYLE,
-                    SPECIAL_STYLE
-                )
-            except KeyboardInterrupt:
-                print(f"{RESET_ALL}\n")
-                return
-
-            # got 'exit' or 'quit' command
-            if user_prompt is None and next_message_start is None:
-                break
-
-            # insert a system message via `!!` prompt
-            if next_message_start == -1:
-                self.add_message('system', user_prompt)
-                continue
-
-            if next_message_start is not None:
-                try:
-                    if stream:
-                        print(f"{BOT_STYLE}{next_message_start}", end='', flush=True)
-                        output = next_message_start + self.model.stream_print(
-                            self.inference_str_from_messages() + next_message_start,
-                            stops=self.format['stops'],
-                            sampler=self.sampler,
-                            end=''
-                        )
-                    else:
-                        print(f"{BOT_STYLE}", end='', flush=True)
-                        output = next_message_start + self.model.generate(
-                            self.inference_str_from_messages() + next_message_start,
-                            stops=self.format['stops'],
-                            sampler=self.sampler
-                        )
-                        print(output, end='', flush=True)
-                except KeyboardInterrupt:
-                    print(f"{DIM_STYLE} [message not added to history; press ENTER to re-roll]\n")
-                    continue
-                else:
-                    self.add_message("bot", output)
-            else:
-                print(BOT_STYLE)
-                if user_prompt != "":
-                    self.add_message("user", user_prompt)
-                try:
-                    if stream:
-                        output = self.model.stream_print(
-                            self.inference_str_from_messages(),
-                            stops=self.format['stops'],
-                            sampler=self.sampler,
-                            end=''
-                        )
-                    else:
-                        output = self.model.generate(
-                            self.inference_str_from_messages(),
-                            stops=self.format['stops'],
-                            sampler=self.sampler
-                        )
-                        print(output, end='', flush=True)
-                except KeyboardInterrupt:
-                    print(f"{DIM_STYLE} [message not added to history; press ENTER to re-roll]\n")
-                    continue
-                else:
-                    self.add_message("bot", output)
-
-            if output.endswith("\n\n"):
-                print(RESET_ALL, end = '', flush=True)
-            elif output.endswith("\n"):
-                print(RESET_ALL)
-            else:
-                print(f"{RESET_ALL}\n")
-
-
-    def reset(self) -> None:
-        """
-        Clear the list of messages, which resets the thread to its original
-        state
-        """
-        self.messages: list[Message] = [
-            self.create_message("system", self.format['system_content'])
-        ] if self._messages is None else self._messages
-
-
-    def as_string(self) -> str:
-        """Return this thread's message history as a string"""
-        thread_string = ''
-        for msg in self.messages:
-            thread_string += msg.as_string()
-        return thread_string
-
-
-    def print_stats(
-        self,
-        end: str = '\n',
-        file: _SupportsWriteAndFlush = sys.stdout,
-        flush: bool = True
-    ) -> None:
-        """Print stats about the context usage in this thread"""
-        thread_len_tokens = self.len_messages()
-        max_ctx_len = self.model.context_length
-        context_used_percentage = round((thread_len_tokens/max_ctx_len)*100)
-        print(f"{thread_len_tokens} / {max_ctx_len} tokens", file=file, flush=flush)
-        print(f"{context_used_percentage}% of context used", file=file, flush=flush)
-        print(f"{len(self.messages)} messages", end=end, file=file, flush=flush)
-        if not flush:
-            file.flush()
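
Since Thread is the main user-facing class removed by this commit, a hedged usage sketch follows. The format dict keys come from the verbose logging in Thread.__init__ above; the Model constructor arguments and the chat-template strings themselves are assumptions, not values defined by this code.

# Sketch: driving the deleted Thread class (template strings are illustrative only).
from webscout.Local.model import Model
from webscout.Local.thread import Thread
from webscout.Local.samplers import LowTempSampling

fmt = {
    'system_prefix': '<|system|>\n', 'system_content': 'You are a helpful assistant.',
    'system_suffix': '\n', 'user_prefix': '<|user|>\n', 'user_content': '',
    'user_suffix': '\n', 'bot_prefix': '<|assistant|>\n', 'bot_content': '',
    'bot_suffix': '\n', 'stops': ['<|user|>'],
}

model = Model('my-model.gguf')                  # assumed constructor
thread = Thread(model, fmt, sampler=LowTempSampling)

reply = thread.send("Summarize GGUF in one sentence.")   # returns the bot's reply
thread.print_stats()                                     # context usage report
# thread.interact()                             # or: interactive terminal chat
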
webscout/Local/utils.py
DELETED
@@ -1,185 +0,0 @@
-from ._version import __version__, __llama_cpp_version__
-
-import sys
-import numpy as np
-
-from typing import Any, Iterable, TextIO
-from time import strftime
-from enum import IntEnum
-from struct import unpack
-from colorama import Fore
-from huggingface_hub import hf_hub_url, cached_download
-
-# color codes used in Thread.interact()
-RESET_ALL = Fore.RESET
-USER_STYLE = RESET_ALL + Fore.GREEN
-BOT_STYLE = RESET_ALL + Fore.CYAN
-DIM_STYLE = RESET_ALL + Fore.LIGHTBLACK_EX
-SPECIAL_STYLE = RESET_ALL + Fore.YELLOW
-
-# for typing of softmax parameter `z`
-class _ArrayLike(Iterable):
-    pass
-
-# for typing of Model.stream_print() parameter `file`
-class _SupportsWriteAndFlush(TextIO):
-    pass
-
-def download_model(repo_id: str, filename: str, cache_dir: str = ".cache") -> str:
-    """
-    Downloads a GGUF model file from Hugging Face Hub.
-
-    repo_id: The Hugging Face repository ID (e.g., 'facebook/bart-large-cnn').
-    filename: The name of the GGUF file within the repository (e.g., 'model.gguf').
-    cache_dir: The directory where the downloaded file should be stored.
-
-    Returns: The path to the downloaded file.
-    """
-    url = hf_hub_url(repo_id, filename)
-    filepath = cached_download(url, cache_dir=cache_dir, force_filename=filename)
-    return filepath
-
-class GGUFReader:
-    """
-    Peek at file header for GGUF metadata
-
-    Raise ValueError if file is not GGUF or is outdated
-
-    Credit to oobabooga for the parts of the code in this class
-
-    Format spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
-    """
-
-    class GGUFValueType(IntEnum):
-        UINT8 = 0
-        INT8 = 1
-        UINT16 = 2
-        INT16 = 3
-        UINT32 = 4
-        INT32 = 5
-        FLOAT32 = 6
-        BOOL = 7
-        STRING = 8
-        ARRAY = 9
-        UINT64 = 10
-        INT64 = 11
-        FLOAT64 = 12
-
-    _simple_value_packing = {
-        GGUFValueType.UINT8: "<B",
-        GGUFValueType.INT8: "<b",
-        GGUFValueType.UINT16: "<H",
-        GGUFValueType.INT16: "<h",
-        GGUFValueType.UINT32: "<I",
-        GGUFValueType.INT32: "<i",
-        GGUFValueType.FLOAT32: "<f",
-        GGUFValueType.UINT64: "<Q",
-        GGUFValueType.INT64: "<q",
-        GGUFValueType.FLOAT64: "<d",
-        GGUFValueType.BOOL: "?",
-    }
-
-    value_type_info = {
-        GGUFValueType.UINT8: 1,
-        GGUFValueType.INT8: 1,
-        GGUFValueType.UINT16: 2,
-        GGUFValueType.INT16: 2,
-        GGUFValueType.UINT32: 4,
-        GGUFValueType.INT32: 4,
-        GGUFValueType.FLOAT32: 4,
-        GGUFValueType.UINT64: 8,
-        GGUFValueType.INT64: 8,
-        GGUFValueType.FLOAT64: 8,
-        GGUFValueType.BOOL: 1,
-    }
-
-    def get_single(self, value_type, file) -> Any:
-        if value_type == GGUFReader.GGUFValueType.STRING:
-            value_length = unpack("<Q", file.read(8))[0]
-            value = file.read(value_length)
-            value = value.decode("utf-8")
-        else:
-            type_str = GGUFReader._simple_value_packing.get(value_type)
-            bytes_length = GGUFReader.value_type_info.get(value_type)
-            value = unpack(type_str, file.read(bytes_length))[0]
-
-        return value
-
-    def load_metadata(self, fname) -> dict:
-        metadata = {}
-        with open(fname, "rb") as file:
-            GGUF_MAGIC = file.read(4)
-
-            if GGUF_MAGIC != b"GGUF":
-                raise ValueError(
-                    "your model file is not a valid GGUF file "
-                    f"(magic number mismatch, got {GGUF_MAGIC}, "
-                    "expected b'GGUF')"
-                )
-
-            GGUF_VERSION = unpack("<I", file.read(4))[0]
-
-            if GGUF_VERSION == 1:
-                raise ValueError(
-                    "your model file reports GGUF version 1, "
-                    "but only versions 2 and above are supported. "
-                    "re-convert your model or download a newer version"
-                )
-
-            # ti_data_count = struct.unpack("<Q", file.read(8))[0]
-            file.read(8)
-            kv_data_count = unpack("<Q", file.read(8))[0]
-
-            for _ in range(kv_data_count):
-                key_length = unpack("<Q", file.read(8))[0]
-                key = file.read(key_length)
-
-                value_type = GGUFReader.GGUFValueType(
-                    unpack("<I", file.read(4))[0]
-                )
-                if value_type == GGUFReader.GGUFValueType.ARRAY:
-                    ltype = GGUFReader.GGUFValueType(
-                        unpack("<I", file.read(4))[0]
-                    )
-                    length = unpack("<Q", file.read(8))[0]
-                    arr = [
-                        GGUFReader.get_single(
-                            self,
-                            ltype,
-                            file
-                        ) for _ in range(length)
-                    ]
-                    metadata[key.decode()] = arr
-                else:
-                    value = GGUFReader.get_single(self, value_type, file)
-                    metadata[key.decode()] = value
-
-        return metadata
-
-def softmax(z: _ArrayLike) -> np.ndarray:
-    """
-    Compute softmax over values in z, where z is array-like
-    """
-    e_z = np.exp(z - np.max(z))
-    return e_z / e_z.sum()
-
-def cls() -> None:
-    """Clear the terminal"""
-    print("\033c\033[3J", end='', flush=True)
-
-# no longer used in this module, but left for others to use
-def get_timestamp_prefix_str() -> str:
-    # helpful: https://strftime.net
-    return strftime("[%Y, %b %e, %a %l:%M %p] ")
-
-def truncate(text: str) -> str:
-    return text if len(text) < 63 else f"{text[:60]}..."
-
-def print_verbose(text: str) -> None:
-    print("webscout.Local: verbose:", text, file=sys.stderr, flush=True)
-
-def print_info(text: str) -> None:
-    print("webscout.Local: info:", text, file=sys.stderr, flush=True)
-
-def print_warning(text: str) -> None:
-    print("webscout.Local: warning:", text, file=sys.stderr, flush=True)