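"""Player: a participant in the game.

Wraps either a human or an AI controller and exposes an async `respond_to`
method that returns the controller's response parsed into a Pydantic model.
"""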
from typing import Type, Literal

from langchain_core.runnables import RunnableLambda
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, AnyMessage
from langchain_core.exceptions import OutputParserException
from pydantic import BaseModel

from game_utils import log
from controllers import controller_from_name

Role = Literal["chameleon", "herd"]


class Player:
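    """A single participant in the game, controlled by either a human or an AI model."""
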
role: Role | None = None
"""The role of the player in the game. Can be "chameleon" or "herd"."""
rounds_played_as_chameleon: int = 0
"""The number of times the player has been the chameleon."""
rounds_played_as_herd: int = 0
"""The number of times the player has been in the herd."""
points: int = 0
"""The number of points the player has."""
    messages: list[AnyMessage]
    """The messages the player has sent and received."""
    prompt_queue: list[str]
    """A queue of prompts to be added to the next prompt."""

    def __init__(
            self,
            name: str,
            controller: str,
            player_id: str | None = None,
            log_filepath: str | None = None
    ):
        self.name = name
        self.id = player_id
        # Initialized per instance so players never share the same list objects
        # through mutable class-level defaults.
        self.messages = []
        self.prompt_queue = []
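        # "human" players talk to the game directly; any other controller name
        # is resolved to an AI controller.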
if controller == "human":
self.controller_type = "human"
else:
self.controller_type = "ai"
self.controller = controller_from_name(controller)
self.log_filepath = log_filepath
if log_filepath:
player_info = {
"id": self.id,
"name": self.name,
"role": self.role,
"controller": {
"name": controller,
"type": self.controller_type
}
}
log(player_info, log_filepath)
# initialize the runnables
self.generate = RunnableLambda(self._generate)
self.format_output = RunnableLambda(self._output_formatter)

    def assign_role(self, role: Role):
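        """Assigns the player a role for the round and updates the per-role round counters."""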
self.role = role
if role == "chameleon":
self.rounds_played_as_chameleon += 1
elif role == "herd":
self.rounds_played_as_herd += 1

    async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries: int = 3) -> BaseModel:
"""Makes the player respond to a prompt. Returns the response in the specified format."""
if self.prompt_queue:
# If there are prompts in the queue, add them to the current prompt
prompt = "\n".join(self.prompt_queue + [prompt])
# Clear the prompt queue
self.prompt_queue = []
message = HumanMessage(content=prompt)
output = await self.generate.ainvoke(message)
        if self.controller_type == "ai":
            # AI responses are free text, so ask the model to restate them in the
            # requested format and retry on parsing failures.
            retries = 0
            while True:
                try:
                    output = await self.format_output.ainvoke({"output_format": output_format})
                    break
                except OutputParserException as e:
                    if retries < max_retries:
                        self.add_to_history(HumanMessage(content=f"Error formatting response: {e}\n\nPlease try again."))
                        retries += 1
                    else:
                        print(f"Max retries reached due to error: {e}")
                        raise
        else:
            # Convert the human's plain-text reply into the requested Pydantic model.
            # popitem() works here because the current output formats have exactly one field.
            field_name = output_format.model_fields.copy().popitem()[0]
            output = output_format.model_validate({field_name: output.content})
return output

    def add_to_history(self, message: AnyMessage):
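        """Appends a message to the player's history and writes it to the log, if one is configured."""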
        self.messages.append(message)
        if self.log_filepath:
            log(message.dict(), self.log_filepath)

    def is_human(self):
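        """Returns True if the player is controlled by a human."""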
return self.controller_type == "human"

    def is_ai(self):
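        """Returns True if the player is controlled by an AI."""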
return not self.is_human()

    def _generate(self, message: HumanMessage):
"""Entry point for the Runnable generating responses, automatically logs the message."""
self.add_to_history(message)
        # AIs need to be fed the whole message history, but humans can simply scroll back and reread it.
if self.controller_type == "human":
response = self.controller.invoke(message.content)
else:
response = self.controller.invoke(self.messages)
self.add_to_history(response)
return response

    def _output_formatter(self, inputs: dict):
        """Asks the model to restate its previous response in the requested Pydantic schema, then parses it."""
output_format: BaseModel = inputs["output_format"]
prompt_template = PromptTemplate.from_template(
"Please rewrite your previous response using the following format: \n\n{format_instructions}"
)
parser = PydanticOutputParser(pydantic_object=output_format)
prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
message = HumanMessage(content=prompt.text)
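        # The reformat request goes through self.generate so it is logged and
        # added to the conversation history like any other message.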
response = self.generate.invoke(message)
return parser.invoke(response)