import os
from typing import List, Dict, Tuple, Optional
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import create_extraction_chain
from langchain.prompts import PromptTemplate
from rank_bm25 import BM25Okapi
import logging
import requests
from _utils.gerar_relatorio_modelo_usuario.DocumentSummarizer_simples import (
    DocumentSummarizer,
)
from _utils.models.gerar_relatorio import (
    ContextualizedChunk,
    RetrievalConfig,
)
from modelos_usuarios.serializer import ModeloUsuarioSerializer
from setup.environment import api_url
from rest_framework.response import Response
from _utils.gerar_relatorio_modelo_usuario.contextual_retriever import (
    ContextualRetriever,
)


class EnhancedDocumentSummarizer(DocumentSummarizer):
    def __init__(
        self,
        openai_api_key: str,
        claude_api_key: str,
        config: RetrievalConfig,
        embedding_model,
        chunk_size,
        chunk_overlap,
        num_k_rerank,
        model_cohere_rerank,
        claude_context_model,
        prompt_relatorio,
        gpt_model,
        gpt_temperature,
        id_modelo_do_usuario,
        prompt_modelo,
        reciprocal_rank_fusion,
    ):
        super().__init__(
            openai_api_key,
            os.environ.get("COHERE_API_KEY"),
            embedding_model,
            chunk_size,
            chunk_overlap,
            num_k_rerank,
            model_cohere_rerank,
        )
        self.config = config
        self.contextual_retriever = ContextualRetriever(
            config, claude_api_key, claude_context_model
        )
        self.logger = logging.getLogger(__name__)
        self.prompt_relatorio = prompt_relatorio
        self.gpt_model = gpt_model
        self.gpt_temperature = gpt_temperature
        self.id_modelo_do_usuario = id_modelo_do_usuario
        self.prompt_modelo = prompt_modelo
        self.reciprocal_rank_fusion = reciprocal_rank_fusion
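
    # Note: the RetrievalConfig passed in is expected to expose at least
    # `num_chunks`, `embedding_weight` and `bm25_weight`; those are the only
    # fields this class reads (see the retrieval methods below).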

    def create_enhanced_vector_store(
        self, chunks: List[ContextualizedChunk], is_contextualized_chunk
    ) -> Tuple[Chroma, BM25Okapi, List[str]]:
        """Create vector store and BM25 index with contextualized chunks"""
        try:
            # Prepare texts, prepending the generated context when available
            if is_contextualized_chunk:
                texts = [f"{chunk.context} {chunk.content}" for chunk in chunks]
            else:
                texts = [f"{chunk.content}" for chunk in chunks]

            # Build per-chunk metadata, then create the vector store
            metadatas = []
            for chunk in chunks:
                if is_contextualized_chunk:
                    context = chunk.context
                else:
                    context = ""
                metadatas.append(
                    {
                        "chunk_id": chunk.chunk_id,
                        "page": chunk.page_number,
                        "start_char": chunk.start_char,
                        "end_char": chunk.end_char,
                        "context": context,
                    }
                )

            vector_store = Chroma.from_texts(
                texts=texts, metadatas=metadatas, embedding=self.embeddings
            )

            # Create BM25 index
            tokenized_texts = [text.split() for text in texts]
            bm25 = BM25Okapi(tokenized_texts)

            # Get chunk IDs in order
            chunk_ids = [chunk.chunk_id for chunk in chunks]

            return vector_store, bm25, chunk_ids
        except Exception as e:
            self.logger.error(f"Error creating enhanced vector store: {str(e)}")
            raise
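
    # Note: ContextualizedChunk, as consumed above, is assumed to expose at least
    # `content`, `context`, `chunk_id`, `page_number`, `start_char` and `end_char`.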

    def retrieve_with_rank_fusion(
        self, vector_store: Chroma, bm25: BM25Okapi, chunk_ids: List[str], query: str
    ) -> List[Dict]:
        """Combine embedding and BM25 retrieval results"""
        try:
            # Get embedding results
            embedding_results = vector_store.similarity_search_with_score(
                query, k=self.config.num_chunks
            )

            # Convert embedding results to a list of (chunk_id, score), mapping
            # Chroma's distance (lower is better) to a similarity-style score
            embedding_list = [
                (doc.metadata["chunk_id"], 1 / (1 + score))
                for doc, score in embedding_results
            ]

            # Get BM25 results
            tokenized_query = query.split()
            bm25_scores = bm25.get_scores(tokenized_query)

            # Convert BM25 scores to a list of (chunk_id, score)
            bm25_list = [
                (chunk_ids[i], float(score)) for i, score in enumerate(bm25_scores)
            ]

            # Sort bm25_list by score in descending order and keep the top N results
            bm25_list = sorted(bm25_list, key=lambda x: x[1], reverse=True)[
                : self.config.num_chunks
            ]

            # Normalize BM25 scores, guarding against an empty or all-zero score
            # list so the normalization never divides by zero
            calculo_max = max((score for _, score in bm25_list), default=0)
            max_bm25 = calculo_max if calculo_max else 1
            bm25_list = [(doc_id, score / max_bm25) for doc_id, score in bm25_list]

            # Pass the lists to rank fusion
            result_lists = [embedding_list, bm25_list]
            weights = [self.config.embedding_weight, self.config.bm25_weight]
            combined_results = self.reciprocal_rank_fusion(
                result_lists, weights=weights
            )

            return combined_results
        except Exception as e:
            self.logger.error(f"Error in rank fusion retrieval: {str(e)}")
            raise
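
    # The injected `reciprocal_rank_fusion` callable is expected to take a list of
    # (chunk_id, score) result lists plus per-list `weights`, and to return a single
    # list of (chunk_id, fused_score) pairs ordered best-first; see the illustrative
    # sketch at the end of this module.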

    def generate_enhanced_summary(
        self,
        vector_store: Chroma,
        bm25: BM25Okapi,
        chunk_ids: List[str],
        query: str = "Summarize the main points of this document",
    ) -> List[Dict]:
        """Generate enhanced summary using both vector and BM25 retrieval"""
        try:
            # Get combined results using rank fusion
            ranked_results = self.retrieve_with_rank_fusion(
                vector_store, bm25, chunk_ids, query
            )

            # Prepare context and track sources
            contexts = []
            sources = []

            # Get full documents for top results
            for chunk_id, score in ranked_results[: self.config.num_chunks]:
                results = vector_store.get(
                    where={"chunk_id": chunk_id}, include=["documents", "metadatas"]
                )

                if results["documents"]:
                    context = results["documents"][0]
                    metadata = results["metadatas"][0]

                    contexts.append(context)
                    sources.append(
                        {
                            "content": context,
                            "page": metadata["page"],
                            "chunk_id": chunk_id,
                            "relevance_score": score,
                            "context": metadata.get("context", ""),
                        }
                    )

            # Fetch the user's template ("modelo") from the backend API
            url_request = f"{api_url}/modelo/{self.id_modelo_do_usuario}"
            print("url_request: ", url_request)
            resposta = requests.get(url_request)
            print("resposta: ", resposta)

            if resposta.status_code != 200:
                return Response(
                    {
                        "error": "Ocorreu um problema. Pode ser que o modelo não tenha sido encontrado. Tente novamente e/ou entre em contato com a equipe técnica"
                    }
                )

            modelo_buscado = resposta.json()["modelo"]

            # Previous approach: load the user's template directly from the
            # database instead of going through the API.
            # from modelos_usuarios.models import ModeloUsuarioModel
            # try:
            #     modelo_buscado = ModeloUsuarioModel.objects.get(
            #         pk=self.id_modelo_do_usuario
            #     )
            #     serializer = ModeloUsuarioSerializer(modelo_buscado)
            #     print("serializer.data: ", serializer.data)
            # except:
            #     return Response(
            #         {
            #             "error": "Ocorreu um problema. Pode ser que o modelo não tenha sido encontrado. Tente novamente e/ou entre em contato com a equipe técnica"
            #         }
            #     )
            # print("modelo_buscado: ", modelo_buscado)

            llm = ChatOpenAI(
                temperature=self.gpt_temperature,
                model_name=self.gpt_model,
                api_key=self.openai_api_key,
            )

            # First pass: generate the report ("relatorio") from the retrieved context
            prompt_gerar_relatorio = PromptTemplate(
                template=self.prompt_relatorio, input_variables=["context"]
            )

            relatorio_gerado = llm.predict(
                prompt_gerar_relatorio.format(context="\n\n".join(contexts))
            )

            # Second pass: rewrite the report following the user's template
            prompt_gerar_modelo = PromptTemplate(
                template=self.prompt_modelo,
                input_variables=["context", "modelo_usuario"],
            )

            modelo_gerado = llm.predict(
                prompt_gerar_modelo.format(
                    context=relatorio_gerado, modelo_usuario=modelo_buscado
                )
            )

            # Split the response into paragraphs
            summaries = [p.strip() for p in modelo_gerado.split("\n\n") if p.strip()]

            # Create structured output
            structured_output = []
            for idx, summary in enumerate(summaries):
                source_idx = min(idx, len(sources) - 1)
                structured_output.append(
                    {
                        "content": summary,
                        "source": {
                            "page": sources[source_idx]["page"],
                            "text": sources[source_idx]["content"][:200] + "...",
                            "context": sources[source_idx]["context"],
                            "relevance_score": sources[source_idx]["relevance_score"],
                            "chunk_id": sources[source_idx]["chunk_id"],
                        },
                    }
                )

            return structured_output
        except Exception as e:
            self.logger.error(f"Error generating enhanced summary: {str(e)}")
            raise
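

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the class above): the `reciprocal_rank_fusion`
# argument is injected by the caller, so its real implementation lives elsewhere.
# The helper below shows one plausible shape for that callable, assuming a
# standard weighted reciprocal rank fusion with the usual k = 60 constant; the
# name `_example_reciprocal_rank_fusion` and its defaults are hypothetical.
# ---------------------------------------------------------------------------
def _example_reciprocal_rank_fusion(
    result_lists: List[List[Tuple[str, float]]],
    weights: Optional[List[float]] = None,
    k: int = 60,
) -> List[Tuple[str, float]]:
    """Fuse several (chunk_id, score) result lists into a single ranking."""
    if weights is None:
        weights = [1.0] * len(result_lists)

    fused_scores: Dict[str, float] = {}
    for results, weight in zip(result_lists, weights):
        # Each list is assumed to be ordered best-first; only the rank matters here.
        for rank, (chunk_id, _score) in enumerate(results):
            fused_scores[chunk_id] = (
                fused_scores.get(chunk_id, 0.0) + weight / (k + rank + 1)
            )

    # Return (chunk_id, fused_score) pairs, best first, matching how
    # generate_enhanced_summary consumes the result.
    return sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)


# Hypothetical usage:
# summarizer = EnhancedDocumentSummarizer(
#     ..., reciprocal_rank_fusion=_example_reciprocal_rank_fusion
# )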