Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| from ase import Atoms | |
| from ase.calculators.calculator import Calculator, all_changes | |
| from ase.data import atomic_numbers | |
| from ase.stress import full_3x3_to_voigt_6_stress | |
| from torch_geometric.data.data import Data | |
| from .model import PosEGNN | |
class PosEGNNCalculator(Calculator):
    """ASE calculator that evaluates a PosEGNN checkpoint.

    Parameters
    ----------
    checkpoint:
        Path to a file loadable with ``torch.load(..., weights_only=True)``;
        it must contain ``"config"`` (passed to ``PosEGNN``) and
        ``"state_dict"`` (loaded strictly into the model).
    device:
        Torch device string on which the model and all input tensors live.
    compute_stress:
        When true, ``"stress"`` is advertised in ``implemented_properties``
        and requested from the model on every call.
    **kwargs:
        Forwarded unchanged to ``ase.calculators.calculator.Calculator``.
    """

    def __init__(self, checkpoint: str, device: str, compute_stress: bool = True, **kwargs):
        Calculator.__init__(self, **kwargs)

        checkpoint_dict = torch.load(checkpoint, weights_only=True, map_location=device)
        self.model = PosEGNN(checkpoint_dict["config"])
        self.model.load_state_dict(checkpoint_dict["state_dict"], strict=True)
        self.model.to(device)
        # eval() switches layers such as dropout/batch-norm (if any) to inference behaviour;
        # one call after moving the model is sufficient.
        self.model.eval()

        self.implemented_properties = ["energy", "forces"]
        if compute_stress:
            self.implemented_properties.append("stress")
        self.device = device
        self.compute_stress = compute_stress

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """Evaluate the model and store energy, forces (and stress) in ``self.results``.

        Follows the ASE ``Calculator.calculate`` contract: the base class stores
        a copy of *atoms* in ``self.atoms`` when one is given, and the model is
        then run on ``self.atoms``, so ``atoms=None`` reuses the structure from
        a previous call instead of crashing.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        data = self._build_data(self.atoms)
        # No torch.no_grad() here: forces/stress may be obtained via autograd inside the model.
        out = self.model.compute_properties(data, compute_stress=self.compute_stress)

        results = {
            # ASE expects the energy as a Python scalar. The input is a single graph, so
            # "total_energy" should hold exactly one value (.item() raises otherwise).
            "energy": out["total_energy"].detach().cpu().item(),
            "forces": out["force"].detach().cpu().numpy(),
        }
        if self.compute_stress:
            # reshape(3, 3) accepts both a (3, 3) and a batched (1, 3, 3) tensor, so the
            # Voigt conversion always yields the flat 6-vector ASE expects.
            stress = out["stress"].detach().cpu().numpy().reshape(3, 3)
            results["stress"] = full_3x3_to_voigt_6_stress(stress)
        self.results = results

    def _build_data(self, atoms):
        """Convert an ``ase.Atoms`` object into a single-graph PyG ``Data`` object.

        All tensors are created directly on ``self.device``:

        - ``z``: atomic numbers as int64;
        - ``box``: the cell as a ``(1, 3, 3)`` float32 tensor;
        - ``pos``: Cartesian positions as an ``(n_atoms, 3)`` float32 tensor;
        - ``batch``: all zeros, since every atom belongs to graph 0.
        """
        z = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long, device=self.device)
        box = torch.tensor(
            np.asarray(atoms.get_cell()), dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        pos = torch.tensor(atoms.get_positions(), dtype=torch.float32, device=self.device)
        batch = torch.zeros(len(z), dtype=torch.long, device=self.device)
        # NOTE(review): the usual PyG convention is ptr = [0, n_atoms] (length num_graphs + 1);
        # kept as [0] to match what PosEGNN currently receives. Confirm against the model code.
        ptr = torch.zeros(1, dtype=torch.long, device=self.device)
        return Data(z=z, pos=pos, box=box, batch=batch, num_graphs=1, ptr=ptr)
def get_invariant_embeddings(self):
    """Return the model's invariant embeddings for this structure.

    Meant to be bound as a method on ``ase.Atoms``, so ``self`` is the Atoms
    object. Its calculator must provide ``_build_data`` and ``model``
    (i.e. a ``PosEGNNCalculator``).

    The model is run without gradient tracking. From its ``"embedding_0"``
    output, index 1 of the last axis is selected and axis 2 is squeezed.
    NOTE(review): the meaning of that index and the resulting shape are
    defined by PosEGNN; confirm them there.

    Raises:
        RuntimeError: if no calculator is attached.
    """
    calculator = self.calc
    if calculator is None:
        raise RuntimeError("No calculator is set.")

    graph = calculator._build_data(self)
    with torch.no_grad():
        model_output = calculator.model(graph)
        invariant = model_output["embedding_0"][..., 1]
        return invariant.squeeze(2)
# Monkey-patch ase.Atoms so that any Atoms instance can call
# atoms.get_invariant_embeddings() once its calculator is set.
setattr(Atoms, "get_invariant_embeddings", get_invariant_embeddings)