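"""
semsearch.py -- Vector Database RAG Proof of Concept (V1).

Streamlit app that chunks HTML pages, stores them in a Weaviate vector
database, retrieves semantically similar chunks for a query, and feeds
them as context to a local llama.cpp Llama-2 model.
"""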
import weaviate
from weaviate.connect import ConnectionParams
from weaviate.classes.init import AdditionalConfig, Timeout
from weaviate.classes.query import QueryReference
from sentence_transformers import SentenceTransformer
from langchain_community.document_loaders import BSHTMLLoader
from pathlib import Path
from lxml import html
import logging
from semantic_text_splitter import HuggingFaceTextSplitter
from tokenizers import Tokenizer
import json
import os
import re
import llama_cpp
from llama_cpp import Llama
import streamlit as st
import subprocess
import time

try:
    if 'logger' not in st.session_state:
        weaviate_logger = logging.getLogger("httpx")
        weaviate_logger.setLevel(logging.WARNING)
        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
        st.session_state.weaviate_logger = weaviate_logger
        st.session_state.logger = logger
    else:
        weaviate_logger = st.session_state.weaviate_logger
        logger = st.session_state.logger
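    # Streamlit re-executes this script top to bottom on every UI interaction,
    # so expensive handles (loggers, db client, model) are cached in
    # st.session_state and reused on reruns.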
    def runStartup():
        logger.info("### Running startup.sh")
        try:
            #result = subprocess.run("/app/startup.sh", shell=False, capture_output=None,
            #                        text=None, timeout=300)
            #logger.info(f"startup.sh stdout: {result.stdout}")
            #logger.info(f"startup.sh stderr: {result.stderr}")
            #logger.info(f"Return code: {result.returncode}")
            # Launch startup.sh in the background and give its services time to come up.
            subprocess.Popen(["/app/startup.sh"])
            time.sleep(180)
        except Exception as e:
            emsg = str(e)
            logger.error(f"subprocess.Popen EXCEPTION. e: {emsg}")
            try:
                with open("/app/startup.log", "r") as file:
                    content = file.read()
                print(content)
            except Exception as e2:
                emsg = str(e2)
                logger.error(f"#### Displaying startup.log EXCEPTION. e2: {emsg}")
        logger.info("### Running startup.sh complete")

    if 'runStartup' not in st.session_state:
        st.session_state.runStartup = True
        runStartup()
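    # Note: the fixed 180 s sleep above is a crude readiness gate for the
    # services started by startup.sh; polling readiness instead (e.g. the v4
    # client's client.is_ready()) would start faster and more reliably.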
    ######################################################################
    # MAINLINE
    #
    logger.info("#### MAINLINE ENTERED.")

    # Function to load the CSS file
    def load_css(file_name):
        logger.info("#### load_css entered.")
        with open(file_name) as f:
            st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
        logger.info("#### load_css exited.")

    # Load the custom CSS
    if 'load_css' not in st.session_state:
        load_css(".streamlit/main.css")
        st.session_state.load_css = True

    st.markdown("<h1 style='text-align: center; color: #666666;'>Vector Database RAG Proof of Concept</h1>",
                unsafe_allow_html=True)
    st.markdown("<h6 style='text-align: center; color: #666666;'>V1</h6>", unsafe_allow_html=True)

    #pathString = "/Users/660565/KPSAllInOne/ProgramFilesX86/WebCopy/DownloadedWebSites/LLMPOC_HTML"
    pathString = "/app/inputDocs"
    chunks = []
    webpageDocNames = []
    page_contentArray = []
    webpageChunks = []
    webpageTitles = []
    webpageChunksDocNames = []
    ######################################################
    # Connect to the Weaviate vector database.
    if 'client' not in st.session_state:
        logger.info("#### Create Weaviate db client connection.")
        client = weaviate.WeaviateClient(
            connection_params=ConnectionParams.from_params(
                http_host="localhost",
                http_port=8080,
                http_secure=False,
                grpc_host="localhost",
                grpc_port=50051,
                grpc_secure=False
            ),
            additional_config=AdditionalConfig(
                timeout=Timeout(init=60, query=1800, insert=1800)  # Values in seconds
            )
        )
        client.connect()
        st.session_state.client = client
        logger.info("#### Create Weaviate db client connection exited.")
    else:
        client = st.session_state.client
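    # Note: for this default local setup, the v4 client also offers the
    # shorthand weaviate.connect_to_local(port=8080, grpc_port=50051), which
    # returns an already-connected client equivalent to the wiring above.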
    #######################################################
    # Read each text input file, parse it into a document,
    # chunk it, collect chunks and document name.
    # Capture the ingest decision now: once the collections are created below,
    # exists() can no longer signal a first run.
    needIngest = (not client.collections.exists("Documents")
                  or not client.collections.exists("Chunks"))
    if needIngest:
        logger.info("#### Read and chunk input text files.")
        # The bert-base-uncased tokenizer drives token-based chunk sizing;
        # build it once rather than per file.
        tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
        logger.debug(f"### tokenizer: {tokenizer}")
        splitter = HuggingFaceTextSplitter(tokenizer, trim_chunks=True)
        for filename in os.listdir(pathString):
            logger.debug(filename)
            path = Path(pathString + "/" + filename)
            # rstrip(".html") would strip any trailing '.', 'h', 't', 'm', 'l'
            # characters; slice off the suffix instead.
            if filename.endswith(".html"):
                filename = filename[:-len(".html")]
            webpageDocNames.append(filename)
            htmlLoader = BSHTMLLoader(path, "utf-8")
            htmlData = htmlLoader.load()
            title = htmlData[0].metadata['title']
            page_content = htmlData[0].page_content
            # Clean data. Remove multiple newlines, etc.
            page_content = re.sub(r'\n+', '\n', page_content)
            page_contentArray.append(page_content)
            webpageTitles.append(title)
            chunksOnePage = splitter.chunks(page_content, chunk_capacity=50)
            chunks = []
            for chnk in chunksOnePage:
                logger.debug(f"#### chnk in file: {chnk}")
                chunks.append(chnk)
            logger.debug(f"chunks: {chunks}")
            webpageChunks.append(chunks)
            webpageChunksDocNames.append(filename + "Chunks")
            logger.debug(f"### filename, title: {filename}, {title}")
        logger.debug(f"### webpageDocNames: {webpageDocNames}")
        logger.info("#### Read and chunk input text files exited.")
    ######################################################
    # Create database webpage and chunks collections.
    if not client.collections.exists("Documents"):
        logger.info("#### createWebpageCollection() entered.")
        #client.collections.delete("Documents")
        class_obj = {
            "class": "Documents",
            "description": "For first attempt at loading a Weaviate database.",
            "vectorizer": "text2vec-transformers",
            "moduleConfig": {
                "text2vec-transformers": {
                    "vectorizeClassName": False
                }
            },
            "vectorIndexType": "hnsw",
            "vectorIndexConfig": {
                "distance": "cosine"
            },
            # bm25 settings live at class level in Weaviate schemas, not on
            # individual properties.
            "invertedIndexConfig": {
                "bm25": {
                    "b": 0.75,
                    "k1": 1.2
                }
            },
            "properties": [
                {
                    "name": "name",
                    "dataType": ["text"],
                    "description": "Source file name."
                },
                {
                    "name": "title",
                    "dataType": ["text"],
                    "description": "HTML doc title.",
                    # 'tokenization' is a property-level field, not part of
                    # the module config.
                    "tokenization": "lowercase",
                    "moduleConfig": {
                        "text2vec-transformers": {
                            "vectorizePropertyName": True,
                            "skip": False
                        }
                    }
                },
                {
                    "name": "content",
                    "dataType": ["text"],
                    "description": "HTML page content.",
                    "tokenization": "whitespace",
                    "moduleConfig": {
                        "text2vec-transformers": {
                            "vectorizePropertyName": True
                        }
                    }
                }
            ]
        }
        wpCollection = client.collections.create_from_dict(class_obj)
        st.session_state.wpCollection = wpCollection
        logger.info("#### createWebpageCollection() exited.")
    else:
        wpCollection = client.collections.get("Documents")
        st.session_state.wpCollection = wpCollection
if not client.collections.exists("Chunks"): | |
logger.info("#### createChunksCollection() entered.") | |
#client.collections.delete("Chunks") | |
class_obj = { | |
"class": "Chunks", | |
"description": "Collection for document chunks.", | |
"vectorizer": "text2vec-transformers", | |
"moduleConfig": { | |
"text2vec-transformers": { | |
"vectorizeClassName": True | |
} | |
}, | |
"vectorIndexType": "hnsw", | |
"vectorIndexConfig": { | |
"distance": "cosine" | |
}, | |
"properties": [ | |
{ | |
"name": "chunk", | |
"dataType": ["text"], | |
"description": "Single webpage chunk.", | |
"vectorizer": "text2vec-transformers", | |
"moduleConfig": { | |
"text2vec-transformers": { | |
"vectorizePropertyName": False, | |
"skip": False, | |
"tokenization": "lowercase" | |
} | |
} | |
}, | |
{ | |
"name": "chunk_index", | |
"dataType": ["int"] | |
}, | |
{ | |
"name": "webpage", | |
"dataType": ["Documents"], | |
"description": "Webpage content chunks.", | |
"invertedIndexConfig": { | |
"bm25": { | |
"b": 0.75, | |
"k1": 1.2 | |
} | |
} | |
} | |
] | |
} | |
wpChunksCollection = client.collections.create_from_dict(class_obj) | |
st.session_state.wpChunksCollection = wpChunksCollection | |
logger.info("#### createChunksCollection() exited.") | |
else: | |
wpChunksCollection = client.collections.get("Chunks") | |
st.session_state.wpChunksCollection = wpChunksCollection | |
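    # Sketch (assuming the weaviate-client v4.x typed API) of an equivalent
    # typed definition, as an alternative to the raw dicts above:
    #
    #   from weaviate.classes.config import Configure, DataType, Property, ReferenceProperty
    #   client.collections.create(
    #       "Chunks",
    #       vectorizer_config=Configure.Vectorizer.text2vec_transformers(),
    #       properties=[
    #           Property(name="chunk", data_type=DataType.TEXT),
    #           Property(name="chunk_index", data_type=DataType.INT),
    #       ],
    #       references=[ReferenceProperty(name="webpage", target_collection="Documents")],
    #   )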
    ###########################################################
    # Create document and chunk objects in the database.
    # Uses the needIngest flag captured before the collections were created;
    # exists() would return True for both collections by this point.
    if needIngest:
        logger.info("#### Create page/doc and chunk db objects.")
        for i, className in enumerate(webpageDocNames):
            title = webpageTitles[i]
            logger.debug(f"## className, title: {className}, {title}")
            page_content = page_contentArray[i]
            # Insert the document.
            wpCollectionObj_uuid = wpCollection.data.insert(
                {
                    "name": className,
                    "title": title,
                    "content": page_content
                }
            )
            # Insert this document's chunks, each cross-referencing its parent
            # document (v4 client: references are passed separately from
            # properties).
            for i2, chunk in enumerate(webpageChunks[i]):
                chunk_uuid = wpChunksCollection.data.insert(
                    properties={
                        "title": title,
                        "chunk": chunk,
                        "chunk_index": i2
                    },
                    references={"webpage": wpCollectionObj_uuid}
                )
        logger.info("#### Create page/doc and chunk db objects exited.")
    #################################################################
    # Initialize the LLM.
    model_path = "/app/llama-2-7b-chat.Q4_0.gguf"
    if 'llm' not in st.session_state:
        logger.info("### Initializing LLM.")
        llm = Llama(model_path,
                    n_gpu_layers=0,
                    split_mode=llama_cpp.LLAMA_SPLIT_MODE_LAYER,
                    main_gpu=0,
                    tensor_split=None,
                    vocab_only=False,
                    use_mmap=True,
                    use_mlock=False,
                    kv_overrides=None,
                    seed=llama_cpp.LLAMA_DEFAULT_SEED,
                    n_ctx=512,
                    n_batch=512,
                    n_threads=8,
                    n_threads_batch=16,
                    rope_scaling_type=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
                    pooling_type=llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
                    rope_freq_base=0.0,
                    rope_freq_scale=0.0,
                    yarn_ext_factor=-1.0,
                    yarn_attn_factor=1.0,
                    yarn_beta_fast=32.0,
                    yarn_beta_slow=1.0,
                    yarn_orig_ctx=0,
                    logits_all=False,
                    embedding=False,
                    offload_kqv=True,
                    last_n_tokens_size=64,
                    lora_base=None,
                    lora_scale=1.0,
                    lora_path=None,
                    numa=False,
                    chat_format=None,
                    chat_handler=None,
                    draft_model=None,
                    tokenizer=None,
                    type_k=None,
                    type_v=None,
                    verbose=True
                    )
        st.session_state.llm = llm
        logger.info("### Initializing LLM exited.")
    else:
        llm = st.session_state.llm
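    # Note: n_ctx=512 is a small context window for RAG; the retrieved chunks
    # plus the user prompt can easily exceed it. Llama-2 natively supports a
    # 4096-token context, so a larger n_ctx may be needed once RAG context is
    # prepended.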
    def getRagData(promptText):
        logger.info("#### getRagData() entered.")
        ###############################################################################
        # Initialize the sentence transformer and encode the query prompt.
        logger.debug(f"#### Encode text query prompt to create vectors. {promptText}")
        model = SentenceTransformer('/app/multi-qa-MiniLM-L6-cos-v1')
        vector = model.encode(promptText)
        vectorList = vector.tolist()
        logger.debug(f"vectorList[2]: {vectorList[2]}")
        # Fetch the closest chunks, asking for the 'webpage' cross-reference so
        # the parent document can be looked up (v4 client: return_references).
        logger.debug("#### Retrieve semchunks from db using vectors from prompt.")
        wpChunksCollection = st.session_state.wpChunksCollection
        semChunks = wpChunksCollection.query.near_vector(
            near_vector=vectorList,
            distance=0.7,
            limit=3,
            return_references=QueryReference(link_on="webpage")
        )
        logger.debug(f"### semChunks: {semChunks}")
        # Collect chunk text; log each chunk's parent document and title.
        ragData = ""
        wpCollection = st.session_state.wpCollection
        for chunkObj in semChunks.objects:
            logger.info(f"#### chunk: {chunkObj.properties['chunk']}")
            ragData = ragData + "\n" + chunkObj.properties['chunk']
            webpage_uuid = chunkObj.references['webpage'].objects[0].uuid
            logger.info(f"webpage_uuid: {webpage_uuid}")
            wpFromChunk = wpCollection.query.fetch_object_by_id(webpage_uuid)
            logger.info(f"### wpFromChunk title: {wpFromChunk.properties['title']}")
        logger.info("#### getRagData() exited.")
        return ragData
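    # Example (hypothetical query): getRagData("What does the overview page
    # cover?") returns the concatenated text of up to 3 chunks whose vectors
    # lie within cosine distance 0.7 of the encoded prompt.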
    # Display UI. Each text area is reseeded from its last submitted value
    # when one exists.
    col1, col2 = st.columns(2)
    with col1:
        if "sysTA" not in st.session_state:
            st.session_state.sysTA = st.text_area(label="sysTA", value="Enter a system prompt.")
        elif "sysTAtext" in st.session_state:
            st.session_state.sysTA = st.text_area(label="sysTA", value=st.session_state.sysTAtext)
        else:
            st.session_state.sysTA = st.text_area(label="sysTA", value=st.session_state.sysTA)
        if "userpTA" not in st.session_state:
            st.session_state.userpTA = st.text_area(label="userpTA", value="Enter a user prompt.")
        elif "userpTAtext" in st.session_state:
            st.session_state.userpTA = st.text_area(label="userpTA", value=st.session_state.userpTAtext)
        else:
            st.session_state.userpTA = st.text_area(label="userpTA", value=st.session_state.userpTA)
    with col2:
        if "ragpTA" not in st.session_state:
            st.session_state.ragpTA = st.text_area(label="ragpTA", value="RAG context will appear here.")
        elif "ragpTAtext" in st.session_state:
            st.session_state.ragpTA = st.text_area(label="ragpTA", value=st.session_state.ragpTAtext)
        else:
            st.session_state.ragpTA = st.text_area(label="ragpTA", value=st.session_state.ragpTA)
        if "rspTA" not in st.session_state:
            st.session_state.rspTA = st.text_area(label="rspTA", value="The LLM response will appear here.")
        elif "rspTAtext" in st.session_state:
            st.session_state.rspTA = st.text_area(label="rspTA", value=st.session_state.rspTAtext)
        else:
            st.session_state.rspTA = st.text_area(label="rspTA", value=st.session_state.rspTA)
    def runLLM(prompt):
        logger = st.session_state.logger
        logger.info("### runLLM entered.")
        max_tokens = 1000
        temperature = 0.3
        top_p = 0.1
        echoVal = True
        stop = ["Q", "\n"]
        modelOutput = llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            echo=echoVal,
            stop=stop,
        )
        result = modelOutput["choices"][0]["text"].strip()
        logger.info(f"### llmResult: {result}")
        logger.info("### runLLM exited.")
        return result
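    # Caution: stop=["Q", "\n"] halts generation at the first newline (and at
    # any literal "Q"), so responses are effectively limited to a single line.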
    def setPrompt(pprompt, ragFlag):
        logger = st.session_state.logger
        logger.info(f"\n### setPrompt() entered. ragFlag: {ragFlag}")
        if ragFlag:
            ragPrompt = getRagData(pprompt)
            userPrompt = "Using this information: " + ragPrompt \
                + "\nprocess the following statement or question and produce a response. " \
                + pprompt
        else:
            userPrompt = st.session_state.sysTA + " " + pprompt
        #prompt = f""" <s> [INST] <<SYS>> {systemTextArea.value} <</SYS>> Q: {userPrompt} A: [/INST]"""
        logger.info("setPrompt exited.")
        logger.info(f"### userPrompt: {userPrompt}")
        return userPrompt
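    # The commented-out template above is the Llama-2 chat format
    # ("<s>[INST] <<SYS>> ... <</SYS>> ... [/INST]"); plain-text prompts work,
    # but llama-2-7b-chat generally behaves better when given its chat template.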
    def on_submitButton_clicked():
        logger = st.session_state.logger
        logger.info("### on_submitButton_clicked entered.")
        st.session_state.sysTAtext = st.session_state.sysTA
        logger.info(f"sysTAtext: {st.session_state.sysTAtext}")
        # The checkbox registers its value under its key in st.session_state.
        st.session_state.userpTAtext = setPrompt(st.session_state.userpTA, st.session_state.selectRag)
        st.session_state.userpTA = st.session_state.userpTAtext
        logger.info(f"userpTAtext: {st.session_state.userpTAtext}")
        st.session_state.rspTAtext = runLLM(st.session_state.userpTAtext)
        st.session_state.rspTA = st.session_state.rspTAtext
        logger.info(f"rspTAtext: {st.session_state.rspTAtext}")
        logger.info("### on_submitButton_clicked exited.")
    with st.sidebar:
        st.checkbox("Enable Query With RAG", value=False, key="selectRag")
        st.button("Run LLM Query", on_click=on_submitButton_clicked)
    logger.info("#### semsearch.py end of code.")
except Exception as e:
    try:
        emsg = str(e)
        logger.error(f"Program-wide EXCEPTION. e: {emsg}")
        with open("/app/startup.log", "r") as file:
            content = file.read()
        logger.debug(content)
    except Exception as e2:
        emsg = str(e2)
        logger.error(f"#### Displaying startup.log EXCEPTION. e2: {emsg}")