File size: 8,990 Bytes
2dc48af
 
5d0cac4
463419e
 
 
84dac74
c9d5325
2dc48af
7f60ac5
b41be8b
c9d5325
2dc48af
6823e55
 
 
2dc48af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9d5325
 
 
 
 
6094eb6
c9d5325
 
 
 
2dc48af
 
c9d5325
 
 
 
 
 
 
 
 
 
 
 
2dc48af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import json
import os
from random import choices, randint, sample

import streamlit as st

os.system("pip install transformers")
os.system("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
os.system("pip install einops")
os.system("pip install sentencepiece")
os.system("pip install openai")

import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer

# Who is the Spy (谁是卧底) game.

state = st.session_state

# title
st.title("谁是卧底😎")

# Load the keyword pairs. UTF-8 is given explicitly because the file presumably
# holds Chinese keywords, and the platform default encoding (e.g. on Windows)
# would otherwise decide how it is decoded.
with open("word_list.json", "r", encoding="utf-8") as f:
    word_list = json.load(f)

# Avatar for every message speaker id: the host, AI players P1..P10 and the human "H".
avatar_dict = {
    "host": "🐼",
    "P1": "🚀",
    "P2": "🚄",
    "P3": "🚁",
    "P4": "🚂",
    "P5": "🚢",
    "P6": "🚤",
    "P7": "🚙",
    "P8": "🚠",
    "P9": "🚲",
    "P10": "🚜",
    "H": "🤹‍♂️",
}

# Session-scoped game state; each entry is initialized only once per session.
## Global message stack (what the chat window displays)
if "messages" not in state:
    state.messages = []
## Player list
if "players" not in state:
    state.players = []
## System prompt template for the AI players (later filled via str.format with a keyword)
if "prompt" not in state:
    with open("prompt.txt", "r", encoding="utf-8") as f:
        state.prompt = f.read()

# OpenAI-compatible client, created once per session. The API key and endpoint
# come from the OPENAI_API_KEY and BASE_URL environment variables.
if "client" not in state:
    state.client = OpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("BASE_URL")
    )
    state.model_name = "internlm/internlm2_5-20b-chat"
    # Disabled alternative: run a local model instead of the API.
    # model_path = "internlm/internlm2_5-7b-chat"
    # state.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
    # state.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # state.model.eval()
def response(messages, temperature=0.5):
    """Send a chat conversation to the configured model and return the reply text.

    Uses the session's client (``state.client``) and model (``state.model_name``).

    Args:
        messages: Chat messages in OpenAI format (dicts with "role" and "content").
        temperature: Sampling temperature passed to the API (default 0.5).

    Returns:
        The first choice's message text, or "" when the API returns no content,
        so callers can safely concatenate the result with other strings.
    """
    completion = state.client.chat.completions.create(
        model=state.model_name,
        messages=messages,
        temperature=temperature,
    )
    # `message.content` is Optional in the OpenAI SDK; normalize None to "".
    return completion.choices[0].message.content or ""

# Settings
## Round bookkeeping; these defaults apply until the sidebar form is submitted.
if "max_round" not in state:
    state.max_round = 100
if "round" not in state:
    state.round = 0

## Sidebar game setup: keyword pair, player count, spy count, maximum rounds.
with st.sidebar:
    st.write("游戏设置")
    with st.form(key="game_setting"):
        # Draw a random keyword pair once per session.
        if "words" not in state:
            words_num = randint(0, len(word_list)-1)
            state.words = word_list[words_num]
        total_num = st.number_input("总人数", 5, 10, 5)
        spy_num = st.number_input("卧底人数", 1, total_num//2, 1)
        max_round = st.number_input("最大回合数", 5, 10, 10)

        submitted = st.form_submit_button("保存设置")

        ## On submit: save the settings and start a fresh game
        ## (players, message stack, description history, round counter).
        if submitted:
            state.spy_word = state.words["spy_word"]            # spy keyword
            state.civilian_word = state.words["civilian_word"]  # civilian keyword
            state.total_num = total_num                         # total number of players
            state.spy_num = spy_num                             # number of spies
            state.max_round = max_round                         # maximum number of rounds
            # Reset per-game history so a new game does not inherit the previous
            # game's chat, AI prompt context or round counter.
            state.messages = []
            state.description = []
            state.round = 0
            ## Build the player list: the human ("H") first, then the AI players.
            human_dignity = randint(0, 1)  # human identity: 0 = civilian, 1 = spy
            if human_dignity == 0:
                state.players = [{"id": "H", "dignity": "civilian"}]
                state.human_word = state.civilian_word
            else:
                state.players = [{"id": "H", "dignity": "spy"}]
                state.human_word = state.spy_word
            ai_ids = [f"P{i}" for i in range(1, total_num)]
            # A spy human already occupies one spy seat. Draw the remaining AI spies
            # *without replacement* so exactly state.spy_num spies exist in total
            # (random.choices could return the same id twice).
            spy_id = set(sample(ai_ids, k=state.spy_num - human_dignity))
            # Assign identities to the AI players only: the human's identity was
            # set above and must not be overwritten.
            state.players += [
                {"id": pid, "dignity": "spy" if pid in spy_id else "civilian"}
                for pid in ai_ids
            ]
        # Keep showing the human's keyword on every rerun, not only on the run
        # in which the form was submitted.
        if "human_word" in state:
            st.write("你的关键词是{}".format(state.human_word))

# Message display window: replays the whole message stack on every rerun,
# showing each speaker's id and avatar above the message text.
with st.container(height=300):
    for entry in state.messages:
        speaker = entry["id"]
        with st.chat_message(speaker, avatar=avatar_dict[speaker]):
            st.text(speaker)
            st.write(entry["message"])
# Main game section: active while fewer than state.max_round rounds have been played.
if state.round < state.max_round:
    if "description" not in state:
        # Description history as "id:text" lines; passed to the AI players as context.
        state.description = []
    ## Round control: one click runs this round's AI description phase.
    start = st.button("开始第{}轮游戏".format(state.round+1))
    if start:
        # state.round is always initialized in the settings section, so no guard is needed.
        state.messages.append({"id":"host", "message":f"第{state.round+1}轮游戏开始"})
        ## Description phase: each AI player describes its own keyword in turn;
        ## later players see the descriptions added earlier in this loop.
        for player in state.players:
            ## The human player describes through the chat input instead.
            if player["id"] == "H":
                continue
            history = "\n".join(state.description)
            if player["dignity"] == "spy":
                word = state.spy_word
                # Spies are additionally told not to repeat existing descriptions.
                task = "/describe "+"请根据描述历史记录描述你的关键词,需要注意不能与已有的描述重复,下面是描述历史记录:\n"+history
            else:
                word = state.civilian_word
                task = "/describe "+"请根据描述历史记录描述你的关键词,下面是描述历史记录:\n"+history
            text = response([
                {"role":"system", "content":state.prompt.format(word)},
                {"role":"system", "content":task},
            ])
            state.messages.append({"id":player["id"], "message": text})
            state.description.append(player["id"] + ":" + text)
        st.rerun()
    ## Voting phase: AI players post their votes, then the human reads them and
    ## picks which player is eliminated.
    col1, col2 = st.columns([8,2])
    with col1:
        vote_id = st.selectbox("投票对象", [a["id"] for a in state.players])
    with col2:
        if st.button("开始投票"):
            # Neither the history nor the roster changes during this loop; build them once.
            history = "\n".join(state.description)
            alive_ids = ",".join([a["id"] for a in state.players])
            for player in state.players:
                if player["id"] == "H":
                    continue
                # Each AI player must reason from its OWN keyword: civilians get the
                # civilian word, spies the spy word.
                if player["dignity"] == "spy":
                    word = state.spy_word
                else:
                    word = state.civilian_word
                text = response([
                    {"role":"system", "content":state.prompt.format(word)},
                    {"role":"system", "content":"/vote "+"请根据描述历史记录选择要投出的玩家,下面是描述历史记录:\n"+history+"\n"+"当前场上存活玩家id为:\n"+alive_ids}
                    ])
                state.messages.append({"id":player["id"], "message": text})
            st.rerun()
        if st.button("投出玩家"):
            state.messages.append({"id":"host", "message":f"玩家{vote_id}被投票出局"})
            state.round += 1
            # Player ids are unique, so this drops exactly the voted-out player.
            state.players = [p for p in state.players if p["id"] != vote_id]
            ## Win check on the remaining players.
            spy_count = sum(1 for p in state.players if p["dignity"] == "spy")
            if spy_count == 0 and state.players:
                state.messages.append({"id":"host", "message":"平民胜利!"})
            elif spy_count and spy_count >= len(state.players)//2:
                # Spies win once they hold at least len//2 of the remaining seats.
                state.messages.append({"id":"host", "message":"卧底胜利!"})
            st.rerun()
    ## Human input: the chat box is offered only while the human ("H") is still in play.
    human_live = any(p["id"] == "H" for p in state.players)
    if human_live and (des := st.chat_input()):
        # The human's description goes into both the visible chat and the AI history.
        state.messages.append({"id":"H", "message":des})
        state.description.append("H:"+des)
        st.rerun()