Tesvia commited on
Commit
eea77dd
·
verified ·
1 Parent(s): e4ed116

Upload 6 files

Browse files
Files changed (1) hide show
  1. agent.py +8 -6
agent.py CHANGED
@@ -56,8 +56,7 @@ def _select_model():
56
  hf_token = os.getenv("HF_API_KEY")
57
  return InferenceClientModel(
58
  model_id=hf_model_id,
59
- token=hf_token,
60
- system_prompt=SYSTEM_PROMPT
61
  )
62
 
63
  if provider == "openai":
@@ -66,8 +65,7 @@ def _select_model():
66
  openai_token = os.getenv("OPENAI_API_KEY")
67
  return OpenAIServerModel(
68
  model_id=openai_model_id,
69
- api_key=openai_token,
70
- system_prompt=SYSTEM_PROMPT
71
  )
72
 
73
  raise ValueError(
@@ -89,8 +87,12 @@ DEFAULT_TOOLS = [
89
  ]
90
 
91
  class GAIAAgent(CodeAgent):
92
- def __init__(self, tools=None):
93
- super().__init__(tools=tools or DEFAULT_TOOLS, model=_select_model())
 
 
 
 
94
 
95
  # Convenience so the object itself can be *called* directly
96
  def __call__(self, question: str, **kwargs: Any) -> str:
 
56
  hf_token = os.getenv("HF_API_KEY")
57
  return InferenceClientModel(
58
  model_id=hf_model_id,
59
+ token=hf_token
 
60
  )
61
 
62
  if provider == "openai":
 
65
  openai_token = os.getenv("OPENAI_API_KEY")
66
  return OpenAIServerModel(
67
  model_id=openai_model_id,
68
+ api_key=openai_token
 
69
  )
70
 
71
  raise ValueError(
 
87
  ]
88
 
89
  class GAIAAgent(CodeAgent):
90
+ def __init__(self, tools=None, *, system_prompt: str = SYSTEM_PROMPT):
91
+ super().__init__(
92
+ tools=tools or DEFAULT_TOOLS,
93
+ model=_select_model(),
94
+ system_prompt=system_prompt
95
+ )
96
 
97
  # Convenience so the object itself can be *called* directly
98
  def __call__(self, question: str, **kwargs: Any) -> str: