jerpint commited on
Commit
9143594
·
1 Parent(s): e02ac39

update data structures

Browse files
Files changed (1) hide show
  1. buster/completers/base.py +21 -22
buster/completers/base.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import logging
2
  import os
3
  from abc import ABC, abstractmethod
@@ -5,8 +6,6 @@ from abc import ABC, abstractmethod
5
  import openai
6
  import promptlayer
7
 
8
- from buster.formatter.base import Response
9
-
10
  logger = logging.getLogger(__name__)
11
  logging.basicConfig(level=logging.INFO)
12
 
@@ -20,44 +19,46 @@ if promptlayer_api_key:
20
  openai = promptlayer.openai
21
  openai.api_key = os.environ.get("OPENAI_API_KEY")
22
 
 
 
 
 
 
23
 
24
  class Completer(ABC):
25
- def __init__(self, cfg):
26
- self.cfg = cfg
27
 
28
  @abstractmethod
29
  def complete(self, prompt) -> str:
30
  ...
31
 
32
- def generate_response(self, user_input, documents) -> Response:
33
  # Call the API to generate a response
34
- prompt = self.prepare_prompt(user_input, documents)
35
- name = self.cfg["name"]
36
- logger.info(f"querying model {name}...")
37
- logger.info(f"{prompt=}")
38
  try:
39
- completion_kwargs = self.cfg["completion_kwargs"]
40
- completion = self.complete(prompt=prompt, **completion_kwargs)
41
  except Exception as e:
42
  # log the error and return a generic response instead.
43
  logger.exception("Error connecting to OpenAI API. See traceback:")
44
- return Response("", True, "We're having trouble connecting to OpenAI right now... Try again soon!")
45
 
46
- return Response(completion)
47
 
48
 
49
  class GPT3Completer(Completer):
50
  def prepare_prompt(
51
  self,
 
52
  user_input: str,
53
- documents: str,
54
  ) -> str:
55
  """
56
  Prepare the prompt with prompt engineering.
57
  """
58
- text_before_docs = self.cfg["text_before_documents"]
59
- text_before_prompt = self.cfg["text_before_prompt"]
60
- return text_before_docs + documents + text_before_prompt + user_input
61
 
62
  def complete(self, prompt, **completion_kwargs):
63
  response = openai.Completion.create(prompt=prompt, **completion_kwargs)
@@ -67,16 +68,14 @@ class GPT3Completer(Completer):
67
  class ChatGPTCompleter(Completer):
68
  def prepare_prompt(
69
  self,
 
70
  user_input: str,
71
- documents: str,
72
  ) -> list:
73
  """
74
  Prepare the prompt with prompt engineering.
75
  """
76
- text_before_docs = self.cfg["text_before_documents"]
77
- text_before_prompt = self.cfg["text_before_prompt"]
78
  prompt = [
79
- {"role": "system", "content": text_before_docs + documents + text_before_prompt},
80
  {"role": "user", "content": user_input},
81
  ]
82
  return prompt
@@ -96,4 +95,4 @@ def get_completer(completer_cfg):
96
  "GPT3": GPT3Completer,
97
  "ChatGPT": ChatGPTCompleter,
98
  }
99
- return completers[name](completer_cfg)
 
1
+ from dataclasses import dataclass
2
  import logging
3
  import os
4
  from abc import ABC, abstractmethod
 
6
  import openai
7
  import promptlayer
8
 
 
 
9
  logger = logging.getLogger(__name__)
10
  logging.basicConfig(level=logging.INFO)
11
 
 
19
  openai = promptlayer.openai
20
  openai.api_key = os.environ.get("OPENAI_API_KEY")
21
 
22
+ @dataclass(slots=True)
23
+ class Completion:
24
+ text: str
25
+ error: bool = False
26
+ error_msg: str | None = None
27
 
28
  class Completer(ABC):
29
+ def __init__(self, completion_kwargs: dict):
30
+ self.completion_kwargs = completion_kwargs
31
 
32
  @abstractmethod
33
  def complete(self, prompt) -> str:
34
  ...
35
 
36
+ def generate_response(self, system_prompt, user_input) -> Completion:
37
  # Call the API to generate a response
38
+ prompt = self.prepare_prompt(system_prompt, user_input)
39
+ logger.info(f"querying model with parameters: {self.completion_kwargs}...")
40
+ logger.info(f"{system_prompt=}")
41
+ logger.info(f"{user_input=}")
42
  try:
43
+ completion = self.complete(prompt=prompt, **self.completion_kwargs)
 
44
  except Exception as e:
45
  # log the error and return a generic response instead.
46
  logger.exception("Error connecting to OpenAI API. See traceback:")
47
+ return Completion("", True, "We're having trouble connecting to OpenAI right now... Try again soon!")
48
 
49
+ return Completion(completion)
50
 
51
 
52
  class GPT3Completer(Completer):
53
  def prepare_prompt(
54
  self,
55
+ system_prompt: str,
56
  user_input: str,
 
57
  ) -> str:
58
  """
59
  Prepare the prompt with prompt engineering.
60
  """
61
+ return system_prompt + user_input
 
 
62
 
63
  def complete(self, prompt, **completion_kwargs):
64
  response = openai.Completion.create(prompt=prompt, **completion_kwargs)
 
68
  class ChatGPTCompleter(Completer):
69
  def prepare_prompt(
70
  self,
71
+ system_prompt: str,
72
  user_input: str,
 
73
  ) -> list:
74
  """
75
  Prepare the prompt with prompt engineering.
76
  """
 
 
77
  prompt = [
78
+ {"role": "system", "content": system_prompt},
79
  {"role": "user", "content": user_input},
80
  ]
81
  return prompt
 
95
  "GPT3": GPT3Completer,
96
  "ChatGPT": ChatGPTCompleter,
97
  }
98
+ return completers[name](completer_cfg["completion_kwargs"])