gaur3009 committed on
Commit 622f41b · verified · 1 Parent(s): 6ac7d2a

Update llm.py

Files changed (1):
  llm.py +15 -2
llm.py CHANGED
@@ -1,10 +1,23 @@
+# llm.py
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
 model = AutoModelForCausalLM.from_pretrained("distilgpt2")
 
+# Fix: add pad_token_id if missing
+if tokenizer.pad_token_id is None:
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
 def generate_answer(context, question):
     prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
-    inputs = tokenizer.encode(prompt, return_tensors='pt', max_length=1024, truncation=True)
-    outputs = model.generate(inputs, max_new_tokens=50, do_sample=True)
+    # Limit to last N chars if prompt is too long
+    prompt = prompt[-1000:]
+
+    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024)
+    outputs = model.generate(
+        inputs["input_ids"],
+        max_new_tokens=50,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id  # fix warning
+    )
     return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
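For reference, a minimal self-contained sketch (an assumption, not part of this commit) of the same fix with two common refinements: passing the tokenizer's attention_mask to generate(), which addresses the same pad-token warning the commit targets, and stripping the echoed prompt from the decoded output, since generate() returns the prompt tokens followed by the continuation. The example context/question and the "Answer:"-splitting helper are hypothetical.

# sketch (assumed variant, not the committed llm.py)
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# distilgpt2 ships without a pad token; reuse EOS, as the commit does
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

def generate_answer(context: str, question: str) -> str:
    prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    prompt = prompt[-1000:]  # same crude length cap as the commit
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # also silences the pad-token warning
        max_new_tokens=50,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # decode() returns prompt + continuation; keep only what follows "Answer:"
    return text.split("Answer:", 1)[-1].strip()

if __name__ == "__main__":
    print(generate_answer("The Eiffel Tower is in Paris.", "Where is the Eiffel Tower?"))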