import weaviate from weaviate.connect import ConnectionParams from weaviate.classes.init import AdditionalConfig, Timeout from sentence_transformers import SentenceTransformer from langchain_community.document_loaders import BSHTMLLoader from pathlib import Path from lxml import html import logging from semantic_text_splitter import HuggingFaceTextSplitter from tokenizers import Tokenizer import json import os import re import llama_cpp from llama_cpp import Llama import streamlit as st import subprocess import time import pprint import io import torch import time from datetime import datetime, timedelta import threading #from huggingface_hub import InferenceClient try: ############################################# # Logging setup including weaviate logging. # ############################################# if 'logging' not in st.session_state: weaviate_logger = logging.getLogger("httpx") #weaviate_logger.setLevel(logging.WARNING) weaviate_logger.setLevel(logging.DEBUG) logger = logging.getLogger(__name__) #logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',level=logging.INFO) logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',level=logging.DEBUG) st.session_state.weaviate_logger = weaviate_logger st.session_state.logger = logger else: weaviate_logger = st.session_state.weaviate_logger logger = st.session_state.logger # Set long session timeout for space. #inference = InferenceClient(repo_id="MVPilgrim/SemanticSearch", timeout=1800) logger.info("###################### Program Entry ############################") logger.info(f"CUDA available: {torch.cuda.is_available()}") logger.info(f"CUDA device count: {torch.cuda.device_count()}") if torch.cuda.is_available(): logger.info(f"CUDA device name: {torch.cuda.get_device_name(0)}") ########################################################################## # Asynchonously run startup.sh which run text2vec-transformers # # asynchronously and the Weaviate Vector Database server asynchronously. # ########################################################################## def runStartup(): logger.info("### Running startup.sh") try: subprocess.Popen(["/app/startup.sh"]) # Wait for text2vec-transformers and Weaviate DB to initialize. time.sleep(120) #subprocess.run(["/app/cmd.sh 'ps -ef'"]) displayStartupshLog() except Exception as e: emsg = str(e) logger.error(f"### subprocess.run or displayStartup.shLog EXCEPTION. e: {emsg}") logger.info("### Running startup.sh complete") def displayStartupshLog(): logger.info("### Displaying /app/startup.log") with open("/app/startup.log", "r") as file: line = file.readline().rstrip() while line: logger.info(line) line = file.readline().rstrip() logger.info("### End of /app/startup.log display.") if 'runStartup' not in st.session_state: st.session_state.runStartup = False if 'runStartup' not in st.session_state: logger.info("### runStartup still not in st.session_state after setting variable.") with st.spinner('Restarting...'): runStartup() try: displayStartupshLog() except Exception as e2: emsg = str(e2) logger.error(f"#### Displaying startup.log EXCEPTION. e2: {emsg}") ######################################### # Function to load the CSS syling file. # ######################################### def load_css(file_name): logger.info("#### load_css entered.") with open(file_name) as f: st.markdown(f'', unsafe_allow_html=True) logger.info("#### load_css exited.") if 'load_css' not in st.session_state: load_css(".streamlit/main.css") st.session_state.load_css = True # Display UI heading. st.markdown("

LLM with RAG Prompting
Proof of Concept

", unsafe_allow_html=True) pathString = "/app/inputDocs" chunks = [] webpageDocNames = [] page_contentArray = [] webpageChunks = [] webpageTitles = [] webpageChunksDocNames = [] ############################################ # Connect to the Weaviate vector database. # ############################################ if 'client' not in st.session_state: logger.info("#### Create Weaviate db client connection.") client = weaviate.WeaviateClient( connection_params=ConnectionParams.from_params( http_host="localhost", http_port="8080", http_secure=False, grpc_host="localhost", grpc_port="50051", grpc_secure=False ), additional_config=AdditionalConfig( timeout=Timeout(init=60, query=1800, insert=1800), # Values in seconds ) ) for i in range(3): try: client.connect() st.session_state.client = client logger.info("#### Create Weaviate db client connection exited.") break except Exception as e: emsg = str(e) logger.error(f"### client.connect() EXCEPTION. e2: {emsg}") time.sleep(45) if i >= 3: raise Exception("client.connect retries exhausted.") else: client = st.session_state.client ######################################################## # Read each text input file, parse it into a document, # # chunk it, collect chunks and document names. # ######################################################## if not client.collections.exists("Documents") or not client.collections.exists("Chunks") : logger.info("#### Read and chunk input RAG document files.") for filename in os.listdir(pathString): logger.debug(filename) path = Path(pathString + "/" + filename) filename = filename.rstrip(".html") webpageDocNames.append(filename) htmlLoader = BSHTMLLoader(path,"utf-8") htmlData = htmlLoader.load() title = htmlData[0].metadata['title'] page_content = htmlData[0].page_content # Clean data. Remove multiple newlines, etc. page_content = re.sub(r'\n+', '\n',page_content) page_contentArray.append(page_content) webpageTitles.append(title) max_tokens = 1000 tokenizer = Tokenizer.from_pretrained("bert-base-uncased") logger.info(f"### tokenizer: {tokenizer}") splitter = HuggingFaceTextSplitter(tokenizer, trim_chunks=True) chunksOnePage = splitter.chunks(page_content, chunk_capacity=50) chunks = [] for chnk in chunksOnePage: logger.debug(f"#### chnk in file: {chnk}") chunks.append(chnk) logger.debug(f"chunks: {chunks}") webpageChunks.append(chunks) webpageChunksDocNames.append(filename + "Chunks") logger.info(f"### filename, title: {filename}, {title}") logger.info(f"### webpageDocNames: {webpageDocNames}") logger.info("#### Read and chunk input RAG document files.") ############################################################# # Create database documents and chunks schemas/collections. # # Each chunk schema points to its corresponding document. # ############################################################# if not client.collections.exists("Documents"): logger.info("#### Create documents schema/collection started.") class_obj = { "class": "Documents", "description": "For first attempt at loading a Weviate database.", "vectorizer": "text2vec-transformers", "moduleConfig": { "text2vec-transformers": { "vectorizeClassName": False } }, "vectorIndexType": "hnsw", "vectorIndexConfig": { "distance": "cosine", }, "properties": [ { "name": "title", "dataType": ["text"], "description": "HTML doc title.", "vectorizer": "text2vec-transformers", "moduleConfig": { "text2vec-transformers": { "vectorizePropertyName": True, "skip": False, "tokenization": "lowercase" } }, "invertedIndexConfig": { "bm25": { "b": 0.75, "k1": 1.2 }, } }, { "name": "content", "dataType": ["text"], "description": "HTML page content.", "moduleConfig": { "text2vec-transformers": { "vectorizePropertyName": True, "tokenization": "whitespace" } } } ] } wpCollection = client.collections.create_from_dict(class_obj) st.session_state.wpCollection = wpCollection logger.info("#### Create documents schema/collection ended.") else: wpCollection = client.collections.get("Documents") st.session_state.wpCollection = wpCollection # Create chunks in db. if not client.collections.exists("Chunks"): logger.info("#### create document chunks schema/collection started.") #client.collections.delete("Chunks") class_obj = { "class": "Chunks", "description": "Collection for document chunks.", "vectorizer": "text2vec-transformers", "moduleConfig": { "text2vec-transformers": { "vectorizeClassName": True } }, "vectorIndexType": "hnsw", "vectorIndexConfig": { "distance": "cosine" }, "properties": [ { "name": "chunk", "dataType": ["text"], "description": "Single webpage chunk.", "vectorizer": "text2vec-transformers", "moduleConfig": { "text2vec-transformers": { "vectorizePropertyName": False, "skip": False, "tokenization": "lowercase" } } }, { "name": "chunk_index", "dataType": ["int"] }, { "name": "webpage", "dataType": ["Documents"], "description": "Webpage content chunks.", "invertedIndexConfig": { "bm25": { "b": 0.75, "k1": 1.2 } } } ] } wpChunksCollection = client.collections.create_from_dict(class_obj) st.session_state.wpChunksCollection = wpChunksCollection logger.info("#### create document chunks schedma/collection ended.") else: wpChunksCollection = client.collections.get("Chunks") st.session_state.wpChunksCollection = wpChunksCollection ################################################################## # Create the actual document and chunks objects in the database. # ################################################################## if 'dbObjsCreated' not in st.session_state: logger.info("#### Create db document and chunk objects started.") st.session_state.dbObjsCreated = True for i, className in enumerate(webpageDocNames): logger.info("#### Creating document object.") title = webpageTitles[i] logger.debug(f"## className, title: {className}, {title}") # Create Webpage Object page_content = page_contentArray[i] # Insert the document. wpCollectionObj_uuid = wpCollection.data.insert( { "name": className, "title": title, "content": page_content } ) logger.info("#### Document object created.") logger.info("#### Create chunk db objects.") st.session_state.wpChunksCollection = wpChunksCollection # Insert the chunks for the document. for i2, chunk in enumerate(webpageChunks[i]): chunk_uuid = wpChunksCollection.data.insert( { "title": title, "chunk": chunk, "chunk_index": i2, "references": { "webpage": wpCollectionObj_uuid } } ) logger.debug(f"Inserting chunk. title,chunk: {title}, {chunk}") logger.info("#### Create chunk db objects created.") logger.info("#### Create db document and chunk objects ended.") ####################### # Initialize the LLM. # ####################### model_path = "/app/llama-2-7b-chat.Q4_0.gguf" #model_path = "/app/Llama-3.2-3B-Instruct-Q4_0.gguf" #model_path = "Meta-Llama-3.1-8B-Instruct-Q8_0.gguf" if 'llm' not in st.session_state: logger.info("### Initializing LLM.") llm = Llama(model_path, #*, n_gpu_layers=-1, split_mode=llama_cpp.LLAMA_SPLIT_MODE_LAYER, main_gpu=0, tensor_split=None, vocab_only=False, use_mmap=True, use_mlock=False, kv_overrides=None, seed=llama_cpp.LLAMA_DEFAULT_SEED, n_ctx=2048, n_batch=512, n_threads=8, n_threads_batch=16, rope_scaling_type=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, pooling_type=llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED, rope_freq_base=0.0, rope_freq_scale=0.0, yarn_ext_factor=-1.0, yarn_attn_factor=1.0, yarn_beta_fast=32.0, yarn_beta_slow=1.0, yarn_orig_ctx=0, logits_all=False, embedding=False, offload_kqv=True, last_n_tokens_size=64, lora_base=None, lora_scale=1.0, lora_path=None, numa=False, chat_format="llama-2", chat_handler=None, draft_model=None, tokenizer=None, type_k=None, type_v=None, verbose=True ) st.session_state.llm = llm logger.info("### Initializing LLM completed.") else: llm = st.session_state.llm ##################################################### # Get RAG data from vector db based on user prompt. # ##################################################### def getRagData(promptText): logger.info("#### getRagData() entered.") ############################################################################### # Initial the the sentence transformer and encode the query prompt. logger.debug(f"#### Encode text query prompt to create vectors. {promptText}") model = SentenceTransformer('/app/multi-qa-MiniLM-L6-cos-v1') vector = model.encode(promptText) logLevel = logger.getEffectiveLevel() if logLevel >= logging.DEBUG: wrks = str(vector) logger.debug(f"### vector: {wrks}") vectorList = [] for vec in vector: vectorList.append(vec) if logLevel >= logging.DEBUG: logger.debug("#### Print vectors.") wrks = str(vectorList) logger.debug(f"vectorList: {wrks}") # Fetch chunks and print chunks. logger.debug("#### Retrieve semchunks from db using vectors from prompt.") wpChunksCollection = st.session_state.wpChunksCollection semChunks = wpChunksCollection.query.near_vector( near_vector=vectorList, distance=2.0, #certainty=0.7, limit=5 ) if logLevel >= logging.DEBUG: wrks = str(semChunks) logger.debug(f"### semChunks[0]: {wrks}") # Print chunks, corresponding document and document title. ragData = "" logger.debug("#### Print individual retrieved chunks.") wpCollection = st.session_state.wpCollection for chunk in enumerate(semChunks.objects): logger.debug(f"#### chunk: {chunk}") ragData = ragData + chunk[1].properties['chunk'] + "\n" webpage_uuid = chunk[1].properties['references']['webpage'] logger.debug(f"webpage_uuid: {webpage_uuid}") wpFromChunk = wpCollection.query.fetch_object_by_id(webpage_uuid) logger.debug(f"### wpFromChunk title: {wpFromChunk.properties['title']}") #collection = client.collections.get("Chunks") if ragData == "" or ragData == None: ragData = "None found." logger.debug("#### ragData: {ragData}") logger.info("#### getRagData() exited.") return ragData ################################################# # Retrieve all RAG data for the user to review. # ################################################# def getAllRagData(): logger.info("#### getAllRagData() entered.") chunksCollection = client.collections.get("Chunks") response = chunksCollection.query.fetch_objects() wstrObjs = str(response.objects) logger.debug(f"### response.objects: {wstrObjs}") for o in response.objects: wstr = o.properties logger.debug(f"### o.properties: {wstr}") logger.info("#### getAllRagData() exited.") return wstrObjs #################################################################### # Prompt the LLM with the user's input and return the completion. # #################################################################### def runLLM(prompt): logger = st.session_state.logger logger.info("### runLLM entered.") max_tokens = 1000 temperature = 0.3 top_p = 0.1 echoVal = True stop = ["Q", "\n"] modelOutput = "" #with st.spinner('Generating Completion (but slowly. 40+ seconds.)...'): #with st.markdown("

LLM with RAG Prompting
Proof of Concept

", # unsafe_allow_html=True): st.session_state.spinGenMsg = True modelOutput = llm.create_chat_completion( prompt #max_tokens=max_tokens, #temperature=temperature, #top_p=top_p, #echo=echoVal, #stop=stop, ) st.session_state.spinGenMsg = False if modelOutput != "": result = modelOutput["choices"][0]["message"]["content"] else: result = "No result returned." #result = str(modelOutput) logger.debug(f"### llmResult: {result}") logger.info("### runLLM exited.") return result ########################################################################## # Build a llama-2 prompt from the user prompt and RAG input if selected. # ########################################################################## def setPrompt(pprompt,ragFlag): logger = st.session_state.logger logger.info(f"### setPrompt() entered. ragFlag: {ragFlag}") if ragFlag: ragPrompt = getRagData(pprompt) st.session_state.ragpTA = ragPrompt if ragFlag != "None found.": userPrompt = pprompt + " " \ + "Also, combine the following information with information in the LLM itself. " \ + "Use the combined information to generate the response. " \ + ragPrompt + " " else: userPrompt = pprompt else: userPrompt = pprompt fullPrompt = [ {"role": "system", "content": st.session_state.sysTA}, {"role": "user", "content": userPrompt} ] #fullPrompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" #fullPrompt += st.session_state.sysTA #fullPrompt += "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" #fullPrompt += userPrompt #fullPrompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" st.session_state.userpTA = userPrompt logger.debug(f"### fullPrompt: {fullPrompt}") logger.info("setPrompt exited.") return fullPrompt ########################## # Display UI text areas. # ########################## col1, col2 = st.columns(2) with col1: if 'spinGenMsg' not in st.session_state or st.session_state.spinGenMsg == False: placeHolder = st.empty() else: elapsedTime = 0.00 startTime = datetime.now() st.session_state.spinGenMsg = False; with st.spinner(f"Generating Completion..."): st.session_state.sysTAtext = st.session_state.sysTA logger.debug(f"sysTAtext: {st.session_state.sysTAtext}") fullPrompt = setPrompt(st.session_state.userpTA,st.selectRag) #fullPrompt = setPrompt(st.session_state.userpTA,st.selectRag) #st.session_state.userpTA = wrklist[1]["content"] logger.debug(f"userpTAtext: {st.session_state.userpTA}") #rsp = runLLM(wrklist) rsp = runLLM(fullPrompt) st.session_state.rspTA = rsp logger.debug(f"rspTAtext: {st.session_state.rspTA}") #if "sysTA" not in st.session_state: # st.session_state.sysTA = st.text_area(label="System Prompt",placeholder="You are a helpful AI assistant", help="Instruct the LLM about how to handle the user prompt.") #elif "sysTAtext" in st.session_state: # st.session_state.sysTA = st.text_area(label="System Prompt",value=st.session_state.sysTAtext,placeholder="You are a helpful AI assistant", help="Instruct the LLM about how to handle the user prompt.") #else: # st.session_state.sysTA = st.text_area(label="System Prompt",value=st.session_state.sysTA,placeholder="You are a helpful AI assistant", help="Instruct the LLM about how to handle the user prompt.") elapsedTime = datetime.now() - startTime logger.info(f"#### elapsedTime: {elapsedTime}") if "sysTA" not in st.session_state: st.session_state.sysTA = st.text_area(label="System Prompt",placeholder="You are a helpful AI assistant", help="Instruct the LLM about how to handle the user prompt.") elif "sysTAtext" in st.session_state: st.session_state.sysTA = st.text_area(label="System Prompt",value=st.session_state.sysTAtext,placeholder="You are a helpful AI assistant", help="Instruct the LLM about how to handle the user prompt.") else: st.session_state.sysTA = st.text_area(label="System Prompt",value=st.session_state.sysTA,placeholder="You are a helpful AI assistant", help="Instruct the LLM about how to handle the user prompt.") if "userpTA" not in st.session_state: st.session_state.userpTA = st.text_area(label="User Prompt",placeholder="Prompt the LLM with a question or instruction.", \ help="Enter a prompt for the LLM. No special characters needed.") elif "userpTAtext" in st.session_state: st.session_state.userpTA = st.text_area (label="User Prompt",value=st.session_state.userpTAtext,placeholder="Prompt the LLM with a question or instruction.", \ help="Enter a prompt for the LLM. No special characters needed.") else: st.session_state.userpTA = st.text_area(label="User Prompt",value=st.session_state.userpTA,placeholder="Prompt the LLM with a question or instruction.", \ help="Enter a prompt for the LLM. No special characters needed.") with col2: if "ragpTA" not in st.session_state: st.session_state.ragpTA = st.text_area(label="RAG Response",placeholder="Output if RAG selected.",help="RAG output if enabled.") elif "ragpTAtext" in st.session_state: st.session_state.ragpTA = st.text_area(label="RAG Response",value=st.session_state.ragpTAtext,placeholder="Output if RAG selected.",help="RAG output if enabled.") else: st.session_state.ragpTA = st.text_area(label="RAG Response",value=st.session_state.ragpTA,placeholder="Output if RAG selected.",help="RAG output if enabled.") if "rspTA" not in st.session_state: st.session_state.rspTA = st.text_area(label="LLM Completion",placeholder="LLM completion.",help="Output area for LLM completion (response).") elif "rspTAtext" in st.session_state: st.session_state.rspTA = st.text_area(label="LLM Completion",value=st.session_state.rspTAtext,placeholder="LLM completion.",help="Output area for LLM completion (response).") else: st.session_state.rspTA = st.text_area(label="LLM Completion",value=st.session_state.rspTA,placeholder="LLM completion.",help="Output area for LLM completion (response).") ##################################### # Run the LLM with the user prompt. # ##################################### def on_runLLMButton_Clicked(): logger = st.session_state.logger logger.info("### on_runLLMButton_Clicked entered.") st.session_state.spinGenMsg = True logger.info("### on_runLLMButton_Clicked exited.") ######################################### # Get all the RAG data for user review. # ######################################### def on_getAllRagDataButton_Clicked(): logger = st.session_state.logger logger.info("### on_getAllRagButton_Clicked entered.") st.session_state.ragpTA = getAllRagData(); logger.info("### on_getAllRagButton_Clicked exited.") ####################################### # Reset all the input, output fields. # ####################################### def on_resetButton_Clicked(): logger = st.session_state.logger logger.info("### on_resetButton_Clicked entered.") st.session_state.sysTA = "" st.session_state.userpTA = "" st.session_state.ragpTA = "" st.session_state.rspTA = "" st.selectRag .value = False logger.info("### on_resetButton_Clicked exited.") ########################################### # Display the sidebar with a checkbox and # # text areas. # ########################################### with st.sidebar: st.selectRag = st.checkbox("Enable RAG",value=False,key="selectRag",help=None,on_change=None,args=None,kwargs=None,disabled=False,label_visibility="visible") st.runLLMButton = st.button("Run LLM Prompt",key=None,help=None,on_click=on_runLLMButton_Clicked,args=None,kwargs=None,type="secondary",disabled=False,use_container_width=False) st.getAllRagDataButton = st.button("Get All Rag Data",key=None,help=None,on_click=on_getAllRagDataButton_Clicked,args=None,kwargs=None,type="secondary",disabled=False,use_container_width=False) st.resetButton = st.button("Reset",key=None,help=None,on_click=on_resetButton_Clicked,args=None,kwargs=None,type="secondary",disabled=False,use_container_width=False) logger.info("#### Program End Execution.") except Exception as e: try: emsg = str(e) logger.error(f"Program-wide EXCEPTION. e: {emsg}") with open("/app/startup.log", "r") as file: content = file.read() logger.debug(content) except Exception as e2: emsg = str(e2) logger.error(f"#### Displaying startup.log EXCEPTION. e2: {emsg}")