Spaces:
Sleeping
Sleeping
""" | |
벡터 검색 구현 모듈 (vector search implementation module)
""" | |
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from sentence_transformers import SentenceTransformer

from .base_retriever import BaseRetriever
logger = logging.getLogger(__name__)


class VectorRetriever(BaseRetriever):
    """
    Embedding-based vector search implementation.

    Each document is a dict; the text stored under ``embedding_field`` is
    encoded with a SentenceTransformer model, and queries are matched against
    the stored embeddings by cosine similarity.

    Invariant: row ``i`` of ``document_embeddings`` belongs to ``documents[i]``.
    """

    # Model used when no model is given and as the fallback name in saved
    # metadata / on load when no usable name is recorded.
    DEFAULT_MODEL_NAME = "paraphrase-multilingual-MiniLM-L12-v2"

    def __init__(
        self,
        embedding_model: Optional[Union[str, SentenceTransformer]] = DEFAULT_MODEL_NAME,
        documents: Optional[List[Dict[str, Any]]] = None,
        embedding_field: str = "text",
        embedding_device: str = "cpu"
    ):
        """
        Initialize the VectorRetriever.

        Args:
            embedding_model: Name of a SentenceTransformer model to load, or an
                already-constructed model instance. ``None`` falls back to
                ``DEFAULT_MODEL_NAME``.
            documents: Optional initial documents, passed to ``add_documents``.
            embedding_field: Name of the document field whose text is embedded.
            embedding_device: Device used when loading a model by name
                ('cpu' or 'cuda').
        """
        self.embedding_field = embedding_field
        self.model_name = None
        # The annotation allows None; storing None would only fail later inside
        # encode(), so resolve it to the default model up front.
        if embedding_model is None:
            embedding_model = self.DEFAULT_MODEL_NAME
        # Load the embedding model.
        if isinstance(embedding_model, str):
            logger.info(f"์๋ฒ ๋ฉ ๋ชจ๋ธ '{embedding_model}' ๋ก๋ ์ค...")
            self.model_name = embedding_model
            self.embedding_model = SentenceTransformer(embedding_model, device=embedding_device)
        else:
            self.embedding_model = embedding_model
            # Pre-loaded instance: try to record a name for it.
            # NOTE(review): it is unclear this condition ever holds for a real
            # SentenceTransformer, and if it does, load() cannot re-create a
            # model from the name "loaded_sentence_transformer" — confirm.
            if hasattr(embedding_model, '_modules') and 'modules' in embedding_model._modules:
                self.model_name = "loaded_sentence_transformer"
        # Document store.
        self.documents: List[Dict[str, Any]] = []
        self.document_embeddings: Optional[np.ndarray] = None
        # Add the initial documents, if any were provided.
        if documents:
            self.add_documents(documents)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Add documents to the retriever and compute their embeddings.

        Documents without ``self.embedding_field`` are skipped with a warning.
        The store is only modified after encoding (and stacking) succeeds, so an
        exception from the model leaves ``documents`` and
        ``document_embeddings`` aligned.

        Args:
            documents: Documents to add.
        """
        if not documents:
            logger.warning("์ถ๊ฐํ ๋ฌธ์๊ฐ ์์ต๋๋ค.")
            return
        # Collect the embeddable documents without touching the store yet.
        accepted_docs = []
        document_texts = []
        for doc in documents:
            if self.embedding_field not in doc:
                logger.warning(f"๋ฌธ์์ ํ๋ '{self.embedding_field}'๊ฐ ์์ต๋๋ค. ๊ฑด๋๋๋๋ค.")
                continue
            accepted_docs.append(doc)
            document_texts.append(doc[self.embedding_field])
        if not document_texts:
            logger.warning(f"์๋ฒ ๋ฉํ ํ ์คํธ๊ฐ ์์ต๋๋ค. ๋ชจ๋ ๋ฌธ์์ '{self.embedding_field}' ํ๋๊ฐ ์๋์ง ํ์ธํ์ธ์.")
            return
        # Encode the document texts.
        logger.info(f"{len(document_texts)}๊ฐ ๋ฌธ์์ ์๋ฒ ๋ฉ ์์ฑ ์ค...")
        new_embeddings = self.embedding_model.encode(document_texts, show_progress_bar=True)
        # Merge with existing embeddings, then commit the documents so both
        # sides of the store grow together.
        if self.document_embeddings is None:
            self.document_embeddings = new_embeddings
        else:
            self.document_embeddings = np.vstack([self.document_embeddings, new_embeddings])
        self.documents.extend(accepted_docs)
        logger.info(f"์ด {len(self.documents)}๊ฐ ๋ฌธ์, {self.document_embeddings.shape[0]}๊ฐ ์๋ฒ ๋ฉ ์ ์ฅ๋จ")

    def search(self, query: str, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
        """
        Run a cosine-similarity vector search for a query.

        Args:
            query: Search query text.
            top_k: Maximum number of results to return; values <= 0 return [].
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            Up to ``top_k`` shallow copies of stored documents, best match
            first, each with an added ``"score"`` key (float cosine similarity).
        """
        if not self.documents or self.document_embeddings is None:
            logger.warning("๊ฒ์ํ ๋ฌธ์๊ฐ ์์ต๋๋ค.")
            return []
        # Slicing with [-0:] would select every document, so non-positive
        # top_k must be handled explicitly.
        if top_k <= 0:
            return []
        # Encode the query.
        query_embedding = self.embedding_model.encode(query)
        # Cosine similarity. A zero-norm vector has a zero dot product with
        # everything, so dividing by 1 there gives a score of 0 instead of NaN
        # (NaN would otherwise sort to the top).
        dots = np.dot(self.document_embeddings, query_embedding)
        norms = np.linalg.norm(self.document_embeddings, axis=1) * np.linalg.norm(query_embedding)
        scores = dots / np.where(norms > 0, norms, 1.0)
        # Highest scores first; a top_k beyond the corpus size returns all.
        top_indices = np.argsort(scores)[::-1][:top_k]
        results = []
        for idx in top_indices:
            doc = self.documents[idx].copy()
            doc["score"] = float(scores[idx])
            results.append(doc)
        return results

    def save(self, directory: str) -> None:
        """
        Save the retriever state to ``directory`` (created if missing).

        Writes ``documents.json``, ``embeddings.npy`` (only when embeddings
        exist; a stale file from an earlier save is removed otherwise), and
        ``model_info.json`` with the model name and embedding dimension.

        NOTE(review): when the model was passed in as an instance,
        ``model_name`` is usually None and the default model name is recorded,
        which may not match the model that produced the embeddings.

        Args:
            directory: Destination directory path.
        """
        os.makedirs(directory, exist_ok=True)
        # Documents.
        with open(os.path.join(directory, "documents.json"), "w", encoding="utf-8") as f:
            json.dump(self.documents, f, ensure_ascii=False, indent=2)
        # Embeddings. Without this cleanup, load() would pick up rows from a
        # previous save that do not belong to the documents written above.
        embeddings_path = os.path.join(directory, "embeddings.npy")
        if self.document_embeddings is not None:
            np.save(embeddings_path, self.document_embeddings)
        elif os.path.exists(embeddings_path):
            os.remove(embeddings_path)
        # Model metadata; falls back to the default name / 384 dims when unknown.
        model_info = {
            "model_name": self.model_name or self.DEFAULT_MODEL_NAME,
            "embedding_dim": self.embedding_model.get_sentence_embedding_dimension() if hasattr(self.embedding_model, 'get_sentence_embedding_dimension') else 384
        }
        with open(os.path.join(directory, "model_info.json"), "w", encoding="utf-8") as f:
            json.dump(model_info, f)
        logger.info(f"๊ฒ์๊ธฐ ์ํ๋ฅผ '{directory}'์ ์ ์ฅํ์ต๋๋ค.")

    @classmethod
    def load(cls, directory: str, embedding_model: Optional[Union[str, SentenceTransformer]] = None) -> "VectorRetriever":
        """
        Load a retriever previously written by ``save``.

        Args:
            directory: Directory to load from.
            embedding_model: Model name or instance to use; when None, the name
                recorded in ``model_info.json`` is used (or the default model if
                no valid name is recorded).

        Returns:
            A new VectorRetriever with the saved documents and, if present,
            the saved embeddings.
        """
        # Model metadata.
        with open(os.path.join(directory, "model_info.json"), "r", encoding="utf-8") as f:
            model_info = json.load(f)
        # Resolve which embedding model to instantiate.
        if embedding_model is None:
            if "model_name" in model_info and isinstance(model_info["model_name"], str):
                embedding_model = model_info["model_name"]
            else:
                # Safety net for older saves with a missing/invalid model name.
                logger.warning("์ ํจํ ๋ชจ๋ธ ์ด๋ฆ์ ์ฐพ์ ์ ์์ต๋๋ค. ๊ธฐ๋ณธ ๋ชจ๋ธ์ ์ฌ์ฉํฉ๋๋ค.")
                embedding_model = cls.DEFAULT_MODEL_NAME
        # Create an empty retriever, then restore its state directly so the
        # documents are not re-encoded.
        retriever = cls(embedding_model=embedding_model)
        with open(os.path.join(directory, "documents.json"), "r", encoding="utf-8") as f:
            retriever.documents = json.load(f)
        embeddings_path = os.path.join(directory, "embeddings.npy")
        if os.path.exists(embeddings_path):
            retriever.document_embeddings = np.load(embeddings_path)
        logger.info(f"๊ฒ์๊ธฐ ์ํ๋ฅผ '{directory}'์์ ๋ก๋ํ์ต๋๋ค.")
        return retriever