# Spaces: Runtime error
import linecache | |
import re | |
from typing import Dict, List, Optional | |
import openai | |
class ChatCompletion:
    """Stateful wrapper around the legacy (openai 0.x) Chat Completions API.

    Keeps a list of system messages and a list of user messages, and sends
    all of them (system messages first) on every request.

    NOTE(review): assistant replies are never stored, so later requests do
    not include the model's previous answers — confirm this is intended.
    NOTE(review): ``openai.ChatCompletion`` and ``openai.error`` do not exist
    in openai>=1.0; this class requires the 0.x client.
    """

    # Default cap on stored history, measured in characters (not tokens).
    DEFAULT_MAX_CONTEXT_LENGTH = 2048

    def __init__(self, model: str = 'gpt-3.5-turbo',
                 api_key: Optional[str] = None, api_key_path: str = './openai_api_key',
                 max_context_length: int = DEFAULT_MAX_CONTEXT_LENGTH):
        """Configure the API key and an empty conversation.

        Args:
            model: default model name, used when a call does not pass one.
            api_key: OpenAI API key. When None, the key is read from line 2
                of ``api_key_path``. If that yields nothing, the key already
                set on the ``openai`` module (if any) is used.
            api_key_path: file whose second line holds the API key.
            max_context_length: when the stored history exceeds this many
                characters, ``chat`` clears it before adding the new message.

        Raises:
            EnvironmentError: no API key was passed or could be found.
        """
        if api_key is None:
            # Prefer the key file (original behaviour); only fall back to a
            # key already configured on the openai module instead of
            # overwriting it with None.
            api_key = self._read_api_key(api_key_path) or getattr(openai, 'api_key', None)
            if not api_key:
                raise EnvironmentError(
                    f'No OpenAI API key: pass api_key or put it on line 2 of {api_key_path!r}')
        # NOTE: this sets the key on the openai module itself, so it is shared
        # by every instance in the process.
        openai.api_key = api_key
        self.model = model
        self.max_context_length = max_context_length
        self.system_messages: List[str] = []
        self.user_messages: List[str] = []

    @staticmethod
    def _read_api_key(api_key_path: str) -> str:
        """Return the key on line 2 of *api_key_path*, stripped of whitespace.

        Returns '' when the file or its second line does not exist.
        """
        # linecache caches file contents; drop the entry if the file changed.
        linecache.checkcache(api_key_path)
        return linecache.getline(api_key_path, 2).strip()

    def chat(self, msg: str, setting: Optional[str] = None, model: Optional[str] = None) -> str:
        """Add *msg* (and optionally a system *setting*) and return the reply.

        Args:
            msg: user message to append to the history.
            setting: optional system message. It is added only if it is not
                already present.
            model: model override for this call; defaults to ``self.model``.

        Returns:
            The reply text, or the error text of an OpenAI API error.
        """
        # Crude history cap: counts characters (not tokens) of the history
        # *before* msg is added, and drops system messages as well.
        if self._context_length() > self.max_context_length:
            self.reset()
        if setting is not None and setting not in self.system_messages:
            self.system_messages.append(setting)
        # Skip storing an identical consecutive message, so re-sending the
        # same prompt behaves like retry().
        if not self.user_messages or msg != self.user_messages[-1]:
            self.user_messages.append(msg)
        return self._run(model)

    def retry(self, model: Optional[str] = None) -> str:
        """Re-send the current history unchanged and return the new reply."""
        return self._run(model)

    def reset(self):
        """Forget all stored system and user messages."""
        self.system_messages.clear()
        self.user_messages.clear()

    def _make_message(self) -> List[Dict]:
        """Build the request payload: all system messages, then all user messages."""
        sys_messages = [{'role': 'system', 'content': msg} for msg in self.system_messages]
        user_messages = [{'role': 'user', 'content': msg} for msg in self.user_messages]
        return sys_messages + user_messages

    def _context_length(self) -> int:
        """Return the total number of characters in the stored history."""
        return sum(map(len, self.system_messages)) + sum(map(len, self.user_messages))

    def _run(self, model: Optional[str] = None) -> str:
        """Send the history to the API and return the reply text.

        OpenAI API errors are returned as their string message rather than
        raised. Any other exception propagates to the caller. Previously it
        was printed, and the method then crashed on an unbound local.
        """
        if model is None:
            model = self.model
        try:
            response = openai.ChatCompletion.create(model=model, messages=self._make_message())
        except openai.error.OpenAIError as e:
            # Honour the `-> str` contract instead of returning the exception object.
            return str(e)
        # Drop leading blank lines the model may emit before the answer.
        return response['choices'][0]['message']['content'].lstrip('\n')

    def __call__(self, msg: str, setting: Optional[str] = None, model: Optional[str] = None) -> str:
        """Alias for :meth:`chat`."""
        return self.chat(msg, setting, model)