Ferhan taha committed on
Commit
c4d1b84
·
verified ·
1 Parent(s): 1d555fd

Upload api.py

Browse files
Files changed (1) hide show
  1. api.py +148 -0
api.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """api.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1XRryfVWG4d_ScN5ADvlZpKmREvTJN3mg
8
+ """
9
+
10
+ import gradio as gr
11
+ import os
12
+
13
+ from langchain_community.document_loaders import PyPDFLoader
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain_community.vectorstores import Chroma
16
+ from langchain.chains import ConversationalRetrievalChain
17
+ from langchain_community.embeddings import HuggingFaceEmbeddings
18
+ from langchain_community.llms import HuggingFacePipeline
19
+ from langchain.chains import ConversationChain
20
+ from langchain.memory import ConversationBufferMemory
21
+ from langchain_community.llms import HuggingFaceEndpoint
22
+
23
+ from pathlib import Path
24
+ import chromadb
25
+ from unidecode import unidecode
26
+
27
+ from transformers import AutoTokenizer
28
+ import transformers
29
+ import torch
30
+ import tqdm
31
+ import accelerate
32
+
33
def load_doc(file_path):
    """Load a PDF and split it into overlapping text chunks.

    Args:
        file_path: path to the PDF file to load.

    Returns:
        A list of langchain Document splits (600-char chunks, 50-char overlap).
    """
    document_pages = PyPDFLoader(file_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
    return splitter.split_documents(document_pages)
39
+
40
# BUG FIX: `!pip install fpdf` is IPython/Colab shell syntax and is a
# SyntaxError in a plain .py module. Install dependencies in the
# environment instead (e.g. `pip install fpdf` or a requirements file).

# Split the source PDF once at import time so later statements can reuse
# the chunks.
splt = load_doc('data.pdf')
43
+
44
def initialize_database(file_path):
    """Build a Chroma vector database from a single PDF file.

    Args:
        file_path: path to the PDF document.

    Returns:
        Tuple of (vector_db, collection_name, status_message).
    """
    # Derive a Chroma-safe collection name from the file stem.
    collection_name = Path(file_path).stem
    # Remove spaces from the naming convention.
    collection_name = collection_name.replace(" ", "-")
    # Limit length to 50 characters.
    collection_name = collection_name[:50]
    # Enforce alphanumeric start and end characters.
    # BUG FIX: Python strings are immutable — the original
    # `collection_name[0] = 'A'` / `collection_name[-1] = 'Z'` raised
    # TypeError; rebuild the string instead. Also guard against an
    # empty stem before indexing.
    if collection_name and not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if collection_name and not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    print('Collection name: ', collection_name)
    # Load the document and create splits, then embed them into a new DB.
    doc_splits = load_doc(file_path)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"
65
+
66
def create_db(splits, collection_name):
    """Embed document splits into an in-memory (ephemeral) Chroma collection.

    Args:
        splits: langchain Document chunks to embed.
        collection_name: name for the Chroma collection.

    Returns:
        The populated Chroma vector store.
    """
    # No persist_directory: the EphemeralClient keeps everything in memory.
    return Chroma.from_documents(
        documents=splits,
        embedding=HuggingFaceEmbeddings(),
        client=chromadb.EphemeralClient(),
        collection_name=collection_name,
    )
77
+
78
# Build the vector store twice (notebook residue kept for compatibility):
# `vec` is the full (db, collection_name, status) tuple from
# initialize_database; `vec_cre` is a bare Chroma DB built from the
# pre-computed splits. Only `vec_cre` is used downstream by the QA chain.
vec = initialize_database('data.pdf')
vec_cre = create_db(splt, 'data')
# BUG FIX: the original bare `vec_cre` / `vec` expression statements were
# notebook display cells; they have no effect in a plain script and were
# removed.
84
+
85
def initialize_llmchain(temperature, max_tokens, top_k, vector_db):
    """Create a ConversationalRetrievalChain backed by a HF endpoint LLM.

    Args:
        temperature: sampling temperature for the endpoint model.
        max_tokens: max new tokens to generate per reply.
        top_k: top-k sampling cutoff.
        vector_db: Chroma store used as the retriever.

    Returns:
        A ConversationalRetrievalChain with buffer memory and source
        documents included in its responses.
    """
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    endpoint_llm = HuggingFaceEndpoint(
        repo_id='mistralai/Mixtral-8x7B-Instruct-v0.1',
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        load_in_8bit=True,
    )
    return ConversationalRetrievalChain.from_llm(
        endpoint_llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
112
+
113
# Global QA chain used by `conversation`: temperature 0.7, up to 1024 new
# tokens, top-k 1, retrieving from the `vec_cre` Chroma store.
qa = initialize_llmchain(0.7, 1024, 1, vec_cre)
114
+
115
def format_chat_history(message, chat_history):
    """Flatten (user, bot) turn pairs into "User: ..." / "Assistant: ..." lines.

    Args:
        message: the current user turn; accepted for interface
            compatibility but not included in the output.
        chat_history: iterable of (user_message, bot_message) pairs.

    Returns:
        A flat list of alternating "User: ..." and "Assistant: ..." strings.
    """
    lines = []
    for user_turn, bot_turn in chat_history:
        lines.extend((f"User: {user_turn}", f"Assistant: {bot_turn}"))
    return lines
121
+
122
def conversation(message, history):
    """Answer `message` with the global QA chain, given the Gradio chat history.

    Args:
        message: the user's current question.
        history: list of (user_message, bot_message) pairs from the chat UI.

    Returns:
        The answer string only (sources are present in the chain response
        but are not surfaced by the ChatInterface).
    """
    formatted_chat_history = format_chat_history(message, history)
    # Generate a response using the retrieval QA chain.
    response = qa({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Some models echo the prompt scaffold; keep only the final answer text.
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    # BUG FIX: the original unconditionally indexed
    # response["source_documents"][0..2] (content + page metadata) into
    # locals that were never used, raising IndexError whenever the
    # retriever returned fewer than three documents; that dead extraction
    # has been removed. Sources remain available in `response` if needed.
    return response_answer
145
+
146
+ conversation("what is dat gov ma", "")
147
+
148
+ gr.ChatInterface(conversation).launch()