Starchik1 commited on
Commit
cd198d1
·
verified ·
1 Parent(s): e974634

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -46
app.py CHANGED
@@ -3,56 +3,57 @@ import gradio as gr
3
 
4
  client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
5
 
6
- # Шаблоны промптов для каждого жанра
7
- GENRE_PROMPTS = {
8
- "Investigation": "Ты находишься в жанре игры Investigation...",
9
- # Остальные жанры остаются неизменными
10
  }
11
 
12
- def format_prompt(message, history, genre, players):
 
13
  genre_prompt = GENRE_PROMPTS.get(genre, "Жанр игры неизвестен.")
14
- prompt = f"Ты персонаж в текстовой игре, тебя зовут Рассказчик, взаимодействующий с игроками на русском языке. " \
15
- f"{genre_prompt} Учитывай, что в игре участвуют несколько игроков: {', '.join(players)}. " \
16
- f"Каждый твой ответ должен учитывать действия всех игроков, их взаимодействие и текущую ситуацию. " \
17
- f"Добавляй элементы кооперации: совместное решение задач, голосование за действия и влияние выборов одного игрока на другого."
18
  for user_prompt, bot_response in history:
19
- prompt += f"[INST] {user_prompt} [/INST]"
20
- prompt += f" {bot_response}</s> "
21
  prompt += f"[INST] {message} [/INST]"
22
  return prompt
23
 
24
- def generate(
25
- prompt, history, system_prompt, genre, players, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
26
- ):
27
- formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history, genre, players)
28
- stream = client.text_generation(formatted_prompt, stream=True, details=True, return_full_text=True)
29
- output = ""
30
- for response in stream:
31
- output += response.token.text
32
- yield output
33
- return output
34
-
35
- additional_inputs = [
36
- gr.Textbox(label="System Prompt", max_lines=1, interactive=True, visible=False),
37
- gr.Radio(label="Game Genre", choices=list(GENRE_PROMPTS.keys()), value="Horror", interactive=True),
38
- gr.Textbox(label="Players (comma-separated)", interactive=True, placeholder="Player1, Player2", visible=True),
39
- gr.Slider(label="Temperature", value=0.1, minimum=0.0, maximum=1.0, step=0.05, interactive=True, visible=False),
40
- gr.Slider(label="Max new tokens", value=1024, minimum=128, maximum=8192, step=64, interactive=True, visible=False),
41
- gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, visible=False),
42
- gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, visible=False),
43
- ]
44
-
45
- examples = [["Подробнее"], ["Обстановка"]]
46
-
47
- def wrapped_generate(prompt, history, system_prompt, genre, temperature, max_new_tokens, top_p, repetition_penalty, players):
48
- player_list = [p.strip() for p in players.split(",") if p.strip()]
49
- return generate(prompt, history, system_prompt, genre, player_list, temperature, max_new_tokens, top_p, repetition_penalty)
50
-
51
- gr.ChatInterface(
52
- fn=wrapped_generate,
53
- chatbot=gr.Chatbot(show_label=False),
54
- additional_inputs=additional_inputs,
55
- title="theGame with Co-op",
56
- examples=examples,
57
- concurrency_limit=20,
58
- ).launch(show_api=False)
 
 
 
 
 
3
 
4
  client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
5
 
6
+ # Хранение состояния игры
7
+ game_state = {
8
+ "players": {},
9
+ "story_progress": "Начало игры..."
10
  }
11
 
12
+ def format_prompt(message, history, genre, player_id):
13
+ # Получаем промпт для выбранного жанра
14
  genre_prompt = GENRE_PROMPTS.get(genre, "Жанр игры неизвестен.")
15
+ prompt = f"Игрок {player_id} в жанре {genre}. {genre_prompt} История: {game_state['story_progress']}."
 
 
 
16
  for user_prompt, bot_response in history:
17
+ prompt += f"[INST] {user_prompt} [/INST] {bot_response}</s> "
 
18
  prompt += f"[INST] {message} [/INST]"
19
  return prompt
20
 
21
+ def generate(player_id, message, genre, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
22
+ if player_id not in game_state["players"]:
23
+ game_state["players"][player_id] = {"history": []}
24
+
25
+ history = game_state["players"][player_id]["history"]
26
+ formatted_prompt = format_prompt(message, history, genre, player_id)
27
+
28
+ generate_kwargs = {
29
+ "temperature": temperature,
30
+ "max_new_tokens": max_new_tokens,
31
+ "top_p": top_p,
32
+ "repetition_penalty": repetition_penalty,
33
+ "do_sample": True,
34
+ "seed": 42,
35
+ }
36
+
37
+ response = client.text_generation(formatted_prompt, **generate_kwargs)
38
+ game_state["players"][player_id]["history"].append((message, response))
39
+ game_state["story_progress"] += f"\n{response}"
40
+
41
+ return response
42
+
43
+ def gradio_interface():
44
+ with gr.Blocks() as demo:
45
+ player_id = gr.Textbox(label="Player ID", placeholder="Введите ваш уникальный идентификатор")
46
+ genre = gr.Radio(label="Game Genre", choices=list(GENRE_PROMPTS.keys()), value="Horror")
47
+ message = gr.Textbox(label="Ваше сообщение")
48
+ output = gr.Textbox(label="Ответ игры")
49
+
50
+ def play_game(player_id, message, genre):
51
+ return generate(player_id, message, genre)
52
+
53
+ submit = gr.Button("Отправить")
54
+ submit.click(play_game, inputs=[player_id, message, genre], outputs=output)
55
+
56
+ demo.launch()
57
+
58
+ if __name__ == "__main__":
59
+ gradio_interface()