LexAIcon / app.py
manuelcozar55's picture
Update app.py
ec25508 verified
raw
history blame
7.8 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from huggingface_hub import login
from PyPDF2 import PdfReader
from docx import Document
import csv
import json
import os
import torch
from langchain.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
# Log in to the Hugging Face Hub only when a token is available in the
# HUGGINGFACE_TOKEN environment variable.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(token=huggingface_token)
# Text-generation model setup.
@st.cache_resource
def load_llm():
    """Create the chat model and its tokenizer (cached via ``st.cache_resource``).

    Returns:
        tuple: ``(chat_model, tokenizer)`` where ``chat_model`` is a
        ``ChatHuggingFace`` wrapper around a ``HuggingFaceEndpoint`` for the
        Mistral instruct repository, and ``tokenizer`` is the ``AutoTokenizer``
        loaded from the same repository id.
    """
    repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
    endpoint = HuggingFaceEndpoint(repo_id=repo_id, task="text-generation")
    chat_model = ChatHuggingFace(llm=endpoint)
    return chat_model, AutoTokenizer.from_pretrained(repo_id)


llm_engine_hf, tokenizer = load_llm()
# Classification model setup.
# Maps the classifier's output index to the document category name. Defined
# before the loader because the loader sizes the classification head from it.
id2label = {0: "multas", 1: "politicas_de_privacidad", 2: "contratos", 3: "denuncias", 4: "otros"}


@st.cache_resource
def load_classification_model():
    """Load the document classifier and its tokenizer (cached via ``st.cache_resource``).

    The model is created with one output per entry in ``id2label``. Without
    an explicit ``num_labels`` the sequence-classification head defaults to
    two outputs, so categories with index 2 and above could never be
    predicted.

    Returns:
        tuple: ``(model, tokenizer)`` loaded from the
        ``mrm8488/legal-longformer-base-8192-spanish`` checkpoint.
    """
    checkpoint = "mrm8488/legal-longformer-base-8192-spanish"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # NOTE(review): the checkpoint name suggests a base (not fine-tuned)
    # model -- confirm the classification head is actually trained for these
    # five categories; otherwise predictions are arbitrary.
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=len(id2label),
        id2label=id2label,
        label2id={label: index for index, label in id2label.items()},
    )
    return model, tokenizer


classification_model, classification_tokenizer = load_classification_model()
# Load the JSON documents for each category.
@st.cache_resource
def load_json_documents():
    """Read the question/answer file of every category (cached via ``st.cache_resource``).

    For each category name, ``./<category>.json`` is opened as UTF-8 and the
    list under its ``"questions_and_answers"`` key is read.

    Returns:
        dict: category name -> list of strings, each being an entry's
        ``"question"`` and ``"answer"`` joined by a single space.
    """

    def read_category(name):
        # One file per category, located next to the working directory.
        with open(f"./{name}.json", "r", encoding="utf-8") as handle:
            entries = json.load(handle)["questions_and_answers"]
        return [item["question"] + " " + item["answer"] for item in entries]

    category_names = ("multas", "politicas_de_privacidad", "contratos", "denuncias", "otros")
    return {name: read_category(name) for name in category_names}


json_documents = load_json_documents()
# Embeddings and vector stores setup.
@st.cache_resource
def create_vector_store():
    """Build one FAISS index per category from ``json_documents`` (cached via ``st.cache_resource``).

    Every document of a category is split into chunks of at most 1000
    characters with a 150-character overlap, and the chunks are embedded with
    the MiniLM sentence-transformers model on CPU.

    Returns:
        dict: category name -> FAISS vector store built from that category's
        chunks.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-l6-v2", model_kwargs={"device": "cpu"})
    # The splitter configuration is the same for every category: build it once.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    stores = {}
    for category, docs in json_documents.items():
        # split_text() takes a single string, while `docs` is a list of
        # strings: split each document on its own and flatten the chunks.
        chunks = [chunk for doc in docs for chunk in text_splitter.split_text(doc)]
        stores[category] = FAISS.from_texts(chunks, embeddings)
    return stores


vector_stores = create_vector_store()
def classify_text(text):
    """Return the category label the classifier predicts for *text*.

    The text is tokenized into PyTorch tensors, truncated and padded to a
    fixed length of 4096 tokens, and run through ``classification_model`` in
    eval mode with gradients disabled. The index of the largest logit is
    looked up in ``id2label``.

    Args:
        text: The document text to classify.

    Returns:
        The label string from ``id2label`` for the highest-scoring class.
    """
    encoded = classification_tokenizer(
        text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
        padding="max_length",
    )
    classification_model.eval()
    with torch.no_grad():
        scores = classification_model(**encoded).logits
    best_index = scores.argmax(dim=-1).item()
    return id2label[best_index]
def translate(text, target_language):
    """Ask the chat model to translate *text* into *target_language*.

    Args:
        text: The document to translate.
        target_language: Name of the target language, inserted verbatim into
            the Spanish-language instruction prompt.

    Returns:
        The ``content`` of the response returned by ``llm_engine_hf.invoke``.
    """
    template = '''
Por favor, traduzca el siguiente documento al {LANGUAGE}:
<document>
{TEXT}
</document>
Asegúrese de que la traducción sea precisa y conserve el significado original del documento.
'''
    # str.format fills both placeholders in a single pass, so placeholder-like
    # text inside the user's document (e.g. a literal "{LANGUAGE}") is left
    # untouched instead of being rewritten by a second substitution.
    formatted_prompt = template.format(LANGUAGE=target_language, TEXT=text)
    response = llm_engine_hf.invoke(formatted_prompt)
    return response.content
def summarize(text, length):
    """Ask the chat model for a summary of *text*.

    Args:
        text: The document to summarize.
        length: A Spanish phrase describing the desired summary length; it is
            inserted verbatim after the word "resumen" in the prompt.

    Returns:
        The ``content`` of the response returned by ``llm_engine_hf.invoke``.
    """
    # The accented characters in the prompt were mis-encoded (mojibake) and
    # are restored here so the model receives valid Spanish.
    template = f'''
Por favor, haga un resumen {length} del siguiente documento:
<document>
{text}
</document>
Asegúrese de que el resumen sea conciso y conserve el significado original del documento.
'''
    response = llm_engine_hf.invoke(template)
    return response.content
def handle_uploaded_file(uploaded_file):
    """Extract plain text from an uploaded file, chosen by its extension.

    Supported extensions (matched case-insensitively): ``.txt``, ``.pdf``,
    ``.docx``, ``.csv`` and ``.json``.

    Args:
        uploaded_file: A binary file-like object exposing ``name`` and
            ``read()``.

    Returns:
        The extracted text. For an unsupported extension the message
        ``"Tipo de archivo no soportado."`` is returned. If extraction raises,
        the exception message is returned as the text (best-effort behaviour
        kept for callers that expect a string in every case).
    """
    try:
        # Lower-case once so "REPORT.PDF" is handled like "report.pdf".
        filename = uploaded_file.name.lower()
        if filename.endswith(".txt"):
            # utf-8-sig decodes plain UTF-8 and also drops a leading BOM.
            return uploaded_file.read().decode("utf-8-sig")
        if filename.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            # `or ""` guards against a page yielding a falsy result so the
            # join never fails on it.
            return "".join(page.extract_text() or "" for page in reader.pages)
        if filename.endswith(".docx"):
            doc = Document(uploaded_file)
            return "\n".join(para.text for para in doc.paragraphs)
        if filename.endswith(".csv"):
            content = uploaded_file.read().decode("utf-8-sig").splitlines()
            # Flatten all cells of all rows into one space-separated string.
            return " ".join(" ".join(row) for row in csv.reader(content))
        if filename.endswith(".json"):
            data = json.load(uploaded_file)
            # ensure_ascii=False keeps accented characters as-is instead of
            # turning them into \uXXXX escapes in the text passed downstream.
            return json.dumps(data, indent=4, ensure_ascii=False)
        return "Tipo de archivo no soportado."
    except Exception as e:
        return str(e)
# Prompt fragment inserted into the summary request for each selectable length.
_SUMMARY_LENGTHS = {
    "corto": "de aproximadamente 50 palabras",
    "medio": "de aproximadamente 100 palabras",
    "largo": "de aproximadamente 500 palabras",
}


def _answer_from_files(uploaded_files, prompt):
    """Answer *prompt* once per uploaded file, using that file's category index.

    Each file is converted to text and classified; the vector store of the
    predicted category is searched with the prompt and the hits are passed to
    the chat model as context. With a single file the model's answer is
    returned unchanged; with several, the answers are joined under a bold
    file-name heading so none of them is discarded.
    """
    answers = []
    for uploaded_file in uploaded_files:
        # Rewind in case the buffer was already consumed on an earlier run.
        uploaded_file.seek(0)
        file_content = handle_uploaded_file(uploaded_file)
        classification = classify_text(file_content)
        search_docs = vector_stores[classification].similarity_search(prompt)
        context = " ".join([doc.page_content for doc in search_docs])
        prompt_with_context = f"Contexto: {context}\n\nPregunta: {prompt}"
        response = llm_engine_hf.invoke(prompt_with_context)
        answers.append((uploaded_file.name, response.content))
    if len(answers) == 1:
        return answers[0][1]
    return "\n\n".join(f"**{name}**\n\n{answer}" for name, answer in answers)


def main():
    """Render the LexAIcon chat page and answer the submitted prompt.

    The page shows the header, a sidebar token field, the operation controls,
    the file uploader and the chat history. When a prompt is submitted the
    reply comes from, in order of precedence: the uploaded files (retrieval
    plus chat model), ``summarize``, ``translate``, or a direct chat-model
    call. The reply is appended to ``st.session_state.messages`` and shown.
    """
    st.image("./icon.jpg", width=100)
    st.title("LexAIcon")
    st.write("Puedes conversar con este chatbot basado en Mistral7B-Instruct y subir archivos para que el chatbot los procese.")

    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "¿Cómo puedo ayudarte?"}]

    with st.sidebar:
        # Deliberately NOT pre-filled with the server-side token: a widget's
        # initial value is sent to the visitor's browser, which would expose
        # the secret. The typed value is kept in
        # st.session_state["huggingface_token"].
        # NOTE(review): nothing in this function reads the entered value.
        st.text_input("HuggingFace Token", type="password", key="huggingface_token")
        st.caption("[Consigue un HuggingFace Token](https://huggingface.co/settings/tokens)")

    # The controls are rendered on every run, outside the chat-input branch.
    # st.chat_input() only returns text on the run triggered by a submission,
    # so widgets created inside that branch would vanish (and lose their
    # selection) as soon as the user touched them.
    operation = st.radio("Selecciona una operación", ["Resumir", "Traducir", "Explicar"])
    target_language = None
    summary_length = None
    if operation == "Traducir":
        target_language = st.selectbox("Selecciona el idioma de traducción", ["español", "inglés", "francés", "alemán"])
    if operation == "Resumir":
        summary_length = st.selectbox("Selecciona la longitud del resumen", ["corto", "medio", "largo"])
    uploaded_files = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"], accept_multiple_files=True)

    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)

        if uploaded_files:
            # Uploaded files take precedence over the selected operation.
            msg = _answer_from_files(uploaded_files, prompt)
        elif operation == "Resumir":
            msg = summarize(prompt, _SUMMARY_LENGTHS[summary_length])
        elif operation == "Traducir":
            msg = translate(prompt, target_language)
        else:
            msg = llm_engine_hf.invoke(prompt).content

        st.session_state.messages.append({"role": "assistant", "content": msg})
        st.chat_message("assistant").write(msg)


if __name__ == "__main__":
    main()