gaur3009 committed
Commit cc8d953 · verified
1 Parent(s): fad0ef9

Update llm.py

Files changed (1)
  1. llm.py +16 -36
llm.py CHANGED
@@ -1,40 +1,20 @@
-from transformers import pipeline
-import torch
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
-qa_pipeline = pipeline(
-    "text2text-generation",
-    model="google/flan-t5-base",
-    device=device
-)
+tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
+model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
 
 def generate_answer(context, question):
-    # Handle empty context
-    if context == "No relevant context found.":
-        prompt = f"""
-        You are a helpful AI assistant. Answer the question based on your general knowledge.
-
-        Question: {question}
-
-        Answer as a helpful paragraph:
-        """
-    else:
-        prompt = f"""
-        You are a helpful AI assistant. Use the context to answer the question.
-
-        Context:
-        {context}
-
-        Question: {question}
-
-        Answer as a comprehensive paragraph with key details:
-        """
-
-    result = qa_pipeline(
-        prompt,
-        max_length=400,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.9
+    prompt = f"""
+    You are a helpful AI assistant.
+    Context:
+    {context}
+    Question: {question}
+    Answer as a helpful paragraph:"""
+
+    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=100,
+        do_sample=False
     )
-
-    return result[0]['generated_text'].strip()
+    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
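
A minimal usage sketch of the patched generate_answer follows; the module name llm and the sample context/question strings are assumptions for illustration, not part of the commit:

    # Hypothetical smoke test for the rewritten generate_answer.
    # Assumes llm.py (as patched above) is importable from the working directory.
    from llm import generate_answer

    context = "Flan-T5 is an instruction-tuned encoder-decoder model from Google."
    question = "What kind of model is Flan-T5?"

    # do_sample=False in the new code means greedy decoding, so repeated calls
    # with the same prompt return the same answer.
    print(generate_answer(context, question))

Since the patched file loads the model at import time, the first import downloads google/flan-t5-base from the Hugging Face Hub if it is not already cached.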