Skip to content

Schema for Fine-tuning

The schema for configuring the input file for fine-tuning.

Schema:

### Schema for fine-tuning configuration file

train:                        ### ANCHOR: Training ML model
  type: dict
  required: true
  schema:
    num_models:               # Number of models to train. Default is 1
      type: integer
    init_data_paths:          # List of paths to initial data.
      type: list
      required: true

    trainset_ratio:           # Ratio of training set. Default is 0.9
      type: float
    validset_ratio:           # Ratio of validation set. Default is 0.1
      type: float
    num_cores_buildgraph:     # Number of cores for building graph data
      type: integer

    init_checkpoints:         # List of checkpoint files, one per model
      type: list

    num_grad_updates:         # Maximum number of gradient updates; used to estimate num_epochs. Default is None
      type: integer

    distributed:
      type: dict
      schema:
        distributed_backend:  # Choices: 'mpi', 'nccl', or 'gloo'
          type: string
        cluster_type:         # Choices: 'slurm' or 'sge'
          type: string
        gpu_per_node:         # Only needed for SGE clusters. Default is 1
          type: integer

    mlp_engine:               # ML engine. Default is 'sevenn'. Choices: 'sevenn'
      type: string
    sevenn_args:              ### See: https://github.com/MDIL-SNU/SevenNet/blob/main/example_inputs/training/input_full.yml
      type: dict
      schema:
        model:
          type: dict
        train:
          type: dict
        data:
          type: dict

Example config 1:

### Example configuration file for fine-tuning with ALFF

train:
  mlp_engine: sevenn
  num_models: 1
  init_data_paths:
    - ../1_gendata/bulk_*/*/02_gendata/data_label.extxyz
    - ../1_gendata/monolayer_*/*/02_gendata/data_label.extxyz
    - ../1_gendata/1_iteration_data/*.extxyz

  trainset_ratio: 0.9
  validset_ratio: 0.1
  num_cores_buildgraph: 16        # Number of cores for building graph data

  # init_checkpoints:             # List of checkpoint files, one per model
  #   - _init_checkpoint/checkpoint_best.pth

  num_grad_updates: 100000        # Maximum number of gradient updates; used to estimate num_epochs. Default is None

  distributed:
    distributed_backend: 'nccl'   # Choices: 'mpi', 'nccl', or 'gloo'
    cluster_type: 'slurm'         # Choices: 'slurm' or 'sge'
    # gpu_per_node: 1             # Only needed for SGE clusters

  sevenn_args:  # Updated: Dec 17, 2024. See: https://github.com/MDIL-SNU/SevenNet/blob/main/example_inputs/training/input_full.yaml
    model:
      chemical_species: ['Mo', 'W', 'S', 'Se', 'Te']  # Elements the model should know. [ 'Univ' | 'Auto' | manual_user_input ]
      cutoff: 5.0                                     # Cutoff radius in Angstroms. Atoms within the cutoff are connected.
      channel: 32                                     # Multiplicity (channel) of node features.
      lmax: 2                                         # Maximum order of irreducible representations (rotation order).
      num_convolution_layer: 4                        # Number of message-passing layers.

      # irreps_manual:                                # Manually set irreps of the model in each layer (e.g., 128 channels + 5 layers)
      #   - "128x0e"
      #   - "128x0e+64x1e+32x2e"
      #   - "128x0e+64x1e+32x2e"
      #   - "128x0e+64x1e+32x2e"
      #   - "128x0e+64x1e+32x2e"
      #   - "128x0e"

      weight_nn_hidden_neurons: [64, 64]              # Hidden neurons in the convolution weight neural network
      radial_basis:                                   # Function and its parameters used to encode radial distance
        radial_basis_name: 'bessel'                   # Only 'bessel' is currently supported
        bessel_basis_num: 8
      cutoff_function:                                # Envelope function, multiplied with radial_basis functions to initialize edge features
        cutoff_function_name: 'poly_cut'              # {'poly_cut' and 'poly_cut_p_value'} or {'XPLOR' and 'cutoff_on'}
        poly_cut_p_value: 6

      act_gate: {'e': 'silu', 'o': 'tanh'}            # Equivalent to 'nonlinearity_gates' in nequip
      act_scalar: {'e': 'silu', 'o': 'tanh'}          # Equivalent to 'nonlinearity_scalars' in nequip

      is_parity: false                                # Parity: true (E(3) group) or false (SE(3) group)

      self_connection_type: linear                    # Default is 'nequip'. 'linear' is used for SevenNet-0.
      interaction_type: nequip

      conv_denominator: "avg_num_neigh"               # Valid options are "avg_num_neigh*", "sqrt_avg_num_neigh", or float
      train_denominator: false                        # Enable training for the denominator in the convolution layer
      train_shift_scale: false                        # Enable training for shift & scale in the output layer

    train:
      random_seed: 1
      train_shuffle: true
      is_train_stress: true                           # Include stress in the loss function
      epoch: 3                                        # End training after this number of epochs
      per_epoch: 20                                   # Save a checkpoint every this many epochs

      # loss: 'Huber'                                 # Default is 'mse' (mean squared error)
      # loss_param:
      #   delta: 0.01

      # Each optimizer and scheduler accepts different parameters.
      # See sevenn/train/optim.py for the supported optimizers and schedulers.
      optimizer: 'adam'                               # Options: 'sgd', 'adagrad', 'adam', 'adamw', 'radam'
      optim_param:
        lr: 5.0e-4

      scheduler: linearlr
      scheduler_param:
        start_factor: 1.0
        total_iters: 3                                # {..epoch}: matches `epoch` above
        end_factor: 1.0e-7

      # scheduler: 'reducelronplateau'                # One of 'steplr', 'multisteplr', 'exponentiallr', 'cosineannealinglr', 'reducelronplateau', 'linearlr'
      # scheduler_param:
      #   factor: 0.75
      #   patience: 2
      #   threshold: 5.0e-5                           # Only changes larger than this value count as a change
      #   min_lr: 1.0e-12                             # Minimum learning rate

      # scheduler: exponentiallr
      # scheduler_param:
      #   gamma: 0.95                                 # Larger gamma means slower decay

      force_loss_weight: 1.0                          # Coefficient for force loss
      stress_loss_weight: 1.0e-4                      # Coefficient for stress loss (stress in kbar; 1 kbar = 0.1 GPa). Alternatives: 1.0e-3, 1.0e-6

      # ['target y', 'metric']
      # Target y: TotalEnergy, Energy, Force, Stress, Stress_GPa, TotalLoss
      # Metric  : RMSE, MAE, or Loss
      error_record:
        - ['Energy', 'RMSE']
        - ['Force', 'RMSE']
        - ['Stress', 'RMSE']
        # - ['Stress_GPa', 'RMSE']
        - ['Energy', 'Loss']
        - ['Force', 'Loss']
        - ['Stress', 'Loss']
        - ['TotalLoss', 'None']
      best_metric: TotalLoss

      ### THANG: do not use this; set `init_checkpoints` above instead.
      # Continue training from a given checkpoint, or from a pre-trained model checkpoint for fine-tuning
      # continue:
      #   checkpoint: 'checkpoint_best.pth'           # Checkpoint of a pre-trained model, or of a model to continue training
      #   reset_optimizer: false                      # Set true for fine-tuning
      #   reset_scheduler: false                      # Set true for fine-tuning

    data:
      batch_size: 280                                 # Per-GPU batch size. Alternative: 250

      shift: 'per_atom_energy_mean'                   # One of 'per_atom_energy_mean*', 'elemwise_reference_energies', float
      scale: 'force_rms'                              # One of 'force_rms*', 'per_atom_energy_std', 'elemwise_force_rms', float
      data_format: 'ase'                              # Default is 'ase'. Choices are 'ase', 'structure_list', '.sevenn_data'
      # data_format_args:                             # Parameters passed to ase.io.read
      #   energy_key: 'ref_energy'                    # Key for energy in extxyz file
      #   force_key: 'ref_forces'                     # Key for force in extxyz file
      #   stress_key: 'ref_stress'                    # Key for stress in extxyz file