Model Implementation

class fadvi._fadvae.GradientReversalFunction(*args, **kwargs)

Bases: torch.autograd.Function

Gradient reversal layer for adversarial training.

static forward(ctx, x, alpha)

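Define the forward of the custom autograd Function. GradientReversalFunction follows Usage 1 below: its forward() takes ctx as the first argument, followed by x and alpha.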

This function is to be overridden by all subclasses. There are two ways to define forward:

Usage 1 (Combined forward and ctx):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

Usage 2 (Separate forward and ctx):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass

  • The forward no longer accepts a ctx argument.

  • Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of forward, and inputs is a tuple of the inputs to forward.

  • See Extending torch.autograd for more details.

The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward() if they are intended to be used in jvp.
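To make Usage 2 concrete, here is a sketch of a gradient reversal layer written with a separate forward and setup_context (available in PyTorch 2.0 and later). The class name ReverseGradUsage2 is hypothetical, and the body is the textbook gradient reversal construction (identity on the forward pass, gradient multiplied by -alpha on the backward pass), not necessarily the code behind GradientReversalFunction, which uses the combined Usage 1 signature. It also shows the backward rules: one return value per forward input, with None for the non-tensor alpha.

```python
import torch
from torch.autograd import Function


class ReverseGradUsage2(Function):
    """Hypothetical gradient reversal layer written in "Usage 2" style."""

    @staticmethod
    def forward(x, alpha):
        # No ctx here: forward only computes the output (an identity view of x).
        return x.view_as(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple of arguments that were passed to forward: (x, alpha).
        _, alpha = inputs
        # alpha is a plain Python float, so storing it on ctx is fine;
        # tensors needed in backward would go through ctx.save_for_backward().
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_output):
        # One value per forward input: a gradient for x, None for alpha.
        grad_x = -ctx.alpha * grad_output if ctx.needs_input_grad[0] else None
        return grad_x, None


x = torch.ones(3, requires_grad=True)
y = ReverseGradUsage2.apply(x, 0.5)
y.sum().backward()
# x.grad is -0.5 for every element: the upstream gradient of 1 was reversed and scaled.
```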

static backward(ctx, grad_output)

Define a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input. For GradientReversalFunction, forward() takes two inputs (x and alpha), so backward() must return two values.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

_backward_cls

alias of GradientReversalFunctionBackward

fadvi._fadvae.gradient_reversal(x, alpha=1.0)

Apply gradient reversal to the input tensor x.
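The sketch below shows how a gradient reversal helper is typically used in adversarial training: the adversary (for example, a batch classifier reading the label latents) receives its normal gradient, while the encoder upstream of the reversal receives the negated, alpha-scaled gradient and so learns to defeat the adversary. The names _GradReverse, reverse_gradient, encoder, and adversary are hypothetical stand-ins; the layer body is the standard gradient reversal construction and is assumed, not confirmed, to match gradient_reversal(x, alpha=1.0).

```python
import torch
from torch import nn
from torch.autograd import Function


class _GradReverse(Function):
    """Hypothetical gradient reversal layer (combined forward/ctx style)."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha        # non-tensor, so storing it on ctx is fine
        return x.view_as(x)      # identity on the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Negate and scale the gradient for x; alpha receives no gradient.
        return -ctx.alpha * grad_output, None


def reverse_gradient(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Hypothetical wrapper mirroring the gradient_reversal(x, alpha=1.0) signature."""
    return _GradReverse.apply(x, alpha)


torch.manual_seed(0)
encoder = nn.Linear(5, 3)    # stands in for a latent encoder
adversary = nn.Linear(3, 2)  # stands in for an adversarial classifier
x = torch.randn(8, 5)
target = torch.randint(0, 2, (8,))


def grads(alpha=None):
    """Return (encoder weight grad, adversary weight grad) for one adversarial loss."""
    encoder.zero_grad()
    adversary.zero_grad()
    z = encoder(x)
    if alpha is not None:
        z = reverse_gradient(z, alpha)
    loss = nn.functional.cross_entropy(adversary(z), target)
    loss.backward()
    return encoder.weight.grad.clone(), adversary.weight.grad.clone()


plain_enc, plain_adv = grads()
rev_enc, rev_adv = grads(alpha=1.0)
# The adversary's gradient is unchanged; the encoder's gradient flips sign.
```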

class fadvi._fadvae.FADVAE(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, n_continuous_cov: int = 0, n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', log_variational: bool = True, gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, use_batch_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'both', use_layer_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'none', unlabeled_category_id: int | None = None, **vae_kwargs)

Bases: SupervisedModuleClass, VAE

Factor Disentanglement Variational Autoencoder.

This model disentangles batch-related variation (z_b), label-related variation (z_l), and residual variation (z_r) using adversarial training and cross-correlation penalties.

Parameters:
  • n_input – Number of input genes

  • n_batch – Number of batches

  • n_labels – Number of labels

  • n_hidden – Number of nodes per hidden layer

  • n_latent_b – Dimensionality of batch latent space

  • n_latent_l – Dimensionality of label latent space

  • n_latent_r – Dimensionality of residual latent space

  • n_layers – Number of hidden layers used for encoder and decoder NNs

  • n_continuous_cov – Number of continuous covariates

  • n_cats_per_cov – Number of categories for each extra categorical covariate

  • dropout_rate – Dropout rate for neural networks

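  • dispersion – Dispersion parameter option: ‘gene’ (constant per gene across cells), ‘gene-batch’ (can differ between batches), ‘gene-label’ (can differ between labels), or ‘gene-cell’ (can differ for every gene in every cell)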

  • log_variational – If True, apply log(data + 1) to the input before encoding, for numerical stability

  • gene_likelihood – One of ‘nb’ (negative binomial), ‘zinb’ (zero-inflated negative binomial, the default), or ‘poisson’

  • use_observed_lib_size – If True, use each cell's observed library size (total counts) as the scaling factor in the mean of the conditional distribution

  • beta – KL divergence weight

  • lambda_b – Weight for batch classification loss

  • lambda_l – Weight for label classification loss

  • alpha_bl – Weight for adversarial loss (label prediction from batch latents)

  • alpha_lb – Weight for adversarial loss (batch prediction from label latents)

  • alpha_rb – Weight for adversarial loss (batch prediction from residual latents)

  • alpha_rl – Weight for adversarial loss (label prediction from residual latents)

  • gamma – Weight for cross-correlation penalty

  • use_batch_norm – Where to use batch normalization: ‘encoder’, ‘decoder’, ‘none’, or ‘both’

  • use_layer_norm – Where to use layer normalization: ‘encoder’, ‘decoder’, ‘none’, or ‘both’

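  • unlabeled_category_id – Integer ID of the label category that marks unlabeled cells, if any

  • **vae_kwargs – Additional keyword arguments passed to VAE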
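For log_variational and use_observed_lib_size, here is a minimal sketch of the two preprocessing steps, assuming FADVAE inherits this behavior from the scvi-tools VAE it subclasses (where, as far as we know, the observed library size is the log of each cell's total count). The helper name prepare_encoder_input is hypothetical.

```python
import torch


def prepare_encoder_input(x: torch.Tensor, log_variational: bool = True):
    """Hypothetical helper: return (encoder input, log observed library size)."""
    # Observed library size: total counts per cell, kept on the log scale.
    library = torch.log(x.sum(dim=1, keepdim=True))
    # log(x + 1) compresses the range of raw counts before they reach the encoder.
    x_enc = torch.log1p(x) if log_variational else x
    return x_enc, library


counts = torch.tensor([[0.0, 1.0, 3.0], [2.0, 2.0, 0.0]])
x_enc, library = prepare_encoder_input(counts)
# Both cells have 4 total counts, so library holds log(4) twice.
```

The library size is computed from the raw counts, not from the log-transformed input, so log_variational only affects what the encoder sees.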

__init__(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, n_continuous_cov: int = 0, n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', log_variational: bool = True, gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, use_batch_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'both', use_layer_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'none', unlabeled_category_id: int | None = None, **vae_kwargs)

_get_inference_input(tensors)

Get input for inference.

_get_generative_input(tensors, inference_outputs)

Get input for generative model.

inference(x, batch_index=None, cat_covs=None, cont_covs=None, n_samples=1)

Run inference to get latent representations.
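As an illustration of the three latent representations, here is a hypothetical encoder that produces one Gaussian posterior per subspace (z_b, z_l, z_r, using the default sizes 30, 30, and 10 from the signature) and concatenates reparameterized samples into a single z for the decoder. The shared hidden layer, the head layout, and the concatenation order are all assumptions, not fadvi's code.

```python
import torch
from torch import nn
from torch.distributions import Normal


class ThreeHeadEncoder(nn.Module):
    """Hypothetical encoder with separate Gaussian posteriors for z_b, z_l and z_r."""

    def __init__(self, n_input: int, n_hidden: int = 128,
                 n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(n_input, n_hidden), nn.ReLU())
        dims = {"z_b": n_latent_b, "z_l": n_latent_l, "z_r": n_latent_r}
        # Each head outputs a mean and a log-variance for its own subspace.
        self.heads = nn.ModuleDict(
            {name: nn.Linear(n_hidden, 2 * d) for name, d in dims.items()}
        )

    def forward(self, x: torch.Tensor) -> dict:
        h = self.body(x)
        out = {}
        for name, head in self.heads.items():
            mean, log_var = head(h).chunk(2, dim=-1)
            q = Normal(mean, torch.exp(0.5 * log_var))
            out[f"q_{name}"] = q
            out[name] = q.rsample()  # reparameterized: gradients reach mean and log_var
        # A single z for the decoder: the three blocks side by side.
        out["z"] = torch.cat([out["z_b"], out["z_l"], out["z_r"]], dim=-1)
        return out


enc = ThreeHeadEncoder(n_input=50)
outputs = enc(torch.rand(4, 50))
```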

generative(z, library, batch_index=None, cat_covs=None, cont_covs=None)

Run generative model to reconstruct data.
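For gene_likelihood=‘nb’, here is a sketch of how a generative step in the scvi-tools style turns decoder output into a count distribution: the per-gene proportions px_scale are scaled by the exponentiated log library size to give the mean, and theta is the inverse dispersion (one value per gene when dispersion=‘gene’). The function name is hypothetical, and torch.distributions.NegativeBinomial is swapped in for scvi-tools' own NB class; ‘zinb’ would add a zero-inflation gate on top.

```python
import torch
from torch.distributions import NegativeBinomial


def nb_likelihood(px_scale: torch.Tensor, library: torch.Tensor,
                  theta: torch.Tensor) -> NegativeBinomial:
    """Hypothetical: NB distribution with mean exp(library) * px_scale."""
    # px_scale: (cells, genes) proportions, e.g. a softmax over genes
    # library:  (cells, 1) log library size
    # theta:    (genes,) inverse dispersion, shared across cells
    mu = torch.exp(library) * px_scale
    # torch's NegativeBinomial has mean total_count * exp(logits),
    # so logits = log(mu / theta) gives mean mu.
    return NegativeBinomial(total_count=theta, logits=torch.log(mu) - torch.log(theta))


px_scale = torch.softmax(torch.randn(2, 4), dim=-1)
library = torch.log(torch.tensor([[100.0], [50.0]]))
dist = nb_likelihood(px_scale, library, torch.full((4,), 5.0))
counts = torch.tensor([[1.0, 0.0, 2.0, 3.0], [0.0, 0.0, 1.0, 1.0]])
recon_nll = -dist.log_prob(counts).sum(dim=-1)  # per-cell reconstruction loss
```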

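cross_covariance_penalty(z_b, z_l, z_r)

Compute a cross-covariance penalty that pushes the latent spaces z_b, z_l, and z_r toward being mutually uncorrelated (a necessary, but not sufficient, condition for independence).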
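One common way to build such a penalty, shown here as a hypothetical sketch rather than fadvi's exact formula: center each latent block over the mini-batch, form the sample cross-covariance matrix for each pair of blocks, and sum the squared entries. The normalization by n - 1 and the plain sum over the three pairs are assumptions. A zero penalty means the blocks are linearly uncorrelated within the batch, which is weaker than statistical independence.

```python
import torch


def cross_covariance_penalty(z_b: torch.Tensor, z_l: torch.Tensor,
                             z_r: torch.Tensor) -> torch.Tensor:
    """Hypothetical: sum of squared cross-covariances over all pairs of latent blocks."""

    def cross_cov(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Center each block over the batch, then form the (dim_a, dim_b) sample cross-covariance.
        a = a - a.mean(dim=0, keepdim=True)
        b = b - b.mean(dim=0, keepdim=True)
        return a.T @ b / (a.shape[0] - 1)

    pairs = [(z_b, z_l), (z_b, z_r), (z_l, z_r)]
    return sum(cross_cov(a, b).pow(2).sum() for a, b in pairs)


# Zero-mean, pairwise orthogonal columns: the blocks are uncorrelated.
a = torch.tensor([[1.0], [-1.0], [1.0], [-1.0]])
b = torch.tensor([[1.0], [1.0], [-1.0], [-1.0]])
c = torch.tensor([[1.0], [-1.0], [-1.0], [1.0]])
uncorrelated = cross_covariance_penalty(a, b, c)
identical = cross_covariance_penalty(a, a, a)  # each pair contributes (4/3)^2
```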

loss(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | Distribution | None], generative_outputs: dict[str, Distribution | None], kl_weight: float = 1.0, labelled_tensors: dict[str, torch.Tensor] | None = None, **kwargs)

Compute the total loss: the negative ELBO plus the supervised, adversarial, and decorrelation terms.
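To make the weighting concrete, here is a hypothetical sketch of how the terms might be combined, assuming each supervised or adversarial term is a cross-entropy and the adversarial heads sit behind a gradient reversal layer (so their losses are added, not subtracted). The dictionary keys and function name are made up for illustration, kl_weight annealing and unlabeled-cell handling are omitted, and the default weights are copied from the FADVAE signature.

```python
import torch
import torch.nn.functional as F


def total_loss(recon_nll, kl, logits, targets, penalty, *, beta=1.0,
               lambda_b=50.0, lambda_l=50.0, alpha_bl=1.0, alpha_lb=1.0,
               alpha_rb=1.0, alpha_rl=1.0, gamma=1.0):
    """Hypothetical weighting of the FADVAE loss terms.

    recon_nll, kl: per-cell reconstruction negative log-likelihood and KL divergence.
    logits: classifier outputs keyed "b", "l" (supervised heads on z_b, z_l) and
            "bl", "lb", "rb", "rl" (adversarial heads, e.g. "lb" = batch from z_l).
    targets: integer "batch" and "label" codes per cell.
    penalty: scalar cross-covariance penalty.
    """
    batch, label = targets["batch"], targets["label"]
    ce = F.cross_entropy
    # Negative ELBO, with beta weighting the KL divergence.
    neg_elbo = (recon_nll + beta * kl).mean()
    # Supervised heads: z_b should predict the batch, z_l should predict the label.
    supervised = lambda_b * ce(logits["b"], batch) + lambda_l * ce(logits["l"], label)
    # Adversarial heads (assumed behind gradient reversal): their losses are added,
    # and the reversal makes the encoders work against them.
    adversarial = (alpha_bl * ce(logits["bl"], label)
                   + alpha_lb * ce(logits["lb"], batch)
                   + alpha_rb * ce(logits["rb"], batch)
                   + alpha_rl * ce(logits["rl"], label))
    return neg_elbo + supervised + adversarial + gamma * penalty
```

With uninformative classifiers (all-zero logits over 3 classes), every cross-entropy equals log 3, so the supervised and adversarial parts contribute (50 + 50 + 4) × log 3 at the default weights.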
