random2222 commited on
Commit
c7c9218
·
verified ·
1 Parent(s): b0ad4e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -5,7 +5,8 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
5
  from langchain_community.document_loaders import PyMuPDFLoader
6
  from langchain_text_splitters import CharacterTextSplitter
7
  from langchain.chains import RetrievalQA
8
- from transformers import pipeline # Local model execution
 
9
 
10
  def create_qa_system():
11
  try:
@@ -34,19 +35,27 @@ def create_qa_system():
34
  # Build vector store
35
  db = FAISS.from_documents(texts, embeddings)
36
 
37
- # Local model pipeline
38
- qa_pipeline = pipeline(
 
 
 
 
39
  "text2text-generation",
40
- model="google/flan-t5-small", # Runs locally
41
- device=-1, # Use CPU
42
  max_length=128,
43
- temperature=0.2
 
44
  )
45
 
 
 
46
  return RetrievalQA.from_chain_type(
47
- llm=qa_pipeline,
48
  chain_type="stuff",
49
  retriever=db.as_retriever(search_kwargs={"k": 2}))
 
50
  except Exception as e:
51
  raise gr.Error(f"Initialization failed: {str(e)}")
52
 
 
5
  from langchain_community.document_loaders import PyMuPDFLoader
6
  from langchain_text_splitters import CharacterTextSplitter
7
  from langchain.chains import RetrievalQA
8
+ from langchain_community.llms import HuggingFacePipeline
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
10
 
11
  def create_qa_system():
12
  try:
 
35
  # Build vector store
36
  db = FAISS.from_documents(texts, embeddings)
37
 
38
+ # Initialize local model with LangChain wrapper
39
+ model_name = "google/flan-t5-small"
40
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
41
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
42
+
43
+ pipe = pipeline(
44
  "text2text-generation",
45
+ model=model,
46
+ tokenizer=tokenizer,
47
  max_length=128,
48
+ temperature=0.2,
49
+ device_map="auto"
50
  )
51
 
52
+ llm = HuggingFacePipeline(pipeline=pipe)
53
+
54
  return RetrievalQA.from_chain_type(
55
+ llm=llm,
56
  chain_type="stuff",
57
  retriever=db.as_retriever(search_kwargs={"k": 2}))
58
+ )
59
  except Exception as e:
60
  raise gr.Error(f"Initialization failed: {str(e)}")
61