CryoFM2: A Generative Foundation Model for Cryo-EM Densities
Overview
CryoFM2 is a flow-based generative foundation model for cryo-EM density maps. It is pretrained on curated EMDB half maps to learn general priors of high-quality cryo-EM densities and can be fine-tuned for downstream tasks.
The model learns a continuous mapping from a simple Gaussian distribution to the complex distribution of cryo-EM densities, enabling stable generation and flexible adaptation. CryoFM2 can also act as a Bayesian prior, integrating naturally with task-specific likelihoods to support applications such as anisotropy-aware refinement, non-uniform reconstruction, and controlled density modification.
Model Details
The pretrained model can be fine-tuned for downstream tasks such as density map enhancement and post-processing.
[Figure: Pre-training architecture]
[Figure: Fine-tuning architecture (for EMhancer/EMReady style post-processing)]
Architecture
- Architecture Type: 3D UNet
- Input Size: 64×64×64 voxels
- Input Channels: 2 for pre-trained model, 3 for fine-tuned model
- Output Channels: 1
- Down Blocks: DownBlock3D, DownBlock3D, AttnDownBlock3D, AttnDownBlock3D
- Up Blocks: AttnUpBlock3D, AttnUpBlock3D, UpBlock3D, UpBlock3D
- Block Output Channels: (64, 128, 256, 512)
- Layers per Block: 2
- Attention Head Dimension: 8
- Normalization: GroupNorm (32 groups)
- Activation: SiLU
- Time Embedding: Positional encoding
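The "positional encoding" time embedding listed above is commonly implemented as a sinusoidal embedding of the scalar time, as in transformer position encodings and DDPM-style UNets. The sketch below shows one such embedding. The embedding width (64, matching the first block's channels), the frequency base of 10000, and the sin-then-cos ordering are assumptions, not values read from the CryoFM2 config. It also checks that each entry of Block Output Channels is divisible by the 32 GroupNorm groups.

```python
import numpy as np

def sinusoidal_time_embedding(t, dim=64, max_period=10000.0):
    """Map scalar times t (shape [B]) to [B, dim] sin/cos features.

    Illustrative sketch; CryoFM2's exact embedding may differ.
    """
    t = np.asarray(t, dtype=np.float64).reshape(-1)
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/max_period
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)

emb = sinusoidal_time_embedding([0.0, 0.5, 1.0])
print(emb.shape)  # (3, 64)

# GroupNorm with 32 groups requires channel counts divisible by 32
block_out_channels = (64, 128, 256, 512)
print(all(c % 32 == 0 for c in block_out_channels))  # True
```

Many flow-matching codebases rescale t from [0, 1] (for example by 1000) before embedding it so that the high-frequency components vary; whether CryoFM2 does so is not stated here.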
Model Variants
- cryofm2-pretrain: Unconditional pretrained model for general density map generation
- cryofm2-emhancer: Fine-tuned model for density map enhancement (EMhancer style)
- cryofm2-emready: Fine-tuned model for density map enhancement (EMReady style)
Play with CryoFM2
Unconditional Generation (Explore Training Data Distribution)
Generate samples from the pretrained model to explore the learned data distribution:
Pretrained Model:
import torch
from mmengine import Config
from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Uncond
# Update the path to your model directory
model_dir = "path/to/cryofm-v2/cryofm2-pretrain"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Uncond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()
def v_xt_t(_xt, _t):
    # Velocity predicted by the network at state _xt and time _t
    return lit_model(_xt, _t)
num_samples = 3
# Enable bfloat16 autocast for faster inference (GPU only; requires bfloat16 support)
with torch.no_grad(), torch.autocast(
    "cuda", dtype=torch.bfloat16, enabled=(device.type == "cuda")
):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=lit_model.device,
        side_shape=64,
    )
# Map outputs back to the original density scale using the configured z-score statistics
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean
# Save generated samples as MRC files
for i in range(num_samples):
    save_mrc(out[i].float().cpu().numpy(), f"sample-{i}.mrc", voxel_size=1.5)
Fine-tuned Models (EMhancer/EMReady):
import torch
from mmengine import Config
from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Cond
# Choose style: "emhancer" or "emready"
style = "emhancer"
model_dir = f"path/to/cryofm-v2/cryofm2-{style}"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Cond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
output_tag = 1 if style == "emhancer" else 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()
def v_xt_t(_xt, _t):
    bs = _xt.shape[0]
    # No input map or volume condition: only the style tag is supplied
    unconditional_generation_conds = {
        "input_cond": None,
        "output_cond": torch.tensor([output_tag] * bs).to(device),
        "vol_cond": None,  # dimension should be [bs, d, h, w]
    }
    return lit_model(_xt, _t, generation_conds=unconditional_generation_conds)
num_samples = 3
# Enable bfloat16 autocast for faster inference (GPU only; requires bfloat16 support)
with torch.no_grad(), torch.autocast(
    "cuda", dtype=torch.bfloat16, enabled=(device.type == "cuda")
):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=lit_model.device,
        side_shape=64,
    )
# Map outputs back to the original density scale using the configured z-score statistics
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean
# Save generated samples as MRC files
for i in range(num_samples):
    save_mrc(out[i].float().cpu().numpy(), f"{style}-sample-{i}.mrc", voxel_size=1.5)
Density Map Modification
CryoFM2 supports various density map modification operations using the pretrained model as a Bayesian prior. Supported operators include:
- denoise: Remove noise from density maps
- inpaint: Fill missing regions (e.g., missing wedge)
- denoise inpaint: Combined denoising and inpainting
- non-uniform weight: Apply non-uniform weighting during reconstruction
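To illustrate what using the pretrained model as a Bayesian prior can look like for an operator such as inpaint, the sketch below performs one Euler step that adds the gradient of a masked Gaussian data-fit term to the prior velocity. It assumes the convention x_t = (1 - t)·x0 + t·x1 (noise at t = 0), estimates the clean map as x̂1 = x_t + (1 - t)·v, and, for simplicity, treats dx̂1/dx_t as the identity. `guided_euler_step` and `guidance_scale` are hypothetical names, and this is not necessarily the update implemented by `uncond_sampling`.

```python
import numpy as np

def guided_euler_step(x_t, t, dt, velocity_fn, y, mask, guidance_scale=1.0):
    """One Euler step of prior velocity plus a likelihood-gradient correction.

    Data-fit term: 0.5 * ||mask * (x1_hat - y)||^2 with a binary mask, so only
    voxels where mask == 1 are observed. Illustrative sketch only; it ignores
    the Jacobian of x1_hat with respect to x_t.
    """
    v = velocity_fn(x_t, t)                 # prior (pretrained model) velocity
    x1_hat = x_t + (1.0 - t) * v            # predicted clean map at t = 1
    grad = mask * (x1_hat - y)              # gradient of the data-fit term
    # Follow the prior flow, then step down the data-fit gradient
    return x_t + dt * v - guidance_scale * dt * grad

# Toy usage: zero prior velocity, observations of 1.0 on half of the voxels
x_t = np.zeros((4, 4, 4))
y = np.ones_like(x_t)
mask = np.zeros_like(x_t)
mask[:2] = 1.0
zero_velocity = lambda x, t: np.zeros_like(x)
x_next = guided_euler_step(x_t, 0.5, 0.1, zero_velocity, y, mask, guidance_scale=2.0)
# Observed voxels move toward y; unobserved voxels are left to the prior
print(x_next[0, 0, 0], x_next[3, 0, 0])
```

Repeating this step from t = 0 to t = 1 gives a sample that follows the learned prior where the data are missing and agrees with the observations where they exist.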
Basic Usage:
python -m cryofm.projects.cryofm2.uncond_sampling \
-i1 half_map_1.mrc \
-i2 half_map_2.mrc \
-o ./output \
--model-dir path/to/cryofm-v2/cryofm2-pretrain \
--op denoise \
--norm-grad \
--use-lamb-w
For inpainting tasks, also pass the path to a RELION STAR file with --data-starfile-path:
python -m cryofm.projects.cryofm2.uncond_sampling \
-i1 half_map_1.mrc \
-i2 half_map_2.mrc \
-o ./output \
--model-dir path/to/cryofm-v2/cryofm2-pretrain \
--op inpaint \
--data-starfile-path path/to/relion_data.star \
--norm-grad \
--use-lamb-w
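The inpaint operator is described as filling missing regions such as a missing wedge. For intuition only, the sketch below builds a binary Fourier-space coverage mask for a single-axis tilt series about the y axis with a maximum tilt of ±60°, then applies it to a random volume. This geometry is an assumption for illustration: the CLI instead takes a RELION STAR file, presumably to derive coverage from the particle data, and how it does so is not described here. `missing_wedge_mask` is a hypothetical helper.

```python
import numpy as np

def missing_wedge_mask(n, max_tilt_deg=60.0):
    """Boolean [n, n, n] mask (axes z, y, x; unshifted FFT layout):
    True where Fourier samples are measured by a tilt series about y.

    The central slice at tilt angle theta contains k with
    kx * sin(theta) + kz * cos(theta) = 0, so k is covered when
    |kz| <= |kx| * tan(max_tilt).
    """
    k = np.fft.fftfreq(n)
    kz, _, kx = np.meshgrid(k, k, k, indexing="ij")
    return np.abs(kz) <= np.abs(kx) * np.tan(np.deg2rad(max_tilt_deg))

mask = missing_wedge_mask(64)
# Simulate missing-wedge data by zeroing unmeasured Fourier components
vol = np.random.default_rng(0).standard_normal((64, 64, 64))
spectrum = np.fft.ifftn(np.fft.fftn(vol) * mask)
wedged = spectrum.real
print(mask.shape, float(mask.mean()))
```

Because the mask is symmetric under k → -k, the masked spectrum stays Hermitian and the inverse transform is real up to rounding error.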
Density Map Post-Processing
CryoFM2 provides fine-tuned models for density map enhancement in different styles, similar to EMhancer and EMReady.
EMhancer Style Enhancement
python -m cryofm.projects.cryofm2.cond_sampling \
-i input_map.mrc \
-o ./output_emhancer \
--model-dir path/to/cryofm-v2/cryofm2-emhancer \
--output-tag 1
EMReady Style Enhancement
python -m cryofm.projects.cryofm2.cond_sampling \
-i input_map.mrc \
-o ./output_emready \
--model-dir path/to/cryofm-v2/cryofm2-emready \
--output-tag 0 \
--cfg-weight 0.5
Parameters:
- -i: Input density map file (MRC format)
- -o: Output directory
- --model-dir: Path to the model directory containing config.yaml and model.safetensors
- --output-tag: Style tag (1 for EMhancer, 0 for EMReady)
- --cfg-weight: Classifier-free guidance weight (optional; the default varies by model)
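--cfg-weight sets the classifier-free guidance weight. A common way to apply such a weight to a flow model is to combine conditional and unconditional velocity predictions as v = v_uncond + w·(v_cond - v_uncond), which the sketch below shows. Whether cond_sampling uses exactly this form is an assumption (some codebases use (1 + w)·v_cond - w·v_uncond instead), and `cfg_velocity` is a hypothetical helper.

```python
import numpy as np

def cfg_velocity(v_cond, v_uncond, weight):
    """Classifier-free guidance: interpolate (weight < 1) or extrapolate
    (weight > 1) from the unconditional toward the conditional prediction."""
    return v_uncond + weight * (v_cond - v_uncond)

v_cond = np.array([1.0, 2.0])
v_uncond = np.array([0.0, 0.0])
print(cfg_velocity(v_cond, v_uncond, 0.5))  # [0.5 1. ]
print(cfg_velocity(v_cond, v_uncond, 1.0))  # [1. 2.]
```

Under this form, a weight of 1 returns the conditional prediction unchanged and a weight of 0 ignores the condition entirely, so 0.5 in the EMReady example would blend the two.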
Performance Tips
- Multi-GPU Inference: Use accelerate launch for faster inference on multiple GPUs:
NCCL_DEBUG=ERROR accelerate launch --num_processes=${NUM_GPUS} --main_process_port=8881 \
-m cryofm.projects.cryofm2.cond_sampling ...
- Mixed Precision: Use the --bf16 flag when available to reduce memory usage and speed up inference.
- Batch Processing: Adjust the batch size to fit your GPU memory.
Limitations
- Input size is fixed at 64×64×64 voxels
- Model performance may vary depending on the input density map quality
- Fine-tuned models are optimized for specific enhancement styles
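Because the network takes fixed 64×64×64 inputs, larger maps have to be processed in pieces. The sketch below shows one generic approach: cut the map into overlapping 64³ tiles, process each tile, and average the overlaps when stitching. The stride of 48 voxels and the uniform averaging are assumptions, and `tile_starts` and `process_in_tiles` are hypothetical helpers; how the cryofm sampling scripts handle maps larger than 64³ is not described here.

```python
import numpy as np

def tile_starts(length, tile=64, stride=48):
    """Start indices whose tiles of size `tile` cover [0, length)."""
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] != length - tile:      # make sure the last tile reaches the edge
        starts.append(length - tile)
    return starts

def process_in_tiles(vol, fn, tile=64, stride=48):
    """Apply fn to overlapping tile^3 patches and average the overlaps."""
    if min(vol.shape) < tile:
        raise ValueError("pad the map to at least the tile size first")
    out = np.zeros(vol.shape, dtype=np.float64)
    weight = np.zeros(vol.shape, dtype=np.float64)
    for z in tile_starts(vol.shape[0], tile, stride):
        for y in tile_starts(vol.shape[1], tile, stride):
            for x in tile_starts(vol.shape[2], tile, stride):
                sl = (slice(z, z + tile), slice(y, y + tile), slice(x, x + tile))
                out[sl] += fn(vol[sl])
                weight[sl] += 1.0
    return out / weight

vol = np.random.default_rng(0).standard_normal((100, 100, 100))
# With an identity "model", stitching must reproduce the input exactly
restored = process_in_tiles(vol, lambda patch: patch)
print(np.allclose(restored, vol))  # True
```

In practice the tiles would be grouped into batches sized to GPU memory, and a smooth blending window instead of uniform averaging would reduce seams at tile boundaries.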
Ethical Considerations
This model is intended for scientific research and structural biology applications. Users should:
- Ensure proper attribution when using generated structures
- Validate generated structures through experimental verification
- Be aware of potential biases in the training data
- Use the model responsibly and in accordance with scientific best practices
Citation
TBA
License
This model is released under the Apache 2.0 License. See the LICENSE file for details.
Acknowledgments
This work is developed by the ByteDance Seed Team. For more information, visit: