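"""Gradio app for building a retrieval-augmented chatbot over uploaded PDFs.

Pipeline: load PDFs -> split pages into overlapping chunks -> embed them into an
in-memory Chroma collection -> answer questions through a LangChain
ConversationalRetrievalChain backed by a Hugging Face hosted LLM.
"""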
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from unidecode import unidecode
import re
# Candidate instruction-tuned models served via the Hugging Face Inference API,
# plus the short display names shown in the UI.
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "google/gemma-7b-it", "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "tiiuae/falcon-7b-instruct",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
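# e.g. os.path.basename("HuggingFaceH4/zephyr-7b-beta") -> "zephyr-7b-beta"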
def load_doc(list_file_path, chunk_size, chunk_overlap):
    # Load every PDF page, then split the pages into overlapping text chunks.
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(pages)
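# A minimal sketch of the expected call (file name and sizes are illustrative):
#   doc_splits = load_doc(["report.pdf"], chunk_size=600, chunk_overlap=40)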
def create_db(splits, collection_name):
    # Embed the chunks and store them in an ephemeral Chroma collection
    # (in-memory only; nothing is persisted across restarts).
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    return Chroma.from_documents(documents=splits, embedding=embedding, client=new_client, collection_name=collection_name)
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.5, desc="Initializing HF Hub...")
    # Hosted inference endpoint for the selected model.
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
    retriever = vector_db.as_retriever()
    # "stuff" concatenates all retrieved chunks into a single prompt.
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
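# A minimal sketch of direct use (model index and settings are illustrative):
#   chain = initialize_llmchain(list_llm[0], 0.3, 1024, 3, vector_db)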
def create_collection_name(filepath):
    # Derive a Chroma-compatible collection name: ASCII alphanumerics and
    # hyphens only, at most 50 characters, at least 3, alphanumeric at both ends.
    collection_name = Path(filepath).stem
    collection_name = unidecode(collection_name.replace(" ", "-"))
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
    if len(collection_name) < 3:
        collection_name += 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name
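# Worked examples of the sanitization above (file names are hypothetical):
#   "Relazione finale 2023.pdf" -> "Relazione-finale-2023"
#   "ab.pdf"                    -> "abxyz"  (padded to the 3-character minimum)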
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    # The collection is named after the first uploaded file.
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = create_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Completed!"
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    # llm_option is the index of the radio selection, mapped back to the repo id.
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Completed!"
def format_chat_history(message, chat_history):
    # Flatten (user, assistant) pairs into plain-text turns; `message` is unused
    # but kept so the call sites stay unchanged.
    return [f"User: {user_message}\nAssistant: {bot_message}" for user_message, bot_message in chat_history]
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    # Some models echo the prompt; keep only the text after "Helpful Answer:".
    response_answer = response["answer"].split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    sources = [(s.page_content.strip(), s.metadata["page"] + 1) for s in response_sources[:3]]
    # Pad to exactly three references so the return arity always matches the
    # Gradio outputs, and emit all contents before all page numbers to line up
    # with `doc_sources + source_pages`.
    while len(sources) < 3:
        sources.append(("", 0))
    contents, pages = zip(*sources)
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history, *contents, *pages
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        gr.Markdown("# PDF-based Chatbot Builder")
        with gr.Tab("Step 1 - Upload PDFs"):
            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents")
        with gr.Tab("Step 2 - Process Documents"):
            db_choice = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index")  # renamed from db_btn so the button below is not shadowed
            with gr.Accordion("Advanced options - Document text splitting", open=False):
                slider_chunk_size = gr.Slider(100, 1000, 1000, step=20, label="Chunk size")
                slider_chunk_overlap = gr.Slider(10, 200, 100, step=10, label="Chunk overlap")
            db_progress = gr.Textbox(label="Vector database initialization", value="None")
            db_btn = gr.Button("Generate vector database")
        with gr.Tab("Step 3 - Initialize QA chain"):
            llm_btn = gr.Radio(list_llm_simple, label="LLM models", value=list_llm_simple[5], type="index")
            with gr.Accordion("Advanced options - LLM model", open=False):
                slider_temperature = gr.Slider(0.01, 1.0, 0.3, step=0.1, label="Temperature")
                slider_maxtokens = gr.Slider(224, 4096, 1024, step=32, label="Max tokens")
                slider_topk = gr.Slider(1, 10, 3, step=1, label="Top-k samples")
            language_btn = gr.Radio(["Italian", "English"], label="Language", value="Italian", type="index")  # selection is not yet wired into the chain
            llm_progress = gr.Textbox(value="None", label="QA chain initialization")
            qachain_btn = gr.Button("Initialize Question-Answering chain")
        with gr.Tab("Step 4 - Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Advanced options - Document references", open=False):
                doc_sources = [gr.Textbox(label=f"Reference {i+1}", lines=2, container=True, scale=20) for i in range(3)]
                source_pages = [gr.Number(label="Page", scale=1) for _ in range(3)]
            msg = gr.Textbox(placeholder="Type a message (e.g. 'What is this document about?')", container=True)
            submit_btn = gr.Button("Send message")
            clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
        # Wire the pipeline: build the vector DB, initialize the chain, then chat.
        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + doc_sources + source_pages)
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot] + doc_sources + source_pages)
    demo.queue().launch(debug=True)
if __name__ == "__main__":
    demo()