# credit to shadowcz007 for this module
# from https://github.com/shadowcz007/comfyui-mixlab-nodes/blob/main/nodes/TextGenerateNode.py
import re
import os

import folder_paths
import comfy.utils
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from .utils import install_package

try:
    from lark import Lark, Transformer, v_args
except ImportError:
    print('install lark-parser...')
    install_package('lark-parser')
    from lark import Lark, Transformer, v_args
model_path = os.path.join(folder_paths.models_dir, 'prompt_generator')
zh_en_model_path = os.path.join(model_path, 'opus-mt-zh-en')

zh_en_model, zh_en_tokenizer = None, None
def correct_prompt_syntax(prompt=""):
    corrected_elements = []

    # Normalize full-width (Chinese) punctuation to ASCII equivalents
    prompt = prompt.replace('（', '(').replace('）', ')').replace('，', ',').replace('；', ',').replace('。', '.').replace('：', ':').replace('\\', ',')
    # Collapse redundant whitespace
    prompt = re.sub(r'\s+', ' ', prompt).strip()
    prompt = prompt.replace("< ", "<").replace(" >", ">").replace("( ", "(").replace(" )", ")").replace("[ ", "[").replace(" ]", "]")

    # Tokenize on commas
    prompt_elements = prompt.split(',')

    def balance_brackets(element, open_bracket, close_bracket):
        # Append missing closing brackets so every opener is matched
        open_brackets_count = element.count(open_bracket)
        close_brackets_count = element.count(close_bracket)
        return element + close_bracket * (open_brackets_count - close_brackets_count)

    for element in prompt_elements:
        element = element.strip()

        # Skip empty elements
        if not element:
            continue

        # Balance round, square, and angle brackets
        if element[0] in '([':
            corrected_element = balance_brackets(element, '(', ')') if element[0] == '(' else balance_brackets(element, '[', ']')
        elif element[0] == '<':
            corrected_element = balance_brackets(element, '<', '>')
        else:
            # Strip stray leading closing brackets
            corrected_element = element.lstrip(')]')

        corrected_elements.append(corrected_element)

    # Reassemble the corrected prompt
    return ','.join(corrected_elements)
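
# Illustrative sketch of the repair behavior (a comment-only example, not
# executed at import time):
#   correct_prompt_syntax('（杰作, [高质量')  ->  '(杰作),[高质量]'
# The full-width parenthesis is normalized and both unbalanced brackets
# receive their missing closers.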
def detect_language(input_str):
    # Count Chinese and English (alphabetic) characters
    count_cn = count_en = 0
    for char in input_str:
        if '\u4e00' <= char <= '\u9fff':
            count_cn += 1
        elif char.isalpha():
            count_en += 1

    # Whichever script dominates decides the language
    if count_cn > count_en:
        return "cn"
    elif count_en > count_cn:
        return "en"
    else:
        return "unknown"
def has_chinese(text):
    # Strip syntax that must never be translated: <lora:...> tags,
    # __wildcard__ tokens, and trailing embedding references
    _text = re.sub(r'<.*?>', '', text)
    _text = re.sub(r'__.*?__', '', _text)
    _text = re.sub(r'embedding:.*?$', '', _text)
    for char in _text:
        if '\u4e00' <= char <= '\u9fff':
            return True
    return False
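
# Illustrative: syntax tokens are ignored when scanning for CJK characters.
#   has_chinese('<lora:某风格:0.8> 1girl')  ->  False  (the tag is stripped first)
#   has_chinese('一个女孩, 1girl')          ->  True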
def translate(text):
    global zh_en_model_path, zh_en_model, zh_en_tokenizer

    # Fall back to the Hub checkpoint if no local copy exists
    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'

    # Lazy-load the model and tokenizer on first use
    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)
        zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        encoded = zh_en_tokenizer([text], return_tensors="pt")
        encoded.to(zh_en_model.device)
        sequences = zh_en_model.generate(**encoded)
        return zh_en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
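
# Illustrative (downloads the checkpoint from the Hub when no local copy exists):
#   translate('一个女孩')  ->  e.g. 'A girl'
# The exact English wording depends on the opus-mt-zh-en checkpoint.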
# Decorator to flatten the tree directly into the function arguments
@v_args(inline=True)
class ChinesePromptTranslate(Transformer):
    def sentence(self, *args):
        return ", ".join(args)

    def phrase(self, *args):
        return "".join(args)

    def emphasis(self, *args):
        # Reconstruct the emphasis with translated content
        return "(" + "".join(args) + ")"

    def weak_emphasis(self, *args):
        print('weak_emphasis:', args)
        return "[" + "".join(args) + "]"
    def embedding(self, *args):
        print('prompt embedding', args[0])
        # Rebuild the embedding:name[:number[:number]] syntax from the parsed pieces
        if len(args) == 1:
            embedding_name = str(args[0])
            return f"embedding:{embedding_name}"
        elif len(args) > 1:
            embedding_name, *numbers = args
            if len(numbers) == 2:
                return f"embedding:{embedding_name}:{numbers[0]}:{numbers[1]}"
            elif len(numbers) == 1:
                return f"embedding:{embedding_name}:{numbers[0]}"
            else:
                return f"embedding:{embedding_name}"

    def lora(self, *args):
        # Rebuild the <lora:name[:number[:number]]> syntax; the first token
        # is the literal "lora" keyword
        if len(args) == 1:
            return f"<lora:{args[0]}>"
        elif len(args) > 1:
            _, lora_name, *numbers = args
            lora_name = str(lora_name).strip()
            if len(numbers) == 2:
                return f"<lora:{lora_name}:{numbers[0]}:{numbers[1]}>"
            elif len(numbers) == 1:
                return f"<lora:{lora_name}:{numbers[0]}>"
            else:
                return f"<lora:{lora_name}>"
    def weight(self, word, number):
        # Translate the word, then re-attach its weight: (word:1.2)
        translated_word = translate(str(word)).rstrip('.')
        return f"({translated_word}:{str(number).strip()})"

    def schedule(self, *args):
        print('prompt schedule', args)
        data = [str(arg).strip() for arg in args]
        return f"[{':'.join(data)}]"
    def word(self, word):
        # Translate each word with the zh->en model, honoring escape syntax
        word = str(word)
        match_cn = re.search(r'@.*?@', word)
        if re.search(r'__.*?__', word):
            # Wildcard tokens (__name__) pass through untranslated
            return word.rstrip('.')
        elif match_cn:
            # Text wrapped in @...@ is kept verbatim; only the surroundings are translated
            chinese = match_cn.group()
            before = word.split('@', 1)
            before = before[0] if len(before) > 0 else ''
            before = translate(str(before)).rstrip('.') if before else ''
            after = word.rsplit('@', 1)
            after = after[len(after) - 1] if len(after) > 1 else ''
            after = translate(after).rstrip('.') if after else ''
            return before + chinese.replace('@', '').rstrip('.') + after
        elif detect_language(word) == "cn":
            # Mostly-Chinese words go through the translation model
            return translate(word).rstrip('.')
        else:
            return word.rstrip('.')
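
# Illustrative: @...@ marks text the transformer must keep verbatim. A
# hypothetical token '红色@MixLab@风格' translates '红色' and '风格' but
# leaves 'MixLab' untouched in the output.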
# Prompt grammar definition (raw string so the regex escapes survive intact)
grammar = r"""
start: sentence
sentence: phrase ("," phrase)*
phrase: emphasis | weight | word | lora | embedding | schedule
emphasis: "(" sentence ")" -> emphasis
    | "[" sentence "]" -> weak_emphasis
weight: "(" word ":" NUMBER ")"
schedule: "[" word ":" word ":" NUMBER "]"
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">"
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)?
word: WORD

NUMBER: /\s*-?\d+(\.\d+)?\s*/
WORD: /[^,:\(\)\[\]<>]+/
"""
def zh_to_en(text):
    global zh_en_model_path, zh_en_model, zh_en_tokenizer

    # Progress bar: one tick for the model work, one per translated prompt
    pbar = comfy.utils.ProgressBar(len(text) + 1)

    texts = [correct_prompt_syntax(t) for t in text]

    install_package('sentencepiece', '0.2.0')

    # Fall back to the Hub checkpoint if no local copy exists
    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'

    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)

    # Ensure the model is on the GPU even if a previous call parked it on the CPU
    zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")

    prompt_result = []
    en_texts = []
    for t in texts:
        if t:
            # Parsing triggers translation through the ChinesePromptTranslate transformer
            parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate())
            result = parser.parse(t).children
            en_texts.append(result[0])

    # Move the model back to the CPU to free VRAM
    zh_en_model.to('cpu')
    pbar.update(1)

    for t in en_texts:
        prompt_result.append(t)
        pbar.update(1)

    if len(prompt_result) == 0:
        prompt_result = [""]

    return prompt_result
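
# Usage sketch (illustrative): translating a batch of prompts in one call.
#   zh_to_en(['一个女孩, (杰作:1.2)', '1girl, masterpiece'])
#   -> e.g. ['a girl, (masterpiece:1.2)', '1girl, masterpiece']
# The exact English wording depends on the opus-mt-zh-en checkpoint.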