# pdfchatbot/app.py — Hugging Face Space by DHEIVER
# Commit 2073925 ("Update app.py", verified), 5.35 kB
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from unidecode import unidecode
from transformers import AutoTokenizer, pipeline
import transformers
import torch
import re
# Fully open, free-to-use Hugging Face Hub models offered in the model dropdown.
list_llm = [
    "google/flan-t5-xxl",  # text-to-text model
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # lightweight chat model
    "microsoft/phi-2",  # logical-reasoning model
    "facebook/opt-1.3b",  # text-generation model
    "EleutherAI/gpt-neo-1.3B",  # open-source GPT-3-style model
    "bigscience/bloom-1b7",  # multilingual model
    "RWKV/rwkv-4-169m-pile",  # RAM-efficient model
    "gpt2-medium",  # classic GPT-2 model
    "databricks/dolly-v2-3b",  # instruction-following model
    "mosaicml/mpt-7b-instruct",  # instruction-following model
]
# Short display names (the final path component of each model id).
# Position i here matches position i in list_llm, so the UI maps a selected
# display name back to its full id with list_llm_simple.index(...).
list_llm_simple = list(map(os.path.basename, list_llm))
# Load PDF documents and split them into chunks
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load every PDF in ``list_file_path`` and split the pages into chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        chunk_size: ``chunk_size`` passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to ``RecursiveCharacterTextSplitter``.

    Returns:
        The chunked documents produced by the splitter, in file/page order.
    """
    # Build every loader up front (as before), then load them one by one.
    pdf_loaders = [PyPDFLoader(path) for path in list_file_path]
    all_pages = []
    for pdf_loader in pdf_loaders:
        all_pages += pdf_loader.load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)
# Build an in-memory vector store from the document chunks
def create_db(splits, collection_name):
    """Embed ``splits`` into a fresh in-memory Chroma collection.

    Args:
        splits: Document chunks to index (e.g. the output of ``load_doc``).
        collection_name: Base name of the Chroma collection. A random
            8-hex-digit suffix is appended so each call gets its own
            collection.

    Returns:
        A LangChain ``Chroma`` vector store backed by the new collection.
    """
    import uuid

    embedding = HuggingFaceEmbeddings()
    # chromadb caches one in-process system for all EphemeralClient()
    # instances, and LangChain's Chroma uses get_or_create_collection. A fixed
    # name would therefore reuse the same collection on every call, mixing
    # chunks from earlier uploads (and other sessions) into this store.
    # Chroma names are capped at 63 chars: keep 54 + "-" + 8 hex digits.
    unique_name = f"{collection_name[:54]}-{uuid.uuid4().hex[:8]}"
    return Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=chromadb.EphemeralClient(),
        collection_name=unique_name,
    )
# Load the LLM locally and build the conversational retrieval chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Load ``llm_model`` and wrap it in a ``ConversationalRetrievalChain``.

    Args:
        llm_model: Hugging Face Hub model id (an entry of ``list_llm``).
        temperature: Sampling temperature. A value <= 0 switches to greedy
            decoding, because transformers rejects non-positive temperatures
            when sampling.
        max_tokens: Maximum number of new tokens generated per answer.
        top_k: Top-k cutoff used when sampling (generation, not retrieval).
        vector_db: Vector store whose retriever feeds the chain.
        progress: Gradio progress tracker.

    Returns:
        A ``ConversationalRetrievalChain`` with buffer memory whose outputs
        include ``answer`` and ``source_documents``.

    Raises:
        gr.Error: If ``vector_db`` is None (no PDFs processed yet).
    """
    if vector_db is None:
        raise gr.Error("Processe os PDFs antes de carregar o modelo.")
    progress(0.1, desc="Carregando tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(llm_model)
    # Encoder-decoder models such as flan-t5 cannot be loaded by the
    # "text-generation" pipeline (causal-LM only); they need
    # "text2text-generation", which HuggingFacePipeline also supports.
    config = transformers.AutoConfig.from_pretrained(llm_model)
    if getattr(config, "is_encoder_decoder", False):
        task = "text2text-generation"
    else:
        task = "text-generation"
    # The UI slider allows 0, which is invalid together with do_sample=True.
    if temperature is not None and temperature <= 0:
        generation_kwargs = {"do_sample": False}
    else:
        generation_kwargs = {"do_sample": True, "top_k": top_k, "temperature": temperature}
    progress(0.4, desc="Inicializando pipeline...")
    pipeline_obj = pipeline(
        task,
        model=llm_model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        max_new_tokens=max_tokens,
        **generation_kwargs,
    )
    llm = HuggingFacePipeline(pipeline=pipeline_obj)
    progress(0.7, desc="Configurando memória...")
    # With return_source_documents=True the chain emits two output keys;
    # the memory must be told which one to store ("answer"), otherwise
    # saving the turn fails with "One output key expected".
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    progress(0.8, desc="Criando cadeia...")
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_db.as_retriever(),
        memory=memory,
        chain_type="stuff",
        return_source_documents=True,
    )
# Gradio interface
def demo():
    """Build the PDF-chat Gradio app and launch it.

    Per-session objects (the vector store and the QA chain) live in
    ``gr.State`` components and are passed to event handlers as inputs and
    outputs. Reading ``State.value`` instead would only ever return the
    initial default (None), not the session's current value.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("## 🤖 Chatbot para PDFs com Modelos Gratuitos")
        with gr.Tab("📤 Upload PDF"):
            pdf_input = gr.Files(label="Selecione seus PDFs", file_types=[".pdf"])
        with gr.Tab("⚙️ Processamento"):
            chunk_size = gr.Slider(100, 1000, value=500, label="Tamanho dos Chunks")
            chunk_overlap = gr.Slider(0, 200, value=50, label="Sobreposição")
            process_btn = gr.Button("Processar PDFs")
        with gr.Tab("🧠 Modelo"):
            model_selector = gr.Dropdown(list_llm_simple, label="Selecione o Modelo", value=list_llm_simple[0])
            temperature = gr.Slider(0, 1, value=0.7, label="Criatividade")
            load_model_btn = gr.Button("Carregar Modelo")
        with gr.Tab("💬 Chat"):
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(label="Sua mensagem")
            clear_btn = gr.ClearButton([msg, chatbot])

        def process_pdfs(files, cs, co):
            """Chunk the uploaded PDFs and index them in a new vector store."""
            if not files:
                raise gr.Error("Envie pelo menos um PDF antes de processar.")
            # Gradio 4+ passes plain file paths; Gradio 3 passed tempfile
            # wrappers exposing the path as `.name`. Accept both.
            paths = [getattr(f, "name", f) for f in files]
            # Slider values may arrive as floats; chunk sizes are integers.
            return create_db(load_doc(paths, int(cs), int(co)), "docs")

        def load_model(model, temp, db, progress=gr.Progress()):
            """Build the QA chain for the selected model over this session's store."""
            if db is None:
                raise gr.Error("Processe os PDFs antes de carregar o modelo.")
            # Declaring `progress` here lets Gradio inject a live tracker,
            # which is forwarded so the loading steps show in the UI.
            return initialize_llmchain(
                list_llm[list_llm_simple.index(model)], temp, 512, 3, db, progress=progress
            )

        def format_source(doc):
            """Render one retrieved chunk as a short, 1-based page citation."""
            page = doc.metadata.get("page")
            label = page + 1 if isinstance(page, int) else "?"
            return f"📄 Página {label}: {doc.page_content[:50]}..."

        def respond(message, chat_history, chain):
            """Answer `message` with the session's chain and append the turn."""
            if chain is None:
                raise gr.Error("Carregue um modelo antes de conversar.")
            # The chain's own memory supplies chat_history, so only the
            # question is passed in.
            result = chain.invoke({"question": message})
            response = result["answer"]
            sources = "\n".join(format_source(doc) for doc in result["source_documents"][:2])
            answer = f"{response}\n\n🔍 Fontes:\n{sources}"
            # The Chatbot output expects the whole conversation, not a bare
            # string; also clear the textbox for the next message.
            return "", (chat_history or []) + [(message, answer)]

        def reset_memory(chain):
            """Forget the chain's conversation memory when the chat is cleared."""
            if chain is not None and getattr(chain, "memory", None) is not None:
                chain.memory.clear()

        process_btn.click(
            process_pdfs,
            inputs=[pdf_input, chunk_size, chunk_overlap],
            outputs=vector_db,
        )
        load_model_btn.click(
            load_model,
            inputs=[model_selector, temperature, vector_db],
            outputs=qa_chain,
        )
        msg.submit(respond, inputs=[msg, chatbot, qa_chain], outputs=[msg, chatbot])
        clear_btn.click(reset_memory, inputs=qa_chain, outputs=None)
    demo.launch()
if __name__ == "__main__":
    # Script entry point: build and launch the Gradio app only when this
    # file is executed directly, not when it is imported.
    demo()