zoya23 commited on
Commit
58fd569
·
verified ·
1 Parent(s): e96fbbe

Update agents/language_agent.py

Browse files
Files changed (1) hide show
  1. agents/language_agent.py +3 -14
agents/language_agent.py CHANGED
@@ -1,21 +1,10 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  from langchain.llms import HuggingFacePipeline
3
  from langchain.chains import RetrievalQA
4
  from agents.retriever_agent import create_vectorstore
5
 
6
  def generate_brief(question):
7
- model_id = "google/flan-t5-small"
8
-
9
- tokenizer = AutoTokenizer.from_pretrained(model_id)
10
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
11
-
12
- pipe = pipeline(
13
- "text2text-generation",
14
- model=model,
15
- tokenizer=tokenizer,
16
- max_length=512,
17
- temperature=0.7
18
- )
19
 
20
  llm = HuggingFacePipeline(pipeline=pipe)
21
 
@@ -23,4 +12,4 @@ def generate_brief(question):
23
  retriever = vectordb.as_retriever()
24
 
25
  qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
26
- return qa_chain.run(question)
 
1
+ from transformers import pipeline
2
  from langchain.llms import HuggingFacePipeline
3
  from langchain.chains import RetrievalQA
4
  from agents.retriever_agent import create_vectorstore
5
 
6
  def generate_brief(question):
7
+ pipe = pipeline("text2text-generation", model="google/flan-t5-small")
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  llm = HuggingFacePipeline(pipeline=pipe)
10
 
 
12
  retriever = vectordb.as_retriever()
13
 
14
  qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
15
+ return qa_chain.run(question)