Athspi committed
Commit aa37cb9 · verified · 1 Parent(s): 8c2acb9

Update app.py

Files changed (1): app.py (+15 -6)
app.py CHANGED
@@ -11,6 +11,10 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map="auto"
 )
 
+# Explicitly set padding token
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
 # System prompt
 system_prompt = "You are a friendly assistant named FastLlama."
 
@@ -25,21 +29,27 @@ def respond(message: str, history: list):
     # Format the prompt with chat history
     full_prompt = format_prompt(message, history)
 
-    # Tokenize input
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
+    # Tokenize input with attention mask
+    inputs = tokenizer(
+        full_prompt,
+        return_tensors="pt",
+        padding=True,
+        truncation=True
+    ).to(model.device)
 
-    # Generate response
+    # Generate response with attention mask
     output = model.generate(
         inputs.input_ids,
+        attention_mask=inputs.attention_mask,
         max_new_tokens=256,
         temperature=0.7,
         top_p=0.9,
         repetition_penalty=1.1,
         do_sample=True,
-        pad_token_id=tokenizer.eos_token_id
+        pad_token_id=tokenizer.pad_token_id
     )
 
-    # Decode response
+    # Decode response while skipping special tokens
     response = tokenizer.decode(
         output[0][inputs.input_ids.shape[-1]:],
         skip_special_tokens=True
@@ -60,6 +70,5 @@ chat = gr.ChatInterface(
     cache_examples=False
 )
 
-# Launch the app
 if __name__ == "__main__":
     chat.launch(server_name="0.0.0.0")
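For context, the whole file after this commit plausibly reads as the single script sketched below. The model ID and the body of format_prompt() never appear in the diff, so both are labeled assumptions; the pad-token fallback, the respond() body, and the launch block are taken directly from the hunks above.

# Sketch of app.py after commit aa37cb9. MODEL_ID and format_prompt() are
# assumptions; neither is shown in the diff.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "your-org/your-model"  # hypothetical: the real ID is outside the diff

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto"
)

# Explicitly set padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# System prompt
system_prompt = "You are a friendly assistant named FastLlama."


def format_prompt(message: str, history: list) -> str:
    # Hypothetical helper: the diff calls format_prompt() but never shows it.
    # A plain-text transcript over Gradio's (user, bot) history tuples is one
    # plausible shape.
    lines = [system_prompt]
    for user_msg, bot_msg in history:
        lines.append(f"User: {user_msg}")
        lines.append(f"Assistant: {bot_msg}")
    lines.append(f"User: {message}")
    lines.append("Assistant:")
    return "\n".join(lines)


def respond(message: str, history: list):
    # Format the prompt with chat history
    full_prompt = format_prompt(message, history)

    # Tokenize input with attention mask
    inputs = tokenizer(
        full_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(model.device)

    # Generate response with attention mask
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id
    )

    # Decode only the newly generated tokens
    response = tokenizer.decode(
        output[0][inputs.input_ids.shape[-1]:],
        skip_special_tokens=True
    )
    return response


chat = gr.ChatInterface(
    respond,
    cache_examples=False
)

if __name__ == "__main__":
    chat.launch(server_name="0.0.0.0")

Passing the attention mask and an explicit pad_token_id also silences the warning that transformers' generate() emits when the pad and EOS tokens are indistinguishable, which appears to be the motivation for this change.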