train_vae {vmsae}    R Documentation

Train VAE for CAR Prior

Description

Trains a Variational Autoencoder (VAE) to learn the spatial structure implied by a Conditional Autoregressive (CAR) prior on the areas defined by W and GEOID. The trained VAE parameters are saved to save_dir under model_name and can later be used as a generator within Hamiltonian Monte Carlo (HMC) sampling.
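This page does not spell out the CAR parameterisation. As a minimal sketch, assuming the common proper CAR form with precision matrix Q = tau * (D - rho * W), where D is the diagonal matrix of row sums of W, prior draws of the kind used as VAE training data could be generated in base R as below. The helper sample_car_prior() and its rho and tau defaults are hypothetical and are not part of vmsae.

```r
# Hypothetical helper (not part of vmsae): draws from a proper CAR prior
# x ~ N(0, Q^{-1}) with Q = tau * (D - rho * W). The rho and tau values are
# illustrative assumptions, not the package's settings.
sample_car_prior <- function(W, n_samples, rho = 0.99, tau = 1) {
  W <- as.matrix(W)
  stopifnot(nrow(W) == ncol(W), abs(rho) < 1, tau > 0)
  D <- diag(rowSums(W), nrow = nrow(W))   # neighbour counts on the diagonal
  Q <- tau * (D - rho * W)                # CAR precision matrix
  R <- chol(Q)                            # upper triangular, Q = t(R) %*% R
  z <- matrix(rnorm(nrow(W) * n_samples), nrow = nrow(W))
  x <- backsolve(R, z)                    # solves R x = z, so cov(x) = Q^{-1}
  t(x)                                    # one draw per row: n_samples x n
}
```

Each row of the result is one spatial field over the areas in W. The Cholesky factor is computed once and reused for every draw, which avoids refactorising Q per sample. Areas with no neighbours (islands) make Q singular, so chol() would fail for such a W.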

Usage

train_vae(
  W,
  GEOID,
  model_name,
  save_dir,
  n_samples = 10000,
  batch_size = 256,
  epoch = 10000,
  lr_init = 0.001,
  lr_min = 1e-07,
  verbose = TRUE,
  use_gpu = TRUE
)

Arguments

W

Matrix. A square proximity or adjacency matrix representing the spatial relationships among the areas. Its rows and columns must follow the same order as GEOID.

GEOID

Character vector. Identifiers for spatial units (e.g., region or area codes).
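Because W and GEOID must describe the same areas in the same order, a quick check before a long training run can catch mismatches early. The function check_car_inputs() below is a hypothetical sketch, not part of vmsae; the symmetry, zero-diagonal, and no-island conditions are assumptions based on what a CAR prior usually requires.

```r
# Hypothetical pre-flight check (not part of vmsae): verifies that W and GEOID
# describe the same areas and that W looks like a valid CAR neighbourhood matrix.
check_car_inputs <- function(W, GEOID) {
  W <- as.matrix(W)
  if (nrow(W) != ncol(W)) stop("W must be square")
  if (nrow(W) != length(GEOID)) stop("nrow(W) must equal length(GEOID)")
  if (anyDuplicated(GEOID)) stop("GEOID values must be unique")
  if (!isSymmetric(unname(W))) stop("W must be symmetric")
  if (any(diag(W) != 0)) stop("W must have a zero diagonal")
  # Areas without neighbours make a CAR precision matrix singular
  if (any(rowSums(W) == 0)) warning("some areas have no neighbours (islands)")
  invisible(TRUE)
}
```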

model_name

Character. The name under which the trained VAE model is saved.

save_dir

Character. Directory in which to save the trained VAE model and associated metadata. No default is given in the function signature, so a path must be supplied (the example uses tempdir()).

n_samples

Integer. Number of samples to draw from the CAR prior as training data. Default is 10000.

batch_size

Integer. Batch size for VAE training. Default is 256.

epoch

Integer. Number of training epochs. Default is 10000.

lr_init

Numeric. Initial learning rate. Default is 0.001.

lr_min

Numeric. Learning rate reached at the final epoch; the learning rate decreases from lr_init to lr_min over training. Default is 1e-7.
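The exact learning-rate schedule is not documented on this page. One schedule consistent with the description, assumed here purely for illustration, is cosine annealing from lr_init at epoch 0 to lr_min at the final epoch. cosine_lr() is a hypothetical helper, not part of vmsae.

```r
# Hypothetical illustration (not vmsae code): cosine annealing from lr_init at
# epoch 0 down to lr_min at epoch n_epochs. Defaults mirror train_vae().
cosine_lr <- function(t, n_epochs = 10000, lr_init = 0.001, lr_min = 1e-7) {
  lr_min + 0.5 * (lr_init - lr_min) * (1 + cos(pi * t / n_epochs))
}
```

An exponential decay, lr_init * (lr_min / lr_init)^(t / n_epochs), would match both endpoints as well; the vmsae Python source is the authority on which schedule is actually used.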

verbose

Logical. If TRUE, prints training progress. Default is TRUE.

use_gpu

Logical. If TRUE, uses a GPU when one is available. Default is TRUE.

Details

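The function requires a Python environment configured through the reticulate package; the VAE training itself is implemented in Python. train_vae() calls py$train_vae(), which is defined in the Python modules sourced by load_environment(). As in the Examples, call install_environment() to set up the environment and load_environment() before calling train_vae().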

Value

A named list containing:

loss

Total training loss

RCL

Reconstruction error

KLD

Kullback–Leibler divergence
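In a standard VAE the total loss is the reconstruction term plus the Kullback–Leibler term. Assuming vmsae follows that convention with a diagonal Gaussian encoder and a squared-error reconstruction term (neither is confirmed on this page, and a weighted KL term is also common), the three returned quantities relate as in this sketch. vae_loss_terms() is hypothetical.

```r
# Hypothetical sketch of a standard VAE objective for one batch (not vmsae code).
# Assumes an encoder q(z | x) = N(mu, diag(exp(logvar))) and a N(0, I) prior on z.
vae_loss_terms <- function(x, x_hat, mu, logvar) {
  rcl <- sum((x - x_hat)^2)                            # reconstruction error
  kld <- -0.5 * sum(1 + logvar - mu^2 - exp(logvar))   # KL(q(z | x) || N(0, I))
  c(loss = rcl + kld, RCL = rcl, KLD = kld)
}
```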

Examples

## Not run: 
library(vmsae)
library(sf)
# install_environment() is time-consuming the first time it is run
install_environment()
load_environment()

acs_data <- read_sf(system.file("example", "mo_county.shp", package = "vmsae"))
W <- readRDS(system.file("example", "W.Rds", package = "vmsae"))

fit <- train_vae(W = W,
  GEOID = acs_data$GEOID,
  model_name = "test",
  save_dir = tempdir(),
  n_samples = 1000,  # set to larger values in practice, e.g. 10000
  batch_size = 256,
  epoch = 1000)      # set to larger values in practice, e.g. 10000

## End(Not run)
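Assuming each component of the returned list is a numeric vector with one value per epoch (this page does not say whether these are per-epoch traces or final values), a hypothetical helper like the one below reports the last recorded value of each term. It is not part of vmsae.

```r
# Hypothetical helper (not part of vmsae). Assumes `res` is the list returned
# by train_vae(), with numeric components `loss`, `RCL`, and `KLD`.
summarise_vae_loss <- function(res) {
  comps <- c("loss", "RCL", "KLD")
  stopifnot(all(comps %in% names(res)))
  # Last recorded value of each term (the only value if they are scalars)
  vapply(res[comps], function(v) tail(as.numeric(v), 1), numeric(1))
}

# Usage, with `res` holding the result of train_vae():
# summarise_vae_loss(res)
# plot(res$loss, type = "l", log = "y", xlab = "epoch", ylab = "total loss")
```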


[Package vmsae version 0.1.1 Index]