xalavapridi / app.py
MrKustic's picture
Update app.py
5e18796 verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
MODEL_NAME = "t-bank-ai/RuDialoGPT-small"
print("Загружаем модель и токенизатор...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# Убедимся, что токенизатор и модель используют одинаковый словарь
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Если в Spaces доступен GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
def chat(user_input):
# Формируем промпт
prompt = f"User: {user_input}\nAssistant:"
try:
# Токенизируем с явным указанием параметров
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512, # Ограничиваем длину входного текста
add_special_tokens=True
)
# Переносим тензоры на нужное устройство
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
# Генерация с обработкой ошибок
with torch.no_grad():
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=200, # Ограничиваем длину выходного текста
pad_token_id=tokenizer.eos_token_id,
do_sample=True,
top_p=0.9,
temperature=0.7,
num_return_sequences=1,
no_repeat_ngram_size=3 # Избегаем повторений
)
# Декодируем результат
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Убираем исходный промпт из ответа
response = generated_text.split("Assistant:")[-1].strip()
return response if response else "Извините, не удалось сгенерировать ответ."
except Exception as e:
print(f"Ошибка при генерации: {str(e)}")
return f"Произошла ошибка при обработке запроса: {str(e)}"
# Создаем интерфейс Gradio
iface = gr.Interface(
fn=chat,
inputs=gr.Textbox(
lines=2,
placeholder="Например: Привет, как дела?",
label="Введите сообщение"
),
outputs=gr.Textbox(label="Ответ модели"),
title="RuDialoGPT-small Chat",
description="Диалоговый чат на базе модели t-bank-ai/RuDialoGPT-small"
)
if __name__ == "__main__":
iface.launch()