# RAG pipeline utilities: document loading and chunking, Chroma vector store
# creation, t-SNE visualizations, and a conversational retrieval chain.
from chromadb import PersistentClient | |
from dotenv import load_dotenv | |
from enum import Enum | |
import plotly.graph_objects as go | |
from langchain.document_loaders import DirectoryLoader, TextLoader | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.schema import Document | |
from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
from langchain_chroma import Chroma | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import ConversationalRetrievalChain | |
import numpy as np | |
import os | |
from pathlib import Path | |
from sklearn.manifold import TSNE | |
from typing import Any, List, Tuple, Generator | |
# Locate the project-root .env (four levels above this file) and load it so
# OpenAI credentials are available before any model objects are created below.
cur_path = Path(__file__)
env_path = cur_path.parent.parent.parent.parent / '.env'
# An explicit raise instead of ``assert``: asserts are stripped under
# ``python -O``, which would let a missing .env slip through silently.
if not env_path.exists():
    raise FileNotFoundError("Please add an .env to the root project path")
load_dotenv(dotenv_path=env_path)
class Rag(Enum):
    """Project-wide RAG configuration constants.

    NOTE(review): ``EMBED_MODEL`` instantiates ``OpenAIEmbeddings`` at import
    time, which is why ``load_dotenv`` must run before this class body is
    executed — confirm credentials are present in non-interactive contexts.
    """
    # Chat model name passed to ChatOpenAI in get_conversation_chain().
    GPT_MODEL = "gpt-4o-mini"
    # HuggingFace sentence-transformer model id (alternative embedder).
    HUG_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    # Shared OpenAI embedding function, created once when this module loads.
    EMBED_MODEL = OpenAIEmbeddings()
    # Directory name used to persist the Chroma vector store.
    DB_NAME = "vector_db"
def add_metadata(doc: Document, doc_type: str) -> Document:
    """
    Tag a Document with its category and hand it back.

    :param doc: Document whose metadata mapping is updated in place.
    :type doc: Document
    :param doc_type: Value stored under the ``doc_type`` metadata key.
    :type doc_type: str
    :return: The same Document instance, mutated.
    :rtype: Document
    """
    doc.metadata.update({"doc_type": doc_type})
    return doc
def get_chunks(folders: Generator[Path, None, None], file_ext='.txt') -> List[Document]:
    """
    Load every matching file under the given folders, tag each document with
    its folder name as ``doc_type`` metadata, and split the result into
    overlapping chunks.

    :param folders: Iterable of folder paths, one folder per document category.
    :type folders: Generator[Path, None, None]
    :param file_ext:
        The file extension to load from the local knowledge base (e.g. '.txt').
    :type file_ext: str
    :return: Document chunks (1000 characters each, 200-character overlap).
    :rtype: List[Document]
    """
    loader_kwargs = {'encoding': 'utf-8'}
    collected: List[Document] = []
    for folder in folders:
        # The folder's basename doubles as the document category.
        category = os.path.basename(folder)
        folder_docs = DirectoryLoader(
            folder,
            glob=f"**/*{file_ext}",
            loader_cls=TextLoader,
            loader_kwargs=loader_kwargs,
        ).load()
        collected.extend(add_metadata(doc, category) for doc in folder_docs)
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return splitter.split_documents(collected)
def create_vector_db(db_name: str, chunks: List[Document], embeddings: Any) -> Any:
    """
    Build a Chroma vector store from document chunks, replacing any store
    already persisted under the same directory.

    :param db_name: Name (persist directory) of the database to create.
    :type db_name: str
    :param chunks: List of document chunks.
    :type chunks: List[Document]
    :param embeddings: Embedding function to use.
    :type embeddings: Any
    :return: Created vector store.
    :rtype: Any
    """
    # Drop the stale collection first so the rebuild starts from scratch.
    if os.path.exists(db_name):
        stale = Chroma(persist_directory=db_name, embedding_function=embeddings)
        stale.delete_collection()
    return Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        persist_directory=db_name,
    )
def get_local_vector_db(path: str) -> Any:
    """
    Open a persistent ChromaDB client rooted at a local directory.

    :param path: Path to the local vector database.
    :type path: str
    :return: Persistent client for the vector database.
    :rtype: Any
    """
    client = PersistentClient(path=path)
    return client
def get_vector_db_info(vector_store: Any) -> None:
    """
    Print how many vectors the store holds and their dimensionality.

    :param vector_store: Vector store to get information from; its private
        ``_collection`` attribute is inspected directly.
    :type vector_store: Any
    """
    collection = vector_store._collection
    count = collection.count()
    # One sample embedding is enough to read the dimensionality.
    sample = collection.get(limit=1, include=["embeddings"])["embeddings"]
    dimensions = len(sample[0])
    print(f"There are {count:,} vectors with {dimensions:,} dimensions in the vector store")
def get_plot_data(collection: Any) -> Tuple[np.ndarray, List[str], List[str], List[str]]:
    """
    Extract embeddings, marker colors, doc types, and texts from a collection.

    :param collection: Collection to get data from; must support ``get`` with
        an ``include`` list returning embeddings/documents/metadatas.
    :type collection: Any
    :return: Tuple containing vectors, colors, document types, and documents.
    :rtype: Tuple[np.ndarray, List[str], List[str], List[str]]
    """
    # Category -> marker color. A dict lookup replaces the original
    # list.index() scan, and unknown doc_types now fall back to 'gray'
    # instead of raising ValueError mid-plot.
    color_by_type = {
        'products': 'blue',
        'employees': 'green',
        'contracts': 'red',
        'company': 'orange',
    }
    result = collection.get(include=['embeddings', 'documents', 'metadatas'])
    vectors = np.array(result['embeddings'])
    documents = result['documents']
    doc_types = [metadata['doc_type'] for metadata in result['metadatas']]
    colors = [color_by_type.get(t, 'gray') for t in doc_types]
    return vectors, colors, doc_types, documents
def get_2d_plot(collection: Any) -> go.Figure:
    """
    Generate a 2D t-SNE scatter plot of the vector store.

    :param collection: Collection to generate plot from.
    :type collection: Any
    :return: 2D scatter plot figure.
    :rtype: go.Figure
    """
    vectors, colors, doc_types, documents = get_plot_data(collection)
    tsne = TSNE(n_components=2, random_state=42)
    reduced_vectors = tsne.fit_transform(vectors)
    fig = go.Figure(data=[go.Scatter(
        x=reduced_vectors[:, 0],
        y=reduced_vectors[:, 1],
        mode='markers',
        marker=dict(size=5, color=colors, opacity=0.8),
        text=[f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(doc_types, documents)],
        hoverinfo='text'
    )])
    fig.update_layout(
        title='2D Chroma Vector Store Visualization',
        # BUG FIX: ``scene`` only styles 3-D plots, so the axis titles were
        # silently ignored here; 2-D axes are titled directly instead.
        xaxis_title='x',
        yaxis_title='y',
        width=800,
        height=600,
        margin=dict(r=20, b=10, l=10, t=40)
    )
    return fig
def get_3d_plot(collection: Any) -> go.Figure:
    """
    Generate a 3D t-SNE scatter plot of the vector store.

    :param collection: Collection to generate plot from.
    :type collection: Any
    :return: 3D scatter plot figure.
    :rtype: go.Figure
    """
    vectors, colors, doc_types, documents = get_plot_data(collection)
    points = TSNE(n_components=3, random_state=42).fit_transform(vectors)
    hover_labels = [
        f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(doc_types, documents)
    ]
    trace = go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(size=5, color=colors, opacity=0.8),
        text=hover_labels,
        hoverinfo='text'
    )
    fig = go.Figure(data=[trace])
    fig.update_layout(
        title='3D Chroma Vector Store Visualization',
        scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),
        width=900,
        height=700,
        margin=dict(r=20, b=10, l=10, t=40)
    )
    return fig
def get_conversation_chain(vectorstore: Any) -> ConversationalRetrievalChain:
    """
    Wire an LLM, buffer memory, and a retriever into a conversational chain.

    :param vectorstore: Vector store to use in the conversation chain.
    :type vectorstore: Any
    :return: Conversational retrieval chain.
    :rtype: ConversationalRetrievalChain
    """
    chat_llm = ChatOpenAI(temperature=0.7, model_name=Rag.GPT_MODEL.value)
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key='answer',
    )
    # k=25 casts a wide retrieval net; the LLM narrows from there.
    return ConversationalRetrievalChain.from_llm(
        llm=chat_llm,
        retriever=vectorstore.as_retriever(search_kwargs={"k": 25}),
        memory=chat_memory,
        return_source_documents=True,
    )
def get_lang_doc(document_text, doc_id, metadata=None, encoding='utf-8'):
    """
    Build a langchain Document that can be used to create a chroma database.

    :type document_text: str
    :param document_text:
        The text to add to a document object.
    :type doc_id: str
    :param doc_id:
        The document id to include.
    :type metadata: dict
    :param metadata:
        A dictionary of metadata to associate with the document object. This
        helps filter items from a vector database. Defaults to an empty dict
        (``Document`` expects a mapping, not ``None``).
    :type encoding: str
    :param encoding:
        Retained for backward compatibility; ``Document`` holds already-decoded
        text, so no encoding is applied here.
    """
    return Document(
        page_content=document_text,
        id=doc_id,
        # ``Document`` has no ``encoding`` field, so that parameter is not
        # forwarded; metadata is normalized to a dict to satisfy the schema.
        metadata=metadata if metadata is not None else {},
    )