zhzluke96
update
01e655b
raw
history blame
3.87 kB
import re
from typing import Optional

import torch
import torch.nn.functional as F
class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Repetition penalty whose strength grows with how often a token recurred.

    For every vocabulary entry the processor counts its occurrences inside the
    last ``past_window`` positions of ``input_ids`` and penalises the entry's
    score by ``penalty ** count``. Token ids at or above ``max_input_ids`` are
    never penalised.
    """

    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        """Store the penalty configuration.

        Args:
            penalty: Base of the penalty factor; must be a ``float`` > 0.
            max_input_ids: First token id that is exempt from the penalty.
            past_window: Number of trailing positions of ``input_ids`` to
                inspect. Because ``-0 == 0``, a value of 0 selects the whole
                history rather than nothing.

        Raises:
            ValueError: If ``penalty`` is not a strictly positive ``float``
                (integers are rejected as well).
        """
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(
                f"`penalty` has to be a strictly positive float, but is {penalty}"
            )

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Return a penalised copy of ``scores``; the input is not modified.

        Args:
            input_ids: Token history, indexed as (batch, sequence).
            scores: Scores indexed as (batch, vocabulary).

        Returns:
            A new tensor shaped like ``scores``. Negative scores are multiplied
            by ``penalty ** count`` and all others divided by it.
        """
        input_ids = input_ids[:, -self.past_window :]

        # one_hot appends a vocabulary axis; summing over the window axis
        # leaves a (batch, vocabulary) matrix of occurrence counts.
        freq = F.one_hot(input_ids, scores.size(1)).sum(1)

        # Exempt ids >= max_input_ids. The slice has to address the vocabulary
        # axis: slicing axis 0 would address batch rows instead, leaving the
        # exempt ids penalised and switching the penalty off for late rows.
        freq[:, self.max_input_ids :] = 0

        # A count of 0 yields a factor of 1, i.e. the score is left unchanged.
        alpha = self.penalty**freq
        scores = torch.where(scores < 0, scores * alpha, scores / alpha)
        return scores
class CustomRepetitionPenaltyLogitsProcessor:
    """Flat repetition penalty over a trailing window of generated tokens.

    Every token id found in the last ``past_window`` positions has its score
    penalised once by ``penalty``, however often it occurred. Ids at or above
    ``max_input_ids`` keep their original score.
    """

    def __init__(self, penalty: float, max_input_ids, past_window):
        """Validate ``penalty`` and remember the configuration.

        Raises:
            ValueError: If ``penalty`` is not a strictly positive ``float``.
        """
        is_positive_float = isinstance(penalty, float) and penalty > 0
        if not is_positive_float:
            raise ValueError(
                f"`penalty` has to be a strictly positive float, but is {penalty}"
            )

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Penalise recently seen tokens; ``scores`` is updated in place.

        Returns:
            The same ``scores`` tensor object that was passed in.
        """
        recent_ids = input_ids[:, -self.past_window :]
        current = torch.gather(scores, 1, recent_ids)

        # Snapshot of the untouched values, used for the exempt ids below.
        untouched = current.detach().clone()

        # Negative scores are multiplied, the rest divided.
        penalised = torch.where(
            current < 0, current * self.penalty, current / self.penalty
        )

        exempt = recent_ids >= self.max_input_ids
        replacement = torch.where(exempt, untouched, penalised)

        scores.scatter_(1, recent_ids, replacement)
        return scores
def count_invalid_characters(s, reserved_tokens: Optional[list] = None):
    """Return the set of characters in ``s`` that fall outside the allowed set.

    Despite the name, the result is a set of the offending characters, not a
    count. Allowed characters are CJK ideographs U+4E00-U+9FFF, ASCII letters,
    the space, and the punctuation ``,`` ``.`` ``。`` ``、`` plus the
    full-width comma U+FF0C.

    Args:
        s: Text to inspect.
        reserved_tokens: Literal substrings removed from ``s`` before the
            check. ``None`` or an empty list means nothing is removed.

    Returns:
        A ``set`` of single-character strings; empty when ``s`` is clean.
    """
    # Try longer tokens first: regex alternation takes the first alternative
    # that matches, so a token that is a prefix of another would otherwise
    # shadow it and leave residue behind. Empty tokens are dropped because an
    # empty alternative matches at every position.
    tokens = sorted(filter(None, reserved_tokens or ()), key=len, reverse=True)
    if tokens:
        reserved_pattern = "|".join(re.escape(token) for token in tokens)
        s = re.sub(reserved_pattern, "", s)

    # NOTE(review): the original class listed "," twice; the first entry was
    # most likely the full-width comma (it sat next to the CJK marks), so it
    # is spelled as an escape here. Confirm against upstream.
    disallowed = re.compile(r"[^\u4e00-\u9fffA-Za-z\uff0c。、,\. ]")
    return set(disallowed.findall(s))
def detect_language(sentence):
    """Guess whether ``sentence`` is mainly Chinese or English.

    Compares the number of CJK ideographs (U+4E00-U+9FFF) with the number of
    whole ASCII-letter words.

    Returns:
        ``"zh"`` when the ideographs strictly outnumber the words, otherwise
        ``"en"`` (this includes ties and the empty string).
    """
    ideograph_total = len(re.findall(r"[\u4e00-\u9fff]", sentence))
    word_total = len(re.findall(r"\b[A-Za-z]+\b", sentence))
    return "zh" if ideograph_total > word_total else "en"
# Punctuation rewrite table consumed through str.maketrans: brackets, colons,
# semicolons and dashes become a comma, exclamation marks become a full stop,
# and curly quotes are deleted (mapped to the empty string).
#
# The first group targets full-width / CJK punctuation. Its full-width keys
# are written as \uXXXX escapes because the glyphs are easily confused with,
# or folded back into, their ASCII twins; when that happens the keys collide
# with the ASCII group further down and the earlier entries are silently
# overridden (e.g. the mapping of the full-width "!" to "。" is lost).
# NOTE(review): the comma values in the first group may originally have been
# the full-width comma U+FF0C; they are kept as they appear here. Confirm.
character_map = {
    # --- full-width / CJK punctuation ---
    "\uff1a": ",",  # full-width colon
    "\uff1b": ",",  # full-width semicolon
    "\uff01": "。",  # full-width exclamation mark
    "\uff08": ",",  # full-width left parenthesis
    "\uff09": ",",  # full-width right parenthesis
    "【": ",",
    "】": ",",
    "『": ",",
    "』": ",",
    "「": ",",
    "」": ",",
    "《": ",",
    "》": ",",
    "\uff0d": ",",  # full-width hyphen-minus
    "‘": "",
    "“": "",
    "’": "",
    "”": "",
    # --- ASCII punctuation ---
    ":": ",",
    ";": ",",
    "!": ".",
    "(": ",",
    ")": ",",
    # "[" and "]" are deliberately left unmapped (disabled in the original).
    ">": ",",
    "<": ",",
    "-": ",",
}
# ASCII punctuation -> full-width / CJK counterpart, consumed through
# str.maketrans by apply_half2full_map.
#
# Except for the three entries noted inline, every target is the character in
# the Unicode "Halfwidth and Fullwidth Forms" block at code point
# ord(key) + 0xFEE0. The targets are written as \uXXXX escapes on purpose:
# the glyphs look almost identical to ASCII and are easily folded back to it,
# which turns the table into a no-op and, for the backslash entry, produces
# an unterminated string literal.
#
# "[", "]" and "_" are deliberately absent (they were disabled in the
# original table) - presumably so bracketed control tokens survive; confirm.
halfwidth_2_fullwidth_map = {
    "!": "\uff01",
    '"': "\u201c",  # left double quotation mark, not a full-width form
    "'": "\u2018",  # left single quotation mark, not a full-width form
    "#": "\uff03",
    "$": "\uff04",
    "%": "\uff05",
    "&": "\uff06",
    "(": "\uff08",
    ")": "\uff09",
    ",": "\uff0c",
    "-": "\uff0d",
    "*": "\uff0a",
    "+": "\uff0b",
    ".": "\u3002",  # ideographic full stop, not a full-width form
    "/": "\uff0f",
    ":": "\uff1a",
    ";": "\uff1b",
    "<": "\uff1c",
    "=": "\uff1d",
    ">": "\uff1e",
    "?": "\uff1f",
    "@": "\uff20",
    "\\": "\uff3c",
    "^": "\uff3e",
    "`": "\uff40",
    "{": "\uff5b",
    "|": "\uff5c",
    "}": "\uff5d",
    "~": "\uff5e",
}
def apply_half2full_map(text):
    """Return ``text`` with every key of ``halfwidth_2_fullwidth_map`` replaced.

    The translation table is rebuilt from the module-level dict on each call,
    so later changes to that dict are picked up.
    """
    return text.translate(str.maketrans(halfwidth_2_fullwidth_map))
def apply_character_map(text):
    """Return ``text`` with every key of ``character_map`` replaced.

    The translation table is rebuilt from the module-level dict on each call,
    so later changes to that dict are picked up.
    """
    return text.translate(str.maketrans(character_map))