File size: 6,303 Bytes
c1772f8
 
 
 
 
 
f0340cd
 
c1772f8
f0340cd
 
c1772f8
 
 
 
 
 
 
 
 
 
 
f0340cd
 
 
 
 
 
 
 
 
 
 
 
 
 
c1772f8
 
 
f0340cd
c1772f8
 
 
f0340cd
 
 
 
 
 
c1772f8
 
 
 
 
f0340cd
c1772f8
f0340cd
 
 
 
 
 
 
c1772f8
f0340cd
 
 
 
 
 
 
 
 
 
 
 
 
 
c1772f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0340cd
c1772f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0340cd
c1772f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0340cd
c1772f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0340cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from huggingface_hub import InferenceClient
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
import os
import re
from unidecode import unidecode

# CSS para estilização
css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
footer {visibility: hidden}
'''

# Inicializar o cliente de inferência
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

# Função de pré-processamento de texto
def preprocess_text(text):
    """Pré-processa o texto removendo ruídos e normalizando."""
    # Remover números de página (ex.: "Página 1", "Page 1 of 10")
    text = re.sub(r'(Página|Page)\s+\d+(?:\s+of\s+\d+)?', '', text, flags=re.IGNORECASE)
    # Remover múltiplos espaços e quebras de linha
    text = re.sub(r'\s+', ' ', text).strip()
    # Normalizar texto (remover acentos e converter para minúsculas)
    text = unidecode(text.lower())
    return text

# Configurar o retriever com pré-processamento e indexação avançada
def initialize_retriever(file_objs, persist_directory="chroma_db"):
    """Carrega documentos PDFs, pré-processa e cria um retriever híbrido."""
    if not file_objs:
        return None, "Nenhum documento carregado."
    
    # Carregar e pré-processar documentos
    documents = []
    for file_obj in file_objs:
        loader = PyPDFLoader(file_obj.name)
        raw_docs = loader.load()
        for doc in raw_docs:
            doc.page_content = preprocess_text(doc.page_content)
            # Adicionar metadados (exemplo: página e origem)
            doc.metadata.update({"source": os.path.basename(file_obj.name)})
        documents.extend(raw_docs)
    
    # Dividir em pedaços menores
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    splits = text_splitter.split_documents(documents)
    
    # Criar embeddings e banco de vetores (Chroma)
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    try:
        # Tentar carregar um banco existente
        vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
        vectorstore.add_documents(splits)  # Adicionar novos documentos
    except:
        # Criar um novo banco se não existir
        vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory=persist_directory)
    
    # Configurar retriever semântico
    semantic_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
    
    # Configurar retriever lexical (BM25)
    bm25_retriever = BM25Retriever.from_documents(splits)
    bm25_retriever.k = 2
    
    # Combinar em um retriever híbrido
    ensemble_retriever = EnsembleRetriever(
        retrievers=[semantic_retriever, bm25_retriever],
        weights=[0.6, 0.4]  # Mais peso para busca semântica
    )
    
    return ensemble_retriever, "Documentos processados com sucesso!"

# Formatar o prompt para RAG
def format_prompt(message, history, retriever=None, system_prompt=None):
    prompt = "<s>"
    
    # Adicionar histórico
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    
    # Adicionar instrução do sistema, se fornecida
    if system_prompt:
        prompt += f"[SYS] {system_prompt} [/SYS]"
    
    # Adicionar contexto recuperado, se houver retriever
    if retriever:
        docs = retriever.get_relevant_documents(message)
        context = "\n".join([f"[{doc.metadata.get('source', 'Unknown')}, Page {doc.metadata.get('page', 'N/A')}] {doc.page_content}" for doc in docs])
        prompt += f"[CONTEXT] {context} [/CONTEXT]"
    
    # Adicionar a mensagem do usuário
    prompt += f"[INST] {message} [/INST]"
    return prompt

# Função de geração com RAG
def generate(
    prompt, history, retriever=None, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Formatar o prompt com contexto RAG
    formatted_prompt = format_prompt(prompt, history, retriever, system_prompt)

    # Gerar resposta em streaming
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output

# Interface Gradio com RAG
def create_demo():
    with gr.Blocks(css=css) as demo:
        retriever_state = gr.State(value=None)
        status = gr.State(value="Nenhum documento carregado")

        # Título
        gr.Markdown("<h1>RAG Chatbot</h1>")

        # Seção de upload de documentos
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Carregar Documentos")
                file_input = gr.Files(label="Upload PDFs", file_types=["pdf"], file_count="multiple")
                process_btn = gr.Button("Processar Documentos")
                status_output = gr.Textbox(label="Status", value="Nenhum documento carregado")

        # Interface de chat
        chat_interface = gr.ChatInterface(
            fn=generate,
            additional_inputs=[
                gr.State(value=retriever_state),
                gr.Textbox(label="System Prompt", placeholder="Digite um prompt de sistema (opcional)", value=None)
            ],
            title="",
        )

        # Evento para processar documentos
        process_btn.click(
            fn=initialize_retriever,
            inputs=[file_input],
            outputs=[retriever_state, status_output]
        )

    return demo

# Lançar a aplicação
demo = create_demo()
demo.queue().launch(share=False)