# RAG_AgenticServer_Small / retrieval / vector_retriever.py
# Author: jeongsoo - commit 6575706 ("init")
"""
벡터 검색 구현 모듈 (vector search implementation module).
"""
import os
import numpy as np
from typing import List, Dict, Any, Optional, Union, Tuple
import logging
from sentence_transformers import SentenceTransformer
from .base_retriever import BaseRetriever
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class VectorRetriever(BaseRetriever):
    """
    Embedding-based vector retriever.

    Each document is a dict. The text stored under ``embedding_field`` is
    encoded with a SentenceTransformer-compatible model, and queries are
    ranked by cosine similarity against the stored document embeddings.
    """

    # Model id used whenever no usable model name is known (same value as the
    # __init__ default).
    _DEFAULT_MODEL_NAME = "paraphrase-multilingual-MiniLM-L12-v2"
    # Placeholder that __init__ may assign to model_name for pre-loaded model
    # instances. It is not a loadable model id, so save()/load() never persist
    # or load it.
    _INSTANCE_MODEL_NAME = "loaded_sentence_transformer"

    def __init__(
        self,
        embedding_model: Optional[Union[str, SentenceTransformer]] = "paraphrase-multilingual-MiniLM-L12-v2",
        documents: Optional[List[Dict[str, Any]]] = None,
        embedding_field: str = "text",
        embedding_device: str = "cpu"
    ):
        """
        Initialize the VectorRetriever.

        Args:
            embedding_model: Model name to load with SentenceTransformer, or an
                already-constructed model instance. ``None`` falls back to the
                default model name.
            documents: Optional initial documents, passed to ``add_documents``.
            embedding_field: Name of the document field whose text is embedded.
            embedding_device: Device used when loading a model by name
                ('cpu' or 'cuda'). Ignored when a model instance is passed.
        """
        self.embedding_field = embedding_field
        self.model_name = None

        # The signature allows None. Treat it as "use the default model"
        # instead of storing None and failing later on .encode().
        if embedding_model is None:
            embedding_model = self._DEFAULT_MODEL_NAME

        if isinstance(embedding_model, str):
            logger.info(f"์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ '{embedding_model}' ๋กœ๋“œ ์ค‘...")
            self.model_name = embedding_model
            self.embedding_model = SentenceTransformer(embedding_model, device=embedding_device)
        else:
            self.embedding_model = embedding_model
            # A pre-loaded instance does not expose its original model id here.
            # NOTE(review): this 'modules' key check may rarely match a real
            # SentenceTransformer; confirm the intent.
            if hasattr(embedding_model, '_modules') and 'modules' in embedding_model._modules:
                self.model_name = self._INSTANCE_MODEL_NAME

        # Row i of document_embeddings is the embedding of documents[i].
        self.documents = []
        self.document_embeddings = None

        if documents:
            self.add_documents(documents)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the retriever and compute their embeddings.

        Documents that lack ``embedding_field`` are skipped with a warning.
        Documents are stored only after their embeddings were computed
        successfully, so ``documents`` and ``document_embeddings`` stay aligned
        even if encoding raises.

        Args:
            documents: Documents to add.
        """
        if not documents:
            logger.warning("์ถ”๊ฐ€ํ•  ๋ฌธ์„œ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
            return

        accepted_docs = []
        document_texts = []
        for doc in documents:
            if self.embedding_field not in doc:
                logger.warning(f"๋ฌธ์„œ์— ํ•„๋“œ '{self.embedding_field}'๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค.")
                continue
            accepted_docs.append(doc)
            document_texts.append(doc[self.embedding_field])

        if not document_texts:
            logger.warning(f"์ž„๋ฒ ๋”ฉํ•  ํ…์ŠคํŠธ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ๋ชจ๋“  ๋ฌธ์„œ์— '{self.embedding_field}' ํ•„๋“œ๊ฐ€ ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.")
            return

        logger.info(f"{len(document_texts)}๊ฐœ ๋ฌธ์„œ์˜ ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ ์ค‘...")
        new_embeddings = np.asarray(
            self.embedding_model.encode(document_texts, show_progress_bar=True)
        )

        # Commit embeddings and documents together, only after encoding succeeded.
        if self.document_embeddings is None:
            self.document_embeddings = new_embeddings
        else:
            self.document_embeddings = np.vstack([self.document_embeddings, new_embeddings])
        self.documents.extend(accepted_docs)

        logger.info(f"์ด {len(self.documents)}๊ฐœ ๋ฌธ์„œ, {self.document_embeddings.shape[0]}๊ฐœ ์ž„๋ฒ ๋”ฉ ์ €์žฅ๋จ")

    def search(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
        """
        Run a cosine-similarity vector search for a query.

        Args:
            query: Search query text.
            top_k: Maximum number of results to return. Values <= 0 return [].
            **kwargs: Extra search parameters (accepted for interface
                compatibility; unused here).

        Returns:
            Shallow copies of the best-matching documents, each with an added
            float ``"score"`` (cosine similarity), ordered from most to least
            similar. Returns [] when there is nothing to search.
        """
        if not self.documents or self.document_embeddings is None:
            logger.warning("๊ฒ€์ƒ‰ํ•  ๋ฌธ์„œ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
            return []
        # Without this guard, argsort(...)[-0:] would select every document.
        if top_k <= 0:
            return []

        # Only rows that have both a document and an embedding can be ranked.
        # This guards against a misaligned state, e.g. from a bad load().
        usable = min(len(self.documents), self.document_embeddings.shape[0])
        doc_embeddings = self.document_embeddings[:usable]

        query_embedding = np.asarray(self.embedding_model.encode(query))

        # Cosine similarity. Clamp the denominator so a zero-length vector scores
        # 0 instead of NaN, because argsort sorts NaN last, which puts it at the
        # top of the reversed ranking.
        norms = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
        scores = np.dot(doc_embeddings, query_embedding) / np.maximum(norms, 1e-12)

        # Descending order; same tie order as the original argsort[-k:][::-1].
        top_indices = np.argsort(scores)[::-1][:min(top_k, usable)]

        results = []
        for idx in top_indices:
            doc = self.documents[idx].copy()
            doc["score"] = float(scores[idx])
            results.append(doc)
        return results

    def save(self, directory: str) -> None:
        """
        Save the retriever state to disk.

        Writes ``documents.json``, ``embeddings.npy`` (only when embeddings
        exist; otherwise any stale file from an earlier save is removed), and
        ``model_info.json`` (model name, embedding dimension, embedding field).

        Args:
            directory: Target directory; created if missing.
        """
        import json

        os.makedirs(directory, exist_ok=True)

        with open(os.path.join(directory, "documents.json"), "w", encoding="utf-8") as f:
            json.dump(self.documents, f, ensure_ascii=False, indent=2)

        embeddings_path = os.path.join(directory, "embeddings.npy")
        if self.document_embeddings is not None:
            np.save(embeddings_path, self.document_embeddings)
        elif os.path.exists(embeddings_path):
            # A leftover file would otherwise be paired with the new documents on load().
            os.remove(embeddings_path)

        # Persist only a loadable model id. An unknown name or the
        # pre-loaded-instance placeholder falls back to the default model.
        model_name = self.model_name
        if not model_name or model_name == self._INSTANCE_MODEL_NAME:
            logger.warning(
                "Embedding model name is unknown; saving default model name '%s'.",
                self._DEFAULT_MODEL_NAME,
            )
            model_name = self._DEFAULT_MODEL_NAME

        model_info = {
            "model_name": model_name,
            "embedding_dim": self.embedding_model.get_sentence_embedding_dimension() if hasattr(self.embedding_model, 'get_sentence_embedding_dimension') else 384,
            "embedding_field": self.embedding_field,
        }
        with open(os.path.join(directory, "model_info.json"), "w", encoding="utf-8") as f:
            json.dump(model_info, f)

        logger.info(f"๊ฒ€์ƒ‰๊ธฐ ์ƒํƒœ๋ฅผ '{directory}'์— ์ €์žฅํ–ˆ์Šต๋‹ˆ๋‹ค.")

    @classmethod
    def load(cls, directory: str, embedding_model: Optional[Union[str, SentenceTransformer]] = None) -> "VectorRetriever":
        """
        Load retriever state previously written by ``save``.

        Args:
            directory: Directory to load from.
            embedding_model: Embedding model to use. If not given, the model
                name stored in ``model_info.json`` is used, falling back to the
                default model when that name is missing or not loadable.

        Returns:
            The loaded VectorRetriever instance.

        Raises:
            FileNotFoundError: If ``model_info.json`` or ``documents.json`` is missing.
        """
        import json

        with open(os.path.join(directory, "model_info.json"), "r", encoding="utf-8") as f:
            model_info = json.load(f)

        if embedding_model is None:
            saved_name = model_info.get("model_name")
            if isinstance(saved_name, str) and saved_name and saved_name != cls._INSTANCE_MODEL_NAME:
                embedding_model = saved_name
            else:
                # Fallback for a missing, non-string (older versions), or placeholder name.
                logger.warning("์œ ํšจํ•œ ๋ชจ๋ธ ์ด๋ฆ„์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๊ธฐ๋ณธ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.")
                embedding_model = cls._DEFAULT_MODEL_NAME

        # Build without documents; the state is filled in from disk below.
        retriever = cls(embedding_model=embedding_model)
        # Older saves did not record the field; "text" was the only value they could load with.
        retriever.embedding_field = model_info.get("embedding_field", "text")

        with open(os.path.join(directory, "documents.json"), "r", encoding="utf-8") as f:
            retriever.documents = json.load(f)

        embeddings_path = os.path.join(directory, "embeddings.npy")
        if os.path.exists(embeddings_path):
            retriever.document_embeddings = np.load(embeddings_path)
            n_embeddings = retriever.document_embeddings.shape[0]
            if n_embeddings != len(retriever.documents):
                logger.warning(
                    "Loaded %d documents but %d embeddings from '%s'; only aligned rows will be searched.",
                    len(retriever.documents), n_embeddings, directory,
                )

        logger.info(f"๊ฒ€์ƒ‰๊ธฐ ์ƒํƒœ๋ฅผ '{directory}'์—์„œ ๋กœ๋“œํ–ˆ์Šต๋‹ˆ๋‹ค.")
        return retriever