Train a U-Net model. Result files are placed in <site>/unet/<model>.
Usage
train(
model,
train = "train",
result = NULL,
resources = NULL,
local = FALSE,
trap = TRUE,
comment = NULL
)
Arguments
- model
The base name of a `.yml` file in `<pars>/unet/` with model parameters. This file contains parameters used in data prep as well as training. Note that `unet_prep` must be run after changing any of these parameters. The `model` file must contain the following:
site: the three-letter site code
years: the year(s) of field data to fit
orthos: file names of all orthophotos to include
patch: size in pixels
depth: number of downsampling stages
classes: vector of target classes (in original classification)
holdout_col: holdout set to use (one of the `bypoly` columns, e.g., `bypoly01`). Holdout sets are created by `gather`, which numbers each poly from 1 to 10, repeating if necessary. There are 5 sets to choose from.
cv: number of cross-validations. Use 1 for a single model, up to 5 for five-fold cross-validation. Cross-validations are systematic, not random. Since each holdout set has only 10 groups, the number of cross-validations is limited by the values of `val` and `test`.
val: validation polys from `holdout_col`. Use NULL to skip validation, or a vector of the validation polys for the first cross-validation (these will be incremented for subsequent cross-validations). For a 20% validation holdout, use `val = c(1, 6)`. This will use `bypoly01 %in% c(1, 6)` for the first cross-validation, `c(2, 7)` for the second, and so on.
test: test polys from `holdout_col`, as with `val`.
overlap: proportion overlap of patches
upscale: number of cells to upscale (default = 1). Use 3 to upscale to 3x3, 5 for 5x5, etc.
smooth: number of cells to include in moving window mean (default = 1). Use 3 to smooth to 3x3, etc.
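The `val`/`test` rotation across cross-validations can be sketched in R as follows. `cv_holdouts()` is a hypothetical helper, not part of this package, and `test = c(2, 7)` is an illustrative value; the sketch also assumes that incremented group numbers wrap from 10 back to 1.

```r
# Hypothetical helper (not part of this package): list the val/test poly
# groups used by each cross-validation by shifting the first-fold vectors.
# Assumption: group numbers wrap around after n_groups (10 -> 1).
cv_holdouts <- function(val, test, cv, n_groups = 10) {
  shift_groups <- function(x, k) {
    if (is.null(x)) return(NULL)          # NULL means "skip this holdout"
    ((x - 1 + k) %% n_groups) + 1         # add k, wrapping within 1..n_groups
  }
  lapply(seq_len(cv) - 1, function(k)
    list(val = shift_groups(val, k), test = shift_groups(test, k)))
}

folds <- cv_holdouts(val = c(1, 6), test = c(2, 7), cv = 5)
folds[[2]]$val    # 2 7
folds[[5]]$test   # 6 1
```

Because the same shift is applied to both vectors, each fold's validation and test groups stay disjoint as long as they are disjoint in the first fold; with two groups each, five folds cycle the test set through all 10 groups.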
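To illustrate how `patch`, `overlap`, and `depth` interact, here is a hedged sketch. `patch_starts()` and `check_depth()` are hypothetical helpers; they assume that `overlap` becomes a stride of `patch * (1 - overlap)` pixels with one extra patch placed flush against the far edge, and that each of the `depth` downsampling stages halves the patch size, as in a standard U-Net.

```r
# Hypothetical sketch (not the package's code) of patch placement along one axis.
patch_starts <- function(n, patch, overlap) {
  if (n < patch) return(integer(0))               # image smaller than one patch
  stride <- max(1, round(patch * (1 - overlap)))  # assumed overlap-to-stride rule
  last <- n - patch + 1                           # last valid starting pixel
  starts <- seq(1, last, by = stride)
  # assumption: add a final patch flush with the edge so no pixels are skipped
  if (starts[length(starts)] != last) starts <- c(starts, last)
  as.integer(starts)
}

# Each downsampling stage halves the spatial size, so patch must divide by 2^depth.
check_depth <- function(patch, depth) {
  patch %% 2^depth == 0
}

patch_starts(n = 1000, patch = 256, overlap = 0.25)   # 1 193 385 577 745
check_depth(patch = 256, depth = 5)                   # TRUE
```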
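The `smooth` option describes a moving-window mean. The base-R function below only illustrates that operation on a small grid; `moving_mean()` is hypothetical, its edge handling (averaging whichever cells fall inside the grid) is an assumption, and this page does not say which layer the package smooths.

```r
# Illustration of a k x k moving-window mean, e.g. smooth = 3 for 3x3.
# Assumption: at grid edges, average only the cells inside the grid.
moving_mean <- function(m, k = 3) {
  stopifnot(k %% 2 == 1)          # window must have a center cell
  if (k == 1) return(m)           # smooth = 1 leaves values unchanged
  h <- (k - 1) %/% 2              # half-width of the window
  out <- m
  for (i in seq_len(nrow(m))) {
    for (j in seq_len(ncol(m))) {
      rows <- max(1, i - h):min(nrow(m), i + h)
      cols <- max(1, j - h):min(ncol(m), j + h)
      out[i, j] <- mean(m[rows, cols], na.rm = TRUE)
    }
  }
  out
}

m <- matrix(c(0, 0, 0,
              0, 9, 0,
              0, 0, 0), nrow = 3, byrow = TRUE)
moving_mean(m, k = 3)   # center 1, corners 2.25, edge midpoints 1.5
```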
- train
The base name of a `.yml` file in `<pars>/unet/` with training parameters. Any parameter present here overrides the same parameter in `model`. This file contains parameters used only in the training phase. The following must be present in either the `model` or `train` file:
in_channels: number of input channels (8 for multispectral + NDVI + NDRE + DEM)
n_epochs: number of training epochs
encoder_name: pre-trained encoder to use. Choices include `resnet10`, `resnet18`, `resnet34`, `resnet50`, `efficientnet-b0`, and others. The lower `resnet` numbers have fewer parameters, so they may give more stable training.
encoder_weights: `imagenet` starts with weights learned on ImageNet (natural images), which gives faster convergence but might bias the model toward RGB patterns. NULL starts with random initialization, so the model learns everything from this dataset: no bias, but slower training.
learning_rate: learning rate for the optimizer
weight_decay: L2 regularization, which penalizes large weights to prevent overfitting. Higher values (e.g., 1e-3) give stronger regularization; lower values (e.g., 1e-5) give weaker.
class_weighting: one of `none`, `freq`, or `sqrt`. If `none`, all classes are given the same weight; `freq` weights them by inverse frequency, and `sqrt` weights them by the square root of the inverse frequency.
batch_size: how many patches to process together. Larger batches (16, 32) make better use of GPU parallelism, so they train faster with more stable gradients but use more GPU memory. Smaller batches (4, 8) give noisier gradients (a useful regularizer), use less memory, and are better for small datasets. Use 8; if overfitting is a problem, try 4.
gradient_clip_max_norm: prevents exploding gradients by capping the gradient norm at this value. Range: 0.5 (aggressive clipping) to 5.0 (gentle); start with 1.0.
use_ordinal: if TRUE, use an ordinal regression U-Net
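A minimal sketch of the three `class_weighting` options, computed from per-class pixel counts. `class_weights()` and the class names and counts are hypothetical, and rescaling the weights to a mean of 1 is an assumption; the package may normalize differently.

```r
# Sketch of the class_weighting options from per-class pixel counts.
# Assumption: weights are rescaled so they average 1.
class_weights <- function(counts, method = c("none", "freq", "sqrt")) {
  method <- match.arg(method)
  freq <- counts / sum(counts)              # class frequencies
  w <- switch(method,
    none = rep(1, length(counts)),          # equal weights
    freq = 1 / freq,                        # inverse frequency
    sqrt = 1 / sqrt(freq))                  # square root of inverse frequency
  w <- w / mean(w)
  names(w) <- names(counts)
  w
}

counts <- c(water = 800, marsh = 150, forest = 50)   # hypothetical counts
class_weights(counts, "freq")
class_weights(counts, "sqrt")
```

With these counts, `freq` gives the rarest class 16 times the weight of the most common one, while `sqrt` narrows that ratio to 4, a milder correction for imbalance.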
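Gradient clipping by global norm, the technique `gradient_clip_max_norm` controls, can be sketched in a few lines. This follows the common rule used by PyTorch's `clip_grad_norm_` (scale all gradients by `max_norm / norm` when the combined norm exceeds `max_norm`); `clip_by_norm()` and the toy gradients are hypothetical, and this is not the package's training code.

```r
# Sketch of global-norm gradient clipping on a list of gradient vectors.
clip_by_norm <- function(grads, max_norm, eps = 1e-6) {
  # combined L2 norm over all parameter gradients
  total_norm <- sqrt(sum(vapply(grads, function(g) sum(g^2), numeric(1))))
  # shrink only when the norm exceeds max_norm; never scale up
  scale <- min(1, max_norm / (total_norm + eps))
  lapply(grads, function(g) g * scale)
}

grads <- list(layer1 = c(3, 0), layer2 = c(0, 4))   # global norm = 5
clip_by_norm(grads, max_norm = 1.0)                 # about 0.6 and 0.8
```

Because every gradient is scaled by the same factor, clipping limits the step size without changing its direction.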
- result
Name for this training run's result subdirectory. If NULL (default), automatically increments to the next available `fitNN` name (e.g., `"fit01"`, `"fit02"`). Specify explicitly to overwrite an existing run.
- resources
Slurm launch resources. See `launch`. These take priority over the function's defaults. Note that this function requires GPUs. By default, it requests 1 L40S (preferred), but will accept V100 or RTX 2080 Ti. To specify only L40S, use `resources = list(constraint = 'l40s')`.
- local
If TRUE, run locally; otherwise, spawn a batch run on Unity
- trap
If TRUE, trap errors in local mode; if FALSE, use normal R error handling. Set to FALSE for debugging. If you get unrecovered errors, the job won't be added to the jobs database. Has no effect if `local = FALSE`.
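The difference between trapping and normal error handling can be sketched with base R's `tryCatch`. `run_local()` is a hypothetical stand-in, not this package's implementation: with `trap = TRUE` an error is caught and returned as a status so the job can still be recorded, while with `trap = FALSE` the error propagates, so tools such as `traceback()` or `options(error = recover)` can inspect it.

```r
# Hypothetical sketch of what `trap` controls in local mode.
run_local <- function(f, trap = TRUE) {
  if (!trap) return(list(status = "ok", value = f()))   # errors propagate
  tryCatch(
    list(status = "ok", value = f()),
    error = function(e) list(status = "error", message = conditionMessage(e))
  )
}

res <- run_local(function() stop("out of memory"), trap = TRUE)
res$status   # "error"
```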
- comment
Optional slurmcollie comment