from huggingface_hub import InferenceClient
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# CSS styling for the Gradio app
css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
footer {visibility: hidden}
'''

# Initialize the inference client (Mistral 7B Instruct hosted on Hugging Face)
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
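# Note: when no token is passed, huggingface_hub falls back to a locally saved
# token or the HF_TOKEN environment variable; access to this model may also
# require accepting its terms on the Hugging Face Hub.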

# Build the retriever from the uploaded PDFs
def initialize_retriever(file_objs):
    """Loads PDF documents and creates a retriever."""
    if not file_objs:
        return None, "No documents loaded."
    
    # Load and split the documents
    documents = []
    for file_obj in file_objs:
        loader = PyPDFLoader(file_obj.name)
        documents.extend(loader.load())
    
    # Split into smaller chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    splits = text_splitter.split_documents(documents)
    
    # Create embeddings and the vector store
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 2})  # Return the 2 most relevant chunks
    
    return retriever, "Documents processed successfully!"
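# Optional persistence sketch (assumption: you want the index to survive restarts).
# Chroma.from_documents also accepts a persist_directory argument, e.g.
#   Chroma.from_documents(documents=splits, embedding=embeddings,
#                         persist_directory="./chroma_db")
# Without it, the store above is in-memory and rebuilt on every upload.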

# Format the prompt for RAG using the Mistral Instruct template
def format_prompt(message, history, retriever=None, system_prompt=None):
    prompt = "<s>"

    # Add conversation history
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "

    # Mistral Instruct has no dedicated system role, so the system prompt and
    # any retrieved context are prepended inside the final [INST] block
    instruction = ""
    if system_prompt:
        instruction += f"{system_prompt}\n\n"

    # Add retrieved context, if a retriever is available
    if retriever:
        # Fetch the most relevant chunks for the user message
        docs = retriever.get_relevant_documents(message)
        context = "\n".join(doc.page_content for doc in docs)
        instruction += f"Context:\n{context}\n\n"

    # Add the user message
    prompt += f"[INST] {instruction}{message} [/INST]"
    return prompt
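# Illustrative prompt shape (one history turn, a system prompt, no retriever;
# the strings are made up for demonstration):
#   <s>[INST] Hi [/INST] Hello!</s> [INST] Answer briefly.\n\nWhat is RAG? [/INST]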

# Streaming generation with optional RAG context
def generate(
    prompt, history, retriever=None, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Format the prompt with RAG context, if available
    formatted_prompt = format_prompt(prompt, history, retriever, system_prompt)

    # Stream the response token by token
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
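# Non-streaming alternative (sketch): with stream=False (the default),
# client.text_generation returns the whole completion as one string:
#   text = client.text_generation(formatted_prompt, **generate_kwargs)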

# Gradio interface with RAG
def create_demo():
    with gr.Blocks(css=css) as demo:
        retriever_state = gr.State(value=None)

        # Title
        gr.Markdown("<h1>RAG Chatbot</h1>")

        # Document upload section
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Load Documents")
                file_input = gr.Files(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
                process_btn = gr.Button("Process Documents")
                status_output = gr.Textbox(label="Status", value="No documents loaded")

        # Chat interface
        chat_interface = gr.ChatInterface(
            fn=generate,
            additional_inputs=[
                retriever_state,  # pass the State directly; wrapping it in another gr.State would hand generate() a component instead of the retriever
                gr.Textbox(label="System Prompt", placeholder="Enter an optional system prompt", value="")
            ],
            title="",
        )

        # Process documents when the button is clicked
        process_btn.click(
            fn=initialize_retriever,
            inputs=[file_input],
            outputs=[retriever_state, status_output]
        )

    return demo

# Launch the application
demo = create_demo()
demo.queue().launch(share=False)
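# Rough dependency list implied by the imports above (package names only;
# pinned versions are an assumption this file does not make):
#   pip install gradio huggingface_hub langchain langchain-community \
#       chromadb sentence-transformers pypdf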