Spaces:
Sleeping
Sleeping
import functools
import os
import re
import time
from typing import Optional

import agentscope
from agentscope.agents import AgentBase, DialogAgent, UserAgent
from agentscope.message import Msg
from agentscope.prompt import PromptEngine, PromptType

from AIGN_Prompt import *
def Retryer(func, max_retries=10, delay=2.333):
    """Wrap ``func`` so that a call is retried whenever it raises.

    Args:
        func: Callable to wrap.
        max_retries: Maximum number of attempts per call.
        delay: Seconds to sleep between consecutive attempts (not after the
            last one).

    Returns:
        A wrapper with the same call signature as ``func`` that returns the
        first successful result.

    Raises:
        ValueError: if every attempt fails. The exception from the last
            attempt is chained as ``__cause__`` so the real error stays visible.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        last_error = None
        for attempt in range(max_retries):
            try:
                return func(*args, **kwargs)
            # Deliberately broad: any failure (network, parsing, ...) is retried.
            except Exception as e:
                last_error = e
                print("-" * 30 + f"\n失败:\n{e}\n" + "-" * 30)
                # No point sleeping once the attempts are exhausted.
                if attempt + 1 < max_retries:
                    time.sleep(delay)
        raise ValueError("失败") from last_error

    return wrapper
class MarkdownAgent(AgentBase):
    """Agent whose inputs and outputs are both Markdown-style text, e.g. for novel generation.

    Requests are sent as ``# key`` sections built from a dict, and responses
    are parsed back into a dict keyed by ``# key`` headings. Keys missing
    from the level-1 headings are looked up under ``##``/``###`` headings.
    """

    def __init__(
        self,
        chatLLM,
        sys_prompt: str,
        user_prompt: str,
        name: str,
        temperature=0.8,
        top_p=0.8,
        use_memory=False,
        first_replay="明白了。",
        is_speak=True,
    ) -> None:
        """Create the agent and prime its conversation history.

        Args:
            chatLLM: Callable accepting ``messages=`` (plus ``temperature=`` and
                ``top_p=`` in ``query``). It returns a dict with at least a
                ``"content"`` key; ``getOutput`` also reads ``"total_tokens"``
                when ``is_speak`` is true.
            sys_prompt: System message placed first in the history.
            user_prompt: Instruction message placed second in the history.
            name: Agent name, passed to ``AgentBase`` and used in spoken messages.
            temperature: Sampling temperature used by ``query``.
            top_p: Nucleus-sampling value used by ``query``.
            use_memory: If true, every query/answer pair is appended to the history.
            first_replay: Canned assistant reply that completes the priming
                exchange. If falsy (e.g. ``None``), ``chatLLM`` is called once,
                with its default sampling settings, to produce that reply.
            is_speak: If true, responses are emitted through ``self.speak``.
        """
        # AgentBase's built-in memory is not used; this class keeps its own
        # ``history`` list instead.
        super().__init__(name=name, use_memory=False)
        self.chatLLM = chatLLM
        self.sys_prompt = sys_prompt
        self.user_prompt = user_prompt
        self.temperature = temperature
        self.top_p = top_p
        self.use_memory = use_memory
        self.is_speak = is_speak
        self.history = [
            {"role": "system", "content": self.sys_prompt},
            {"role": "user", "content": self.user_prompt},
        ]
        if first_replay:
            self.history.append({"role": "assistant", "content": first_replay})
        else:
            resp = chatLLM(messages=self.history)
            self.history.append({"role": "assistant", "content": resp["content"]})
            if self.is_speak:
                self.speak(Msg(self.name, resp["content"]))
        # Number of priming messages (system, user, first assistant reply).
        # clear_memory() truncates back to this so the priming reply is kept.
        self._base_history_len = len(self.history)

    def query(self, user_input: str) -> str:
        """Send ``user_input`` after the current history and return the raw response.

        The return value is whatever ``chatLLM`` returns; callers index it
        with ``["content"]``. When ``use_memory`` is true, the exchange is
        appended to the history.
        """
        resp = self.chatLLM(
            messages=self.history + [{"role": "user", "content": user_input}],
            temperature=self.temperature,
            top_p=self.top_p,
        )
        if self.use_memory:
            self.history.append({"role": "user", "content": user_input})
            self.history.append({"role": "assistant", "content": resp["content"]})
        return resp

    def getOutput(self, input_content: str, output_keys: list) -> dict:
        """Query the LLM once and parse the ``# key`` sections of its reply.

        Args:
            input_content: Full user message to send.
            output_keys: Section names that must be present and non-empty.

        Returns:
            Dict mapping section names to their stripped text. It contains
            every key in ``output_keys`` plus any other ``# `` sections found.

        Raises:
            ValueError: if a requested key is missing or empty under both the
                ``# key`` and the ``##``/``###`` heading styles. The raw LLM
                output is included in the message to aid debugging.
        """
        resp = self.query(input_content)
        output = resp["content"]
        lines = output.split("\n")
        sections = self.parse_sections1(lines, output_keys)
        for k in output_keys:
            if sections.get(k):
                continue
            # The model sometimes uses ``## key`` instead of ``# key``.
            section_content = self.parse_sections2(lines, k)
            if not section_content:
                raise ValueError(f"fail to parse {k} in output:\n{output}\n")
            sections[k] = section_content
        if self.is_speak:
            self.speak(
                Msg(
                    self.name,
                    f"total_tokens: {resp['total_tokens']}\n{resp['content']}\n",
                )
            )
        return sections

    def parse_sections1(self, lines, output_keys):
        """Group ``lines`` under level-1 ``# name`` headings.

        Lines before the first heading are ignored, and each content line is
        stripped. Every key in ``output_keys`` is present in the result (as
        ``""`` if it never appeared), followed by any other headings found.
        """
        sections = {key: [] for key in output_keys}
        current_section = ""
        for line in lines:
            if line.startswith(("# ", " # ")):
                # Both prefixes are two characters before the heading text.
                current_section = line[2:].strip()
                sections[current_section] = []
            elif current_section:
                sections[current_section].append(line.strip())
        return {key: "\n".join(body).strip() for key, body in sections.items()}

    def parse_sections2(self, lines, k):
        """Fallback parser: return the text under a ``##``/``###`` heading that mentions ``k``.

        Capturing starts after a sub-heading whose text contains ``k``
        (case-insensitive). It stops at the next ``##``-or-deeper heading that
        does not mention ``k``, or at any level-1 ``# `` heading. Returns
        ``""`` if nothing is found.
        """
        key = k.lower()
        content = []
        capturing = False
        for line in lines:
            stripped_line = line.strip()
            # Lines are stripped, so "##" also covers "###" and indented variants.
            is_sub_heading = stripped_line.startswith("##")
            if is_sub_heading and key in stripped_line.lower():
                capturing = True
                continue
            if capturing and (is_sub_heading or stripped_line.startswith("# ")):
                break
            if capturing:
                content.append(stripped_line)
        return "\n".join(content).strip()

    def invoke(self, inputs: dict, output_keys: list) -> dict:
        """Format ``inputs`` as ``# key`` sections, query the LLM and parse ``output_keys``.

        Only non-empty string values are sent. The whole query-and-parse step
        goes through ``Retryer``, so a reply that fails to parse triggers a
        fresh LLM call.
        """
        input_content = "".join(
            f"# {k}\n{v}\n\n"
            for k, v in inputs.items()
            if isinstance(v, str) and v
        )
        # Retries are enabled. While debugging, call self.getOutput directly
        # to surface the first failure instead.
        return Retryer(self.getOutput)(input_content, output_keys)

    def clear_memory(self):
        """Drop accumulated exchanges, keeping the priming messages (memory mode only)."""
        if self.use_memory:
            self.history = self.history[: self._base_history_len]
class AIGN:
    """Novel-generation pipeline driven by ``MarkdownAgent`` instances.

    Holds the evolving novel state: outline, paragraphs, writing plan,
    temporary setting and rolling memory. Step methods call the agents and
    update that state; ``history_states`` keeps snapshots for ``undo``.
    """

    # Chapter headings such as "## 第一章：标题" or "第12章 标题".
    # Group 1 is the chapter number, group 2 the rest of the heading line.
    _CHAPTER_PATTERN = re.compile(r"#*\s*第([一二三四五六七八九十百千万亿\d]+)章[:：]?\s*(.+)")

    def __init__(self, chatLLM):
        agentscope.init()
        self.chatLLM = chatLLM
        self.novel_outline = ""
        self.paragraph_list = []
        self.novel_content = ""
        self.writing_plan = ""
        self.temp_setting = ""
        self.writing_memory = ""
        # Text generated since writing_memory was last refreshed (see updateMemory).
        self.no_memory_paragraph = ""
        self.user_idea = ""
        self.user_requriments = ""
        self.history_states = []  # state snapshots consumed by undo()
        self.chapter_list = []  # (title, content) tuples from split_chapters()
        self.novel_outline_writer = MarkdownAgent(
            chatLLM=self.chatLLM,
            sys_prompt=system_prompt,
            user_prompt=novel_outline_writer_prompt,
            name="NovelOutlineWriter",
            temperature=0.98,
        )
        # NOTE: the beginning writer is disabled, so genBeginning() raises until
        # ``self.novel_beginning_writer`` is assigned. Previous configuration:
        # self.novel_beginning_writer = MarkdownAgent(
        #     chatLLM=self.chatLLM,
        #     sys_prompt=system_prompt + "不少于5000字",
        #     user_prompt=novel_beginning_writer_prompt,
        #     name="NovelBeginningWriter",
        #     temperature=0.80,
        # )
        self.novel_writer = MarkdownAgent(
            chatLLM=self.chatLLM,
            sys_prompt=system_prompt,
            user_prompt=novel_writer_prompt,
            name="NovelWriter",
            temperature=0.81,
        )
        self.memory_maker = MarkdownAgent(
            chatLLM=self.chatLLM,
            sys_prompt=system_prompt,
            user_prompt=memory_maker_prompt,
            name="MemoryMaker",
            temperature=0.66,
        )

    def split_chapters(self, novel_content):
        """Split ``novel_content`` into ``("第N章 title", content)`` tuples.

        Text before the first chapter heading is discarded. Returns an empty
        list when no heading is found.
        """
        parts = self._CHAPTER_PATTERN.split(novel_content)
        # re.split with two groups yields
        # [preamble, num, title, body, num, title, body, ...].
        # Always drop the preamble, even when non-empty, so the triples stay aligned.
        parts = parts[1:]
        return [
            (f"第{num}章 {title}", body)
            for num, title, body in zip(parts[0::3], parts[1::3], parts[2::3])
        ]

    def update_chapter_list(self):
        """Recompute ``chapter_list`` from ``novel_content``."""
        self.chapter_list = self.split_chapters(self.novel_content)

    def updateNovelContent(self):
        """Rebuild ``novel_content`` from ``paragraph_list`` and refresh the chapter list.

        Each paragraph is followed by a blank line. Returns the new content.
        """
        self.novel_content = "".join(f"{paragraph}\n\n" for paragraph in self.paragraph_list)
        self.update_chapter_list()
        return self.novel_content

    def genNovelOutline(self, user_idea=None):
        """Generate the outline from ``user_idea``, store it and return it.

        If ``user_idea`` is given it replaces the stored idea first.
        """
        if user_idea:
            self.user_idea = user_idea
        resp = self.novel_outline_writer.invoke(
            inputs={"用户想法": self.user_idea},
            output_keys=["大纲"],
        )
        self.novel_outline = resp["大纲"]
        return self.novel_outline

    def add_separator(self):
        """Append a visible separator to the content and the last paragraph, for later manual editing."""
        separator = "\n---------------------下一段---------------------\n"
        self.novel_content += separator
        if self.paragraph_list:
            self.paragraph_list[-1] += separator

    def genBeginning(self, user_requriments=None):
        """Generate the opening paragraph, plan and temporary setting.

        Raises:
            AttributeError: if ``self.novel_beginning_writer`` has not been
                set (its construction in ``__init__`` is commented out).
        """
        writer = getattr(self, "novel_beginning_writer", None)
        if writer is None:
            raise AttributeError(
                "AIGN.novel_beginning_writer is not configured "
                "(its construction in AIGN.__init__ is commented out)"
            )
        if user_requriments:
            self.user_requriments = user_requriments
        resp = writer.invoke(
            inputs={
                "用户想法": self.user_idea,
                "小说大纲": self.novel_outline,
                "用户要求": self.user_requriments,
            },
            output_keys=["开头", "计划", "临时设定"],
        )
        # Snapshot only after the LLM call succeeded and before any state changes.
        self.save_state()
        beginning = resp["开头"]
        self.writing_plan = resp["计划"]
        self.temp_setting = resp["临时设定"]
        self.paragraph_list.append(beginning)
        self.updateNovelContent()
        # self.add_separator()  # optional: separator to ease later editing
        return beginning

    def getLastParagraph(self, max_length=2000):
        """Return the most recent paragraphs, oldest first, each followed by "\\n".

        Paragraphs are collected newest-first while the running length plus
        the next paragraph stays below ``max_length``. If even the newest
        paragraph is too long, its last ``max_length`` characters are returned
        instead of an empty string, so the writer always sees some context.
        Returns "" when there are no paragraphs or ``max_length <= 0``.
        """
        if max_length <= 0 or not self.paragraph_list:
            return ""
        collected = ""
        for paragraph in reversed(self.paragraph_list):
            if len(collected) + len(paragraph) >= max_length:
                break
            collected = paragraph + "\n" + collected
        if not collected:
            collected = self.paragraph_list[-1][-max_length:] + "\n"
        return collected

    def recordNovel(self, path="novel_record.md"):
        """Write outline, body, memory, plan and temporary setting to ``path`` as Markdown (UTF-8)."""
        record_content = ""
        record_content += f"# 大纲\n\n{self.novel_outline}\n\n"
        record_content += "# 正文\n\n"
        record_content += self.novel_content
        record_content += f"# 记忆\n\n{self.writing_memory}\n\n"
        record_content += f"# 计划\n\n{self.writing_plan}\n\n"
        record_content += f"# 临时设定\n\n{self.temp_setting}\n\n"
        with open(path, "w", encoding="utf-8") as f:
            f.write(record_content)

    def updateMemory(self):
        """Fold pending text into ``writing_memory`` once more than 2000 characters have accumulated."""
        if len(self.no_memory_paragraph) > 2000:
            resp = self.memory_maker.invoke(
                inputs={
                    "前文记忆": self.writing_memory,
                    "正文内容": self.no_memory_paragraph,
                },
                output_keys=["新的记忆"],
            )
            self.writing_memory = resp["新的记忆"]
            self.no_memory_paragraph = ""

    def _snapshot_state(self):
        """Return a copy of the undoable state (the paragraph list is copied, not shared)."""
        return {
            "novel_outline": self.novel_outline,
            "paragraph_list": list(self.paragraph_list),
            "novel_content": self.novel_content,
            "writing_plan": self.writing_plan,
            "temp_setting": self.temp_setting,
            "writing_memory": self.writing_memory,
            "no_memory_paragraph": self.no_memory_paragraph,
        }

    def save_state(self):
        """Push a snapshot of the current state onto ``history_states``."""
        self.history_states.append(self._snapshot_state())

    def undo(self):
        """Restore the most recent snapshot. Returns False if there is none."""
        if not self.history_states:
            return False
        previous_state = self.history_states.pop()
        self.novel_outline = previous_state["novel_outline"]
        self.paragraph_list = list(previous_state["paragraph_list"])
        self.novel_content = previous_state["novel_content"]
        self.writing_plan = previous_state["writing_plan"]
        self.temp_setting = previous_state["temp_setting"]
        self.writing_memory = previous_state["writing_memory"]
        self.no_memory_paragraph = previous_state.get(
            "no_memory_paragraph", self.no_memory_paragraph
        )
        self.update_chapter_list()
        return True

    def genNextParagraph(self, user_requriments=None):
        """Generate the next paragraph, update plan/setting/memory, save the record and return it.

        An undo snapshot is recorded only after the LLM call succeeds, so a
        failed attempt does not leave a no-op entry in ``history_states``.
        """
        if user_requriments:
            self.user_requriments = user_requriments
        resp = self.novel_writer.invoke(
            inputs={
                "用户想法": self.user_idea,
                "大纲": self.novel_outline,
                "前文记忆": self.writing_memory,
                "临时设定": self.temp_setting,
                "计划": self.writing_plan,
                "用户要求": self.user_requriments,
                "上文内容": self.getLastParagraph(),
            },
            output_keys=["段落", "计划", "临时设定"],
        )
        # Nothing has been mutated yet, so this captures the pre-generation state.
        self.save_state()
        next_paragraph = resp["段落"]
        self.paragraph_list.append(next_paragraph)
        self.writing_plan = resp["计划"]
        self.temp_setting = resp["临时设定"]
        self.no_memory_paragraph += f"\n{next_paragraph}"
        self.updateMemory()
        self.updateNovelContent()
        self.recordNovel()
        # self.add_separator()  # optional: separator to ease later editing
        return next_paragraph