# RAG chatbot: a Gradio UI that answers questions over uploaded PDFs using
# Mistral-7B-Instruct served by the Hugging Face Inference API.
from huggingface_hub import InferenceClient
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
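
# Assumed PyPI dependencies for the imports above (unpinned): gradio,
# huggingface_hub, langchain, langchain-community, plus chromadb (for Chroma),
# pypdf (for PyPDFLoader), and sentence-transformers (for HuggingFaceEmbeddings).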

css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
footer {visibility: hidden}
'''

# Client for the hosted Inference API endpoint of the instruct-tuned model.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")


def initialize_retriever(file_objs):
    """Load the uploaded PDFs and build a similarity-search retriever over them."""
    if not file_objs:
        return None, "No documents loaded."

    # Read every uploaded PDF into a list of LangChain documents.
    documents = []
    for file_obj in file_objs:
        loader = PyPDFLoader(file_obj.name)
        documents.extend(loader.load())

    # Split the pages into overlapping chunks so each embedding covers a bounded span.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    splits = text_splitter.split_documents(documents)

    # Embed the chunks into an in-memory Chroma index (rebuilt on every click of
    # the process button) and retrieve the 2 closest chunks per query.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

    return retriever, "Documents processed successfully!"


def format_prompt(message, history, retriever=None, system_prompt=None):
    """Assemble a Mistral-instruct prompt from the chat history, an optional
    system prompt, and context retrieved for the current message."""
    prompt = "<s>"

    # Replay earlier turns in the [INST] ... [/INST] instruct format.
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "

    # [SYS] and [CONTEXT] are not special Mistral tokens; they are plain-text
    # markers this app uses to set off the system prompt and retrieved chunks.
    if system_prompt:
        prompt += f"[SYS] {system_prompt} [/SYS]"

    if retriever:
        docs = retriever.get_relevant_documents(message)
        context = "\n".join([doc.page_content for doc in docs])
        prompt += f"[CONTEXT] {context} [/CONTEXT]"

    prompt += f"[INST] {message} [/INST]"
    return prompt
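
# generate() streams completions for prompts shaped by format_prompt above. On
# a first turn with retrieval, the prompt looks roughly like this (illustrative
# shape only; the chunk text depends on the uploaded PDFs):
#   <s>[CONTEXT] ...retrieved chunks... [/CONTEXT][INST] user question [/INST]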
def generate(
    prompt, history, retriever=None, system_prompt=None,
    temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
):
    # Clamp temperature away from zero; the API rejects non-positive values.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history, retriever, system_prompt)

    # Stream tokens from the Inference API; yielding the growing string lets
    # Gradio render the answer incrementally.
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
        yield output


def create_demo():
    with gr.Blocks(css=css) as demo:
        # Holds the retriever built from the uploaded PDFs (None until processing runs).
        retriever_state = gr.State(value=None)

        gr.Markdown("<h1>RAG Chatbot</h1>")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Upload Documents")
                file_input = gr.Files(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
                process_btn = gr.Button("Process Documents")
                status_output = gr.Textbox(label="Status", value="No documents loaded")

        chat_interface = gr.ChatInterface(
            fn=generate,
            additional_inputs=[
                # Pass the State component itself; wrapping it in another gr.State
                # would hand generate() the component instead of the retriever.
                retriever_state,
                gr.Textbox(label="System Prompt", placeholder="Enter a system prompt (optional)", value=None),
            ],
            title="",
        )

        process_btn.click(
            fn=initialize_retriever,
            inputs=[file_input],
            outputs=[retriever_state, status_output],
        )

    return demo


demo = create_demo()
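# queue() enables streamed generator responses; launch() serves the app locally
# (Gradio's default address is http://127.0.0.1:7860).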
demo.queue().launch(share=False)