hkoppen committed
Commit bb97980 · verified · 1 parent: 7dc110c

Update document_qa_engine.py

Switch the OpenAI "gpt" branch to streamed responses: remove "stream": False from generation_kwargs and pass streaming_callback=lambda x: print(x) to OpenAIChatGenerator.

Files changed (1):
  1. document_qa_engine.py +142 -141
document_qa_engine.py CHANGED
@@ -1,141 +1,142 @@
- from typing import List
-
- from haystack.dataclasses import ChatMessage
- from pypdf import PdfReader
- from haystack.utils import Secret
- from haystack import Pipeline, Document, component
-
- from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
- from haystack.components.writers import DocumentWriter
- from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
- from haystack.document_stores.in_memory import InMemoryDocumentStore
- from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
- from haystack.components.builders import DynamicChatPromptBuilder
- from haystack.components.generators.chat import OpenAIChatGenerator, HuggingFaceTGIChatGenerator
- from haystack.document_stores.types import DuplicatePolicy
-
- SENTENCE_RETREIVER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-
- MAX_TOKENS = 500
-
- template = """
- As a professional HR recruiter given the following information, answer the question shortly and concisely in 1 or 2 sentences.
-
- Context:
- {% for document in documents %}
- {{ document.content }}
- {% endfor %}
-
- Question: {{question}}
- Answer:
- """
-
-
- @component
- class UploadedFileConverter:
-     """
-     A component to convert uploaded PDF files to Documents
-     """
-
-     @component.output_types(documents=List[Document])
-     def run(self, uploaded_file):
-         pdf = PdfReader(uploaded_file)
-         documents = []
-         # uploaded file name without .pdf at the end and with _ and page number at the end
-         name = uploaded_file.name.rstrip('.PDF') + '_'
-         for page in pdf.pages:
-             documents.append(
-                 Document(
-                     content=page.extract_text(),
-                     meta={'name': name + f"_{page.page_number}"}))
-         return {"documents": documents}
-
-
- def create_ingestion_pipeline(document_store):
-     doc_embedder = SentenceTransformersDocumentEmbedder(model=SENTENCE_RETREIVER_MODEL)
-     doc_embedder.warm_up()
-
-     pipeline = Pipeline()
-     pipeline.add_component("converter", UploadedFileConverter())
-     pipeline.add_component("cleaner", DocumentCleaner())
-     pipeline.add_component("splitter",
-                            DocumentSplitter(split_by="passage", split_length=100, split_overlap=10))
-     pipeline.add_component("embedder", doc_embedder)
-     pipeline.add_component("writer",
-                            DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))
-
-     pipeline.connect("converter", "cleaner")
-     pipeline.connect("cleaner", "splitter")
-     pipeline.connect("splitter", "embedder")
-     pipeline.connect("embedder", "writer")
-     return pipeline
-
-
- def create_inference_pipeline(document_store, model_name, api_key):
-     if model_name == "local LLM":
-         generator = OpenAIChatGenerator(api_key=Secret.from_token("<local LLM doesn't need an API key>"),
-                                         model=model_name,
-                                         api_base_url="http://localhost:1234/v1",
-                                         generation_kwargs={"max_tokens": MAX_TOKENS}
-                                         )
-     elif "gpt" in model_name:
-         generator = OpenAIChatGenerator(api_key=Secret.from_token(api_key), model=model_name,
-                                         generation_kwargs={"max_tokens": MAX_TOKENS, "stream": False}
-                                         )
-     else:
-         generator = HuggingFaceTGIChatGenerator(token=Secret.from_token(api_key), model=model_name,
-                                                 generation_kwargs={"max_new_tokens": MAX_TOKENS}
-                                                 )
-     pipeline = Pipeline()
-     pipeline.add_component("text_embedder",
-                            SentenceTransformersTextEmbedder(model=SENTENCE_RETREIVER_MODEL))
-     pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store, top_k=3))
-     pipeline.add_component("prompt_builder",
-                            DynamicChatPromptBuilder(runtime_variables=["query", "documents"]))
-     pipeline.add_component("llm", generator)
-     pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
-     pipeline.connect("retriever.documents", "prompt_builder.documents")
-     pipeline.connect("prompt_builder.prompt", "llm.messages")
-
-     return pipeline
-
-
- class DocumentQAEngine:
-     def __init__(self,
-                  model_name,
-                  api_key=None
-                  ):
-         self.api_key = api_key
-         self.model_name = model_name
-         document_store = InMemoryDocumentStore()
-         self.chunks = []
-         self.inference_pipeline = create_inference_pipeline(document_store, model_name, api_key)
-         self.pdf_ingestion_pipeline = create_ingestion_pipeline(document_store)
-
-     def ingest_pdf(self, uploaded_file):
-         self.pdf_ingestion_pipeline.run({"converter": {"uploaded_file": uploaded_file}})
-
-     def inference(self, query, input_messages: List[dict]):
-         system_message = ChatMessage.from_system(
-             "You are a professional HR recruiter that answers questions based on the content of the uploaded CV. in 1 or 2 sentences.")
-         messages = [system_message]
-         for message in input_messages:
-             if message["role"] == "user":
-                 messages.append(ChatMessage.from_system(message["content"]))
-             else:
-                 messages.append(
-                     ChatMessage.from_user(message["content"]))
-         messages.append(ChatMessage.from_user("""
-         Relevant information from the uploaded CV:
-         {% for doc in documents %}
-         {{ doc.content }}
-         {% endfor %}
-
-         \nQuestion: {{query}}
-         \nAnswer:
-         """))
-         res = self.inference_pipeline.run(data={"text_embedder": {"text": query},
-                                                 "prompt_builder": {"prompt_source": messages,
-                                                                    "query": query
-                                                                    }})
-         return res["llm"]["replies"][0].content
 
 
+ from typing import List
+
+ from haystack.dataclasses import ChatMessage
+ from pypdf import PdfReader
+ from haystack.utils import Secret
+ from haystack import Pipeline, Document, component
+
+ from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
+ from haystack.components.writers import DocumentWriter
+ from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
+ from haystack.document_stores.in_memory import InMemoryDocumentStore
+ from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever
+ from haystack.components.builders import DynamicChatPromptBuilder
+ from haystack.components.generators.chat import OpenAIChatGenerator, HuggingFaceTGIChatGenerator
+ from haystack.document_stores.types import DuplicatePolicy
+
+ SENTENCE_RETREIVER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+
+ MAX_TOKENS = 500
+
+ template = """
+ As a professional HR recruiter given the following information, answer the question shortly and concisely in 1 or 2 sentences.
+
+ Context:
+ {% for document in documents %}
+ {{ document.content }}
+ {% endfor %}
+
+ Question: {{question}}
+ Answer:
+ """
+
+
+ @component
+ class UploadedFileConverter:
+     """
+     A component to convert uploaded PDF files to Documents
+     """
+
+     @component.output_types(documents=List[Document])
+     def run(self, uploaded_file):
+         pdf = PdfReader(uploaded_file)
+         documents = []
+         # uploaded file name without .pdf at the end and with _ and page number at the end
+         name = uploaded_file.name.rstrip('.PDF') + '_'
+         for page in pdf.pages:
+             documents.append(
+                 Document(
+                     content=page.extract_text(),
+                     meta={'name': name + f"_{page.page_number}"}))
+         return {"documents": documents}
+
+
+ def create_ingestion_pipeline(document_store):
+     doc_embedder = SentenceTransformersDocumentEmbedder(model=SENTENCE_RETREIVER_MODEL)
+     doc_embedder.warm_up()
+
+     pipeline = Pipeline()
+     pipeline.add_component("converter", UploadedFileConverter())
+     pipeline.add_component("cleaner", DocumentCleaner())
+     pipeline.add_component("splitter",
+                            DocumentSplitter(split_by="passage", split_length=100, split_overlap=10))
+     pipeline.add_component("embedder", doc_embedder)
+     pipeline.add_component("writer",
+                            DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))
+
+     pipeline.connect("converter", "cleaner")
+     pipeline.connect("cleaner", "splitter")
+     pipeline.connect("splitter", "embedder")
+     pipeline.connect("embedder", "writer")
+     return pipeline
+
+
+ def create_inference_pipeline(document_store, model_name, api_key):
+     if model_name == "local LLM":
+         generator = OpenAIChatGenerator(api_key=Secret.from_token("<local LLM doesn't need an API key>"),
+                                         model=model_name,
+                                         api_base_url="http://localhost:1234/v1",
+                                         generation_kwargs={"max_tokens": MAX_TOKENS}
+                                         )
+     elif "gpt" in model_name:
+         generator = OpenAIChatGenerator(api_key=Secret.from_token(api_key), model=model_name,
+                                         generation_kwargs={"max_tokens": MAX_TOKENS},
+                                         streaming_callback=lambda x: print(x),
+                                         )
+     else:
+         generator = HuggingFaceTGIChatGenerator(token=Secret.from_token(api_key), model=model_name,
+                                                 generation_kwargs={"max_new_tokens": MAX_TOKENS}
+                                                 )
+     pipeline = Pipeline()
+     pipeline.add_component("text_embedder",
+                            SentenceTransformersTextEmbedder(model=SENTENCE_RETREIVER_MODEL))
+     pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store, top_k=3))
+     pipeline.add_component("prompt_builder",
+                            DynamicChatPromptBuilder(runtime_variables=["query", "documents"]))
+     pipeline.add_component("llm", generator)
+     pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
+     pipeline.connect("retriever.documents", "prompt_builder.documents")
+     pipeline.connect("prompt_builder.prompt", "llm.messages")
+
+     return pipeline
+
+
+ class DocumentQAEngine:
+     def __init__(self,
+                  model_name,
+                  api_key=None
+                  ):
+         self.api_key = api_key
+         self.model_name = model_name
+         document_store = InMemoryDocumentStore()
+         self.chunks = []
+         self.inference_pipeline = create_inference_pipeline(document_store, model_name, api_key)
+         self.pdf_ingestion_pipeline = create_ingestion_pipeline(document_store)
+
+     def ingest_pdf(self, uploaded_file):
+         self.pdf_ingestion_pipeline.run({"converter": {"uploaded_file": uploaded_file}})
+
+     def inference(self, query, input_messages: List[dict]):
+         system_message = ChatMessage.from_system(
+             "You are a professional HR recruiter that answers questions based on the content of the uploaded CV. in 1 or 2 sentences.")
+         messages = [system_message]
+         for message in input_messages:
+             if message["role"] == "user":
+                 messages.append(ChatMessage.from_system(message["content"]))
+             else:
+                 messages.append(
+                     ChatMessage.from_user(message["content"]))
+         messages.append(ChatMessage.from_user("""
+         Relevant information from the uploaded CV:
+         {% for doc in documents %}
+         {{ doc.content }}
+         {% endfor %}
+
+         \nQuestion: {{query}}
+         \nAnswer:
+         """))
+         res = self.inference_pipeline.run(data={"text_embedder": {"text": query},
+                                                 "prompt_builder": {"prompt_source": messages,
+                                                                    "query": query
+                                                                    }})
+         return res["llm"]["replies"][0].content
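For context, a minimal usage sketch of the updated engine follows. It is not part of the commit: the file name, model name, chat history, and API key below are illustrative assumptions, and it presumes document_qa_engine.py is importable and a valid OpenAI key is at hand. Any model name containing "gpt" takes the new streaming branch, so the reply is printed chunk by chunk as it arrives.

# Hypothetical usage sketch (assumptions: a local cv.pdf, a valid OpenAI key,
# and document_qa_engine.py on the import path).
from document_qa_engine import DocumentQAEngine

engine = DocumentQAEngine(model_name="gpt-3.5-turbo", api_key="sk-...")

# UploadedFileConverter only needs a file-like object with a .name attribute,
# which a plain binary file handle already provides.
with open("cv.pdf", "rb") as pdf_file:
    engine.ingest_pdf(pdf_file)

# Prior turns use the {"role": ..., "content": ...} shape that inference()
# iterates over; the reply streams to stdout via streaming_callback before
# the full text is returned.
history = [{"role": "user", "content": "Here is my CV."}]
answer = engine.inference("How many years of Python experience are listed?", history)
print(answer)

Because the callback is hard-coded as lambda x: print(x), redirecting the stream elsewhere (for example into a UI placeholder) would require changing create_inference_pipeline; the callback receives Haystack StreamingChunk objects, whose .content field carries the text delta.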