Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,17 @@
|
|
1 |
-
from huggingface_hub import InferenceClient
|
2 |
import gradio as gr
|
3 |
|
|
|
4 |
client = InferenceClient(
|
5 |
"mistralai/Mistral-7B-Instruct-v0.3"
|
6 |
)
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
def format_prompt(message, history, genre):
|
9 |
prompt = f"Ты — персонаж в текстовой игре, тебя зовут Рассказчик, взаимодействующий с игроком на русском языке. " \
|
10 |
f"Ты находишься в жанре игры: {genre}. Каждый твой ответ должен продолжать сюжет игры, принимая во внимание прошлые действия игрока и текущую ситуацию. " \
|
@@ -20,6 +27,7 @@ def format_prompt(message, history, genre):
|
|
20 |
prompt += f"[INST] {message} [/INST]"
|
21 |
return prompt
|
22 |
|
|
|
23 |
def generate(
|
24 |
prompt, history, system_prompt, genre, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
|
25 |
):
|
@@ -46,20 +54,26 @@ def generate(
|
|
46 |
yield output
|
47 |
return output
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
additional_inputs = [
|
50 |
gr.Textbox(
|
51 |
label="System Prompt",
|
52 |
max_lines=1,
|
53 |
interactive=True,
|
54 |
-
visible=False #
|
55 |
),
|
56 |
gr.Radio(
|
57 |
label="Game Genre",
|
58 |
choices=["Investigation", "Fantasy", "Sci-Fi", "Horror", "Adventure", "Mystery"],
|
59 |
-
value="Horror", #
|
60 |
interactive=True,
|
61 |
info="Select the genre of the game",
|
62 |
-
visible=True #
|
63 |
),
|
64 |
gr.Slider(
|
65 |
label="Temperature",
|
@@ -69,7 +83,7 @@ additional_inputs = [
|
|
69 |
step=0.05,
|
70 |
interactive=True,
|
71 |
info="Higher values produce more diverse outputs",
|
72 |
-
visible=False #
|
73 |
),
|
74 |
gr.Slider(
|
75 |
label="Max new tokens",
|
@@ -79,7 +93,7 @@ additional_inputs = [
|
|
79 |
step=64,
|
80 |
interactive=True,
|
81 |
info="The maximum numbers of new tokens",
|
82 |
-
visible=False #
|
83 |
),
|
84 |
gr.Slider(
|
85 |
label="Top-p (nucleus sampling)",
|
@@ -89,7 +103,7 @@ additional_inputs = [
|
|
89 |
step=0.05,
|
90 |
interactive=True,
|
91 |
info="Higher values sample more low-probability tokens",
|
92 |
-
visible=False #
|
93 |
),
|
94 |
gr.Slider(
|
95 |
label="Repetition penalty",
|
@@ -99,12 +113,19 @@ additional_inputs = [
|
|
99 |
step=0.05,
|
100 |
interactive=True,
|
101 |
info="Penalize repeated tokens",
|
102 |
-
visible=False #
|
|
|
|
|
|
|
|
|
|
|
103 |
)
|
104 |
]
|
105 |
|
|
|
106 |
examples = [["Подробнее"], ["Варианты"]]
|
107 |
|
|
|
108 |
gr.ChatInterface(
|
109 |
fn=generate,
|
110 |
chatbot=gr.Chatbot(show_label=False, show_share_button=True, show_copy_button=True, likeable=True, layout="panel"),
|
@@ -113,3 +134,11 @@ gr.ChatInterface(
|
|
113 |
examples=examples,
|
114 |
concurrency_limit=20,
|
115 |
).launch(show_api=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import InferenceClient
|
2 |
import gradio as gr
|
3 |
|
4 |
+
# Подключаем клиента для текстовой модели
|
5 |
client = InferenceClient(
|
6 |
"mistralai/Mistral-7B-Instruct-v0.3"
|
7 |
)
|
8 |
|
9 |
+
# Подключаем клиента для генерации изображений
|
10 |
+
image_client = InferenceClient(
|
11 |
+
"stabilityai/stable-diffusion-2"
|
12 |
+
)
|
13 |
+
|
14 |
+
# Форматируем текстовый запрос
|
15 |
def format_prompt(message, history, genre):
|
16 |
prompt = f"Ты — персонаж в текстовой игре, тебя зовут Рассказчик, взаимодействующий с игроком на русском языке. " \
|
17 |
f"Ты находишься в жанре игры: {genre}. Каждый твой ответ должен продолжать сюжет игры, принимая во внимание прошлые действия игрока и текущую ситуацию. " \
|
|
|
27 |
prompt += f"[INST] {message} [/INST]"
|
28 |
return prompt
|
29 |
|
30 |
+
# Генерация текста
|
31 |
def generate(
|
32 |
prompt, history, system_prompt, genre, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
|
33 |
):
|
|
|
54 |
yield output
|
55 |
return output
|
56 |
|
57 |
+
# Генерация изображений
|
58 |
+
def generate_image(description):
|
59 |
+
image = image_client.text_to_image(description, width=512, height=512)
|
60 |
+
return image.url
|
61 |
+
|
62 |
+
# Дополнительные настройки для Gradio
|
63 |
additional_inputs = [
|
64 |
gr.Textbox(
|
65 |
label="System Prompt",
|
66 |
max_lines=1,
|
67 |
interactive=True,
|
68 |
+
visible=False # Скрыть поле ввода системного запроса
|
69 |
),
|
70 |
gr.Radio(
|
71 |
label="Game Genre",
|
72 |
choices=["Investigation", "Fantasy", "Sci-Fi", "Horror", "Adventure", "Mystery"],
|
73 |
+
value="Horror", # Жанр по умолчанию
|
74 |
interactive=True,
|
75 |
info="Select the genre of the game",
|
76 |
+
visible=True # Отображение выбора жанра
|
77 |
),
|
78 |
gr.Slider(
|
79 |
label="Temperature",
|
|
|
83 |
step=0.05,
|
84 |
interactive=True,
|
85 |
info="Higher values produce more diverse outputs",
|
86 |
+
visible=False # Скрыть слайдер температуры
|
87 |
),
|
88 |
gr.Slider(
|
89 |
label="Max new tokens",
|
|
|
93 |
step=64,
|
94 |
interactive=True,
|
95 |
info="The maximum numbers of new tokens",
|
96 |
+
visible=False # Скрыть слайдер максимального числа токенов
|
97 |
),
|
98 |
gr.Slider(
|
99 |
label="Top-p (nucleus sampling)",
|
|
|
103 |
step=0.05,
|
104 |
interactive=True,
|
105 |
info="Higher values sample more low-probability tokens",
|
106 |
+
visible=False # Скрыть слайдер top-p
|
107 |
),
|
108 |
gr.Slider(
|
109 |
label="Repetition penalty",
|
|
|
113 |
step=0.05,
|
114 |
interactive=True,
|
115 |
info="Penalize repeated tokens",
|
116 |
+
visible=False # Скрыть слайдер штрафа за повторения
|
117 |
+
),
|
118 |
+
gr.Textbox(
|
119 |
+
label="Image Description",
|
120 |
+
placeholder="Describe the scene or character you want to generate.",
|
121 |
+
interactive=True
|
122 |
)
|
123 |
]
|
124 |
|
125 |
+
# Пример выбора
|
126 |
examples = [["Подробнее"], ["Варианты"]]
|
127 |
|
128 |
+
# Интерфейс Gradio
|
129 |
gr.ChatInterface(
|
130 |
fn=generate,
|
131 |
chatbot=gr.Chatbot(show_label=False, show_share_button=True, show_copy_button=True, likeable=True, layout="panel"),
|
|
|
134 |
examples=examples,
|
135 |
concurrency_limit=20,
|
136 |
).launch(show_api=False)
|
137 |
+
|
138 |
+
# Интерфейс для генерации картинок
|
139 |
+
gr.Interface(
|
140 |
+
fn=generate_image,
|
141 |
+
inputs="text",
|
142 |
+
outputs="image",
|
143 |
+
title="Image Generator"
|
144 |
+
).launch()
|