"""
RAG search chatbot - retriever initialization module
"""
import os
import logging
import pickle
import gzip
from datetime import datetime
# Module-level logger
logger = logging.getLogger(__name__)

def save_embeddings(base_retriever, file_path):
    """Compress the embedding data and save it to a file."""
    try:
        # Create the target directory if it does not exist
        # (guard against a bare filename, where dirname() returns "")
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Attach a timestamp so stale caches can be detected on load
        save_data = {
            'timestamp': datetime.now().isoformat(),
            'retriever': base_retriever
        }
        # Compress with gzip to reduce file size
        with gzip.open(file_path, 'wb') as f:
            pickle.dump(save_data, f)
        logger.info(f"Saved compressed embedding data to {file_path}.")
        return True
    except Exception as e:
        logger.error(f"Error while saving embeddings: {e}")
        return False

def load_embeddings(file_path, max_age_days=30):
    """Load previously saved embedding data from a file."""
    try:
        if not os.path.exists(file_path):
            logger.info(f"No saved embedding file found at {file_path}.")
            return None
        # Load the gzip-compressed pickle
        with gzip.open(file_path, 'rb') as f:
            data = pickle.load(f)
        # Check the timestamp; refuse data that is too old
        saved_time = datetime.fromisoformat(data['timestamp'])
        age = (datetime.now() - saved_time).days
        if age > max_age_days:
            logger.info(f"Saved embeddings are {age} days old; regenerating.")
            return None
        logger.info(f"Loaded embedding data from {file_path} (created: {saved_time}).")
        return data['retriever']
    except Exception as e:
        logger.error(f"Error while loading embeddings: {e}")
        return None
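
# Illustrative round trip (a hedged sketch; "my_retriever" and the paths are
# placeholders, not names defined by this module):
#   ok = save_embeddings(my_retriever, "index/cached_embeddings.gz")
#   restored = load_embeddings("index/cached_embeddings.gz", max_age_days=7)
# load_embeddings returns None on a miss, a stale cache, or any load error,
# so callers can always fall back to rebuilding the index. Note that
# pickle.load executes code during deserialization; only load caches that
# this application wrote itself.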

def init_retriever(app, base_retriever, retriever, ReRanker):
    """Initialize or load the retriever object."""
    from utils.document_processor import DocumentProcessor
    from retrieval.vector_retriever import VectorRetriever

    # Paths for the vector index and the embedding cache file
    # (index_path is defined up front because the document-loading and
    # save steps below use it regardless of which branch runs)
    index_path = app.config['INDEX_PATH']
    cache_path = os.path.join(index_path, "cached_embeddings.gz")

    # First, try to load previously saved embedding data
    cached_retriever = load_embeddings(cache_path)
    if cached_retriever:
        logger.info("Successfully loaded cached embedding data.")
        base_retriever = cached_retriever
    else:
        # No cached data; initialize the retriever the regular way:
        # load an existing VectorRetriever index or create a new one
        if os.path.exists(os.path.join(index_path, "documents.json")):
            try:
                logger.info(f"Loading existing vector index from '{index_path}'...")
                base_retriever = VectorRetriever.load(index_path)
                doc_count = len(base_retriever.documents) if hasattr(base_retriever, 'documents') else 0
                logger.info(f"Loaded {doc_count} documents.")
            except Exception as e:
                logger.error(f"Error while loading index: {e}. Initializing a new retriever.")
                base_retriever = VectorRetriever()
        else:
            logger.info("No existing index found; initializing a new retriever...")
            base_retriever = VectorRetriever()

    # Load documents from the data folder if the retriever is still empty
    data_path = app.config['DATA_FOLDER']
    if (not hasattr(base_retriever, 'documents') or not base_retriever.documents) and os.path.exists(data_path):
        logger.info(f"Loading documents from {data_path}...")
        try:
            docs = DocumentProcessor.load_documents_from_directory(
                data_path,
                extensions=[".txt", ".md", ".csv"],
                recursive=True
            )
            if docs and hasattr(base_retriever, 'add_documents'):
                logger.info(f"Adding {len(docs)} documents to the retriever...")
                base_retriever.add_documents(docs)
                if hasattr(base_retriever, 'save'):
                    logger.info(f"Saving retriever state to '{index_path}'...")
                    try:
                        base_retriever.save(index_path)
                        logger.info("Index saved.")
                        # Cache the freshly built retriever for fast restarts
                        if hasattr(base_retriever, 'documents') and base_retriever.documents:
                            save_embeddings(base_retriever, cache_path)
                            logger.info(f"Cached the retriever to {cache_path}.")
                    except Exception as e:
                        logger.error(f"Error while saving index: {e}")
        except Exception as e:
            logger.error(f"Error while loading documents from DATA_FOLDER: {e}")

    # Initialize the reranking retriever
    logger.info("Initializing the reranking retriever...")
    try:
        # Custom rerank function: blends the base vector score (weight 0.7)
        # with a length-normalized query-term frequency (weight 0.3)
        def custom_rerank_fn(query, results):
            query_terms = set(query.lower().split())
            for result in results:
                if isinstance(result, dict) and "text" in result:
                    text = result["text"].lower()
                    # Count how many query terms occur in the candidate text
                    term_freq = sum(1 for term in query_terms if term in text)
                    # Normalize by text length so long documents are not favored
                    normalized_score = term_freq / (len(text.split()) + 1) * 10
                    result["rerank_score"] = result.get("score", 0) * 0.7 + normalized_score * 0.3
                elif isinstance(result, dict):
                    result["rerank_score"] = result.get("score", 0)
            results.sort(key=lambda x: x.get("rerank_score", 0) if isinstance(x, dict) else 0, reverse=True)
            return results

        # Wrap the base retriever with the ReRanker class
        retriever = ReRanker(
            base_retriever=base_retriever,
            rerank_fn=custom_rerank_fn,
            rerank_field="text"
        )
        logger.info("Reranking retriever initialized.")
    except Exception as e:
        logger.error(f"Failed to initialize the reranking retriever: {e}")
        retriever = base_retriever  # Fall back to the base retriever on failure
    return retriever
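
# Minimal wiring sketch (illustrative only): it assumes a Flask-style app
# object exposing a config mapping, and the ReRanker import path below is
# an assumption, not something this module defines.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class _App:
        # Stand-in for the real application object
        config = {"INDEX_PATH": "./index", "DATA_FOLDER": "./data"}

    from retrieval.reranker import ReRanker  # assumed module path

    retriever = init_retriever(_App(), None, None, ReRanker)
    print(f"Active retriever: {type(retriever).__name__}")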