# genesis/visualization.py
"""
Visualization utilities for GENESIS-AI
Generates interactive pathway and funding network graphs from Neo4j data.
"""

import os
import logging
from typing import List, Tuple, Optional

import plotly.graph_objects as go

# Optional Neo4j import
try:
    from neo4j import GraphDatabase
except ImportError:
    GraphDatabase = None

# Module-level logger. Using the root-level helpers (logging.info, ...) at
# import time would implicitly call logging.basicConfig() when the root logger
# has no handlers, silently overriding the host application's logging setup.
logger = logging.getLogger(__name__)

# =========================
# CONFIGURATION
# =========================
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")

COLOR_PATHWAY_NODE = "#00FF99"   # Biotech green
COLOR_MOLECULE_NODE = "#00BFFF"  # Blue for molecules
COLOR_FUNDING_NODE = "#FF9900"   # Orange for companies
COLOR_INVESTOR_NODE = "#FFD700"  # Gold for investors
COLOR_EDGE = "#AAAAAA"

# =========================
# NEO4J CONNECTION
# =========================
# The driver is created once at import time, and only when the neo4j package
# is installed and all three credentials are set. Otherwise it stays None and
# every query below returns an empty result.
driver = None
if GraphDatabase and NEO4J_URI and NEO4J_USER and NEO4J_PASSWORD:
    try:
        # NOTE(review): constructing the driver may not open a connection by
        # itself; bad hosts/credentials might only surface on the first query.
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
        logger.info("[Neo4j] Connected for visualization.")
    except Exception as e:
        logger.error("[Neo4j] Connection failed: %s", e)
        driver = None
else:
    logger.info("[Neo4j] No URI/user/password set — skipping visualization connection.")


def run_neo4j_query(query: str, params: Optional[dict] = None) -> list:
    """
    Run a Cypher query and return all result records.

    Args:
        query: Cypher query text. Pass values through ``params`` rather than
            interpolating them into the string.
        params: Query parameters; ``None`` means no parameters.

    Returns:
        A list of neo4j records, or ``[]`` when no driver is available
        (not configured, failed to construct, or closed via ``close_driver``).

    Raises:
        Any exception raised by the neo4j driver while opening the session or
        running the query is propagated unchanged.
    """
    if driver is None:
        logger.warning("[Neo4j] No active connection — returning empty result.")
        return []
    with driver.session() as session:
        # Materialize the records while the session is still open.
        return list(session.run(query, params or {}))


# =========================
# GRAPH UTILS
# =========================
def create_plotly_graph(
    nodes: List[Tuple[str, str, str]],
    edges: List[Tuple[str, str]],
    title: str,
) -> go.Figure:
    """
    Create an interactive Plotly network graph.

    Args:
        nodes: ``(node_id, label, color)`` triples. ``label`` is drawn above
            the marker and shown on hover; ``color`` fills the marker.
        edges: ``(src_id, dst_id)`` pairs. The graph is undirected, and a
            repeated pair is drawn only once.
        title: Figure title (centered).

    Returns:
        A figure with one line trace for the edges and one marker+text trace
        for the nodes, on a dark background. The spring layout uses a fixed
        seed, so the same input produces the same picture.
    """
    import networkx as nx  # imported lazily: only needed when a graph is drawn

    graph = nx.Graph()
    for node_id, label, color in nodes:
        graph.add_node(node_id, label=label, color=color)
    graph.add_edges_from(edges)

    pos = nx.spring_layout(graph, seed=42, k=0.5)

    edge_x, edge_y = [], []
    # Iterate the graph's (deduplicated) edges rather than the raw list, so a
    # pair reported several times is not drawn as stacked segments.
    for src, dst in graph.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        # The trailing None breaks the line so each edge is its own segment.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    node_x, node_y, node_text, node_color = [], [], [], []
    for node_id, label, color in nodes:
        x, y = pos[node_id]
        node_x.append(x)
        node_y.append(y)
        node_text.append(label)
        node_color.append(color)

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1, color=COLOR_EDGE),
        hoverinfo='none',
        mode='lines'
    )

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=node_text,
        textposition="top center",
        hoverinfo="text",
        marker=dict(
            color=node_color,
            size=18,
            line=dict(width=2, color='#000000')
        )
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        title=title,
        title_x=0.5,
        plot_bgcolor="#111111",
        paper_bgcolor="#111111",
        font=dict(color="#FFFFFF"),
        showlegend=False,
        margin=dict(l=10, r=10, t=40, b=10)
    )
    return fig


# =========================
# PATHWAY GRAPH
# =========================
def generate_pathway_graph(pathway_name: str) -> Optional[go.Figure]:
    """
    Generate an interactive graph for a given metabolic pathway.
    Only takes pathway_name — matches pipeline.py signature.

    The pathway is drawn as a hub connected to every Molecule it reaches via
    an ``INVOLVES`` relationship.

    Returns:
        The figure, or ``None`` when the query returns no rows (unknown
        pathway, no linked molecules, or no Neo4j connection).
    """
    query = """
    MATCH (p:Pathway {name: $name})-[r:INVOLVES]->(m:Molecule)
    RETURN p.name AS pathway, m.name AS molecule
    """
    results = run_neo4j_query(query, {"name": pathway_name})
    if not results:
        return None

    nodes = [(pathway_name, pathway_name, COLOR_PATHWAY_NODE)]
    edges = []
    seen_nodes = {pathway_name}

    for record in results:
        mol_name = record["molecule"]
        # A Molecule without a `name` property comes back as None, which
        # networkx rejects as a node id; skip it instead of failing the graph.
        if mol_name is None:
            continue
        if mol_name not in seen_nodes:
            nodes.append((mol_name, mol_name, COLOR_MOLECULE_NODE))
            seen_nodes.add(mol_name)
        edges.append((pathway_name, mol_name))

    return create_plotly_graph(nodes, edges, f"Metabolic Pathway: {pathway_name}")


# =========================
# FUNDING NETWORK
# =========================
def generate_funding_network(industry_keyword: str) -> Optional[go.Figure]:
    """
    Generate an interactive funding network graph for companies in a given biotech domain.

    Companies whose ``industry`` contains ``industry_keyword``
    (case-insensitive) are linked to each Investor they have a ``FUNDED_BY``
    relationship with.

    Returns:
        The figure, or ``None`` when the query returns no rows (no matching
        companies or no Neo4j connection).
    """
    query = """
    MATCH (c:Company)-[f:FUNDED_BY]->(i:Investor)
    WHERE toLower(c.industry) CONTAINS toLower($keyword)
    RETURN c.name AS company, i.name AS investor
    """
    results = run_neo4j_query(query, {"keyword": industry_keyword})
    if not results:
        return None

    nodes = []
    edges = []
    seen_nodes = set()

    for record in results:
        comp = record["company"]
        inv = record["investor"]
        # Unnamed endpoints come back as None, which networkx cannot use as a
        # node id; drop the whole edge rather than crash.
        if comp is None or inv is None:
            continue
        # Company and investor ids share one namespace, so an entity that
        # appears in both roles is drawn once, in the color of its first role.
        if comp not in seen_nodes:
            nodes.append((comp, comp, COLOR_FUNDING_NODE))
            seen_nodes.add(comp)
        if inv not in seen_nodes:
            nodes.append((inv, inv, COLOR_INVESTOR_NODE))
            seen_nodes.add(inv)
        edges.append((comp, inv))

    return create_plotly_graph(nodes, edges, f"Funding Network: {industry_keyword}")


# =========================
# CLEANUP
# =========================
def close_driver():
    """
    Close the module-level Neo4j driver, if one exists.

    Afterwards ``driver`` is reset to ``None``, so later ``run_neo4j_query``
    calls take the "no active connection" path (returning ``[]``) instead of
    using a closed driver. Calling this repeatedly is harmless.
    """
    global driver
    if driver is None:
        return
    try:
        driver.close()
    finally:
        driver = None