File size: 4,125 Bytes
778c3d7 f596e58 a92f249 f596e58 a92f249 f596e58 790f101 778c3d7 f596e58 a92f249 f596e58 a92f249 c2392fe 5de0b8a f596e58 c2392fe f596e58 5de0b8a c2392fe f596e58 a92f249 5de0b8a c2392fe ea658a2 172af0f f596e58 172af0f f596e58 c2392fe f596e58 790f101 f596e58 790f101 f596e58 5de0b8a f596e58 5de0b8a f596e58 5de0b8a f596e58 5de0b8a f596e58 ea658a2 f596e58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import os
from typing import Type
from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, AnyMessage
from pydantic import BaseModel
from game_utils import log
from controllers import controller_from_name
class Player:
    """A game participant whose turns are produced by a human or an AI controller.

    Every message sent to or received from the controller is appended to
    ``self.messages`` and, when a log file path is configured, passed to ``log``.
    """

    def __init__(
            self,
            name: str,
            controller: str,
            role: str,
            player_id: str = None,
            log_filepath: str = None
    ):
        """Create a player.

        Args:
            name: Display name of the player.
            controller: Controller name. ``"human"`` marks the player as human;
                any other value marks it as AI. The controller object itself is
                built by ``controller_from_name``.
            role: The player's role in the game.
            player_id: Optional identifier for the player.
            log_filepath: Optional path handed to ``log``. When falsy, nothing
                is logged.
        """
        self.name = name
        self.id = player_id
        self.controller_type = "human" if controller == "human" else "ai"
        self.controller = controller_from_name(controller)
        self.role = role
        self.messages = []       # full conversation history with the controller
        self.prompt_queue = []   # prompts to deliver together with the next one
        self.log_filepath = log_filepath

        if log_filepath:
            player_info = {
                "id": self.id,
                "name": self.name,
                "role": self.role,
                "controller": {
                    "name": controller,
                    "type": self.controller_type
                }
            }
            log(player_info, log_filepath)

        # Wrap the synchronous helpers as Runnables so they can be awaited via ainvoke.
        self.generate = RunnableLambda(self._generate)
        self.format_output = RunnableLambda(self._output_formatter)

    async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries=3):
        """Make the player respond to a prompt and return it as ``output_format``.

        Any prompts waiting in ``self.prompt_queue`` are sent together with
        ``prompt`` (joined by newlines), and the queue is then emptied.

        Args:
            prompt: The text to send to the player.
            output_format: Pydantic model class the response is converted into.
            max_retries: For AI players, how many extra formatting attempts are
                made after a failed parse. Each failure is reported back to the
                model before retrying.

        Returns:
            An instance of ``output_format``.

        Raises:
            ValueError: If an AI response still cannot be parsed after
                ``max_retries`` retries.
        """
        if self.prompt_queue:
            # Deliver the queued prompts with the current one, then reset the
            # queue so they are not re-sent on every subsequent turn.
            prompt = "\n".join(self.prompt_queue + [prompt])
            self.prompt_queue = []

        message = HumanMessage(content=prompt)
        output = await self.generate.ainvoke(message)

        if self.controller_type == "ai":
            # Ask the model to restate its answer in the requested schema, feeding
            # parse errors back to it, for at most max_retries additional attempts.
            attempts = max(max_retries, 0) + 1
            for attempt in range(attempts):
                try:
                    output = await self.format_output.ainvoke({"output_format": output_format})
                    break
                except ValueError as e:
                    if attempt == attempts - 1:
                        raise
                    self.add_to_history(HumanMessage(content=f"Error formatting response: {e} \n\n Please try again."))
        else:
            # Human replies are free text: wrap them into the model's field.
            # NOTE: this only works because current output formats have exactly one field.
            field_name = list(output_format.model_fields)[-1]
            output = output_format.model_validate({field_name: output.content})

        return output

    def add_to_history(self, message: AnyMessage):
        """Append ``message`` to the history and log it if a log file is configured."""
        self.messages.append(message)
        # Same guard as in __init__: only log when a log path was provided.
        if self.log_filepath:
            log(message.dict(), self.log_filepath)

    def _generate(self, message: HumanMessage):
        """Send ``message`` to the controller and return its response.

        Entry point for the ``generate`` Runnable. Both the outgoing message and
        the response are recorded via ``add_to_history``.
        """
        self.add_to_history(message)
        # AIs need the whole message history; humans can look back at it themselves.
        if self.controller_type == "human":
            response = self.controller.invoke(message.content)
        else:
            response = self.controller.invoke(self.messages)
        self.add_to_history(response)
        return response

    def _output_formatter(self, inputs: dict):
        """Ask the controller to rewrite its previous response in a structured format.

        Args:
            inputs: Dict with key ``"output_format"``, holding the Pydantic model
                class to parse into.

        Returns:
            The parsed ``output_format`` instance.

        Raises:
            ValueError: If parsing fails. ``respond_to`` catches ``ValueError`` to
                drive its retries.
        """
        output_format: Type[BaseModel] = inputs["output_format"]
        prompt_template = PromptTemplate.from_template(
            "Please rewrite your previous response using the following format: \n\n{format_instructions}"
        )
        parser = PydanticOutputParser(pydantic_object=output_format)
        prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
        message = HumanMessage(content=prompt.text)
        response = self.generate.invoke(message)
        return parser.invoke(response)
|