fadvi.FADVAE

class fadvi.FADVAE(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, n_continuous_cov: int = 0, n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', log_variational: bool = True, gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, use_batch_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'both', use_layer_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'none', unlabeled_category_id: int | None = None, **vae_kwargs)

Factor Disentanglement Variational Autoencoder.

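This model disentangles three sources of variation into separate latent spaces: batch-related (z_b), label-related (z_l), and residual (z_r). It does so with supervised batch and label classifiers, adversarial training, and a cross-covariance penalty between the latent spaces.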

Parameters:
  • n_input – Number of input genes

  • n_batch – Number of batches

  • n_labels – Number of labels

  • n_hidden – Number of nodes per hidden layer

  • n_latent_b – Dimensionality of batch latent space

  • n_latent_l – Dimensionality of label latent space

  • n_latent_r – Dimensionality of residual latent space

  • n_layers – Number of hidden layers used for encoder and decoder NNs

  • n_continuous_cov – Number of continuous covariates

  • n_cats_per_cov – Number of categories for each extra categorical covariate

  • dropout_rate – Dropout rate for neural networks

  • dispersion – Dispersion parameter option; one of ‘gene’, ‘gene-batch’, ‘gene-label’, ‘gene-cell’

  • log_variational – If True, apply log(1 + x) to the data before encoding, for numerical stability

  • gene_likelihood – One of ‘nb’, ‘zinb’, ‘poisson’

  • use_observed_lib_size – If True, use observed library size

  • beta – KL divergence weight

  • lambda_b – Weight for batch classification loss

  • lambda_l – Weight for label classification loss

  • alpha_bl – Weight for adversarial loss (label prediction from batch latents)

  • alpha_lb – Weight for adversarial loss (batch prediction from label latents)

  • alpha_rb – Weight for adversarial loss (batch prediction from residual latents)

  • alpha_rl – Weight for adversarial loss (label prediction from residual latents)

  • gamma – Weight for the cross-covariance (decorrelation) penalty

  • use_batch_norm – Where to use batch normalization; one of ‘encoder’, ‘decoder’, ‘none’, ‘both’

  • use_layer_norm – Where to use layer normalization; one of ‘encoder’, ‘decoder’, ‘none’, ‘both’

  • unlabeled_category_id – Label category ID that marks unlabeled cells, if any

  • **vae_kwargs – Additional keyword arguments for the VAE
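The `dispersion` options match the names used by scvi-tools' VAE, so a reasonable reading is that they control how the inverse-dispersion parameter of the NB/ZINB likelihood is shared. The sketch below assumes that scvi-tools convention; `dispersion_shape` is a hypothetical helper, not part of `fadvi`.

```python
from typing import Literal, Optional, Tuple

Dispersion = Literal["gene", "gene-batch", "gene-label", "gene-cell"]


def dispersion_shape(
    dispersion: Dispersion, n_input: int, n_batch: int, n_labels: int
) -> Optional[Tuple[int, ...]]:
    """Hypothetical helper: shape of the learned dispersion parameter.

    Assumes the scvi-tools convention. Returns None for 'gene-cell', where
    the dispersion is predicted per cell by the decoder rather than stored
    as a free parameter.
    """
    if dispersion == "gene":
        return (n_input,)  # one value per gene, shared by all cells
    if dispersion == "gene-batch":
        return (n_input, n_batch)  # one value per gene and batch
    if dispersion == "gene-label":
        return (n_input, n_labels)  # one value per gene and label
    if dispersion == "gene-cell":
        return None  # produced by the decoder for every cell
    raise ValueError(f"unknown dispersion option: {dispersion!r}")
```

For example, with 2000 genes and 3 batches, `dispersion_shape("gene-batch", 2000, 3, 5)` gives `(2000, 3)`.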

_get_inference_input(tensors)

Collect the inputs for inference() from a minibatch of tensors.
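A minimal sketch of what this step does: pick the fields that `inference(x, batch_index=None, cat_covs=None, cont_covs=None, n_samples=1)` expects out of a minibatch dict. The key names below are hypothetical (modeled on scvi-tools registry names); the real keys used by `fadvi` may differ.

```python
from typing import Any, Dict

# Hypothetical minibatch keys; the real registry keys may differ.
X_KEY = "X"
BATCH_KEY = "batch"
CAT_COVS_KEY = "extra_categorical_covs"
CONT_COVS_KEY = "extra_continuous_covs"


def get_inference_input(tensors: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch: map a minibatch dict to keyword arguments for inference().

    Optional covariates become None when absent, matching the defaults of
    inference(x, batch_index=None, cat_covs=None, cont_covs=None).
    """
    return {
        "x": tensors[X_KEY],
        "batch_index": tensors.get(BATCH_KEY),
        "cat_covs": tensors.get(CAT_COVS_KEY),
        "cont_covs": tensors.get(CONT_COVS_KEY),
    }
```

Returning a dict keyed by argument name lets the caller unpack it directly into `inference(**...)`.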

_get_generative_input(tensors, inference_outputs)

Collect the inputs for generative() from the minibatch tensors and the inference outputs.
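Because `generative(z, library, ...)` takes a single `z`, one plausible design is that the three latent blocks are concatenated before decoding. The sketch below assumes exactly that, plus hypothetical output keys `"z_b"`, `"z_l"`, `"z_r"`, and `"library"`; neither the keys nor the concatenation are confirmed by this page.

```python
from typing import Any, Dict

import numpy as np


def get_generative_input(
    tensors: Dict[str, Any], inference_outputs: Dict[str, Any]
) -> Dict[str, Any]:
    """Sketch: build keyword arguments for generative().

    Hypothetical assumptions: inference_outputs stores the latent blocks
    under "z_b", "z_l", "z_r", and the decoder consumes them concatenated
    along the last axis.
    """
    z = np.concatenate(
        [inference_outputs["z_b"], inference_outputs["z_l"], inference_outputs["z_r"]],
        axis=-1,
    )
    return {
        "z": z,
        "library": inference_outputs["library"],
        "batch_index": tensors.get("batch"),
        "cat_covs": tensors.get("extra_categorical_covs"),
        "cont_covs": tensors.get("extra_continuous_covs"),
    }
```

With the defaults `n_latent_b=30`, `n_latent_l=30`, `n_latent_r=10`, the concatenated latent under this assumption has 70 dimensions per cell.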

inference(x, batch_index=None, cat_covs=None, cont_covs=None, n_samples=1)

Run inference to obtain the latent representations z_b, z_l, and z_r.
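A minimal sketch of two standard VAE inference steps that the parameters suggest: the optional log(1 + x) transform controlled by `log_variational`, and reparameterized sampling of `n_samples` draws from a diagonal Gaussian posterior. That each of z_b, z_l, z_r has such a posterior is an assumption (it is the scvi-tools default); the encoder networks that produce `mean` and `log_var` are omitted.

```python
from typing import Optional

import numpy as np


def preprocess_counts(x: np.ndarray, log_variational: bool = True) -> np.ndarray:
    """Apply log(1 + x) to raw counts before encoding when log_variational is True."""
    return np.log1p(x) if log_variational else x


def reparameterize(
    mean: np.ndarray,
    log_var: np.ndarray,
    n_samples: int = 1,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Draw n_samples from N(mean, exp(log_var)) via z = mean + std * eps.

    Assumes a diagonal Gaussian posterior for a latent block (z_b, z_l, or
    z_r). Returns an array of shape (n_samples, *mean.shape).
    """
    rng = np.random.default_rng() if rng is None else rng
    std = np.exp(0.5 * log_var)
    eps = rng.standard_normal((n_samples, *mean.shape))
    return mean + std * eps  # broadcasting adds the leading sample axis
```

Sampling through `mean + std * eps` keeps the draw differentiable with respect to `mean` and `log_var` when written in an autodiff framework, which is why VAEs use it.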

generative(z, library, batch_index=None, cat_covs=None, cont_covs=None)

Run generative model to reconstruct data.
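The reconstruction term depends on `gene_likelihood`. The sketch below shows the negative binomial (‘nb’) and zero-inflated negative binomial (‘zinb’) log-probabilities for a single count, with mean `mu`, inverse dispersion `theta`, and zero-inflation probability `pi`; ‘poisson’ is not shown. The mean parameterization (library size times a softmax over genes, with `library` on the log scale) is an assumption borrowed from scvi-tools, and `nb_mean` is a hypothetical helper.

```python
import math
from typing import List


def nb_mean(decoder_logits: List[float], log_library: float) -> List[float]:
    """Sketch: per-gene NB mean = library size * softmax(decoder output).

    Assumes (as in scvi-tools) that `library` is the log library size and
    that the decoder emits one logit per gene.
    """
    m = max(decoder_logits)
    exps = [math.exp(v - m) for v in decoder_logits]  # numerically stable softmax
    total = sum(exps)
    return [math.exp(log_library) * e / total for e in exps]


def nb_log_prob(x: int, mu: float, theta: float) -> float:
    """log NB(x; mean=mu, inverse dispersion=theta), for mu > 0, theta > 0."""
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def zinb_log_prob(x: int, mu: float, theta: float, pi: float) -> float:
    """log ZINB: an extra point mass at zero with probability pi, NB otherwise."""
    if x == 0:
        return math.log(pi + (1.0 - pi) * math.exp(nb_log_prob(0, mu, theta)))
    return math.log(1.0 - pi) + nb_log_prob(x, mu, theta)
```

The zero-inflated variant only changes the probability of a zero count, which is how it models the excess zeros common in single-cell data.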

cross_covariance_penalty(z_b, z_l, z_r)

Compute a cross-covariance penalty between z_b, z_l, and z_r that pushes the three latent spaces to be mutually decorrelated.
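One common form of such a penalty is the sum of squared Frobenius norms of the sample cross-covariance matrices between each pair of latent blocks. The sketch below uses that form; the exact formula in `fadvi` (normalization, weighting, whether correlations are used instead of covariances) is an assumption. Note that a covariance penalty only removes linear dependence, which is weaker than full statistical independence.

```python
import numpy as np


def cross_cov(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Sample cross-covariance between columns of a (N x d_a) and b (N x d_b)."""
    a_c = a - a.mean(axis=0, keepdims=True)
    b_c = b - b.mean(axis=0, keepdims=True)
    return a_c.T @ b_c / (a.shape[0] - 1)


def cross_covariance_penalty(z_b: np.ndarray, z_l: np.ndarray, z_r: np.ndarray) -> float:
    """Sum of squared Frobenius norms of the pairwise cross-covariances.

    Assumed form: zero exactly when, within the minibatch, every dimension
    of one block is uncorrelated with every dimension of the other blocks.
    """
    pairs = [(z_b, z_l), (z_b, z_r), (z_l, z_r)]
    return float(sum(np.sum(cross_cov(a, b) ** 2) for a, b in pairs))
```

In training, this value would be multiplied by `gamma` and added to the loss, so gradients push the encoder toward latent blocks that carry non-overlapping (linear) information.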

loss(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | Distribution | None], generative_outputs: dict[str, Distribution | None], kl_weight: float = 1.0, labelled_tensors: dict[str, torch.Tensor] | None = None, **kwargs)

Compute the total loss including ELBO, supervised, adversarial, and decorrelation terms.
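A sketch of how the constructor weights could combine the loss terms into one scalar. Several points are assumptions, not confirmed by this page: that `kl_weight` (for example, a warm-up factor) multiplies `beta`; that each adversarial cross-entropy is added with its `alpha_*` weight while its gradient is reversed for the encoder elsewhere (for example, by a gradient reversal layer); and that the terms are already averaged per minibatch. `LossTerms` and `total_loss` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class LossTerms:
    """Hypothetical per-minibatch loss components (already averaged)."""

    reconstruction: float  # -log p(x | z), e.g. ZINB negative log-likelihood
    kl: float  # KL(q(z | x) || p(z)) summed over z_b, z_l, z_r
    batch_ce: float  # batch classifier on z_b
    label_ce: float  # label classifier on z_l
    adv_bl: float  # adversary predicting labels from z_b
    adv_lb: float  # adversary predicting batches from z_l
    adv_rb: float  # adversary predicting batches from z_r
    adv_rl: float  # adversary predicting labels from z_r
    cross_cov: float  # decorrelation penalty between latent blocks


def total_loss(
    t: LossTerms,
    kl_weight: float = 1.0,
    beta: float = 1.0,
    lambda_b: float = 50.0,
    lambda_l: float = 50.0,
    alpha_bl: float = 1.0,
    alpha_lb: float = 1.0,
    alpha_rb: float = 1.0,
    alpha_rl: float = 1.0,
    gamma: float = 1.0,
) -> float:
    """Weighted sum of the loss terms, defaulting to the constructor's weights.

    Assumption: adversarial cross-entropies are added here and their gradient
    is reversed for the encoder elsewhere; this is not confirmed by the docs.
    """
    elbo = t.reconstruction + kl_weight * beta * t.kl
    supervised = lambda_b * t.batch_ce + lambda_l * t.label_ce
    adversarial = (
        alpha_bl * t.adv_bl + alpha_lb * t.adv_lb + alpha_rb * t.adv_rb + alpha_rl * t.adv_rl
    )
    return elbo + supervised + adversarial + gamma * t.cross_cov
```

With the default `lambda_b = lambda_l = 50`, the supervised classifier terms dominate unless their cross-entropies are small, which keeps z_b and z_l strongly tied to batch and label information.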