gaur3009 committed on
Commit
5a36579
·
verified ·
1 Parent(s): 82957ca

Update llm.py

Files changed (1)
  1. llm.py +13 -16
llm.py CHANGED
@@ -1,23 +1,20 @@
- # llm.py
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
- tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
- model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-
- # Fix: add pad_token_id if missing
- if tokenizer.pad_token_id is None:
-     tokenizer.pad_token_id = tokenizer.eos_token_id
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
 
  def generate_answer(context, question):
-     prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
-     # Limit to last N chars if prompt is too long
-     prompt = prompt[-1000:]
+     prompt = f"""Context:
+ {context}
+
+ Based on the above context, answer the question:
+ Question: {question}
+ Answer:"""
 
-     inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024)
+     inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512)
      outputs = model.generate(
-         inputs["input_ids"],
-         max_new_tokens=50,
-         do_sample=True,
-         pad_token_id=tokenizer.eos_token_id  # fix warning
+         **inputs,
+         max_new_tokens=80,
+         do_sample=False  # deterministic
      )
      return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
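For reference, a minimal usage sketch of the updated generate_answer; the context and question strings below are hypothetical examples, not taken from the repository:

    from llm import generate_answer

    # Hypothetical inputs; any (context, question) pair is handled the same way.
    context = "The Eiffel Tower was completed in 1889 and is located in Paris."
    question = "When was the Eiffel Tower completed?"
    print(generate_answer(context, question))  # expected: something like "1889"

A note on the design choice: a causal LM's generate output includes the prompt tokens, so the old code's decode call returned the prompt plus the continuation, whereas a seq2seq model like flan-t5-small returns only the decoder output, so the same decode call now yields just the answer. Passing **inputs also forwards the attention_mask alongside input_ids, which is why the old pad_token_id workaround is no longer needed.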