import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import torch MODEL_NAME = "t-bank-ai/RuDialoGPT-small" print("Загружаем модель и токенизатор...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) # Убедимся, что токенизатор и модель используют одинаковый словарь tokenizer.pad_token = tokenizer.eos_token model.config.pad_token_id = model.config.eos_token_id # Если в Spaces доступен GPU device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) model.eval() def chat(user_input): # Формируем промпт prompt = f"User: {user_input}\nAssistant:" try: # Токенизируем с явным указанием параметров inputs = tokenizer( prompt, return_tensors="pt", padding=True, truncation=True, max_length=512, # Ограничиваем длину входного текста add_special_tokens=True ) # Переносим тензоры на нужное устройство input_ids = inputs["input_ids"].to(device) attention_mask = inputs["attention_mask"].to(device) # Генерация с обработкой ошибок with torch.no_grad(): outputs = model.generate( input_ids=input_ids, attention_mask=attention_mask, max_length=200, # Ограничиваем длину выходного текста pad_token_id=tokenizer.eos_token_id, do_sample=True, top_p=0.9, temperature=0.7, num_return_sequences=1, no_repeat_ngram_size=3 # Избегаем повторений ) # Декодируем результат generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) # Убираем исходный промпт из ответа response = generated_text.split("Assistant:")[-1].strip() return response if response else "Извините, не удалось сгенерировать ответ." except Exception as e: print(f"Ошибка при генерации: {str(e)}") return f"Произошла ошибка при обработке запроса: {str(e)}" # Создаем интерфейс Gradio iface = gr.Interface( fn=chat, inputs=gr.Textbox( lines=2, placeholder="Например: Привет, как дела?", label="Введите сообщение" ), outputs=gr.Textbox(label="Ответ модели"), title="RuDialoGPT-small Chat", description="Диалоговый чат на базе модели t-bank-ai/RuDialoGPT-small" ) if __name__ == "__main__": iface.launch()