cloudwalk_swarm / app.py
k3ybladewielder's picture
Update app.py
c08d114 verified
raw
history blame
14.6 kB
# libs
from huggingface_hub import hf_hub_download
from langchain.agents import initialize_agent, Tool, AgentType
from langchain.chains import RetrievalQA, LLMChain
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_community.utilities import SerpAPIWrapper
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig, AutoModelForImageTextToText
import logging
import os
import torch
import yaml
import traceback
# ----------- SETUP -----------
import warnings
from dotenv import load_dotenv
from langchain_text_splitters import CharacterTextSplitter
from functions import fn_rebuild_vector_store
logging.getLogger("langchain.text_splitter").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY")
with open('./config.yaml', 'r', encoding='utf-8') as file:
config = yaml.safe_load(file)
EMBEDDING_MODEL = config.get('EMBEDDING_MODEL')
LLM_MODEL = config.get('LLM_MODEL')
LLM_MODEL_GGUF = config.get('LLM_MODEL_GGUF')
LLM_MODEL_FILE = config.get('LLM_MODEL_FILE')
REBUILD_VECTOR_STORE= config.get('REBUILD_VECTOR_STORE', False)
CHUNK_SIZE = config.get('CHUNK_SIZE', 500)
CHUNK_OVERLAP = config.get('CHUNK_OVERLAP', 50)
CACHE_FOLDER = config.get('CACHE_FOLDER', './cache')
URL_LIST = config.get('URL_LIST', [])
VS_BASE = config.get('VS_BASE', './vs')
# ----------- VECTOR STORE CREATION -----------
# executando fn para veirficacao True/False de criação de vector store
fn_rebuild_vector_store(REBUILD_VECTOR_STORE, URL_LIST, VS_BASE, EMBEDDING_MODEL, CACHE_FOLDER, CHUNK_SIZE, CHUNK_OVERLAP)
# ----------- SWARM -----------
def get_llm():
logger.info(f"Carregando modelo do HuggingFace: {LLM_MODEL}")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL,
cache_dir=CACHE_FOLDER)
model = AutoModelForCausalLM.from_pretrained(
LLM_MODEL,
cache_dir=CACHE_FOLDER,
device_map="auto",
torch_dtype=torch.float16
)
text_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=550,
temperature=0.6,
eos_token_id=tokenizer.eos_token_id,
return_full_text=False
)
return HuggingFacePipeline(pipeline=text_pipeline)
def get_llm():
"""Carrega o LLM quantizado localmente usando bitsandbytes e um pipeline com chat template."""
# Atualizando o nome do modelo para carregar o tokenizer correto,
# mesmo que o arquivo GGUF seja de outro repo.
# O tokenizer ainda deve ser compatível com o modelo base "google/gemma-2b-it".
TOKENIZER_MODEL = LLM_MODEL
model_path = os.path.join(CACHE_FOLDER, LOCAL_MODEL_FILE)
logger.info(f"Carregando LLM quantizado localmente: {LOCAL_MODEL_FILE}")
try:
# Configuração da quantização com bitsandbytes
# `load_in_4bit=True` habilita a quantização de 4 bits
# `bnb_4bit_quant_type` define o tipo de quantização (fp4 ou nf4)
# `bnb_4bit_compute_dtype` define o tipo de dados para o cálculo (float16 é comum para GPUs)
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
llm_int8_threshold=6.0, # Adicionado para compatibilidade
llm_int8_skip_modules=None, # Adicionado para compatibilidade
llm_int8_enable_fp32_cpu_offload=False # Adicionado para compatibilidade
)
# Carregando o tokenizer e o modelo, aplicando a quantização
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL, cache_dir=CACHE_FOLDER)
# O modelo Gemma-2b-it possui um template de chat embutido
if tokenizer.chat_template is None:
logger.warning("O modelo não tem um template de chat. Usando o template padrão.")
tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ '[User]: ' + message['content'] + '\n\n' }}{% else %}{{ '[Assistant]: ' + message['content'] + '\n\n' }}{% endif %}{% endfor %}"
# O modelo a ser carregado agora é o arquivo GGUF local.
# No entanto, bitsandbytes é geralmente usado para carregar modelos transformers
# não GGUF. Para GGUF, você geralmente usaria uma biblioteca como `ctransformers`
# ou `llama-cpp-python`.
# Como o código original usava bitsandbytes com um nome de arquivo GGUF,
# assumirei que a intenção era carregar um modelo compatível com transformers/bitsandbytes,
# talvez com um nome de arquivo .safetensors ou .bin.
# Vou reverter para carregar o modelo diretamente do repo original
# "google/gemma-2b-it" usando bitsandbytes, já que o arquivo GGUF
# não é o formato esperado para bitsandbytes/transformers.
# Se a intenção REALMENTE for usar o arquivo GGUF, a abordagem de carregamento
# precisará ser completamente reescrita usando uma biblioteca apropriada (ex: ctransformers).
logger.warning("Detectado uso de bitsandbytes com nome de arquivo .gguf. Bitsandbytes é para modelos transformers (ex: .safetensors, .bin). Revertendo para carregar o modelo diretamente do repo original 'google/gemma-2b-it' com bitsandbytes.")
model = AutoModelForCausalLM.from_pretrained(
TOKENIZER_MODEL, # Carregando do repo original para usar bitsandbytes
cache_dir=CACHE_FOLDER,
device_map="auto", # Tenta usar a GPU, se disponível
quantization_config=quantization_config # Adicionando a configuração de quantização
)
# Criando a pipeline de geração de texto
text_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=520,
temperature=0.3,
# Parâmetros de parada para evitar que o modelo continue a gerar após a resposta
# Note: stop_sequence pode não ser universalmente suportado por todas as pipelines/modelos
# dependendo da implementação específica.
eos_token_id=tokenizer.eos_token_id # Usar EOS token é mais robusto
)
# Retornando a LLM da LangChain que usa a pipeline
return HuggingFacePipeline(pipeline=text_pipeline)
except Exception as e:
logger.error(f"Erro ao carregar o modelo. Erro: {e}")
# Informar o usuário sobre a incompatibilidade potencial
if "bitsandbytes" in str(e).lower() and ".gguf" in LOCAL_MODEL_FILE.lower():
logger.error("Possível erro de incompatibilidade: bitsandbytes é usado para modelos transformers, não GGUF. Considere usar uma biblioteca como 'ctransformers' ou 'llama-cpp-python' para arquivos GGUF.")
raise e
def get_embedding_model():
return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL, cache_folder=CACHE_FOLDER)
def load_vector_store():
logger.info("Loading FAISS vector store...")
embedding_model = get_embedding_model()
faiss_file = os.path.join(VS_BASE, "index.faiss")
pkl_file = os.path.join(VS_BASE, "index.pkl")
if not os.path.exists(faiss_file) or not os.path.exists(pkl_file):
raise FileNotFoundError(f"Arquivos .faiss e .pkl não encontrados em {VS_BASE}")
return FAISS.load_local(VS_BASE, embedding_model, allow_dangerous_deserialization=True)
def build_specialist_agents(vectorstore, llm):
template_base = (
"Você é um especialista da InfinityPay. Use o contexto abaixo para responder à pergunta de forma clara e direta.\n\n"
"Contexto: {context}\n\nPergunta: {question}\n\nResposta:")
prompt_template = PromptTemplate(template=template_base, input_variables=["context", "question"])
def make_agent():
return RetrievalQA.from_chain_type(
llm=llm,
retriever=vectorstore.as_retriever(),
chain_type_kwargs={"prompt": prompt_template}
)
return {
"GENERIC": Tool(name="GENERIC", func=make_agent().run, description="Agente genérico sobre a InfinityPay."),
"MAQUININHA": Tool(name="MAQUININHA", func=make_agent().run, description="Especialista em maquininhas."),
"COBRANCA_ONLINE": Tool(name="COBRANCA_ONLINE", func=make_agent().run, description="Especialista em cobranças online."),
"PDV_ECOMMERCE": Tool(name="PDV_ECOMMERCE", func=make_agent().run, description="Especialista em PDV e ecommerce."),
"CONTA_DIGITAL": Tool(name="CONTA_DIGITAL", func=make_agent().run, description="Especialista em conta digital, Pix, boleto, cartão, etc.")
}
def load_react_agent(llm):
if not SERPAPI_API_KEY or SERPAPI_API_KEY == "sua_serpapi_key":
return None
try:
react_tool = Tool(
name="WebSearch",
func=SerpAPIWrapper(serpapi_api_key=SERPAPI_API_KEY).run,
description="Busca na web."
)
return initialize_agent(
tools=[react_tool],
llm=llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=False,
handle_parsing_errors=True
)
except Exception as e:
logger.error(f"Erro no ReAct: {e}")
return None
def fallback_fn(input_text: str, llm) -> str:
prompt_text = (
"A seguinte pergunta do usuário não pode ser direcionada para um agente específico.\n"
"Responda de forma geral e amigável, informando que a equipe de suporte pode ajudar.\n"
f"\n\nPergunta: {input_text}"
)
try:
response = llm.invoke(prompt_text)
clean_response = response.strip().split("<eos>")[0].strip()
return clean_response.replace("[Assistant]:", "").strip()
except Exception as e:
return "Desculpe, não consegui processar sua solicitação agora."
def build_router_chain(llm, tokenizer):
return None # Roteador baseado em palavras-chave substitui LLMChain
def keyword_router(input_text: str) -> str:
keywords_map = {
"MAQUININHA": ["maquininha", "máquina", "POS", "pagamento físico"],
"COBRANCA_ONLINE": ["link de pagamento", "cobrança online", "pagamento online", "checkout"],
"PDV_ECOMMERCE": ["PDV", "ecommerce", "venda online", "loja virtual"],
"CONTA_DIGITAL": ["conta digital", "pix", "boleto", "transferência", "cartão"]
}
input_lower = input_text.lower()
for agent, keywords in keywords_map.items():
if any(keyword.lower() in input_lower for keyword in keywords):
return agent
return "GENERIC"
def keyword_router(input_text: str) -> str:
keywords_map = {
"MAQUININHA": ["maquininha", "máquina", "POS", "pagamento físico", "taxa", "%"],
"COBRANCA_ONLINE": ["pagamento", "link de pagamento", "cobrança online", "pagamento online", "checkout"],
"PDV_ECOMMERCE": ["PDV", "ecommerce", "venda online", "loja virtual"],
"CONTA_DIGITAL": ["conta digital", "pix", "boleto", "transferência", "cartão"]
}
input_lower = input_text.lower()
for agent, keywords in keywords_map.items():
if any(keyword in input_lower for keyword in keywords):
return agent
return "GENERIC" # ou "Fallback" se quiser forçar atendimento humano
# def swarm_router(input_text: str, tools: dict, router_chain, llm) -> str:
# try:
# agent_name = keyword_router(input_text)
# selected_tool = tools.get(agent_name, tools["Fallback"])
# if agent_name == "Fallback":
# return selected_tool.func(input_text, llm)
# elif selected_tool.func:
# return selected_tool.run(input_text)
# else:
# return fallback_fn(input_text, llm)
# except Exception as e:
# return fallback_fn(input_text, llm)
def swarm_router(input_text: str, tools: dict, router_chain, llm) -> str:
try:
agent_name = keyword_router(input_text)
selected_tool = tools.get(agent_name)
if selected_tool and selected_tool.func:
return selected_tool.run(input_text)
else:
return fallback_fn(input_text, llm)
except Exception as e:
return fallback_fn(input_text, llm)
import gradio as gr
# Variáveis globais para reuso no Gradio
llm = None
tokenizer = None
tools = None
router_chain = None
def setup():
global llm, tokenizer, tools, router_chain
logger.info("Inicializando Swarm via Gradio...")
try:
llm = get_llm()
tokenizer = llm.pipeline.tokenizer
except Exception as e:
logger.error("Erro ao carregar LLM.")
print(traceback.print_exc())
return "Erro ao carregar o modelo."
try:
vectorstore = load_vector_store()
except Exception as e:
logger.error("Erro ao carregar vectorstore.")
print(traceback.print_exc())
vectorstore = None
specialists = build_specialist_agents(vectorstore, llm) if vectorstore else {}
react_agent = load_react_agent(llm)
router_chain = build_router_chain(llm, tokenizer)
tools_local = {}
tools_local.update(specialists)
if react_agent:
tools_local["ReAct"] = Tool(name="ReAct", func=react_agent.run, description="Busca externa na web.")
tools_local["Fallback"] = Tool(name="Fallback", func=lambda x: fallback_fn(x, llm), description="Fallback generalista.")
tools = tools_local
def gradio_response(user_input):
if not tools:
return "Agentes ainda não estão prontos. Aguarde o carregamento."
return swarm_router(user_input, tools, router_chain, llm)
# Inicializa o sistema
setup()
# Interface Gradio
gr.Interface(
fn=gradio_response,
inputs=gr.Textbox(label="Sua pergunta", placeholder="Digite sua dúvida aqui..."),
outputs=gr.Textbox(label="Resposta do Swarm"),
title="Assistente InfinityPay",
description="Digite uma pergunta relacionada à InfinityPay e receba uma resposta especializada.",
theme="default",
examples=[
["Quais serviços a infinite pay oferece?"],
["Quais as taxas da maquininha?"],
["Como pedir uma maquininha?"],
]
).launch(share=True)