MrKustic commited on
Commit
5e18796
·
verified ·
1 Parent(s): d1e8d7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -28
app.py CHANGED
@@ -8,46 +8,71 @@ print("Загружаем модель и токенизатор...")
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
 
11
- # Если в Spaces доступен GPU, переводим модель на него
12
- device = 0 if torch.cuda.is_available() else -1
13
- if device == 0:
14
- model = model.to("cuda")
15
 
 
 
 
16
  model.eval()
17
 
18
  def chat(user_input):
19
- # Дополнительно можно добавить обозначение конца строки для корректного завершения генерации
20
- input_with_eos = user_input + tokenizer.eos_token
21
 
22
- # Токенизируем входной текст
23
- inputs = tokenizer.encode(input_with_eos, return_tensors="pt")
24
- if device >= 0:
25
- inputs = inputs.to("cuda")
 
 
 
 
 
 
26
 
27
- # Генерация ответа
28
- outputs = model.generate(
29
- inputs,
30
- max_length=200, # можно изменить длину генерируемого текста
31
- pad_token_id=tokenizer.eos_token_id,
32
- do_sample=True,
33
- top_p=0.9,
34
- temperature=0.7
35
- )
36
- # Декодируем сгенерированный текст
37
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
38
-
39
- # Если модель возвращает и исходный текст, можно убрать его:
40
- if generated_text.startswith(user_input):
41
- generated_text = generated_text[len(user_input):].strip()
42
- return generated_text
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # Создаем интерфейс Gradio
45
  iface = gr.Interface(
46
  fn=chat,
47
- inputs=gr.Textbox(lines=2, placeholder="Например: Привет, как дела?", label="Введите сообщение"),
 
 
 
 
48
  outputs=gr.Textbox(label="Ответ модели"),
49
  title="RuDialoGPT-small Chat",
50
- description="Диалоговый чат на базе модели t-bank-ai/RuDialoGPT-small от Hugging Face"
51
  )
52
 
53
  if __name__ == "__main__":
 
8
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
 
11
+ # Убедимся, что токенизатор и модель используют одинаковый словарь
12
+ tokenizer.pad_token = tokenizer.eos_token
13
+ model.config.pad_token_id = model.config.eos_token_id
 
14
 
15
+ # Если в Spaces доступен GPU
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model = model.to(device)
18
  model.eval()
19
 
20
  def chat(user_input):
21
+ # Формируем промпт
22
+ prompt = f"User: {user_input}\nAssistant:"
23
 
24
+ try:
25
+ # Токенизируем с явным указанием параметров
26
+ inputs = tokenizer(
27
+ prompt,
28
+ return_tensors="pt",
29
+ padding=True,
30
+ truncation=True,
31
+ max_length=512, # Ограничиваем длину входного текста
32
+ add_special_tokens=True
33
+ )
34
 
35
+ # Переносим тензоры на нужное устройство
36
+ input_ids = inputs["input_ids"].to(device)
37
+ attention_mask = inputs["attention_mask"].to(device)
38
+
39
+ # Генерация с обработкой ошибок
40
+ with torch.no_grad():
41
+ outputs = model.generate(
42
+ input_ids=input_ids,
43
+ attention_mask=attention_mask,
44
+ max_length=200, # Ограничиваем длину выходного текста
45
+ pad_token_id=tokenizer.eos_token_id,
46
+ do_sample=True,
47
+ top_p=0.9,
48
+ temperature=0.7,
49
+ num_return_sequences=1,
50
+ no_repeat_ngram_size=3 # Избегаем повторений
51
+ )
52
+
53
+ # Декодируем результат
54
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
+
56
+ # Убираем исходный промпт из ответа
57
+ response = generated_text.split("Assistant:")[-1].strip()
58
+
59
+ return response if response else "Извините, не удалось сгенерировать ответ."
60
+
61
+ except Exception as e:
62
+ print(f"Ошибка при генерации: {str(e)}")
63
+ return f"Произошла ошибка при обработке запроса: {str(e)}"
64
 
65
  # Создаем интерфейс Gradio
66
  iface = gr.Interface(
67
  fn=chat,
68
+ inputs=gr.Textbox(
69
+ lines=2,
70
+ placeholder="Например: Привет, как дела?",
71
+ label="Введите сообщение"
72
+ ),
73
  outputs=gr.Textbox(label="Ответ модели"),
74
  title="RuDialoGPT-small Chat",
75
+ description="Диалоговый чат на базе модели t-bank-ai/RuDialoGPT-small"
76
  )
77
 
78
  if __name__ == "__main__":