Model Implementation
- class fadvi._fadvae.GradientReversalFunction(*args, **kwargs)
Bases: torch.autograd.Function
Gradient reversal layer for adversarial training.
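A gradient reversal layer is conventionally the identity in the forward pass and multiplies the incoming gradient by -alpha in the backward pass. The sketch below assumes that convention. GradReverse is a hypothetical stand-in written for illustration, not the fadvi source.

```python
import torch


class GradReverse(torch.autograd.Function):
    """Hypothetical stand-in for GradientReversalFunction (not the fadvi source)."""

    @staticmethod
    def forward(ctx, x, alpha):
        # alpha is a plain float, so it can be stored directly on ctx.
        ctx.alpha = alpha
        # Identity in the forward pass; view_as gives autograd a distinct
        # output tensor to attach this Function's backward to.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the incoming gradient. One return value per
        # forward() input: a gradient for x and None for the float alpha.
        return grad_output.neg() * ctx.alpha, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = GradReverse.apply(x, 0.5)
y.sum().backward()
print(y.detach())  # same values as x: the forward pass is the identity
print(x.grad)      # d(sum)/dx = 1 per element, reversed and scaled by 0.5
```

Custom Functions are invoked through .apply(), never by calling forward() directly.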
- static forward(ctx, x, alpha)
Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses. There are two ways to define forward:
Usage 1 (Combined forward and ctx):
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        pass
It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
See Combined or separate forward() and setup_context() for more details.
Usage 2 (Separate forward and ctx):
    @staticmethod
    def forward(*args: Any, **kwargs: Any) -> Any:
        pass

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        pass
The forward no longer accepts a ctx argument.
Instead, you must also override the torch.autograd.Function.setup_context() staticmethod to handle setting up the ctx object. output is the output of the forward, and inputs is a Tuple of inputs to the forward.
See Extending torch.autograd for more details.
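For comparison, here is a minimal sketch of the same gradient-reversal idea written in the Usage 2 style, assuming PyTorch 2.0 or later (where setup_context is supported). GradReverseSeparate is a hypothetical name used only for this illustration.

```python
import torch


class GradReverseSeparate(torch.autograd.Function):
    """Hypothetical gradient reversal written in the Usage 2 style."""

    @staticmethod
    def forward(x, alpha):
        # No ctx argument: forward() only computes the output (an identity copy).
        return x.clone()

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple of arguments that were passed to forward().
        _, alpha = inputs
        ctx.alpha = alpha

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient for x is reversed and scaled; alpha gets None.
        return -ctx.alpha * grad_output, None


x = torch.ones(2, requires_grad=True)
GradReverseSeparate.apply(x, 2.0).sum().backward()
print(x.grad)
```

Separating setup_context from forward is what lets a Function compose with torch.func transforms; for plain autograd use, both styles behave the same.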
The context can be used to store arbitrary data that can then be retrieved during the backward pass. Tensors should not be stored directly on ctx (though this is not currently enforced for backward compatibility). Instead, tensors should be saved either with ctx.save_for_backward() if they are intended to be used in backward (equivalently, vjp) or with ctx.save_for_forward() if they are intended to be used in jvp.
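To illustrate the rule about tensors on ctx, the sketch below assumes a hypothetical variant in which alpha is passed as a tensor rather than a float, so it is saved with ctx.save_for_backward() and read back from ctx.saved_tensors. GradReverseTensorAlpha is an illustrative name, not part of fadvi.

```python
import torch


class GradReverseTensorAlpha(torch.autograd.Function):
    """Hypothetical variant whose scaling factor alpha is a tensor."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Tensors go through save_for_backward rather than ctx attributes.
        ctx.save_for_backward(alpha)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        (alpha,) = ctx.saved_tensors
        # alpha does not require grad here, so its gradient slot is None.
        return -alpha * grad_output, None


x = torch.ones(3, requires_grad=True)
alpha = torch.tensor(0.25)
GradReverseTensorAlpha.apply(x, alpha).sum().backward()
print(x.grad)
```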
- static backward(ctx, grad_output)
Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses. (Defining this function is equivalent to defining the vjp function.)
It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can return None as the gradient for that input.
The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad, a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.
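The sketch below shows how backward() can consult ctx.needs_input_grad to skip work and still return one value per forward() input. It assumes the (x, alpha) signature documented above with a float alpha; GradReverseChecked and seen_flags are hypothetical names used only to make the flags observable.

```python
import torch

seen_flags = []  # records ctx.needs_input_grad for inspection


class GradReverseChecked(torch.autograd.Function):
    """Hypothetical gradient reversal that only computes gradients it needs."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One boolean per forward() input, here (x, alpha).
        seen_flags.append(tuple(ctx.needs_input_grad))
        grad_x = None
        if ctx.needs_input_grad[0]:
            grad_x = -ctx.alpha * grad_output
        # Return exactly as many values as forward() had inputs.
        return grad_x, None


x = torch.tensor([1.0, -1.0], requires_grad=True)
GradReverseChecked.apply(x, 1.0).sum().backward()
print(seen_flags, x.grad)
```

Because alpha is a Python float, its flag is False, so returning None for it is always valid.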
- _backward_cls
alias of GradientReversalFunctionBackward
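In adversarial training, the reversal layer sits between an encoder and an adversarial classifier: the classifier is trained normally, while the encoder receives the negated gradient and so learns to defeat it. The sketch below checks that behaviour on toy nn.Linear layers. GradientReversal, _GradReverse, and encoder_and_head_grads are hypothetical helpers, and the identity-forward/negated-backward behaviour is assumed rather than taken from the fadvi source.

```python
import torch
from torch import nn


class _GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None


class GradientReversal(nn.Module):
    """Hypothetical module wrapper so the layer can be used like any nn.Module."""

    def __init__(self, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        return _GradReverse.apply(x, self.alpha)


def encoder_and_head_grads(use_grl: bool, alpha: float = 1.0):
    torch.manual_seed(0)  # identical weights and data for both runs
    encoder = nn.Linear(4, 3)  # stands in for a latent encoder
    head = nn.Linear(3, 2)     # stands in for an adversarial classifier
    x = torch.randn(8, 4)
    target = torch.randint(0, 2, (8,))
    z = encoder(x)
    if use_grl:
        z = GradientReversal(alpha)(z)
    loss = nn.functional.cross_entropy(head(z), target)
    loss.backward()
    return encoder.weight.grad.clone(), head.weight.grad.clone()


enc_plain, head_plain = encoder_and_head_grads(use_grl=False)
enc_rev, head_rev = encoder_and_head_grads(use_grl=True, alpha=0.5)
```

The classifier's gradients are unchanged, while the encoder's gradients are exactly -alpha times what they would be without the layer.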
- class fadvi._fadvae.FADVAE(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, n_continuous_cov: int = 0, n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', log_variational: bool = True, gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, use_batch_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'both', use_layer_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'none', unlabeled_category_id: int | None = None, **vae_kwargs)
Bases: SupervisedModuleClass, VAE
Factor Disentanglement Variational Autoencoder.
This model disentangles batch-related variation (z_b), label-related variation (z_l), and residual variation (z_r) using adversarial training and cross-correlation penalties.
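With the default sizes, the three subspaces have 30 (z_b), 30 (z_l), and 10 (z_r) dimensions. Because generative() takes a single z, the sketch below assumes, purely for illustration, that the subspaces are concatenated in the order [z_b | z_l | z_r]; the actual layout inside FADVAE is not stated here.

```python
import torch

# Default subspace sizes taken from the FADVAE signature.
n_latent_b, n_latent_l, n_latent_r = 30, 30, 10

# Hypothetical concatenated latent matrix for 128 cells: [z_b | z_l | z_r].
z = torch.randn(128, n_latent_b + n_latent_l + n_latent_r)
z_b, z_l, z_r = torch.split(z, [n_latent_b, n_latent_l, n_latent_r], dim=-1)
print(z_b.shape, z_l.shape, z_r.shape)
```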
- Parameters:
n_input – Number of input genes
n_batch – Number of batches
n_labels – Number of labels
n_hidden – Number of nodes per hidden layer
n_latent_b – Dimensionality of batch latent space
n_latent_l – Dimensionality of label latent space
n_latent_r – Dimensionality of residual latent space
n_layers – Number of hidden layers used for encoder and decoder NNs
n_continuous_cov – Number of continuous covariates
n_cats_per_cov – Number of categories for each extra categorical covariate
dropout_rate – Dropout rate for neural networks
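dispersion – Dispersion parameter option: ‘gene’ (constant per gene across cells), ‘gene-batch’ (can differ between batches), ‘gene-label’ (can differ between labels), or ‘gene-cell’ (can differ for every gene in every cell)
log_variational – If True, apply log(data + 1) to the data prior to encoding, for numerical stability
gene_likelihood – One of ‘nb’ (negative binomial), ‘zinb’ (zero-inflated negative binomial), or ‘poisson’ (Poisson)
use_observed_lib_size – If True, use each cell's observed library size to scale the mean of the conditional distribution instead of learning it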
beta – KL divergence weight
lambda_b – Weight for batch classification loss
lambda_l – Weight for label classification loss
alpha_bl – Weight for adversarial loss (label prediction from batch latents)
alpha_lb – Weight for adversarial loss (batch prediction from label latents)
alpha_rb – Weight for adversarial loss (batch prediction from residual latents)
alpha_rl – Weight for adversarial loss (label prediction from residual latents)
gamma – Weight for cross-correlation penalty
use_batch_norm – Where to use batch normalization: ‘encoder’, ‘decoder’, ‘both’, or ‘none’
use_layer_norm – Where to use layer normalization: ‘encoder’, ‘decoder’, ‘both’, or ‘none’
**vae_kwargs – Additional keyword arguments passed to VAE
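As a rough illustration of log_variational and use_observed_lib_size, the sketch below follows the convention of scvi-tools' VAE, which FADVAE subclasses: the observed library is the log of each cell's total counts, and the encoder sees log(1 + x). Whether FADVAE preprocesses in exactly this way is an assumption, and preprocess is a hypothetical helper.

```python
import torch


def preprocess(x: torch.Tensor, log_variational: bool = True):
    """Hypothetical helper mirroring scVI-style input handling."""
    # Observed library size: log of total counts per cell, kept as a column.
    library = torch.log(x.sum(dim=1, keepdim=True))
    # Encoder input: log(1 + x) when log_variational is True.
    x_enc = torch.log1p(x) if log_variational else x
    return x_enc, library


counts = torch.tensor([[0.0, 3.0, 1.0], [2.0, 2.0, 4.0]])
x_enc, library = preprocess(counts)
print(x_enc, library)
```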
- __init__(n_input: int, n_batch: int = 0, n_labels: int = 0, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, n_continuous_cov: int = 0, n_cats_per_cov: Iterable[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', log_variational: bool = True, gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, use_batch_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'both', use_layer_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'none', unlabeled_category_id: int | None = None, **vae_kwargs)
- inference(x, batch_index=None, cat_covs=None, cont_covs=None, n_samples=1)
Run inference to get latent representations.
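A common way for a VAE's inference step to produce latents is to sample a Gaussian posterior with the reparameterization trick, adding a leading sample dimension when n_samples > 1. The sketch below assumes a Normal posterior parameterized by mean and log-variance for one subspace; sample_latent is hypothetical and does not reproduce FADVAE's encoder.

```python
import torch
from torch.distributions import Normal


def sample_latent(q_mean, q_logvar, n_samples=1):
    """Hypothetical posterior sampling for one latent subspace."""
    qz = Normal(q_mean, torch.exp(0.5 * q_logvar))  # scale = standard deviation
    # rsample() uses the reparameterization trick, so z stays differentiable.
    z = qz.rsample() if n_samples == 1 else qz.rsample((n_samples,))
    return qz, z


q_mean = torch.zeros(64, 30, requires_grad=True)
q_logvar = torch.zeros(64, 30)
_, z_single = sample_latent(q_mean, q_logvar)
_, z_multi = sample_latent(q_mean, q_logvar, n_samples=5)
print(z_single.shape, z_multi.shape)
```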
- generative(z, library, batch_index=None, cat_covs=None, cont_covs=None)
Run generative model to reconstruct data.
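In scVI-style decoders, the network outputs a per-cell probability vector over genes, which is then multiplied by the (exponentiated) library size to give the expected counts fed to the NB, ZINB, or Poisson likelihood. The sketch below assumes that convention, a single hypothetical nn.Linear decoder, and a 70-dimensional z (30 + 30 + 10); it is not FADVAE's actual decoder.

```python
import torch
from torch import nn

n_latent, n_genes = 70, 5
decoder = nn.Linear(n_latent, n_genes)  # hypothetical single-layer decoder


def generative_sketch(z, library):
    # Normalized expression: one probability vector over genes per cell.
    px_scale = torch.softmax(decoder(z), dim=-1)
    # Expected counts: rescale by the library size (library is on log scale).
    px_rate = torch.exp(library) * px_scale
    return px_scale, px_rate


z = torch.randn(4, n_latent)
library = torch.log(torch.tensor([[100.0], [200.0], [50.0], [10.0]]))
px_scale, px_rate = generative_sketch(z, library)
print(px_rate.sum(dim=-1))  # each cell's expected total equals its library size
```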
- cross_covariance_penalty(z_b, z_l, z_r)
Compute a cross-covariance penalty between z_b, z_l, and z_r that encourages the three subspaces to be decorrelated (a linear proxy for independence).
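One standard form of such a penalty is the sum of squared entries of the sample cross-covariance matrix for each pair of subspaces. The sketch below assumes that form (unnormalized, summed over the pairs (z_b, z_l), (z_b, z_r), (z_l, z_r)); the exact formula used by FADVAE may differ, and both function names are hypothetical.

```python
import torch


def _cross_cov(a, b):
    # Center each column, then form the (d_a x d_b) sample cross-covariance.
    a = a - a.mean(dim=0, keepdim=True)
    b = b - b.mean(dim=0, keepdim=True)
    return a.T @ b / (a.shape[0] - 1)


def cross_covariance_penalty_sketch(z_b, z_l, z_r):
    pairs = [(z_b, z_l), (z_b, z_r), (z_l, z_r)]
    # Squared Frobenius norm of each cross-covariance, summed over pairs.
    return sum(_cross_cov(a, b).pow(2).sum() for a, b in pairs)


# Orthogonal, zero-mean columns: every cross-covariance is exactly zero.
z_b = torch.tensor([[1.0], [-1.0], [1.0], [-1.0]])
z_l = torch.tensor([[1.0], [1.0], [-1.0], [-1.0]])
z_r = torch.tensor([[1.0], [-1.0], [-1.0], [1.0]])
print(cross_covariance_penalty_sketch(z_b, z_l, z_r))
```

A zero penalty only rules out linear dependence; non-linear dependence between subspaces is what the adversarial terms are meant to address.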
- loss(tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor | Distribution | None], generative_outputs: dict[str, Distribution | None], kl_weight: float = 1.0, labelled_tensors: dict[str, torch.Tensor] | None = None, **kwargs)
Compute the total loss including ELBO, supervised, adversarial, and decorrelation terms.
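The weights in the constructor suggest how the terms are combined. The sketch below is a hypothetical weighted sum using the documented defaults; the term names, the pairing of each classifier with a particular subspace, and applying kl_weight together with beta are assumptions. With a gradient reversal layer in front of each adversarial classifier, the adversarial cross-entropies are added with a positive sign, since the reversal flips their effect on the encoder.

```python
def total_loss_sketch(terms, *, beta=1.0, kl_weight=1.0, lambda_b=50.0,
                      lambda_l=50.0, alpha_bl=1.0, alpha_lb=1.0, alpha_rb=1.0,
                      alpha_rl=1.0, gamma=1.0):
    """Hypothetical combination of FADVAE-style loss terms (floats or tensors)."""
    return (
        terms["reconstruction"]
        + kl_weight * beta * terms["kl"]          # KL annealing times beta
        + lambda_b * terms["ce_batch"]            # batch classifier (assumed on z_b)
        + lambda_l * terms["ce_label"]            # label classifier (assumed on z_l)
        + alpha_bl * terms["adv_label_from_b"]    # adversary: labels from z_b
        + alpha_lb * terms["adv_batch_from_l"]    # adversary: batches from z_l
        + alpha_rb * terms["adv_batch_from_r"]    # adversary: batches from z_r
        + alpha_rl * terms["adv_label_from_r"]    # adversary: labels from z_r
        + gamma * terms["cross_cov"]              # decorrelation penalty
    )


terms = {name: 1.0 for name in [
    "reconstruction", "kl", "ce_batch", "ce_label", "adv_label_from_b",
    "adv_batch_from_l", "adv_batch_from_r", "adv_label_from_r", "cross_cov",
]}
print(total_loss_sketch(terms))
```

With every term equal to 1.0 the two classification losses dominate, which reflects how large lambda_b and lambda_l are by default relative to the other weights.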