vella-backend / _utils / gerar_relatorio_modelo_usuario / EnhancedDocumentSummarizer.py
import os
from typing import List, Dict, Tuple, Optional
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import create_extraction_chain
from langchain.prompts import PromptTemplate
from rank_bm25 import BM25Okapi
import logging
import requests
from _utils.gerar_relatorio_modelo_usuario.DocumentSummarizer_simples import (
DocumentSummarizer,
)
from _utils.models.gerar_relatorio import (
ContextualizedChunk,
RetrievalConfig,
)
from modelos_usuarios.serializer import ModeloUsuarioSerializer
from setup.environment import api_url
from rest_framework.response import Response
from _utils.gerar_relatorio_modelo_usuario.contextual_retriever import (
ContextualRetriever,
)
class EnhancedDocumentSummarizer(DocumentSummarizer):
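    """Summarizer that extends DocumentSummarizer with contextualized chunks,
    a hybrid Chroma + BM25 index, rank-fusion retrieval, and two-stage report
    generation (draft report, then restructuring against the user's template).
    """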
def __init__(
self,
openai_api_key: str,
claude_api_key: str,
config: RetrievalConfig,
embedding_model,
chunk_size,
chunk_overlap,
num_k_rerank,
model_cohere_rerank,
claude_context_model,
prompt_relatorio,
gpt_model,
gpt_temperature,
id_modelo_do_usuario,
prompt_modelo,
reciprocal_rank_fusion,
):
super().__init__(
openai_api_key,
os.environ.get("COHERE_API_KEY"),
embedding_model,
chunk_size,
chunk_overlap,
num_k_rerank,
model_cohere_rerank,
)
self.config = config
self.contextual_retriever = ContextualRetriever(
config, claude_api_key, claude_context_model
)
self.logger = logging.getLogger(__name__)
self.prompt_relatorio = prompt_relatorio
self.gpt_model = gpt_model
self.gpt_temperature = gpt_temperature
self.id_modelo_do_usuario = id_modelo_do_usuario
self.prompt_modelo = prompt_modelo
self.reciprocal_rank_fusion = reciprocal_rank_fusion
def create_enhanced_vector_store(
self, chunks: List[ContextualizedChunk], is_contextualized_chunk
) -> Tuple[Chroma, BM25Okapi, List[str]]:
"""Create vector store and BM25 index with contextualized chunks"""
try:
# Prepare texts with context
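            # When contextualization is enabled, each chunk's generated context is
            # prepended so both the embeddings and the BM25 index can match on it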
if is_contextualized_chunk:
texts = [f"{chunk.context} {chunk.content}" for chunk in chunks]
else:
texts = [f"{chunk.content}" for chunk in chunks]
# Create vector store
metadatas = []
for chunk in chunks:
if is_contextualized_chunk:
context = chunk.context
else:
context = ""
metadatas.append(
{
"chunk_id": chunk.chunk_id,
"page": chunk.page_number,
"start_char": chunk.start_char,
"end_char": chunk.end_char,
"context": context,
}
)
vector_store = Chroma.from_texts(
texts=texts, metadatas=metadatas, embedding=self.embeddings
)
# Create BM25 index
tokenized_texts = [text.split() for text in texts]
bm25 = BM25Okapi(tokenized_texts)
# Get chunk IDs in order
chunk_ids = [chunk.chunk_id for chunk in chunks]
return vector_store, bm25, chunk_ids
except Exception as e:
self.logger.error(f"Error creating enhanced vector store: {str(e)}")
raise
def retrieve_with_rank_fusion(
self, vector_store: Chroma, bm25: BM25Okapi, chunk_ids: List[str], query: str
) -> List[Dict]:
"""Combine embedding and BM25 retrieval results"""
try:
# Get embedding results
embedding_results = vector_store.similarity_search_with_score(
query, k=self.config.num_chunks
)
# Convert embedding results to list of (chunk_id, score)
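            # (Chroma returns a distance where lower is better; 1 / (1 + distance)
            # maps it to a similarity-style score where higher is better)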
embedding_list = [
(doc.metadata["chunk_id"], 1 / (1 + score))
for doc, score in embedding_results
]
# Get BM25 results
tokenized_query = query.split()
bm25_scores = bm25.get_scores(tokenized_query)
# Convert BM25 scores to list of (chunk_id, score)
bm25_list = [
(chunk_ids[i], float(score)) for i, score in enumerate(bm25_scores)
]
# Sort bm25_list by score in descending order and limit to top N results
bm25_list = sorted(bm25_list, key=lambda x: x[1], reverse=True)[
: self.config.num_chunks
]
# Normalize BM25 scores
            # Guard against an empty result list and a max of 0 (all-zero scores),
            # which would otherwise raise or cause a division by zero below
            calculo_max = max((score for _, score in bm25_list), default=0)
            max_bm25 = calculo_max if calculo_max else 1
bm25_list = [(doc_id, score / max_bm25) for doc_id, score in bm25_list]
# Pass the lists to rank fusion
result_lists = [embedding_list, bm25_list]
weights = [self.config.embedding_weight, self.config.bm25_weight]
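            # self.reciprocal_rank_fusion is injected via the constructor; it is
            # expected to accept the ranked lists plus per-list weights and return
            # a single fused list of (chunk_id, score) pairs, best first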
combined_results = self.reciprocal_rank_fusion(
result_lists, weights=weights
)
return combined_results
except Exception as e:
self.logger.error(f"Error in rank fusion retrieval: {str(e)}")
raise
def generate_enhanced_summary(
self,
vector_store: Chroma,
bm25: BM25Okapi,
chunk_ids: List[str],
query: str = "Summarize the main points of this document",
) -> List[Dict]:
"""Generate enhanced summary using both vector and BM25 retrieval"""
try:
# Get combined results using rank fusion
ranked_results = self.retrieve_with_rank_fusion(
vector_store, bm25, chunk_ids, query
)
# Prepare context and track sources
contexts = []
sources = []
# Get full documents for top results
for chunk_id, score in ranked_results[: self.config.num_chunks]:
results = vector_store.get(
where={"chunk_id": chunk_id}, include=["documents", "metadatas"]
)
if results["documents"]:
context = results["documents"][0]
metadata = results["metadatas"][0]
contexts.append(context)
sources.append(
{
"content": context,
"page": metadata["page"],
"chunk_id": chunk_id,
"relevance_score": score,
"context": metadata.get("context", ""),
}
)
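            # Fetch the user's template (modelo) from the backend API so the report
            # generated below can be reshaped to follow it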
url_request = f"{api_url}/modelo/{self.id_modelo_do_usuario}"
print("url_request: ", url_request)
resposta = requests.get(url_request)
print("resposta: ", resposta)
if resposta.status_code != 200:
return Response(
{
"error": "Ocorreu um problema. Pode ser que o modelo não tenha sido encontrado. Tente novamente e/ou entre em contato com a equipe técnica"
}
)
modelo_buscado = resposta.json()["modelo"]
            # Alternative approach (not used): fetch the template directly from the
            # database via ModeloUsuarioModel and serialize it with
            # ModeloUsuarioSerializer instead of calling the HTTP endpoint above.
llm = ChatOpenAI(
temperature=self.gpt_temperature,
model_name=self.gpt_model,
api_key=self.openai_api_key,
)
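            # Two-stage generation: first draft a report from the retrieved context,
            # then restructure that draft to follow the user's template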
prompt_gerar_relatorio = PromptTemplate(
template=self.prompt_relatorio, input_variables=["context"]
)
relatorio_gerado = llm.predict(
prompt_gerar_relatorio.format(context="\n\n".join(contexts))
)
prompt_gerar_modelo = PromptTemplate(
template=self.prompt_modelo,
input_variables=["context", "modelo_usuario"],
)
modelo_gerado = llm.predict(
prompt_gerar_modelo.format(
context=relatorio_gerado, modelo_usuario=modelo_buscado
)
)
# Split the response into paragraphs
summaries = [p.strip() for p in modelo_gerado.split("\n\n") if p.strip()]
# Create structured output
structured_output = []
for idx, summary in enumerate(summaries):
source_idx = min(idx, len(sources) - 1)
structured_output.append(
{
"content": summary,
"source": {
"page": sources[source_idx]["page"],
"text": sources[source_idx]["content"][:200] + "...",
"context": sources[source_idx]["context"],
"relevance_score": sources[source_idx]["relevance_score"],
"chunk_id": sources[source_idx]["chunk_id"],
},
}
)
return structured_output
except Exception as e:
self.logger.error(f"Error generating enhanced summary: {str(e)}")
raise
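

# --- Illustrative usage sketch (assumptions, not part of the production flow) ---
# The reciprocal_rank_fusion callable is injected through the constructor and is
# defined elsewhere in the codebase; _example_reciprocal_rank_fusion below only
# sketches the interface this class relies on (ranked lists + per-list weights in,
# fused (chunk_id, score) pairs out, best first). Every literal in the __main__
# block (model names, prompts, id_modelo_do_usuario, chunk values) is a placeholder,
# and the RetrievalConfig / ContextualizedChunk constructors are assumed to accept
# the field names the methods above read from them.
def _example_reciprocal_rank_fusion(result_lists, weights=None, k=60):
    """Minimal weighted reciprocal rank fusion, for illustration only."""
    weights = weights or [1.0] * len(result_lists)
    fused: Dict[str, float] = {}
    for result_list, weight in zip(result_lists, weights):
        # Each list is scored "higher is better"; rank 1 is the best item
        ordered = sorted(result_list, key=lambda item: item[1], reverse=True)
        for rank, (doc_id, _score) in enumerate(ordered, start=1):
            fused[doc_id] = fused.get(doc_id, 0.0) + weight / (k + rank)
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    config = RetrievalConfig(num_chunks=5, embedding_weight=0.5, bm25_weight=0.5)
    summarizer = EnhancedDocumentSummarizer(
        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
        claude_api_key=os.environ.get("ANTHROPIC_API_KEY", ""),
        config=config,
        embedding_model="text-embedding-3-small",
        chunk_size=1000,
        chunk_overlap=200,
        num_k_rerank=5,
        model_cohere_rerank="rerank-multilingual-v3.0",
        claude_context_model="claude-3-haiku-20240307",
        prompt_relatorio="Write a report based on the context below:\n{context}",
        gpt_model="gpt-4o-mini",
        gpt_temperature=0,
        id_modelo_do_usuario=1,
        prompt_modelo=(
            "Rewrite the report below so it follows the user's template.\n"
            "Report: {context}\nTemplate: {modelo_usuario}"
        ),
        reciprocal_rank_fusion=_example_reciprocal_rank_fusion,
    )
    chunks = [
        ContextualizedChunk(
            content="Example chunk text.",
            context="Context generated for this chunk.",
            chunk_id="chunk-1",
            page_number=1,
            start_char=0,
            end_char=19,
        )
    ]
    vector_store, bm25, chunk_ids = summarizer.create_enhanced_vector_store(
        chunks, is_contextualized_chunk=True
    )
    print(summarizer.generate_enhanced_summary(vector_store, bm25, chunk_ids))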