|
import os |
|
import openai |
|
|
|
class BaseAgent:

    """Abstract base for question-answering agents.

    Subclasses provide the agent's behavior by overriding ``__call__``,
    which maps a question string to an answer string.
    """

    def __call__(self, question: str) -> str:
        # Base class deliberately has no behavior; subclasses must override.
        raise NotImplementedError("Agent must implement __call__ method.")
|
|
|
class LLMOpenAIAgent(BaseAgent):

    """Agent that answers questions via OpenAI's chat-completion endpoint.

    Uses the legacy ``openai.ChatCompletion`` interface (openai < 1.0);
    this call was removed in openai >= 1.0, so the pinned library version
    matters.  Model and sampling settings are configurable but default to
    the original hard-coded values.
    """

    def __init__(
        self,
        api_key=None,
        model: str = "gpt-3.5-turbo",
        max_tokens: int = 256,
        temperature: float = 0.2,
    ):
        """Configure the agent.

        Args:
            api_key: OpenAI API key; falls back to the ``OPENAI_API_KEY``
                environment variable when not given.
            model: Chat model identifier to query.
            max_tokens: Completion length cap per request.
            temperature: Sampling temperature (low = near-deterministic).
        """
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        # NOTE(review): this mutates module-global state, so every agent in
        # the process shares whichever key was set last.
        openai.api_key = self.api_key

    def __call__(self, question: str) -> str:
        """Send *question* as a single user message and return the reply text.

        Any failure (network, auth, API error) is converted into an
        ``"Error: ..."`` string rather than raised — callers relying on
        that best-effort contract are preserved.
        """
        try:
            # Only the API call itself can raise; keep the try body minimal.
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=[{"role": "user", "content": question}],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
        except Exception as e:  # deliberate best-effort: report, don't crash
            return f"Error: {e}"
        return response.choices[0].message["content"].strip()