# pdfchatbot / app.py
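# Gradio app: upload PDFs, split them into chunks, embed the chunks into a
# Chroma vector store, and chat over them with a ConversationalRetrievalChain
# backed by a freely available Hugging Face model.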
import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from transformers import AutoTokenizer, pipeline
import torch
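# Assumed (unpinned) dependencies, inferred from the imports above:
# gradio, langchain, chromadb, pypdf, sentence-transformers, transformers, torch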
# List of fully open, free models
list_llm = [
    "google/flan-t5-xxl",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "microsoft/phi-2",
    "facebook/opt-1.3b",
    "EleutherAI/gpt-neo-1.3B",
    "bigscience/bloom-1b7",
    "RWKV/rwkv-4-169m-pile",
    "gpt2-medium",
    "databricks/dolly-v2-3b",
    "mosaicml/mpt-7b-instruct"
]
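# Short display names (org prefix stripped) shown in the model dropdown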
list_llm_simple = [name.split("/")[-1] for name in list_llm]
# Load PDF documents and split them into chunks
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(file_path) for file_path in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(pages)
# Create the vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    return Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        persist_directory=f"./{collection_name}"
    )
# Initialize the LLM and the conversational retrieval chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.1, desc="Carregando tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(llm_model)
    progress(0.4, desc="Inicializando pipeline...")
    # Pick the correct task for each model (flan-t5 is a seq2seq model)
    task = "text2text-generation" if "flan-t5" in llm_model.lower() else "text-generation"
    # Device selection: GPU index 0 when available, otherwise CPU
    device = 0 if torch.cuda.is_available() else -1
    if "phi-2" in llm_model.lower() and device == 0:
        device = "cuda"
    pipeline_obj = pipeline(
        task,
        model=llm_model,
        tokenizer=tokenizer,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device=device,
        max_new_tokens=max_tokens,
        do_sample=True,
        top_k=top_k,
        temperature=temperature
    )
    llm = HuggingFacePipeline(pipeline=pipeline_obj)
    progress(0.7, desc="Configurando memória...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",  # needed because the chain also returns source_documents
        return_messages=True
    )
    progress(0.8, desc="Criando cadeia...")
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_db.as_retriever(),
        memory=memory,
        return_source_documents=True
    )
# Gradio interface
def demo():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        vector_db = gr.State(None)
        qa_chain = gr.State(None)
        gr.Markdown("## 🤖 Chatbot para PDFs com Modelos Gratuitos")
        with gr.Tab("📤 Upload PDF"):
            pdf_input = gr.Files(label="Selecione seus PDFs", file_types=[".pdf"])
        with gr.Tab("⚙️ Processamento"):
            chunk_size = gr.Slider(100, 1000, value=500, label="Tamanho dos Chunks")
            chunk_overlap = gr.Slider(0, 200, value=50, label="Sobreposição")
            process_btn = gr.Button("Processar PDFs")
            process_status = gr.Textbox(label="Status do Processamento", interactive=False)
        with gr.Tab("🧠 Modelo"):
            model_selector = gr.Dropdown(list_llm_simple, label="Selecione o Modelo", value=list_llm_simple[1])
            temperature = gr.Slider(0, 1, value=0.7, label="Criatividade")
            load_model_btn = gr.Button("Carregar Modelo")
            model_status = gr.Textbox(label="Status do Modelo", interactive=False)
        with gr.Tab("💬 Chat"):
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(label="Sua mensagem")
            clear_btn = gr.Button("Limpar Chat")
        # Event handlers
        def process_documents(files, cs, co):
            try:
                file_paths = [f.name for f in files]
                splits = load_doc(file_paths, cs, co)
                db = create_db(splits, "docs")
                return db, "Documentos processados!"
            except Exception as e:
                return None, f"Erro: {str(e)}"
        process_btn.click(
            process_documents,
            inputs=[pdf_input, chunk_size, chunk_overlap],
            outputs=[vector_db, process_status]
        )
        def load_model(model, temp, vector_db_state):
            try:
                if vector_db_state is None:
                    raise ValueError("Processe os documentos primeiro.")
                model_name = list_llm[list_llm_simple.index(model)]
                qa = initialize_llmchain(model_name, temp, 512, 3, vector_db_state)
                return qa, "Modelo carregado!"
            except Exception as e:
                return None, f"Erro: {str(e)}"
        load_model_btn.click(
            load_model,
            inputs=[model_selector, temperature, vector_db],
            outputs=[qa_chain, model_status]
        )
        def respond(message, chat_history, qa):
            # The per-session chain comes in as a gr.State input; reading
            # qa_chain.value here would only ever see the initial (None) value
            if qa is None:
                return "Erro: Modelo não carregado ou documentos não processados!", chat_history
            try:
                result = qa({"question": message, "chat_history": chat_history})
                response = result["answer"]
                sources = "\n".join([f"📄 Página {doc.metadata['page']+1}: {doc.page_content[:50]}..."
                                     for doc in result.get("source_documents", [])[:2]])
                chat_history.append((message, f"{response}\n\n🔍 Fontes:\n{sources}"))
                return "", chat_history
            except Exception as e:
                return f"Erro na geração: {str(e)}", chat_history
        msg.submit(respond, [msg, chatbot, qa_chain], [msg, chatbot])
        clear_btn.click(lambda: [], outputs=[chatbot])
    demo.launch()
if __name__ == "__main__":
    demo()