import os

import openai

# Using TGI Inference Endpoints from Hugging Face requires HF_ENDPOINT_URL and
# HF_API_TOKEN; the default OpenAI backend expects OPENAI_API_KEY to be set.
# api_type = "tgi"
api_type = "openai"

if api_type == "tgi":
    model_name = "tgi"
    client = openai.Client(
        base_url=os.environ['HF_ENDPOINT_URL'] + "/v1/",
        api_key=os.environ['HF_API_TOKEN']
    )
else:
    model_name = "gpt-3.5-turbo"
    client = openai.Client()


class Player:
    """A participant whose replies come either from a human at the console or from the chat model."""

    def __init__(self, name: str, controller: str, role: str):
        self.name = name
        self.controller = controller  # "human" or "ai"
        self.role = role
        self.messages = []  # chat history in OpenAI message format

    def collect_input(self, prompt: str) -> str:
        """Store the input and output in the messages list. Return the output."""
        self.messages.append({"role": "user", "content": prompt})
        output = self.respond(prompt)
        self.messages.append({"role": "assistant", "content": output})
        return output

    def respond(self, prompt: str) -> str:
        if self.controller == "human":
            # Human players answer on the command line.
            print(prompt)
            return input()
        elif self.controller == "ai":
            # AI players reply with a chat completion over the stored history
            # (the prompt was already appended by collect_input).
            chat_completion = client.chat.completions.create(
                model=model_name,
                messages=self.messages,
                stream=False,
            )
            return chat_completion.choices[0].message.content

    def add_message(self, message: str):
        """Add a message to the messages list. No response required."""
        self.messages.append({"role": "user", "content": message})
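

# --- Illustrative usage (a minimal sketch, not part of the original script) ---
# The names, roles, and prompts below are placeholders; the "ai" branch needs
# valid credentials for whichever backend was configured above.
if __name__ == "__main__":
    bot = Player(name="Alice", controller="ai", role="participant")
    bot.add_message("You are Alice. Keep your answers to one sentence.")  # context only, no reply
    print(bot.collect_input("Please introduce yourself."))  # stores prompt and reply, returns the reply

    human = Player(name="Bob", controller="human", role="participant")
    print(human.collect_input("Please introduce yourself."))  # reads the reply from stdin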