# credit to shadowcz007 for this module
# from https://github.com/shadowcz007/comfyui-mixlab-nodes/blob/main/nodes/TextGenerateNode.py
"""Chinese -> English prompt translation that keeps prompt syntax intact.

Prompts are normalised by ``correct_prompt_syntax``, parsed with a small Lark
grammar, and rebuilt by ``ChinesePromptTranslate``, which sends word fragments
detected as Chinese through the opus-mt-zh-en model while re-emitting emphasis
brackets, weights, LoRA tags, embeddings and schedules.
"""
import re
import os
import folder_paths
import comfy.utils
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from .utils import install_package

try:
    from lark import Lark, Transformer, v_args
except ImportError:
    print('install lark-parser...')
    install_package('lark-parser')
    from lark import Lark, Transformer, v_args

model_path = os.path.join(folder_paths.models_dir, 'prompt_generator')
zh_en_model_path = os.path.join(model_path, 'opus-mt-zh-en')
# Lazily loaded by _load_zh_en_model() and shared by translate() / zh_to_en().
zh_en_model, zh_en_tokenizer = None, None

# Full-width CJK punctuation (and ';' / backslash) -> ASCII punctuation the grammar understands.
_PUNCTUATION_TABLE = str.maketrans({
    '\uff08': '(',   # FULLWIDTH LEFT PARENTHESIS
    '\uff09': ')',   # FULLWIDTH RIGHT PARENTHESIS
    '\uff0c': ',',   # FULLWIDTH COMMA
    '\uff1b': ',',   # FULLWIDTH SEMICOLON
    ';': ',',
    '\u3002': '.',   # IDEOGRAPHIC FULL STOP
    '\uff1a': ':',   # FULLWIDTH COLON
    '\\': ',',
})


def _balance_brackets(element, open_bracket, close_bracket):
    """Append as many ``close_bracket`` as needed to match the ``open_bracket`` count."""
    missing = element.count(open_bracket) - element.count(close_bracket)
    # A negative count multiplies to '', so surplus closers are left untouched.
    return element + close_bracket * missing


def correct_prompt_syntax(prompt=""):
    """Normalise punctuation/whitespace and repair unbalanced brackets in a prompt.

    The prompt is split on commas. Each non-empty element is stripped, elements
    starting with ``(``, ``[`` or ``<`` get their missing closing brackets appended,
    and other elements lose stray leading ``)`` / ``]``. Elements that end up empty
    are dropped (a dangling comma would make the grammar fail to parse) and the
    rest are re-joined with ``,`` (no space).
    """
    prompt = prompt.translate(_PUNCTUATION_TABLE)
    # Collapse runs of whitespace into a single space.
    prompt = re.sub(r'\s+', ' ', prompt).strip()
    # Remove spaces just inside brackets so element[0] below is the bracket itself.
    prompt = prompt.replace("< ", "<").replace(" >", ">").replace("( ", "(").replace(" )", ")").replace("[ ", "[").replace(' ]', ']')

    corrected_elements = []
    for element in prompt.split(','):
        element = element.strip()
        if not element:
            continue
        first = element[0]
        if first == '(':
            corrected = _balance_brackets(element, '(', ')')
        elif first == '[':
            corrected = _balance_brackets(element, '[', ']')
        elif first == '<':
            corrected = _balance_brackets(element, '<', '>')
        else:
            # Drop closing brackets that have no opener at the start of the element.
            corrected = element.lstrip(')]')
        if corrected:
            corrected_elements.append(corrected)
    return ','.join(corrected_elements)


def _is_cjk(char):
    """Return True if ``char`` is in the CJK Unified Ideographs block (U+4E00..U+9FFF)."""
    return '\u4e00' <= char <= '\u9fff'


def detect_language(input_str):
    """Classify ``input_str`` by comparing CJK-ideograph and other alphabetic counts.

    Returns ``"cn"`` or ``"en"`` for whichever count is larger, and ``"unknow"``
    (sic, kept for callers comparing against it) on a tie.
    """
    count_cn = count_en = 0
    for char in input_str:
        if _is_cjk(char):
            count_cn += 1
        elif char.isalpha():
            count_en += 1
    if count_cn > count_en:
        return "cn"
    if count_en > count_cn:
        return "en"
    return "unknow"


def has_chinese(text):
    """Return True if ``text`` contains a CJK ideograph outside of ``<...>``,
    ``__...__`` and ``embedding:...`` spans."""
    stripped = re.sub(r'<.*?>', '', text)
    stripped = re.sub(r'__.*?__', '', stripped)
    stripped = re.sub(r'embedding:.*?$', '', stripped)
    return any(_is_cjk(char) for char in stripped)


def _torch_device():
    """Return ``"cuda"`` when CUDA is available, else ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _load_zh_en_model():
    """Load (once) and return the shared ``(model, tokenizer)`` pair.

    Uses the local ``models/prompt_generator/opus-mt-zh-en`` directory when it
    exists, otherwise the ``Helsinki-NLP/opus-mt-zh-en`` hub id. The globals are
    assigned only after both parts loaded, so a failed tokenizer load does not
    leave a model without a tokenizer behind.
    """
    global zh_en_model_path, zh_en_model, zh_en_tokenizer
    if zh_en_model is None or zh_en_tokenizer is None:
        # The opus-mt tokenizer needs sentencepiece.
        install_package('sentencepiece', '0.2.0')
        if not os.path.exists(zh_en_model_path):
            zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'
        model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)
        zh_en_model, zh_en_tokenizer = model, tokenizer
    return zh_en_model, zh_en_tokenizer


def translate(text):
    """Translate one string from Chinese to English with the opus-mt-zh-en model."""
    model, tokenizer = _load_zh_en_model()
    model.to(_torch_device())
    with torch.no_grad():
        # Truncate to the model's maximum input length instead of failing on long text.
        encoded = tokenizer([text], return_tensors="pt", truncation=True)
        encoded = encoded.to(model.device)
        sequences = model.generate(**encoded)
    return tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]


@v_args(inline=True)  # pass each rule's children to the method as positional arguments
class ChinesePromptTranslate(Transformer):
    """Lark transformer that rebuilds a prompt string, translating Chinese words.

    Lark applies the callbacks bottom-up, so composite rules (``weight``,
    ``emphasis``, ``schedule``...) receive children that ``word()`` already
    translated. Anonymous literals such as ``"("`` or ``":"`` are filtered out
    and never reach these methods.
    """

    def sentence(self, *args):
        return ", ".join(args)

    def phrase(self, *args):
        return "".join(args)

    def emphasis(self, *args):
        # Re-wrap the already-translated inner sentence in emphasis brackets.
        return "(" + "".join(args) + ")"

    def weak_emphasis(self, *args):
        print('weak_emphasis:', args)
        return "[" + "".join(args) + "]"

    def embedding(self, *args):
        # args: (name[, number[, number]])
        print('prompt embedding', args[0])
        return "embedding:" + ":".join(str(arg).strip() for arg in args)

    def lora(self, *args):
        # args: (tag, name[, number[, number]]); rebuild the tag verbatim
        # instead of dropping it from the prompt.
        return "<" + ":".join(str(arg).strip() for arg in args) + ">"

    def weight(self, word, number):
        # `word` is already the output of word(); translating it again would
        # re-translate English text and break __wildcard__ / @verbatim@ spans.
        translated_word = str(word).rstrip('.')
        return f"({translated_word}:{str(number).strip()})"

    def schedule(self, *args):
        print('prompt schedule', args)
        data = [str(arg).strip() for arg in args]
        return f"[{':'.join(data)}]"

    def word(self, word):
        """Translate a plain word fragment when it is detected as Chinese.

        ``__wildcard__`` fragments are returned untouched. Text between the first
        ``@...@`` pair is kept verbatim, with the parts before the first ``@`` and
        after the last ``@`` translated separately. Trailing periods, which the
        model tends to add, are stripped.
        """
        word = str(word)
        verbatim = re.search(r'@.*?@', word)
        if re.search(r'__.*?__', word):
            return word.rstrip('.')
        if verbatim:
            before = word.split('@', 1)[0]
            before = translate(before).rstrip('.') if before else ''
            after = word.rsplit('@', 1)[-1]
            after = translate(after).rstrip('.') if after else ''
            return before + verbatim.group().replace('@', '').rstrip('.') + after
        if detect_language(word) == "cn":
            return translate(word).rstrip('.')
        return word.rstrip('.')


# Prompt grammar (raw string: the regex escapes are meant for Lark, not Python).
grammar = r"""
start: sentence
sentence: phrase ("," phrase)*
phrase: emphasis | weight | word | lora | embedding | schedule
emphasis: "(" sentence ")" -> emphasis
        | "[" sentence "]" -> weak_emphasis
weight: "(" word ":" NUMBER ")"
schedule: "[" word ":" word ":" NUMBER "]"
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">"
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)?
word: WORD

NUMBER: /\s*-?\d+(\.\d+)?\s*/
WORD: /[^,:\(\)\[\]<>]+/
"""


def zh_to_en(text):
    """Translate prompts from Chinese to English while preserving prompt syntax.

    Args:
        text: list of prompt strings. A single ``str`` is treated as a
            one-element list rather than being iterated character by character.

    Returns:
        list[str]: one translated prompt per non-empty (after syntax correction)
        input prompt. Empty prompts are skipped, so the list may be shorter than
        the input. Returns ``[""]`` when nothing was translated.
    """
    if isinstance(text, str):
        text = [text]
    # Progress bar: one step for the translation phase plus one per result.
    pbar = comfy.utils.ProgressBar(len(text) + 1)
    texts = [correct_prompt_syntax(t) for t in text]

    en_texts = []
    if any(texts):
        model, _ = _load_zh_en_model()
        model.to(_torch_device())
        # Build the LALR parser once per call instead of once per prompt.
        parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate())
        try:
            for t in texts:
                if t:
                    en_texts.append(parser.parse(t).children[0])
        finally:
            # Free GPU memory even if parsing or translation raised.
            model.to('cpu')

    pbar.update(1)
    for _ in en_texts:
        pbar.update(1)

    return en_texts or [""]