import argparse from open_flamingo.eval.models.mistral_model import EvalModel from open_flamingo.train.distributed import init_distributed_device, world_info_from_env parser = argparse.ArgumentParser() parser.add_argument( "--model", type=str, help="Model name. Currently only `OpenFlamingo` is supported.", default="open_flamingo", ) def main(): model_args = { "config_yaml": "configs/mlm_multi_source_v1_zephyr_ift_zero2.yaml", "checkpoint_path": "cruise_logs/zephyr_freeze_ift/mp_rank_00_model_states.pt", "precision": "bf16", } eval_model = EvalModel(model_args) tokenizer = eval_model.tokenizer # tokenizer.save_pretrained('hf_weights') if __name__ == "__main__": main()