Spaces:
Runtime error
Runtime error
| import torch | |
| import os | |
| import imageio | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import glob | |
| import random | |
| from sklearn.decomposition import PCA | |
| from src import const | |
| from src.molecule_builder import get_bond_order | |
def save_xyz_file(path, one_hot, positions, node_mask, names, is_geom, suffix=''):
    """Write one XYZ file per batch element.

    Each file is named ``{names[i]}_{suffix}.xyz`` inside ``path`` and contains
    the atom count, an empty comment line, and one ``symbol x y z`` row per
    atom selected by ``node_mask``.

    Args:
        path: Directory the files are written to.
        one_hot: Batched atom-type scores; the atom type of each node is the
            argmax over the last dimension.
        positions: Batched coordinates, indexed as ``positions[batch, atom, 0..2]``.
        node_mask: Batched mask selecting which nodes are real atoms.
            NOTE(review): assumed to be boolean/0-1 with a squeezable trailing
            dimension -- confirm against callers.
        names: Per-batch-element base names used in the file names.
        is_geom: Selects ``const.GEOM_IDX2ATOM`` over ``const.IDX2ATOM`` for
            mapping atom indices to symbols.
        suffix: Inserted between the name and the ``.xyz`` extension.
    """
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    for batch_i in range(one_hot.size(0)):
        mask = node_mask[batch_i].squeeze()
        n_atoms = mask.sum()
        atom_idx = torch.where(mask)[0]
        atoms = torch.argmax(one_hot[batch_i], dim=1)

        file_path = os.path.join(path, f'{names[batch_i]}_{suffix}.xyz')
        # The context manager guarantees the handle is closed even if a write
        # fails; the encoding matches the one load_molecule_xyz reads with.
        with open(file_path, "w", encoding='utf8') as f:
            f.write("%d\n\n" % n_atoms)
            for atom_i in atom_idx:
                atom = idx2atom[atoms[atom_i].item()]
                f.write("%s %.9f %.9f %.9f\n" % (
                    atom,
                    positions[batch_i, atom_i, 0],
                    positions[batch_i, atom_i, 1],
                    positions[batch_i, atom_i, 2],
                ))
def load_xyz_files(path, suffix=''):
    """Return paths of the ``*_{suffix}.xyz`` files in ``path``.

    The integer that follows the last underscore of each file name (once the
    ``_{suffix}.xyz`` ending is removed) is used as the sort key; files are
    returned from the highest index to the lowest.

    Args:
        path: Directory to scan (non-recursive).
        suffix: Suffix that precedes the ``.xyz`` extension.

    Returns:
        List of ``os.path.join(path, fname)`` strings in descending index order.

    Raises:
        ValueError: If a matching file name does not end in an integer.
    """
    ending = f'_{suffix}.xyz'

    def trailing_index(fname):
        return int(fname.replace(ending, '').split('_')[-1])

    matching = [fname for fname in os.listdir(path) if fname.endswith(ending)]
    # reverse=True keeps the sort stable, so ties stay in listing order.
    matching.sort(key=trailing_index, reverse=True)
    return [os.path.join(path, fname) for fname in matching]
def load_molecule_xyz(file, is_geom):
    """Parse an XYZ file into position and one-hot atom-type tensors.

    The first line is read as the atom count, the second line is skipped, and
    each of the next ``n_atoms`` lines is parsed as ``symbol x y z``.

    Args:
        file: Path of the XYZ file.
        is_geom: Selects the ``const.GEOM_*`` atom tables over the default ones.

    Returns:
        Tuple ``(positions, one_hot, charges)`` of tensors with shapes
        ``(n_atoms, 3)``, ``(n_atoms, n_atom_types)`` and ``(n_atoms, 1)``.
        ``charges`` is always all zeros.

    Raises:
        KeyError: If an atom symbol is not in the selected atom table.
        IndexError: If the file has fewer atom rows than its header declares.
    """
    atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    # Read everything up front so the handle is released before parsing.
    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        f.readline()  # second line of an XYZ file is a free-form comment
        atom_lines = f.readlines()

    one_hot = torch.zeros(n_atoms, len(idx2atom))
    charges = torch.zeros(n_atoms, 1)
    positions = torch.zeros(n_atoms, 3)

    for i in range(n_atoms):
        # Whitespace split tolerates tabs and column-aligned (multi-space)
        # rows, which a literal split(' ') turns into empty fields.
        fields = atom_lines[i].split()
        one_hot[i, atom2idx[fields[0]]] = 1
        # Only the three coordinates after the symbol are used; any extra
        # trailing columns are ignored.
        positions[i, :] = torch.tensor([float(e) for e in fields[1:4]])

    return positions, one_hot, charges
def draw_sphere(ax, x, y, z, size, color, alpha):
    """Draw a sphere surface of radius ``size`` centred at ``(x, y, z)``.

    The surface is sampled on a 100 x 100 grid of azimuth/polar angles and
    handed to ``ax.plot_surface`` with a stride of 2 in both directions.

    Args:
        ax: Object providing ``plot_surface`` (a 3D matplotlib axes).
        x, y, z: Centre of the sphere.
        size: Radius of the sphere.
        color: Surface colour forwarded to ``plot_surface``.
        alpha: Surface opacity forwarded to ``plot_surface``.
    """
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)

    surface_x = x + size * np.outer(np.cos(azimuth), np.sin(polar))
    surface_y = y + size * np.outer(np.sin(azimuth), np.sin(polar))
    surface_z = z + size * np.outer(np.ones(np.size(azimuth)), np.cos(polar))

    ax.plot_surface(
        surface_x, surface_y, surface_z,
        rstride=2, cstride=2, color=color, alpha=alpha,
    )
def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None):
    """Draw the bonds and atoms of one molecule on a 3D axes.

    Bonds are drawn between every atom pair for which ``get_bond_order``
    returns a positive value. Atoms are drawn either as sphere surfaces
    (``spheres_3d=True``) or as a single scatter plot.

    Args:
        ax: 3D matplotlib axes to draw on.
        positions: Per-atom coordinates indexed as ``positions[:, 0..2]``.
            NOTE(review): the sphere branch calls ``.item()`` on coordinates,
            so a torch tensor is presumably expected -- confirm.
        atom_type: Per-atom integer type ids used to index the colour/radius
            tables from ``const`` and the index-to-symbol table.
        alpha: Opacity for bonds and atoms.
        spheres_3d: Draw atoms as sphere surfaces instead of scatter markers.
        hex_bg_color: Colour used for the bond lines.
        is_geom: Selects ``const.GEOM_IDX2ATOM`` over ``const.IDX2ATOM``.
        fragment_mask: Optional per-atom mask; atoms whose entry equals 1 are
            drawn fully opaque in sphere mode. Defaults to all ones.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    # Per-atom-type lookup tables, indexed by the integer atom type.
    colors_dic = np.array(const.COLORS)
    radius_dic = np.array(const.RADII)
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(len(x))

    # Loop-invariant base width, hoisted out of the pair loop.
    line_width = (3 - 2) * 2 * 2
    n_atoms = len(x)
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            bond_order = get_bond_order(atom1, atom2, dist)
            if bond_order > 0:
                # Bond order 4 gets a thicker line than every other order.
                linewidth_factor = 1.5 if bond_order == 4 else 1
                linewidth_factor *= 0.5
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=line_width * linewidth_factor * 2,
                    c=hex_bg_color,
                    alpha=alpha,
                )

    if spheres_3d:
        for xi, yi, zi, radius, color, in_fragment in zip(x, y, z, radii, colors, fragment_mask):
            # Use a per-atom alpha: overwriting `alpha` itself would make every
            # atom after the first fragment atom opaque as well.
            sphere_alpha = 1.0 if in_fragment == 1 else alpha
            draw_sphere(ax, xi.item(), yi.item(), zi.item(), 0.5 * radius, color, sphere_alpha)
    else:
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)
def plot_data3d(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None):
    """Render one molecule in a 3D matplotlib figure and save or show it.

    Args:
        positions: Per-atom coordinates. Must support ``.abs().max().item()``
            (e.g. a torch tensor), which is used to size the axes.
        atom_type: Per-atom integer type ids, forwarded to ``plot_molecule``.
        is_geom: Forwarded to ``plot_molecule`` to select the atom tables.
        camera_elev: Elevation angle passed to ``ax.view_init``.
        camera_azim: Azimuth angle passed to ``ax.view_init``.
        save_path: If given, the figure is saved there; otherwise it is shown
            with ``plt.show()``.
        spheres_3d: Draw atoms as sphere surfaces. Also raises the output DPI
            from 50 to 120 and brightens the saved image by a factor of 1.4.
        bg: ``'black'`` for a black background with white bonds; any other
            value gives a white background with black bonds.
        alpha: Opacity forwarded to ``plot_molecule``.
        fragment_mask: Optional per-atom mask forwarded to ``plot_molecule``.
    """
    dark_bg = bg == 'black'
    face_color = (0, 0, 0) if dark_bg else (1, 1, 1)
    # Bonds are drawn in the colour that contrasts with the background.
    hex_bg_color = '#FFFFFF' if dark_bg else '#000000'

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(face_color)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.set_alpha(0)
        # Public replacement for setting the private `_axis3don` flag.
        ax.set_axis_off()

        # `ax.w_xaxis` was removed in Matplotlib 3.8; `ax.xaxis` is the same
        # axis object and is available on every version.
        ax.xaxis.line.set_color("black" if dark_bg else "white")

        plot_molecule(
            ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask
        )

        # Symmetric cubic limits derived from the largest coordinate magnitude,
        # clamped to the range [3.2, 40].
        max_value = positions.abs().max().item()
        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        dpi = 120 if spheres_3d else 50

        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)

            if spheres_3d:
                # Re-open the saved image and brighten it in place.
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
def visualize_chain(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None
):
    """Render every XYZ frame in ``path`` to PNG and assemble them into a GIF.

    Frames are discovered with ``load_xyz_files(path)`` and rendered in the
    order it returns them. A PCA fitted on the last frame of that order fixes
    the orientation used for all frames. Each PNG is written next to its XYZ
    file, and the GIF is written as ``output.gif`` in the same directory.

    Args:
        path: Directory containing the XYZ frames.
        spheres_3d: Forwarded to ``plot_data3d``.
        bg: Forwarded to ``plot_data3d``.
        alpha: Forwarded to ``plot_data3d``.
        wandb: Optional object providing ``log`` and ``Video`` (presumably the
            ``wandb`` module); when given, the GIF is logged under ``mode``.
        mode: Key used when logging to ``wandb``.
        is_geom: Forwarded to ``load_molecule_xyz`` and ``plot_data3d``.
        fragment_mask: Forwarded unchanged to ``plot_data3d`` for every frame.

    Raises:
        IndexError: If ``path`` contains no matching XYZ files.
    """
    files = load_xyz_files(path)

    # Fit PCA to the final molecule to obtain the best orientation for visualization.
    final_positions, _, _ = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(final_positions)

    save_paths = []
    for file in files:
        positions, one_hot, _ = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Transform positions of each frame according to the best orientation of the last frame.
        positions = torch.tensor(pca.transform(positions))

        # Swap the 4-character '.xyz' extension for '.png'.
        fn = file[:-4] + '.png'
        plot_data3d(
            positions, atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    # os.path.join instead of manual '/' concatenation.
    gif_path = os.path.join(os.path.dirname(save_paths[0]), 'output.gif')
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})