# --- Embedding helper functions ---
# Standard-library imports used in this section (logger, app, VectorRetriever,
# DocumentProcessor, ReRanker, and MockComponent are assumed to be defined
# earlier in the module).
import gzip
import os
import pickle
import threading
from datetime import datetime


def save_embeddings(base_retriever, file_path):
    """Compress the embedding data and save it to a file."""
    try:
        # Create the target directory if it does not exist
        cache_dir = os.path.dirname(file_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        # Attach a timestamp so stale caches can be detected on load
        save_data = {
            'timestamp': datetime.now().isoformat(),
            'retriever': base_retriever
        }
        # Compress while saving to reduce file size
        with gzip.open(file_path, 'wb') as f:
            pickle.dump(save_data, f)
        logger.info(f"Saved compressed embedding data to {file_path}.")
        return True
    except Exception as e:
        logger.error(f"Error while saving embeddings: {e}")
        return False
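
# Cache layout sketch: the file written above is a gzip-compressed pickle of
#   {'timestamp': '<ISO-8601 string>', 'retriever': <retriever object>}
# so a cache file can be inspected outside this module, e.g. (path assumed):
#     with gzip.open("index/cached_embeddings.gz", "rb") as f:
#         created = pickle.load(f)['timestamp']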

def load_embeddings(file_path, max_age_days=30):
    """Load saved embedding data from a file."""
    try:
        if not os.path.exists(file_path):
            logger.info(f"No saved embedding file at {file_path}.")
            return None
        # Load the compressed file
        with gzip.open(file_path, 'rb') as f:
            data = pickle.load(f)
        # Check the timestamp (do not use data that is too old)
        saved_time = datetime.fromisoformat(data['timestamp'])
        age = (datetime.now() - saved_time).days
        if age > max_age_days:
            logger.info(f"Saved embeddings are {age} days old; regenerating.")
            return None
        logger.info(f"Loaded embedding data from {file_path} (created: {saved_time}).")
        return data['retriever']
    except Exception as e:
        logger.error(f"Error while loading embeddings: {e}")
        return None
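
# Minimal round-trip sketch for the two helpers above (hypothetical path;
# any picklable retriever object works). With max_age_days=7, a cache older
# than a week is treated as missing and None is returned:
#     if save_embeddings(base_retriever, "index/cached_embeddings.gz"):
#         restored = load_embeddings("index/cached_embeddings.gz", max_age_days=7)
# Since pickle executes code on load, only cache files produced by this app
# itself should be loaded.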

def init_retriever():
    """Initialize or load the retriever object."""
    global base_retriever, retriever
    # Path of the embedding cache file
    cache_path = os.path.join(app.config['INDEX_PATH'], "cached_embeddings.gz")
    # Try to load previously saved embedding data first
    cached_retriever = load_embeddings(cache_path)
    if cached_retriever:
        logger.info("Successfully loaded cached embedding data.")
        base_retriever = cached_retriever
    else:
        # No cached data, so fall back to the regular initialization path
        index_path = app.config['INDEX_PATH']
        # Load or initialize the VectorRetriever
        if os.path.exists(os.path.join(index_path, "documents.json")):
            try:
                logger.info(f"Loading existing vector index from '{index_path}'...")
                base_retriever = VectorRetriever.load(index_path)
                logger.info(f"Loaded {len(base_retriever.documents) if hasattr(base_retriever, 'documents') else 0} documents.")
            except Exception as e:
                logger.error(f"Error while loading index: {e}. Initializing a new retriever.")
                base_retriever = VectorRetriever()
        else:
            logger.info("No existing index found; initializing a new retriever...")
            base_retriever = VectorRetriever()
        # Load documents from the data folder
        data_path = app.config['DATA_FOLDER']
        if (not hasattr(base_retriever, 'documents') or not base_retriever.documents) and os.path.exists(data_path):
            logger.info(f"Loading documents from {data_path}...")
            try:
                docs = DocumentProcessor.load_documents_from_directory(
                    data_path,
                    extensions=[".txt", ".md", ".csv"],
                    recursive=True
                )
                if docs and hasattr(base_retriever, 'add_documents'):
                    logger.info(f"Adding {len(docs)} documents to the retriever...")
                    base_retriever.add_documents(docs)
                    if hasattr(base_retriever, 'save'):
                        logger.info(f"Saving retriever state to '{index_path}'...")
                        try:
                            base_retriever.save(index_path)
                            logger.info("Index saved.")
                            # Cache the freshly built retriever for fast restarts
                            if hasattr(base_retriever, 'documents') and base_retriever.documents:
                                save_embeddings(base_retriever, cache_path)
                                logger.info(f"Retriever cached to {cache_path}.")
                        except Exception as e:
                            logger.error(f"Error while saving index: {e}")
            except Exception as e:
                logger.error(f"Error while loading documents from DATA_FOLDER: {e}")
    # Initialize the reranking retriever
    logger.info("Initializing the reranking retriever...")
    try:
        # Custom reranking function: blend the base score with a simple
        # length-normalized term-frequency score
        def custom_rerank_fn(query, results):
            query_terms = set(query.lower().split())
            for result in results:
                if isinstance(result, dict) and "text" in result:
                    text = result["text"].lower()
                    term_freq = sum(1 for term in query_terms if term in text)
                    normalized_score = term_freq / (len(text.split()) + 1) * 10
                    result["rerank_score"] = result.get("score", 0) * 0.7 + normalized_score * 0.3
                elif isinstance(result, dict):
                    result["rerank_score"] = result.get("score", 0)
            results.sort(key=lambda x: x.get("rerank_score", 0) if isinstance(x, dict) else 0, reverse=True)
            return results
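        # Worked example of the blend above (illustrative numbers): for the
        # query "vector search", a hit with score=0.8 whose 9-word text
        # contains both terms gets
        #     normalized_score = 2 / (9 + 1) * 10 = 2.0
        #     rerank_score     = 0.8 * 0.7 + 2.0 * 0.3 = 1.16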
        # Wrap the base retriever with the ReRanker class
        retriever = ReRanker(
            base_retriever=base_retriever,
            rerank_fn=custom_rerank_fn,
            rerank_field="text"
        )
        logger.info("Reranking retriever initialized.")
    except Exception as e:
        logger.error(f"Failed to initialize the reranking retriever: {e}")
        retriever = base_retriever  # Fall back to the base retriever on failure
    return retriever

def background_init():
    """Run the retriever initialization in the background."""
    global app_ready, retriever, base_retriever
    # Mark the app as immediately usable
    app_ready = True
    logger.info("App set to immediately usable (app_ready=True)")
    try:
        # Initialize a fallback base retriever (safety net)
        if base_retriever is None:
            base_retriever = MockComponent()
            if hasattr(base_retriever, 'documents'):
                base_retriever.documents = []
        # Set a temporary retriever
        if retriever is None:
            retriever = MockComponent()
            if not hasattr(retriever, 'search'):
                retriever.search = lambda query, **kwargs: []
        # Try to load cached embeddings
        cache_path = os.path.join(app.config['INDEX_PATH'], "cached_embeddings.gz")
        cached_retriever = load_embeddings(cache_path)
        if cached_retriever:
            # Use the cached data directly when available
            base_retriever = cached_retriever
            # Simple reranking function: sort results by their existing score
            def simple_rerank(query, results):
                if results:
                    for result in results:
                        if isinstance(result, dict):
                            result["rerank_score"] = result.get("score", 0)
                    results.sort(key=lambda x: x.get("rerank_score", 0) if isinstance(x, dict) else 0, reverse=True)
                return results
            # Initialize the reranking retriever
            retriever = ReRanker(
                base_retriever=base_retriever,
                rerank_fn=simple_rerank,
                rerank_field="text"
            )
            logger.info("Retriever initialized from cached embeddings (fast start)")
        else:
            # No cached data, so run the full initialization
            logger.info("No cached embeddings; starting full initialization")
            retriever = init_retriever()
            logger.info("Full initialization complete")
        logger.info("App initialization complete (all components ready)")
    except Exception as e:
        logger.error(f"Serious error during background app initialization: {e}", exc_info=True)
        # Create fallback objects if initialization failed
        if base_retriever is None:
            base_retriever = MockComponent()
            if hasattr(base_retriever, 'documents'):
                base_retriever.documents = []
        if retriever is None:
            retriever = MockComponent()
            if not hasattr(retriever, 'search'):
                retriever.search = lambda query, **kwargs: []
        logger.warning("Errors occurred during initialization, but the app remains usable.")

# Start the background initialization thread (daemon, so it will not block
# process shutdown)
init_thread = threading.Thread(target=background_init)
init_thread.daemon = True
init_thread.start()
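
# Hedged sketch of a consumer (hypothetical route; `request` would come from
# flask, and the search signature and top_k kwarg are assumptions):
#     @app.route("/search")
#     def search_endpoint():
#         hits = retriever.search(request.args.get("q", ""), top_k=5)
#         return {"ready": app_ready, "results": hits}
# Once background_init() has run its first few lines, `retriever` exposes a
# callable `search` (via the MockComponent stub at worst), so handlers can
# usually query it without a separate readiness check.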