fadvi.FADVAE
- class fadvi.FADVAE(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, n_continuous_cov: int = 0, n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', log_variational: bool = True, gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, use_batch_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'both', use_layer_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'none', unlabeled_category_id: int | None = None, **vae_kwargs)
Factor Disentanglement Variational Autoencoder.
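This model splits the latent representation into three subspaces that capture batch-related variation (z_b), label-related variation (z_l), and residual variation (z_r). Disentanglement is encouraged by supervised batch and label classifiers, adversarial classifiers, and a cross-covariance penalty between the subspaces.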
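The weights documented below suggest which information each subspace is trained to encode or to hide. The sketch tabulates that reading. The assumption that the batch and label classification losses (lambda_b, lambda_l) are computed from z_b and z_l respectively is inferred from the description, not stated by the source; `Term`, `TERMS` and `allowed_information` are hypothetical helpers, not part of fadvi.

```python
# Hypothetical summary of the disentanglement objective implied by the
# documented weights; not part of the fadvi API.
from typing import NamedTuple


class Term(NamedTuple):
    latent: str  # latent subspace the classifier reads
    target: str  # quantity the classifier predicts
    kind: str    # "supervised" (should be predictable) or "adversarial" (should not)
    weight: str  # FADVAE constructor argument that scales the term


TERMS = [
    Term("z_b", "batch", "supervised", "lambda_b"),   # assumed: read from z_b
    Term("z_l", "label", "supervised", "lambda_l"),   # assumed: read from z_l
    Term("z_b", "label", "adversarial", "alpha_bl"),
    Term("z_l", "batch", "adversarial", "alpha_lb"),
    Term("z_r", "batch", "adversarial", "alpha_rb"),
    Term("z_r", "label", "adversarial", "alpha_rl"),
]


def allowed_information(latent: str) -> dict:
    """List the targets a subspace is trained to encode and to hide."""
    out = {"encode": [], "hide": []}
    for t in TERMS:
        if t.latent == latent:
            out["encode" if t.kind == "supervised" else "hide"].append(t.target)
    return out


for z in ("z_b", "z_l", "z_r"):
    print(z, allowed_information(z))
```

Under this reading z_r is pushed to hide both batch and label, which is what makes it "residual".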
- Parameters:
n_input – Number of input genes
n_batch – Number of batches
n_labels – Number of labels
n_hidden – Number of nodes per hidden layer
n_latent_b – Dimensionality of batch latent space
n_latent_l – Dimensionality of label latent space
n_latent_r – Dimensionality of residual latent space
n_layers – Number of hidden layers used for encoder and decoder NNs
n_continuous_cov – Number of continuous covariates
n_cats_per_cov – Number of categories for each extra categorical covariate
dropout_rate – Dropout rate for neural networks
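dispersion – Granularity of the negative binomial dispersion parameter: ‘gene’ (constant per gene across cells), ‘gene-batch’ (can differ between batches), ‘gene-label’ (can differ between labels), or ‘gene-cell’ (can differ for every gene in every cell)
log_variational – If True, apply log(1 + x) to the data before encoding, for numerical stability
gene_likelihood – Count distribution: ‘zinb’ (zero-inflated negative binomial, default), ‘nb’ (negative binomial), or ‘poisson’
use_observed_lib_size – If True, use the observed library size (total counts per cell) to scale the likelihood mean instead of inferring it
beta – Weight of the KL divergence term in the ELBO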
lambda_b – Weight for batch classification loss
lambda_l – Weight for label classification loss
alpha_bl – Weight for adversarial loss (label prediction from batch latents)
alpha_lb – Weight for adversarial loss (batch prediction from label latents)
alpha_rb – Weight for adversarial loss (batch prediction from residual latents)
alpha_rl – Weight for adversarial loss (label prediction from residual latents)
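gamma – Weight for the cross-covariance penalty between z_b, z_l and z_r
use_batch_norm – Where to apply batch normalization: ‘encoder’, ‘decoder’, ‘both’, or ‘none’
use_layer_norm – Where to apply layer normalization: ‘encoder’, ‘decoder’, ‘both’, or ‘none’
unlabeled_category_id – Label category ID assigned to unlabeled cells
**vae_kwargs – Additional keyword arguments for the VAE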
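The dispersion option determines how many dispersion values the model learns. The sketch below returns the resulting shape for each option; it assumes FADVAE follows the scvi-tools VAE convention (per-gene, per-gene-and-batch, per-gene-and-label, or decoder-predicted per cell), which the source does not confirm. `dispersion_shape` is a hypothetical helper.

```python
def dispersion_shape(dispersion, n_input, n_batch=0, n_labels=0, n_cells=1):
    """Shape of the NB/ZINB dispersion for each option (assumed scvi-tools convention)."""
    if dispersion == "gene":        # one value per gene, shared by all cells
        return (n_input,)
    if dispersion == "gene-batch":  # one value per gene and batch
        return (n_input, n_batch)
    if dispersion == "gene-label":  # one value per gene and label
        return (n_input, n_labels)
    if dispersion == "gene-cell":   # predicted by the decoder for every cell
        return (n_cells, n_input)
    raise ValueError(f"unknown dispersion option: {dispersion!r}")


for option in ("gene", "gene-batch", "gene-label", "gene-cell"):
    print(option, dispersion_shape(option, n_input=2000, n_batch=3, n_labels=8, n_cells=256))
```

With gene_likelihood='poisson' there is no dispersion parameter, so this option has no effect in that case.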
- inference(x, batch_index=None, cat_covs=None, cont_covs=None, n_samples=1)
Run the encoders on x to obtain the latent representations z_b, z_l and z_r, drawing n_samples samples from the approximate posterior.
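A minimal sketch of the sampling step, assuming each encoder outputs a Gaussian mean and variance per latent block (the standard VAE setup) and that, as in scvi-tools, a leading sample axis is added only when n_samples > 1. NumPy stands in for torch here; `sample_latents` is a hypothetical helper, not the module's code.

```python
import numpy as np


def sample_latents(means, variances, n_samples=1, rng=None):
    """Draw z = mu + sigma * eps for every latent block (reparameterization trick).

    means / variances: dicts mapping block name -> array of shape (n_cells, d).
    Returns arrays of shape (n_cells, d) if n_samples == 1,
    otherwise (n_samples, n_cells, d).
    """
    rng = np.random.default_rng() if rng is None else rng
    samples = {}
    for name, mu in means.items():
        sigma = np.sqrt(variances[name])
        shape = mu.shape if n_samples == 1 else (n_samples, *mu.shape)
        eps = rng.standard_normal(shape)
        # mu and sigma broadcast across the leading sample axis when present
        samples[name] = mu + sigma * eps
    return samples


rng = np.random.default_rng(0)
n_cells = 5
dims = {"z_b": 30, "z_l": 30, "z_r": 10}  # FADVAE defaults for n_latent_b/l/r
means = {k: rng.normal(size=(n_cells, d)) for k, d in dims.items()}
variances = {k: np.ones((n_cells, d)) for k, d in dims.items()}

one = sample_latents(means, variances, rng=rng)
many = sample_latents(means, variances, n_samples=3, rng=rng)
print(one["z_r"].shape, many["z_r"].shape)
```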
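- generative(z, library, batch_index=None, cat_covs=None, cont_covs=None)
Run the generative model: decode the latent variables z and the library size into the parameters of the gene likelihood used to reconstruct the data.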
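The sketch below shows one common way a decoder turns latents and library size into expected counts, assuming the scvi-tools convention that library holds the log total counts and the mean is exp(library) times per-gene proportions. The single linear layer, the concatenation of z_b, z_l and z_r, and `decode_mean` are illustrative assumptions; covariates such as batch_index are ignored.

```python
import numpy as np


def decode_mean(z, library, weight, bias):
    """Hypothetical linear decoder returning expected counts per cell and gene.

    z:       (n_cells, n_latent) concatenation of z_b, z_l, z_r
    library: (n_cells, 1) log library size
    weight:  (n_latent, n_genes), bias: (n_genes,)
    """
    logits = z @ weight + bias
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    px_scale = np.exp(logits)
    px_scale /= px_scale.sum(axis=-1, keepdims=True)  # gene proportions, rows sum to 1
    return np.exp(library) * px_scale  # rescale proportions to the cell's total counts


rng = np.random.default_rng(0)
n_cells, n_genes, n_latent = 4, 50, 30 + 30 + 10
x = rng.poisson(5.0, size=(n_cells, n_genes)) + 1  # toy counts, no empty cells
library = np.log(x.sum(axis=1, keepdims=True))  # observed library size, log scale
z = rng.normal(size=(n_cells, n_latent))
mean = decode_mean(z, library, 0.1 * rng.normal(size=(n_latent, n_genes)), np.zeros(n_genes))
print(np.allclose(mean.sum(axis=1), x.sum(axis=1)))  # True
```

Because the proportions sum to one, each cell's expected total matches its observed library size, which is what use_observed_lib_size=True relies on.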
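- cross_covariance_penalty(z_b, z_l, z_r)
Compute the cross-covariance penalty between z_b, z_l and z_r, which pushes the three latent subspaces toward being uncorrelated (a necessary but not sufficient condition for independence).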
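A minimal sketch of such a penalty, assuming it is the sum of squared Frobenius norms of the three pairwise sample cross-covariance matrices; the exact form used by FADVAE (normalization, which pairs are included) is not stated in the source. NumPy stands in for torch.

```python
import numpy as np


def cross_cov(a, b):
    """Sample cross-covariance matrix between the columns of a and b."""
    a_c = a - a.mean(axis=0, keepdims=True)
    b_c = b - b.mean(axis=0, keepdims=True)
    return a_c.T @ b_c / (a.shape[0] - 1)


def cross_covariance_penalty(z_b, z_l, z_r):
    """Sum of squared Frobenius norms of the three pairwise cross-covariances."""
    pairs = [(z_b, z_l), (z_b, z_r), (z_l, z_r)]
    return float(sum(np.sum(cross_cov(x, y) ** 2) for x, y in pairs))


rng = np.random.default_rng(0)
z_b = rng.normal(size=(1000, 30))
z_l = rng.normal(size=(1000, 30))
z_r = rng.normal(size=(1000, 10))
independent = cross_covariance_penalty(z_b, z_l, z_r)
leaky = cross_covariance_penalty(z_b, z_b + 0.1 * z_l, z_r)  # z_l mostly copies z_b
print(independent < leaky)  # True
```

Zero penalty only rules out linear dependence between subspaces; nonlinear dependence is left to the adversarial terms.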
- loss(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | Distribution | None], generative_outputs: dict[str, Distribution | None], kl_weight: float = 1.0, labelled_tensors: dict[str, torch.Tensor] | None = None, **kwargs)
Compute the total loss, combining the ELBO, supervised classification, adversarial, and decorrelation terms.
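How the constructor weights might combine into a single objective, as a sketch under stated assumptions: kl_weight is treated as a warm-up factor multiplying beta, the supervised classifiers are assumed to read z_b and z_l, and the adversarial entries are assumed to be already oriented for minimization by the encoder (for example via gradient reversal). `total_loss` and the term names are hypothetical, not the module's code.

```python
def total_loss(terms, kl_weight=1.0, beta=1.0, lambda_b=50.0, lambda_l=50.0,
               alpha_bl=1.0, alpha_lb=1.0, alpha_rb=1.0, alpha_rl=1.0, gamma=1.0):
    """Weighted sum of scalar loss terms (hypothetical helper).

    `terms` maps term names to scalars already averaged over cells.
    Defaults mirror the FADVAE constructor defaults.
    """
    return (
        terms["reconstruction"]             # negative log-likelihood of x
        + kl_weight * beta * terms["kl"]    # warm-up factor times beta (assumed)
        + lambda_b * terms["batch_ce"]      # supervised batch classifier (assumed on z_b)
        + lambda_l * terms["label_ce"]      # supervised label classifier (assumed on z_l)
        + alpha_bl * terms["adv_bl"]        # label predicted from z_b
        + alpha_lb * terms["adv_lb"]        # batch predicted from z_l
        + alpha_rb * terms["adv_rb"]        # batch predicted from z_r
        + alpha_rl * terms["adv_rl"]        # label predicted from z_r
        + gamma * terms["cov_penalty"]      # cross-covariance penalty
    )


terms = dict.fromkeys(
    ["reconstruction", "kl", "batch_ce", "label_ce",
     "adv_bl", "adv_lb", "adv_rb", "adv_rl", "cov_penalty"], 1.0)
print(total_loss(terms))  # 107.0 with the default weights
```

With the defaults, the two supervised terms (lambda_b = lambda_l = 50) dominate the other terms at equal magnitude, so the alpha and gamma weights may need tuning upward if the subspaces stay entangled.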