Spaces:
Sleeping
Sleeping
File size: 8,990 Bytes
2dc48af 5d0cac4 463419e 84dac74 c9d5325 2dc48af 7f60ac5 b41be8b c9d5325 2dc48af 6823e55 2dc48af c9d5325 6094eb6 c9d5325 2dc48af c9d5325 2dc48af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
import streamlit as st
from random import choices, randint
import os
os.system("pip install transformers")
os.system("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
os.system("pip install einops")
os.system("pip install sentencepiece")
os.system("pip install openai")
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from openai import OpenAI
# """
# 谁是卧底游戏
# """
# Shorthand for Streamlit's per-session state store, used throughout the app.
state = st.session_state
# Page title
st.title("谁是卧底😎")
# Load the keyword pairs from the JSON file. Entries are later indexed at
# random and read through their "spy_word" / "civilian_word" keys.
# The file is decoded explicitly as UTF-8: the game text is Chinese (the words
# presumably are too), and the platform default encoding is not guaranteed to
# be UTF-8.
with open("word_list.json", "r", encoding="utf-8") as f:
    word_list = json.load(f)
# Emoji avatar per speaker id in the chat window: the host, AI players P1-P10,
# and the human player "H". Insertion order matches the original listing.
# NOTE(review): if the "H" avatar renders as two glyphs, the zero-width joiner
# (U+200D) between the juggler and the male sign may be missing — confirm.
avatar_dict = dict(
    host="🐼",
    P1="🚀",
    P2="🚄",
    P3="🚁",
    P4="🚂",
    P5="🚢",
    P6="🚤",
    P7="🚙",
    P8="🚠",
    P9="🚲",
    P10="🚜",
    H="🤹♂️",
)
# Define the game state and keep per-session data in Streamlit session state.
## Global message stack (every entry is {"id": speaker, "message": text}).
if "messages" not in state:
    state.messages = []
## Player list (each player is {"id": ..., "dignity": "spy" | "civilian"}).
if "players" not in state:
    state.players = []
## System prompt template for the AI players, read once per session. It is
## later filled with the player's keyword via str.format, so it is expected to
## contain a "{}" placeholder. Decoded explicitly as UTF-8 so reading it does
## not depend on the platform default encoding.
if "prompt" not in state:
    with open("prompt.txt", "r", encoding="utf-8") as f:
        state.prompt = f.read()
# Build the OpenAI-compatible chat client once per session. The session_state
# guard keeps script reruns from constructing a new client every time.
# Credentials and endpoint come from the OPENAI_API_KEY / BASE_URL env vars.
if "client" not in state:
    state.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("BASE_URL"))
    # Model identifier sent with every chat-completion request.
    state.model_name = "internlm/internlm2_5-20b-chat"
def response(messages, temperature=0.5):
    """Send *messages* to the session's chat model and return the reply text.

    Args:
        messages: Chat messages, a list of {"role": ..., "content": ...} dicts,
            passed straight through to ``chat.completions.create``.
        temperature: Sampling temperature (defaults to 0.5, the value
            previously hard-coded here).

    Returns:
        The text content of the first choice, or "" when the API returns no
        text content.
    """
    completion = state.client.chat.completions.create(
        model=state.model_name,
        messages=messages,
        temperature=temperature,
    )
    content = completion.choices[0].message.content
    # The SDK types `content` as Optional[str]; callers concatenate the result
    # with other strings, so never hand back None.
    return content if content is not None else ""
# Settings
## Round bookkeeping: the round limit (overwritten when the sidebar settings
## are saved) and the current round counter, each set only if still missing.
for _key, _default in (("max_round", 100), ("round", 0)):
    if _key not in state:
        setattr(state, _key, _default)
## Sidebar game setup: keywords, number of players, number of rounds, etc.
with st.sidebar:
    st.write("游戏设置")
    with st.form(key="game_setting"):
        # Draw one keyword pair per session; it is reused on every rerun.
        if "words" not in state:
            words_num = randint(0, len(word_list)-1)
            state.words = word_list[words_num]
        total_num = st.number_input("总人数", 5, 10, 5)
        spy_num = st.number_input("卧底人数", 1, total_num//2, 1)
        max_round = st.number_input("最大回合数", 5, 10, 10)
        submitted = st.form_submit_button("保存设置")
        ## On submit: save the settings and initialize players, the message
        ## stack, etc.
        if submitted:
            # `sample` draws WITHOUT replacement (unlike `random.choices`), so
            # every requested spy seat goes to a distinct AI player.
            from random import sample

            state.spy_word = state.words["spy_word"]  # spy keyword
            state.civilian_word = state.words["civilian_word"]  # civilian keyword
            state.total_num = total_num  # total number of players
            state.spy_num = spy_num  # number of spies
            state.max_round = max_round  # maximum number of rounds
            # Start a fresh game: clear state left over from a previous one.
            state.round = 0
            state.messages = []
            state.description = []
            ## Build the player list: the human player first, then AI players.
            human_dignity = randint(0, 1)  # human player's role, 0: civilian 1: spy
            human_is_spy = human_dignity == 1
            if human_is_spy:
                state.players = [{"id": "H", "dignity": "spy"}]
                st.write("你的关键词是{}".format(state.spy_word))
            else:
                state.players = [{"id": "H", "dignity": "civilian"}]
                st.write("你的关键词是{}".format(state.civilian_word))
            ai_ids = [f"P{i}" for i in range(1, total_num)]
            # A human spy already fills one spy seat, so only the remaining
            # seats are handed out to AI players.
            ai_spy_num = state.spy_num - (1 if human_is_spy else 0)
            spy_id = set(sample(ai_ids, k=ai_spy_num))
            # Assign roles to AI players only, so the human's role chosen
            # above is never overwritten.
            state.players += [
                {"id": pid, "dignity": "spy" if pid in spy_id else "civilian"}
                for pid in ai_ids
            ]
# Message display: replay the full message history inside a fixed-height
# (300) container, one chat bubble per message with the speaker's avatar.
with st.container(height=300):
    for msg in state.messages:
        speaker = msg["id"]
        with st.chat_message(speaker, avatar=avatar_dict[speaker]):
            # Speaker id as a plain-text label, then the message body.
            st.text(speaker)
            st.write(msg["message"])
# Main game flow.
def _player_word(player):
    """Return the keyword *player* should play with, chosen by its role."""
    return state.spy_word if player["dignity"] == "spy" else state.civilian_word


if state.round < state.max_round:
    # Description history as "id:text" lines, fed back to the AI players.
    if "description" not in state:
        state.description = []
    ## Round control and start.
    start = st.button("开始第{}轮游戏".format(state.round+1))
    if start:
        state.messages.append({"id":"host", "message":f"第{state.round+1}轮游戏开始"})
        ## Description phase: each AI player describes its keyword in turn.
        for player in state.players:
            ## The human player describes via the chat input instead; skip.
            if player["id"] == "H":
                continue
            if player["dignity"] == "spy":
                # Spies are additionally told not to repeat earlier descriptions.
                instruction = "/describe "+"请根据描述历史记录描述你的关键词,需要注意不能与已有的描述重复,下面是描述历史记录:\n"
            else:
                instruction = "/describe "+"请根据描述历史记录描述你的关键词,下面是描述历史记录:\n"
            text = response([
                {"role":"system", "content":state.prompt.format(_player_word(player))},
                {"role":"system", "content":instruction+"\n".join(state.description)},
            ])
            state.messages.append({"id":player["id"], "message": text})
            state.description.append(player["id"] + ":" + text)
        st.rerun()
    ## Voting phase: AI players post their votes; the human player then tallies
    ## them and picks who is voted out.
    col1, col2 = st.columns([8,2])
    with col1:
        vote_id = st.selectbox("投票对象", [a["id"] for a in state.players])
    with col2:
        if st.button("开始投票"):
            # The alive list does not change while votes are collected.
            alive_ids = ",".join([a["id"] for a in state.players])
            for player in state.players:
                if player["id"] == "H":
                    continue
                # Each AI player votes with its OWN keyword in the system
                # prompt, so civilians are never shown the spy word.
                text = response([
                    {"role":"system", "content":state.prompt.format(_player_word(player))},
                    {"role":"system", "content":"/vote "+"请根据描述历史记录选择要投出的玩家,下面是描述历史记录:\n"+"\n".join(state.description)+"\n"+"当前场上存活玩家id为:\n"+alive_ids},
                ])
                state.messages.append({"id":player["id"], "message": text})
            st.rerun()
    # st.selectbox yields None when there are no options (e.g. before the
    # settings are saved); there is then nobody to eliminate.
    if st.button("投出玩家") and vote_id is not None:
        state.messages.append({"id":"host", "message":f"玩家{vote_id}被投票出局"})
        state.round += 1
        # Player ids are unique, so this removes exactly the voted player.
        state.players = [p for p in state.players if p["id"] != vote_id]
        ## Check whether any spy is still alive.
        spy_count = sum(1 for p in state.players if p["dignity"] == "spy")
        if spy_count == 0 and state.players:
            state.messages.append({"id":"host", "message":"平民胜利!"})
        elif spy_count and spy_count >= len(state.players)//2:
            ## Spies win once their count reaches half of the survivors,
            ## rounded down (e.g. one spy among three remaining players).
            state.messages.append({"id":"host", "message":"卧底胜利!"})
        st.rerun()
# Human player's turn: while "H" is still in the game, whatever they type in
# the chat box is shown as a message and appended to the description history.
human_live = any(player["id"] == "H" for player in state.players)
if human_live:
    des = st.chat_input()
    if des:
        state.messages.append({"id":"H", "message":des})
        state.description.append("H:"+des)
        st.rerun()
|