# Spaces: Running
# genesis/graph_tools.py
import logging
import os

from neo4j import GraphDatabase, basic_auth
# =========================
# CONFIG
# =========================
# Connection settings are taken from the process environment at import time;
# each one is None when the corresponding variable is not set.
NEO4J_URI = os.environ.get("NEO4J_URI")
NEO4J_USER = os.environ.get("NEO4J_USER")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD")

# Shared driver handle for this module: assigned by init_driver(),
# reset to None by close_driver().
driver = None
# =========================
# INIT CONNECTION
# =========================
def init_driver() -> None:
    """Initialize the module-level Neo4j driver if credentials are set.

    When NEO4J_URI, NEO4J_USER and NEO4J_PASSWORD are all non-empty, a driver
    using basic auth is built and stored in the global ``driver``. Otherwise,
    or if building the driver raises, ``driver`` is left as None (the error is
    logged, not raised).

    Any driver kept from an earlier call is closed first, so calling this
    again to reconnect does not leak the previous driver.
    """
    global driver

    # Release any driver from a previous call before replacing it.
    previous, driver = driver, None
    if previous is not None:
        try:
            previous.close()
        except Exception as exc:
            # Best effort: a failed close must not prevent re-initialization.
            logging.warning("[Neo4j] Failed to close previous driver: %s", exc)

    if not (NEO4J_URI and NEO4J_USER and NEO4J_PASSWORD):
        logging.info("[Neo4j] No URI/user/password set — skipping connection.")
        return

    try:
        driver = GraphDatabase.driver(
            NEO4J_URI,
            auth=basic_auth(NEO4J_USER, NEO4J_PASSWORD),
        )
    except Exception as exc:
        # `driver` is still None here: the assignment above never completed.
        logging.error("[Neo4j] Connection failed: %s", exc)
        return

    # NOTE(review): per the neo4j driver docs, creating a driver does not by
    # itself open a connection (driver.verify_connectivity() does that), so
    # this message only confirms that the driver object was built.
    logging.info("[Neo4j] Connected successfully.")


# Call on import
init_driver()
def is_connected():
    """Report whether a usable Neo4j driver is currently held.

    Returns:
        False when the module-level ``driver`` is None; otherwise whether the
        driver object exposes a ``session`` attribute.
    """
    if driver is None:
        return False
    return hasattr(driver, "session")
# =========================
# QUERY FUNCTIONS
# =========================
def run_query(cypher, params=None):
    """Execute a Cypher statement and collect every record it returns.

    Args:
        cypher: Cypher query text (read or write).
        params: Optional query parameters; a falsy value is sent as ``{}``.

    Returns:
        A list of the records produced by the query. An empty list is also
        returned when there is no driver or when the query raises (the error
        is logged), so callers cannot tell "no rows" from "failed".
    """
    if is_connected():
        try:
            with driver.session() as session:
                # Materialize all records before the session closes.
                cursor = session.run(cypher, params or {})
                return [record for record in cursor]
        except Exception as exc:
            logging.error(f"[Neo4j] Query failed: {exc}")
            return []
    logging.warning("[Neo4j] No active connection β returning empty result.")
    return []
def write_data(cypher, params=None) -> bool:
    """Run a write statement (e.g. CREATE/MERGE) and report whether it succeeded.

    Args:
        cypher: Cypher statement text.
        params: Optional query parameters; a falsy value is sent as ``{}``.

    Returns:
        True once the statement has run to completion. False when there is no
        driver, or when running or consuming the statement raised (the error
        is logged, not raised).
    """
    if not is_connected():
        logging.warning("[Neo4j] No active connection — skipping write.")
        return False
    try:
        with driver.session() as session:
            # Consume explicitly so execution errors are raised inside this try.
            # A result left for session.close() to consume may have its error
            # swallowed there (observed in neo4j 4.x/5.x drivers; verify against
            # the installed version), which would make us report a false success.
            session.run(cypher, params or {}).consume()
    except Exception as e:
        logging.error("[Neo4j] Write failed: %s", e)
        return False
    return True
# =========================
# BULK GRAPH CREATION
# =========================
def save_graph_data(nodes, edges) -> bool:
    """
    Save nodes and edges to Neo4j atomically.

    nodes: iterable of dicts with keys {id, label, type}. Each becomes an
        ``:Entity`` node merged on ``id``, with ``label`` and ``type`` set
        (later duplicates of the same id overwrite earlier ones).
    edges: iterable of dicts with keys {source, target, type}. Each becomes a
        ``:RELATION`` relationship, carrying ``type`` as a property, merged
        between the Entity nodes with ids ``source`` and ``target``. An edge
        whose endpoints are not found is skipped.

    Every row is checked for its required keys before anything is sent. All
    writes then go through one explicit transaction, with one batched
    (UNWIND) statement per kind. A failure therefore rolls back instead of
    leaving the graph half-saved.

    Returns True on success. Returns False when there is no driver, a row is
    missing a key, or Neo4j reports an error (logged, not raised).
    """
    if not is_connected():
        logging.warning("[Neo4j] No active connection — skipping graph save.")
        return False

    node_query = """
    UNWIND $rows AS row
    MERGE (n:Entity {id: row.id})
    SET n.label = row.label, n.type = row.type
    """
    edge_query = """
    UNWIND $rows AS row
    MATCH (a:Entity {id: row.source})
    MATCH (b:Entity {id: row.target})
    MERGE (a)-[r:RELATION {type: row.type}]->(b)
    """

    try:
        # Build parameter rows first: a missing key raises here, before any write.
        node_rows = [
            {"id": node["id"], "label": node["label"], "type": node["type"]}
            for node in nodes
        ]
        edge_rows = [
            {"source": edge["source"], "target": edge["target"], "type": edge["type"]}
            for edge in edges
        ]
        with driver.session() as session:
            # Exiting the transaction block on an exception rolls it back.
            with session.begin_transaction() as tx:
                # Nodes go first so the edge MATCHes can see them in this transaction.
                if node_rows:
                    tx.run(node_query, {"rows": node_rows}).consume()
                if edge_rows:
                    tx.run(edge_query, {"rows": edge_rows}).consume()
                tx.commit()
        logging.info("[Neo4j] Graph data saved successfully.")
        return True
    except Exception as e:
        logging.error("[Neo4j] Error saving graph data: %s", e)
        return False
# =========================
# CLEANUP
# =========================
def close_driver() -> None:
    """Close the module-level Neo4j driver, if any, and clear it.

    The global ``driver`` is reset to None even when ``close()`` raises, so
    is_connected() never keeps reporting a driver that was being shut down.
    Any exception from ``close()`` still propagates to the caller.
    """
    global driver
    if driver is None:
        return
    try:
        driver.close()
    finally:
        driver = None
    logging.info("[Neo4j] Connection closed.")