Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
3 |
|
4 |
# Create a class for the session state
|
5 |
class SessionState:
|
@@ -33,10 +34,15 @@ if input_text:
|
|
33 |
history_string = "\n".join(message for role, message in session_state.conversation_history)
|
34 |
|
35 |
# Tokenize the input text and history
|
36 |
-
inputs = tokenizer.encode_plus(history_string,
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
# Generate the response from the model with additional parameters
|
39 |
-
outputs = model.generate(**inputs, max_length=max_length, do_sample=True
|
40 |
|
41 |
# Decode the response
|
42 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
3 |
+
import torch
|
4 |
|
5 |
# Create a class for the session state
|
6 |
class SessionState:
|
|
|
34 |
history_string = "\n".join(message for role, message in session_state.conversation_history)
|
35 |
|
36 |
# Tokenize the input text and history
|
37 |
+
inputs = tokenizer.encode_plus(history_string, return_tensors="pt")
|
38 |
+
inputs["input_ids"] = torch.cat([inputs["input_ids"], torch.tensor([tokenizer.sep_token_id])], dim=-1)
|
39 |
+
inputs["attention_mask"] = torch.cat([inputs["attention_mask"], torch.tensor([1])], dim=-1)
|
40 |
+
inputs = tokenizer.encode_plus(input_text, return_tensors="pt", add_special_tokens=False)
|
41 |
+
inputs["input_ids"] = torch.cat([inputs["input_ids"], inputs["input_ids"]], dim=-1)
|
42 |
+
inputs["attention_mask"] = torch.cat([inputs["attention_mask"], inputs["attention_mask"]], dim=-1)
|
43 |
|
44 |
# Generate the response from the model with additional parameters
|
45 |
+
outputs = model.generate(**inputs, max_length=max_length, do_sample=True, temperature=temperature)
|
46 |
|
47 |
# Decode the response
|
48 |
response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|