Spaces:
Runtime error
Runtime error
| import base64 | |
| import hmac | |
| import json | |
| from datetime import datetime, timezone | |
| from urllib.parse import urlencode, urlparse | |
| from websocket import create_connection, WebSocketConnectionClosedException | |
| from utils.tools import get_prompt, process_response, init_script, create_script | |
| class SparkAPI: | |
| __api_url = 'ws(s)://spark-api.xf-yun.com/v2.1/chat' | |
| __max_token = 4096 | |
| def __init__(self, app_id, api_key, api_secret): | |
| self.__app_id = app_id | |
| self.__api_key = api_key | |
| self.__api_secret = api_secret | |
| def __set_max_tokens(self, token): | |
| if isinstance(token, int) is False or token < 0: | |
| print("set_max_tokens() error: tokens should be a positive integer!") | |
| return | |
| self.__max_token = token | |
| def __get_authorization_url(self): | |
| authorize_url = urlparse(self.__api_url) | |
| # 1. generate data | |
| date = datetime.now(timezone.utc).strftime('%a, %d %b %Y %H:%M:%S %Z') | |
| """ | |
| Generation rule of Authorization parameters | |
| 1) Obtain the APIKey and APISecret parameters from the console. | |
| 2) Use the aforementioned date to dynamically concatenate a string tmp. Here we take Huobi's URL as an example, | |
| the actual usage requires replacing the host and path with the specific request URL. | |
| """ | |
| signature_origin = "host: {}\ndate: {}\nGET {} HTTP/1.1".format( | |
| authorize_url.netloc, date, authorize_url.path | |
| ) | |
| signature = base64.b64encode( | |
| hmac.new( | |
| self.__api_secret.encode(), | |
| signature_origin.encode(), | |
| digestmod='sha256' | |
| ).digest() | |
| ).decode() | |
| authorization_origin = \ | |
| 'api_key="{}",algorithm="{}",headers="{}",signature="{}"'.format( | |
| self.__api_key, "hmac-sha256", "host date request-line", signature | |
| ) | |
| authorization = base64.b64encode( | |
| authorization_origin.encode()).decode() | |
| params = { | |
| "authorization": authorization, | |
| "date": date, | |
| "host": authorize_url.netloc | |
| } | |
| ws_url = self.__api_url + "?" + urlencode(params) | |
| return ws_url | |
| def __build_inputs( | |
| self, | |
| message: dict, | |
| user_id: str = "001", | |
| domain: str = "general", | |
| temperature: float = 0.5, | |
| max_tokens: int = 4096 | |
| ): | |
| input_dict = { | |
| "header": { | |
| "app_id": self.__app_id, | |
| "uid": user_id, | |
| }, | |
| "parameter": { | |
| "chat": { | |
| "domain": domain, | |
| "temperature": temperature, | |
| "max_tokens": max_tokens, | |
| } | |
| }, | |
| "payload": { | |
| "message": message | |
| } | |
| } | |
| return json.dumps(input_dict) | |
| def chat( | |
| self, | |
| query: str, | |
| history: list = None, # store the conversation history | |
| user_id: str = "001", | |
| domain: str = "general", | |
| max_tokens: int = 4096, | |
| temperature: float = 0.5, | |
| ): | |
| if history is None: | |
| history = [] | |
| # the max of max_length is 4096 | |
| max_tokens = min(max_tokens, 4096) | |
| url = self.__get_authorization_url() | |
| ws = create_connection(url) | |
| message = get_prompt(query, history) | |
| input_str = self.__build_inputs( | |
| message=message, | |
| user_id=user_id, | |
| domain=domain, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| ) | |
| ws.send(input_str) | |
| response_str = ws.recv() | |
| try: | |
| while True: | |
| response, history, status = process_response( | |
| response_str, history) | |
| """ | |
| The final return result, which means a complete conversation. | |
| doc url: https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E | |
| """ | |
| if len(response) == 0 or status == 2: | |
| break | |
| response_str = ws.recv() | |
| return response | |
| except WebSocketConnectionClosedException: | |
| print("Connection closed") | |
| finally: | |
| ws.close() | |
| # Stream output statement, used for terminal chat. | |
| def streaming_output( | |
| self, | |
| query: str, | |
| history: list = None, # store the conversation history | |
| user_id: str = "001", | |
| domain: str = "general", | |
| max_tokens: int = 4096, | |
| temperature: float = 0.5, | |
| ): | |
| if history is None: | |
| history = [] | |
| # the max of max_length is 4096 | |
| max_tokens = min(max_tokens, 4096) | |
| url = self.__get_authorization_url() | |
| ws = create_connection(url) | |
| message = get_prompt(query, history) | |
| input_str = self.__build_inputs( | |
| message=message, | |
| user_id=user_id, | |
| domain=domain, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| ) | |
| # print(input_str) | |
| # send question or prompt to url, and receive the answer | |
| ws.send(input_str) | |
| response_str = ws.recv() | |
| # Continuous conversation | |
| try: | |
| while True: | |
| response, history, status = process_response( | |
| response_str, history) | |
| yield response, history | |
| if len(response) == 0 or status == 2: | |
| break | |
| response_str = ws.recv() | |
| except WebSocketConnectionClosedException: | |
| print("Connection closed") | |
| finally: | |
| ws.close() | |
| def chat_stream(self): | |
| history = [] | |
| try: | |
| print("输入init来初始化剧本,输入create来创作剧本,输入exit或stop来终止对话\n") | |
| while True: | |
| query = input("Ask: ") | |
| if query == 'init': | |
| jsonfile = input("请输入剧本文件路径:") | |
| script_data = init_script(history, jsonfile) | |
| print( | |
| f"正在导入剧本{script_data['name']},角色信息:{script_data['characters']},剧情介绍:{script_data['summary']}") | |
| query = f"我希望你能够扮演这个剧本杀游戏的主持人,我希望你能够逐步引导玩家到达最终结局,同时希望你在游戏中设定一些随机事件,需要玩家依靠自身的能力解决,当玩家做出偏离主线的行为或者与剧本无关的行为时,你需要委婉地将玩家引导至正常游玩路线中,对于玩家需要决策的事件,你需要提供一些行动推荐,下面是剧本介绍:{script_data}" | |
| if query == 'create': | |
| name = input('请输入剧本名称:') | |
| characters = input('请输入角色信息:') | |
| summary = input('请输入剧情介绍:') | |
| details = input('请输入剧本细节') | |
| create_script(name, characters, summary, details) | |
| print('剧本创建成功!') | |
| continue | |
| if query == "exit" or query == "stop": | |
| break | |
| for response, _ in self.streaming_output(query, history): | |
| print("\r" + response, end="") | |
| print("\n") | |
| finally: | |
| print("\nThank you for using the SparkDesk AI. Welcome to use it again!") | |
| from langchain.llms.base import LLM | |
| from typing import Any, List, Mapping, Optional | |
| class Spark_forlangchain(LLM): | |
| # 类的成员变量,类型为整型 | |
| n: int | |
| app_id: str | |
| api_key: str | |
| api_secret: str | |
| # 用于指定该子类对象的类型 | |
| def _llm_type(self) -> str: | |
| return "Spark" | |
| # 重写基类方法,根据用户输入的prompt来响应用户,返回字符串 | |
| def _call( | |
| self, | |
| query: str, | |
| history: list = None, # store the conversation history | |
| user_id: str = "001", | |
| domain: str = "general", | |
| max_tokens: int = 4096, | |
| temperature: float = 0.7, | |
| stop: Optional[List[str]] = None, | |
| ) -> str: | |
| if stop is not None: | |
| raise ValueError("stop kwargs are not permitted.") | |
| bot = SparkAPI(app_id=self.app_id, api_key=self.api_key, | |
| api_secret=self.api_secret) | |
| response = bot.chat(query, history, user_id, | |
| domain, max_tokens, temperature) | |
| return response | |
| # 返回一个字典类型,包含LLM的唯一标识 | |
| def _identifying_params(self) -> Mapping[str, Any]: | |
| """Get the identifying parameters.""" | |
| return {"n": self.n} |