pvyas96 commited on
Commit
96f9572
·
verified ·
1 Parent(s): ef4096a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
 
4
  # Create a class for the session state
5
  class SessionState:
@@ -33,10 +34,15 @@ if input_text:
33
  history_string = "\n".join(message for role, message in session_state.conversation_history)
34
 
35
  # Tokenize the input text and history
36
- inputs = tokenizer.encode_plus(history_string, input_text, return_tensors="pt")
 
 
 
 
 
37
 
38
  # Generate the response from the model with additional parameters
39
- outputs = model.generate(**inputs, max_length=max_length, do_sample=True ,temperature=temperature)
40
 
41
  # Decode the response
42
  response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
 
5
  # Create a class for the session state
6
  class SessionState:
 
34
  history_string = "\n".join(message for role, message in session_state.conversation_history)
35
 
36
  # Tokenize the input text and history
37
+ inputs = tokenizer.encode_plus(history_string, return_tensors="pt")
38
+ inputs["input_ids"] = torch.cat([inputs["input_ids"], torch.tensor([tokenizer.sep_token_id])], dim=-1)
39
+ inputs["attention_mask"] = torch.cat([inputs["attention_mask"], torch.tensor([1])], dim=-1)
40
+ inputs = tokenizer.encode_plus(input_text, return_tensors="pt", add_special_tokens=False)
41
+ inputs["input_ids"] = torch.cat([inputs["input_ids"], inputs["input_ids"]], dim=-1)
42
+ inputs["attention_mask"] = torch.cat([inputs["attention_mask"], inputs["attention_mask"]], dim=-1)
43
 
44
  # Generate the response from the model with additional parameters
45
+ outputs = model.generate(**inputs, max_length=max_length, do_sample=True, temperature=temperature)
46
 
47
  # Decode the response
48
  response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()