|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import importlib | 
					
						
						|  | import json | 
					
						
						|  | import traceback | 
					
						
						|  | from abc import ABC | 
					
						
						|  | from copy import deepcopy | 
					
						
						|  | from functools import partial | 
					
						
						|  |  | 
					
						
						|  | import pandas as pd | 
					
						
						|  |  | 
					
						
						|  | from agent.component import component_class | 
					
						
						|  | from agent.component.base import ComponentBase | 
					
						
						|  | from agent.settings import flow_logger, DEBUG | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class Canvas(ABC): | 
					
						
						|  | """ | 
					
						
						|  | dsl = { | 
					
						
						|  | "components": { | 
					
						
						|  | "begin": { | 
					
						
						|  | "obj":{ | 
					
						
						|  | "component_name": "Begin", | 
					
						
						|  | "params": {}, | 
					
						
						|  | }, | 
					
						
						|  | "downstream": ["answer_0"], | 
					
						
						|  | "upstream": [], | 
					
						
						|  | }, | 
					
						
						|  | "answer_0": { | 
					
						
						|  | "obj": { | 
					
						
						|  | "component_name": "Answer", | 
					
						
						|  | "params": {} | 
					
						
						|  | }, | 
					
						
						|  | "downstream": ["retrieval_0"], | 
					
						
						|  | "upstream": ["begin", "generate_0"], | 
					
						
						|  | }, | 
					
						
						|  | "retrieval_0": { | 
					
						
						|  | "obj": { | 
					
						
						|  | "component_name": "Retrieval", | 
					
						
						|  | "params": {} | 
					
						
						|  | }, | 
					
						
						|  | "downstream": ["generate_0"], | 
					
						
						|  | "upstream": ["answer_0"], | 
					
						
						|  | }, | 
					
						
						|  | "generate_0": { | 
					
						
						|  | "obj": { | 
					
						
						|  | "component_name": "Generate", | 
					
						
						|  | "params": {} | 
					
						
						|  | }, | 
					
						
						|  | "downstream": ["answer_0"], | 
					
						
						|  | "upstream": ["retrieval_0"], | 
					
						
						|  | } | 
					
						
						|  | }, | 
					
						
						|  | "history": [], | 
					
						
						|  | "messages": [], | 
					
						
						|  | "reference": [], | 
					
						
						|  | "path": [["begin"]], | 
					
						
						|  | "answer": [] | 
					
						
						|  | } | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, dsl: str, tenant_id=None): | 
					
						
						|  | self.path = [] | 
					
						
						|  | self.history = [] | 
					
						
						|  | self.messages = [] | 
					
						
						|  | self.answer = [] | 
					
						
						|  | self.components = {} | 
					
						
						|  | self.dsl = json.loads(dsl) if dsl else { | 
					
						
						|  | "components": { | 
					
						
						|  | "begin": { | 
					
						
						|  | "obj": { | 
					
						
						|  | "component_name": "Begin", | 
					
						
						|  | "params": { | 
					
						
						|  | "prologue": "Hi there!" | 
					
						
						|  | } | 
					
						
						|  | }, | 
					
						
						|  | "downstream": [], | 
					
						
						|  | "upstream": [] | 
					
						
						|  | } | 
					
						
						|  | }, | 
					
						
						|  | "history": [], | 
					
						
						|  | "messages": [], | 
					
						
						|  | "reference": [], | 
					
						
						|  | "path": [], | 
					
						
						|  | "answer": [] | 
					
						
						|  | } | 
					
						
						|  | self._tenant_id = tenant_id | 
					
						
						|  | self._embed_id = "" | 
					
						
						|  | self.load() | 
					
						
						|  |  | 
					
						
						|  | def load(self): | 
					
						
						|  | self.components = self.dsl["components"] | 
					
						
						|  | cpn_nms = set([]) | 
					
						
						|  | for k, cpn in self.components.items(): | 
					
						
						|  | cpn_nms.add(cpn["obj"]["component_name"]) | 
					
						
						|  |  | 
					
						
						|  | assert "Begin" in cpn_nms, "There have to be an 'Begin' component." | 
					
						
						|  | assert "Answer" in cpn_nms, "There have to be an 'Answer' component." | 
					
						
						|  |  | 
					
						
						|  | for k, cpn in self.components.items(): | 
					
						
						|  | cpn_nms.add(cpn["obj"]["component_name"]) | 
					
						
						|  | param = component_class(cpn["obj"]["component_name"] + "Param")() | 
					
						
						|  | param.update(cpn["obj"]["params"]) | 
					
						
						|  | param.check() | 
					
						
						|  | cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param) | 
					
						
						|  | if cpn["obj"].component_name == "Categorize": | 
					
						
						|  | for _, desc in param.category_description.items(): | 
					
						
						|  | if desc["to"] not in cpn["downstream"]: | 
					
						
						|  | cpn["downstream"].append(desc["to"]) | 
					
						
						|  |  | 
					
						
						|  | self.path = self.dsl["path"] | 
					
						
						|  | self.history = self.dsl["history"] | 
					
						
						|  | self.messages = self.dsl["messages"] | 
					
						
						|  | self.answer = self.dsl["answer"] | 
					
						
						|  | self.reference = self.dsl["reference"] | 
					
						
						|  | self._embed_id = self.dsl.get("embed_id", "") | 
					
						
						|  |  | 
					
						
						|  | def __str__(self): | 
					
						
						|  | self.dsl["path"] = self.path | 
					
						
						|  | self.dsl["history"] = self.history | 
					
						
						|  | self.dsl["messages"] = self.messages | 
					
						
						|  | self.dsl["answer"] = self.answer | 
					
						
						|  | self.dsl["reference"] = self.reference | 
					
						
						|  | self.dsl["embed_id"] = self._embed_id | 
					
						
						|  | dsl = { | 
					
						
						|  | "components": {} | 
					
						
						|  | } | 
					
						
						|  | for k in self.dsl.keys(): | 
					
						
						|  | if k in ["components"]:continue | 
					
						
						|  | dsl[k] = deepcopy(self.dsl[k]) | 
					
						
						|  |  | 
					
						
						|  | for k, cpn in self.components.items(): | 
					
						
						|  | if k not in dsl["components"]: | 
					
						
						|  | dsl["components"][k] = {} | 
					
						
						|  | for c in cpn.keys(): | 
					
						
						|  | if c == "obj": | 
					
						
						|  | dsl["components"][k][c] = json.loads(str(cpn["obj"])) | 
					
						
						|  | continue | 
					
						
						|  | dsl["components"][k][c] = deepcopy(cpn[c]) | 
					
						
						|  | return json.dumps(dsl, ensure_ascii=False) | 
					
						
						|  |  | 
					
						
						|  | def reset(self): | 
					
						
						|  | self.path = [] | 
					
						
						|  | self.history = [] | 
					
						
						|  | self.messages = [] | 
					
						
						|  | self.answer = [] | 
					
						
						|  | self.reference = [] | 
					
						
						|  | for k, cpn in self.components.items(): | 
					
						
						|  | self.components[k]["obj"].reset() | 
					
						
						|  | self._embed_id = "" | 
					
						
						|  |  | 
					
						
						|  | def run(self, **kwargs): | 
					
						
						|  | ans = "" | 
					
						
						|  | if self.answer: | 
					
						
						|  | cpn_id = self.answer[0] | 
					
						
						|  | self.answer.pop(0) | 
					
						
						|  | try: | 
					
						
						|  | ans = self.components[cpn_id]["obj"].run(self.history, **kwargs) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | ans = ComponentBase.be_output(str(e)) | 
					
						
						|  | self.path[-1].append(cpn_id) | 
					
						
						|  | if kwargs.get("stream"): | 
					
						
						|  | assert isinstance(ans, partial) | 
					
						
						|  | return ans | 
					
						
						|  | self.history.append(("assistant", ans.to_dict("records"))) | 
					
						
						|  | return ans | 
					
						
						|  |  | 
					
						
						|  | if not self.path: | 
					
						
						|  | self.components["begin"]["obj"].run(self.history, **kwargs) | 
					
						
						|  | self.path.append(["begin"]) | 
					
						
						|  |  | 
					
						
						|  | self.path.append([]) | 
					
						
						|  | ran = -1 | 
					
						
						|  |  | 
					
						
						|  | def prepare2run(cpns): | 
					
						
						|  | nonlocal ran, ans | 
					
						
						|  | for c in cpns: | 
					
						
						|  | if self.path[-1] and c == self.path[-1][-1]: continue | 
					
						
						|  | cpn = self.components[c]["obj"] | 
					
						
						|  | if cpn.component_name == "Answer": | 
					
						
						|  | self.answer.append(c) | 
					
						
						|  | else: | 
					
						
						|  | if DEBUG: print("RUN: ", c) | 
					
						
						|  | if cpn.component_name == "Generate": | 
					
						
						|  | cpids = cpn.get_dependent_components() | 
					
						
						|  | if any([c not in self.path[-1] for c in cpids]): | 
					
						
						|  | continue | 
					
						
						|  | ans = cpn.run(self.history, **kwargs) | 
					
						
						|  | self.path[-1].append(c) | 
					
						
						|  | ran += 1 | 
					
						
						|  |  | 
					
						
						|  | prepare2run(self.components[self.path[-2][-1]]["downstream"]) | 
					
						
						|  | while 0 <= ran < len(self.path[-1]): | 
					
						
						|  | if DEBUG: print(ran, self.path) | 
					
						
						|  | cpn_id = self.path[-1][ran] | 
					
						
						|  | cpn = self.get_component(cpn_id) | 
					
						
						|  | if not cpn["downstream"]: break | 
					
						
						|  |  | 
					
						
						|  | loop = self._find_loop() | 
					
						
						|  | if loop: raise OverflowError(f"Too much loops: {loop}") | 
					
						
						|  |  | 
					
						
						|  | if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]: | 
					
						
						|  | switch_out = cpn["obj"].output()[1].iloc[0, 0] | 
					
						
						|  | assert switch_out in self.components, \ | 
					
						
						|  | "{}'s output: {} not valid.".format(cpn_id, switch_out) | 
					
						
						|  | try: | 
					
						
						|  | prepare2run([switch_out]) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | for p in [c for p in self.path for c in p][::-1]: | 
					
						
						|  | if p.lower().find("answer") >= 0: | 
					
						
						|  | self.get_component(p)["obj"].set_exception(e) | 
					
						
						|  | prepare2run([p]) | 
					
						
						|  | break | 
					
						
						|  | traceback.print_exc() | 
					
						
						|  | break | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | prepare2run(cpn["downstream"]) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | for p in [c for p in self.path for c in p][::-1]: | 
					
						
						|  | if p.lower().find("answer") >= 0: | 
					
						
						|  | self.get_component(p)["obj"].set_exception(e) | 
					
						
						|  | prepare2run([p]) | 
					
						
						|  | break | 
					
						
						|  | traceback.print_exc() | 
					
						
						|  | break | 
					
						
						|  |  | 
					
						
						|  | if self.answer: | 
					
						
						|  | cpn_id = self.answer[0] | 
					
						
						|  | self.answer.pop(0) | 
					
						
						|  | ans = self.components[cpn_id]["obj"].run(self.history, **kwargs) | 
					
						
						|  | self.path[-1].append(cpn_id) | 
					
						
						|  | if kwargs.get("stream"): | 
					
						
						|  | assert isinstance(ans, partial) | 
					
						
						|  | return ans | 
					
						
						|  |  | 
					
						
						|  | self.history.append(("assistant", ans.to_dict("records"))) | 
					
						
						|  |  | 
					
						
						|  | return ans | 
					
						
						|  |  | 
					
						
						|  | def get_component(self, cpn_id): | 
					
						
						|  | return self.components[cpn_id] | 
					
						
						|  |  | 
					
						
						|  | def get_tenant_id(self): | 
					
						
						|  | return self._tenant_id | 
					
						
						|  |  | 
					
						
						|  | def get_history(self, window_size): | 
					
						
						|  | convs = [] | 
					
						
						|  | for role, obj in self.history[window_size * -2:]: | 
					
						
						|  | convs.append({"role": role, "content": (obj if role == "user" else | 
					
						
						|  | '\n'.join(pd.DataFrame(obj)['content']))}) | 
					
						
						|  | return convs | 
					
						
						|  |  | 
					
						
						|  | def add_user_input(self, question): | 
					
						
						|  | self.history.append(("user", question)) | 
					
						
						|  |  | 
					
						
						|  | def set_embedding_model(self, embed_id): | 
					
						
						|  | self._embed_id = embed_id | 
					
						
						|  |  | 
					
						
						|  | def get_embedding_model(self): | 
					
						
						|  | return self._embed_id | 
					
						
						|  |  | 
					
						
						|  | def _find_loop(self, max_loops=2): | 
					
						
						|  | path = self.path[-1][::-1] | 
					
						
						|  | if len(path) < 2: return False | 
					
						
						|  |  | 
					
						
						|  | for i in range(len(path)): | 
					
						
						|  | if path[i].lower().find("answer") >= 0: | 
					
						
						|  | path = path[:i] | 
					
						
						|  | break | 
					
						
						|  |  | 
					
						
						|  | if len(path) < 2: return False | 
					
						
						|  |  | 
					
						
						|  | for l in range(2, len(path) // 2): | 
					
						
						|  | pat = ",".join(path[0:l]) | 
					
						
						|  | path_str = ",".join(path) | 
					
						
						|  | if len(pat) >= len(path_str): return False | 
					
						
						|  | loop = max_loops | 
					
						
						|  | while path_str.find(pat) == 0 and loop >= 0: | 
					
						
						|  | loop -= 1 | 
					
						
						|  | if len(pat)+1 >= len(path_str): | 
					
						
						|  | return False | 
					
						
						|  | path_str = path_str[len(pat)+1:] | 
					
						
						|  | if loop < 0: | 
					
						
						|  | pat = " => ".join([p.split(":")[0] for p in path[0:l]]) | 
					
						
						|  | return pat + " => " + pat | 
					
						
						|  |  | 
					
						
						|  | return False | 
					
						
						|  |  |