fadvi.FADVI

class fadvi.FADVI(adata: AnnData, registry: dict | None = None, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, **model_kwargs)

Factor Disentanglement Variational Inference model.

This model disentangles batch-related, label-related, and residual variation into three separate latent spaces (batch b, label l, and residual r) using supervised classification heads, adversarial training, and cross-correlation penalties.

Parameters:
  • adata – AnnData object that has been registered via setup_anndata().

  • registry – Registry of the datamodule used to train the FADVI model.

  • n_hidden – Number of nodes per hidden layer.

  • n_latent_b – Dimensionality of the batch latent space.

  • n_latent_l – Dimensionality of the label latent space.

  • n_latent_r – Dimensionality of the residual latent space.

  • n_layers – Number of hidden layers used for encoder and decoder NNs.

  • dropout_rate – Dropout rate for neural networks.

  • dispersion

    One of the following:

    • 'gene' - dispersion parameter of NB is constant per gene across cells

    • 'gene-batch' - dispersion can differ between batches

    • 'gene-label' - dispersion can differ between labels

    • 'gene-cell' - dispersion can differ for every gene in every cell

  • gene_likelihood

    One of:

    • 'nb' - Negative binomial distribution

    • 'zinb' - Zero-inflated negative binomial distribution

    • 'poisson' - Poisson distribution

  • use_observed_lib_size – If True, use the observed library size for RNA as the scaling factor in the mean of the conditional distribution.

  • beta – Weight of the KL divergence term in the ELBO.

  • lambda_b – Weight for batch classification loss.

  • lambda_l – Weight for label classification loss.

  • alpha_bl – Weight for adversarial loss (label prediction from batch latents).

  • alpha_lb – Weight for adversarial loss (batch prediction from label latents).

  • alpha_rb – Weight for adversarial loss (batch prediction from residual latents).

  • alpha_rl – Weight for adversarial loss (label prediction from residual latents).

  • gamma – Weight for cross-correlation penalty.

  • **model_kwargs – Keyword arguments passed to FADVAE.
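The parameter list names a cross-correlation penalty (weighted by gamma) but does not spell out its form. As a rough illustration only, the sketch below assumes a common choice: the mean squared Pearson correlation between every dimension of one latent block and every dimension of another. The function cross_correlation_penalty is hypothetical and is not part of the fadvi API; FADVAE may compute its penalty differently.

```python
import numpy as np


def cross_correlation_penalty(z_a: np.ndarray, z_b: np.ndarray, eps: float = 1e-8) -> float:
    """Hypothetical cross-correlation penalty between two latent blocks.

    z_a: (n_cells, d_a) and z_b: (n_cells, d_b) latent codes for the same cells.
    Returns the mean squared entry of the (d_a, d_b) cross-correlation matrix.
    """
    # Standardize every latent dimension across cells (zero mean, unit variance).
    a = (z_a - z_a.mean(axis=0)) / (z_a.std(axis=0) + eps)
    b = (z_b - z_b.mean(axis=0)) / (z_b.std(axis=0) + eps)
    # Pearson correlation between each dimension of z_a and each dimension of z_b.
    corr = a.T @ b / z_a.shape[0]
    # Near zero when the blocks are linearly uncorrelated; grows as they share information.
    return float(np.mean(corr ** 2))


rng = np.random.default_rng(0)
z_b = rng.normal(size=(1000, 30))  # stand-in batch latent (n_latent_b=30)
z_l = rng.normal(size=(1000, 30))  # stand-in label latent (n_latent_l=30)
print(cross_correlation_penalty(z_b, z_l))          # small: independent blocks
print(cross_correlation_penalty(z_b, z_b[:, :10]))  # larger: a block that copies z_b
```

Under this assumption, the weighted term added to the training loss would be gamma times the penalty, summed over the latent pairs being decorrelated.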

Examples

>>> import anndata
>>> import fadvi
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> fadvi.FADVI.setup_anndata(adata, batch_key="batch", labels_key="labels")
>>> model = fadvi.FADVI(adata)
>>> model.train()
>>> adata.obsm["X_fadvi_b"] = model.get_latent_representation(representation="b")
>>> adata.obsm["X_fadvi_l"] = model.get_latent_representation(representation="l")
>>> adata.obsm["X_fadvi_r"] = model.get_latent_representation(representation="r")
_training_plan_cls

alias of SemiSupervisedTrainingPlanFixed

__init__(adata: AnnData, registry: dict | None = None, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, **model_kwargs)
_set_batch_mapping()

Set up batch mapping for converting codes to original batch labels.

_get_label_mapping_for_predictions()

Get label mapping excluding unlabeled category for predictions.

_get_code_to_label_for_predictions()

Get code-to-label mapping excluding unlabeled category for predictions.

classmethod setup_anndata(adata: AnnData, layer: str | None = None, batch_key: str | None = None, labels_key: str | None = None, unlabeled_category: str = 'Unknown', size_factor_key: str | None = None, categorical_covariate_keys: list[str] | None = None, continuous_covariate_keys: list[str] | None = None, **kwargs) -> AnnDataManager | None

Set up AnnData object for FADVI model.

Creates a mapping between the data fields used by FADVI and the AnnData object. The existing data in adata are not modified; setup only adds the fields listed under Returns.

Parameters:
  • adata – AnnData object. Rows represent cells, columns represent features.

  • layer – If not None, uses this as the key in adata.layers for raw count data.

  • batch_key – Key in adata.obs for batch information. Categories will automatically be converted into integer categories and saved to adata.obs['_scvi_batch']. If None, assigns the same batch to all the data.

  • labels_key – Key in adata.obs for label information. Categories will automatically be converted into integer categories and saved to adata.obs['_scvi_labels']. If None, assigns the same label to all the data.

  • unlabeled_category – Value in adata.obs[labels_key] that indicates unlabeled observations.

  • size_factor_key – Key in adata.obs for size factor information. Instead of using library size as a size factor, the provided size factor column will be used as offset in the mean of the likelihood. Assumed to be on linear scale.

  • categorical_covariate_keys – Keys in adata.obs that correspond to categorical data. These covariates are used in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Therefore, they should not be used for biologically relevant factors that you do not want to correct for.

  • continuous_covariate_keys – Keys in adata.obs that correspond to continuous data. These covariates are used in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Therefore, they should not be used for biologically relevant factors that you do not want to correct for.

Returns:

None. Adds the following fields:

  • .uns['_scvi'] – scvi setup dictionary

  • .obs['_scvi_labels'] – labels encoded as integers

  • .obs['_scvi_batch'] – batch encoded as integers

Return type:

AnnDataManager | None
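To illustrate what converting categories "into integer categories" means, the sketch below encodes a toy batch column with pandas Categorical codes and flags cells that carry the default unlabeled_category value. This is not scvi-tools' registration code; the column names batch_code and is_labeled are hypothetical, and the exact code ordering scvi-tools assigns (including where the unlabeled category ends up) may differ.

```python
import pandas as pd

# Toy obs table; "batch" and "labels" mirror batch_key="batch", labels_key="labels".
obs = pd.DataFrame(
    {
        "batch": ["s1", "s2", "s1", "s3"],
        "labels": ["T cell", "Unknown", "B cell", "T cell"],
    }
)

# Convert the batch column to integer codes (pandas sorts the categories).
batch_cat = pd.Categorical(obs["batch"])
obs["batch_code"] = batch_cat.codes  # hypothetical column name
print(dict(enumerate(batch_cat.categories)))  # code -> original batch name

# Flag cells whose label equals the unlabeled_category value ("Unknown" by default).
unlabeled_category = "Unknown"
obs["is_labeled"] = obs["labels"] != unlabeled_category  # hypothetical column name
```

Keeping the code-to-name dictionary around is what allows integer predictions to be mapped back to the original category names later.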

get_latent_representation(adata: AnnData | None = None, indices: list[int] | None = None, give_mean: bool = True, mc_samples: int = 5000, batch_size: int | None = None, return_dist: bool = False, representation: Literal['full', 'b', 'batch', 'l', 'label', 'r', 'residual', 'lr', 'label_residual'] = 'label') -> dict[str, torch.Tensor] | torch.Tensor

Return the latent representation for each cell.

Parameters:
  • adata – AnnData object with equivalent structure to initial AnnData. If None, defaults to the AnnData object used to initialize the model.

  • indices – Indices of cells in adata to use. If None, all cells are used.

  • give_mean – If True, return the mean of the latent distribution; if False, return a sample from it.

  • mc_samples – For distributions that have no closed-form mean (e.g. LogNormal), how many Monte Carlo samples to take for computing mean.

  • batch_size – Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.

  • return_dist – If True, a mapping will be returned with key “dist” that contains distributional samples of the latent variables.

  • representation

    Which latent representation to return (default: "label"). One of:

    • "full" - concatenated representation [z_b, z_l, z_r]

    • "b" or "batch" - batch representation only

    • "l" or "label" - label representation only

    • "r" or "residual" - residual representation only

    • "lr" or "label_residual" - concatenated label and residual representation

Returns:

Low-dimensional representation for each cell or dict of tensors if return_dist is True.

Return type:

torch.Tensor | dict[str, torch.Tensor]
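The concatenated representations have predictable widths. With the default sizes (n_latent_b=30, n_latent_l=30, n_latent_r=10), "full" has 70 columns and "lr" has 40. The NumPy sketch below uses random stand-in arrays rather than model output, and assumes the column order follows the documented order ([z_b, z_l, z_r] for "full", label then residual for "lr").

```python
import numpy as np

# Random stand-ins for per-cell latent means with the default block sizes.
n_cells = 5
rng = np.random.default_rng(0)
z_b = rng.normal(size=(n_cells, 30))
z_l = rng.normal(size=(n_cells, 30))
z_r = rng.normal(size=(n_cells, 10))

# "full": [z_b, z_l, z_r] concatenated along the feature axis.
full = np.concatenate([z_b, z_l, z_r], axis=1)
# "lr" / "label_residual": label block followed by residual block (assumed order).
lr = np.concatenate([z_l, z_r], axis=1)

print(full.shape)  # (5, 70)
print(lr.shape)    # (5, 40)

# Under this ordering, each block can be sliced back out of "full" by its offsets.
assert np.array_equal(full[:, 30:60], z_l)
```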

get_normalized_expression(adata: AnnData | None = None, indices: list[int] | None = None, transform_batch: str | int | None = None, gene_list: list[str] | None = None, library_size: float | Literal['latent'] = 1, n_samples: int = 1, n_samples_overall: int | None = None, batch_size: int | None = None, return_mean: bool = True, return_numpy: bool | None = None)

Return the normalized (decoded) gene expression.

This is denoted as \(\rho_n\) in the scVI paper.

Parameters:
  • adata – AnnData object with equivalent structure to initial AnnData. If None, defaults to the AnnData object used to initialize the model.

  • indices – Indices of cells in adata to use. If None, all cells are used.

  • transform_batch

    Batch to condition on. If transform_batch is:

    • None, then real observed batch is used.

    • int, then batch transform_batch is used.

  • gene_list – Return frequencies of expression for a subset of genes. This can save memory when working with large datasets and few genes are of interest.

  • library_size – Scale the expression frequencies to a common library size. This allows gene expression levels to be interpreted on a common scale of relevant magnitude. If set to “latent”, use the latent library size.

  • n_samples – Number of posterior samples to use for estimation.

  • n_samples_overall – Number of posterior samples to use for estimation. Overrides n_samples.

  • batch_size – Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.

  • return_mean – Whether to return the mean of the samples.

  • return_numpy – Return an ndarray instead of a DataFrame. The DataFrame includes gene names as columns. If either n_samples=1 or return_mean=True, defaults to False; otherwise, defaults to True.

Returns:

If n_samples > 1 and return_mean is False, then the shape is (samples, cells, genes). Otherwise, shape is (cells, genes). In this case, return type is DataFrame unless return_numpy is True.

Return type:

np.ndarray | pd.DataFrame
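The library_size parameter rescales the decoded expression frequencies (\(\rho_n\), which sum to 1 per cell) to a common total. The sketch below applies that scaling to a hand-written frequency matrix; the values are made up for illustration and do not come from a trained model.

```python
import numpy as np

# Hypothetical decoded frequencies for 2 cells x 4 genes; each row sums to 1.
rho = np.array(
    [
        [0.5, 0.25, 0.125, 0.125],
        [0.0625, 0.1875, 0.25, 0.5],
    ]
)
assert np.allclose(rho.sum(axis=1), 1.0)

# library_size=1 (the default) leaves the frequencies unchanged; a value such as
# 1e4 puts every cell on a common "per 10,000" scale.
library_size = 1e4
scaled = rho * library_size
print(scaled[0])  # [5000. 2500. 1250. 1250.]

# Restricting to a gene subset (like gene_list) just keeps the selected columns.
subset = scaled[:, [0, 2]]
```

Because every cell is scaled to the same total, differences between rows reflect relative expression rather than sequencing depth.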

_batch_classifier_for_interpretability(x, batch_index=None, cat_covs=None, cont_covs=None)

Classifier function for batch prediction with interpretability methods.

_label_classifier_for_interpretability(x, batch_index=None, cat_covs=None, cont_covs=None)

Classifier function for label prediction with interpretability methods.

_compute_attributions(method: Literal['ig', 'gs'], x: torch.Tensor, predictions: torch.Tensor, soft: bool, batch_index: torch.Tensor | None, cat_covs: torch.Tensor | None, cont_covs: torch.Tensor | None, prediction_mode: str, method_args: dict | None = None) -> torch.Tensor

Compute feature attributions using the specified interpretability method.

Parameters:
  • method – Interpretability method: “ig” for Integrated Gradients or “gs” for GradientShap

  • x – Input tensor

  • predictions – Model predictions (soft or hard)

  • soft – Whether predictions are soft (probabilities) or hard (class indices)

  • batch_index – Batch indices, passed to the model as an additional forward argument

  • cat_covs – Categorical covariates, passed to the model as an additional forward argument

  • cont_covs – Continuous covariates, passed to the model as an additional forward argument

  • prediction_mode – Prediction mode (“batch” or “label”)

  • method_args – Additional arguments for the interpretability method

Returns:

Feature attributions tensor

Return type:

torch.Tensor
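GradientShap ("gs") estimates attributions by sampling a reference baseline, a random point on the path from that baseline to a (optionally noised) input, and the gradient at that point; the attribution is the average of gradient times (input minus baseline). The NumPy sketch below applies this idea to a toy logistic classifier with a hand-derived gradient. It is not captum's implementation and not how FADVI wires its classifiers; gradient_shap, sigmoid, and the toy weights are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gradient_shap(x, baselines, w, b, n_samples=20000, stdev=0.0, seed=0):
    """Toy GradientShap for f(x) = sigmoid(w @ x + b) using its analytic gradient."""
    rng = np.random.default_rng(seed)
    # Draw one baseline per sample from the provided baseline distribution.
    base = baselines[rng.integers(0, len(baselines), size=n_samples)]
    # Optionally add Gaussian noise to the input (stdev=0 disables it).
    x_noisy = x + stdev * rng.normal(size=(n_samples, x.shape[0]))
    # Random point on the straight line from baseline to input.
    alpha = rng.uniform(size=(n_samples, 1))
    point = base + alpha * (x_noisy - base)
    s = sigmoid(point @ w + b)
    grad = (s * (1.0 - s))[:, None] * w  # df/dx at each sampled point
    # Expected value of gradient * (input - baseline).
    return np.mean(grad * (x_noisy - base), axis=0)


w = np.array([1.0, -2.0, 0.5])
x = np.array([1.0, 0.5, 2.0])
baselines = np.zeros((1, 3))  # a single all-zero reference profile
attr = gradient_shap(x, baselines, w, b=0.1)
# In expectation, the attributions sum to f(x) - f(baseline).
print(attr.sum(), sigmoid(w @ x + 0.1) - sigmoid(0.1))
```

Sampling over several baselines, rather than a single one, is what distinguishes GradientShap from Integrated Gradients with a fixed reference.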

_validate_interpretability_setup(method: str | None) -> None

Validate interpretability method and dependencies.

predict(adata: AnnData | None = None, indices: list[int] | None = None, prediction_mode: Literal['b', 'batch', 'l', 'label'] = 'label', soft: bool = False, batch_size: int | None = None, return_numpy: bool = True, interpretability: Literal['ig', 'gs'] | None = None, interpretability_args: dict | None = None, return_dict: bool = True) -> np.ndarray | pd.DataFrame | tuple[np.ndarray | pd.DataFrame, np.ndarray] | dict

Predict batch or label categories using the supervised classification heads.

This method uses the trained encoders and classification heads to predict either batch categories (using batch latent b) or label categories (using label latent l).

Parameters:
  • adata – AnnData object with equivalent structure to initial AnnData. If None, defaults to the AnnData object used to initialize the model.

  • indices – Indices of cells in adata to use. If None, all cells are used.

  • prediction_mode

    What to predict (default: "label"). One of:

    • "b" or "batch" - predict batch categories using the batch latent (b) and the batch classifier

    • "l" or "label" - predict label categories using the label latent (l) and the label classifier

  • soft – If True, return class probabilities. If False, return class predictions.

  • batch_size – Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.

  • return_numpy – Return an ndarray instead of a Tensor.

  • interpretability – If None, perform regular prediction. If "ig", use Integrated Gradients to compute feature attributions for predictions. If "gs", use GradientShap to compute feature attributions for predictions. Requires the captum library.

  • interpretability_args – Keyword arguments for the interpretability method. Works with both IG and GradientShap.

  • return_dict – If True, return interpretability results as dict (new format). If False, return as tuple (backward compatibility). Only affects interpretability results.

Returns:

If interpretability is None:

If soft=True, returns class probabilities with shape (n_cells, n_classes); if soft=False, returns class predictions with shape (n_cells,). If return_numpy=True, returns a numpy array; otherwise, a torch tensor.

If interpretability is "ig" or "gs":

  • If return_dict=True (default), returns a dict with keys:

    • "predictions": model predictions (same format as above)

    • "attributions": feature attributions array with shape (n_cells, n_genes)

  • If return_dict=False, returns a tuple (predictions, attributions) for backward compatibility.

Return type:

predictions, or a dict/tuple containing predictions and attributions
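Integrated Gradients ("ig") attributes a prediction to input features by averaging gradients along the straight-line path from a baseline to the input and multiplying by (input minus baseline). The NumPy sketch below uses a toy logistic classifier with a hand-derived gradient and a midpoint Riemann sum. It is not captum's implementation and not FADVI's internal classifier; integrated_gradients, sigmoid, and the toy weights are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def integrated_gradients(x, baseline, w, b, n_steps=200):
    """Toy Integrated Gradients for f(x) = sigmoid(w @ x + b).

    Approximates IG_i = (x_i - x'_i) * integral_0^1 df/dx_i(x' + a (x - x')) da
    with a midpoint Riemann sum over n_steps points on the path.
    """
    alphas = (np.arange(n_steps) + 0.5) / n_steps        # midpoints in (0, 1)
    path = baseline + alphas[:, None] * (x - baseline)   # (n_steps, n_features)
    s = sigmoid(path @ w + b)
    grads = (s * (1.0 - s))[:, None] * w                 # df/dx along the path
    return (x - baseline) * grads.mean(axis=0)


w = np.array([1.0, -2.0, 0.5])
x = np.array([1.0, 0.5, 2.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(x, baseline, w, b=0.1)
# Completeness: attributions sum to f(x) - f(baseline), up to discretization error.
print(attr.sum(), sigmoid(w @ x + 0.1) - sigmoid(0.1))
```

The completeness property is a convenient sanity check: if the summed attributions drift far from f(x) - f(baseline), the number of integration steps is usually too small.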

_compute_predictions(x: Tensor, batch_index: Tensor | None, cat_covs: Tensor | None, cont_covs: Tensor | None, prediction_mode: str, soft: bool) -> Tensor

Compute predictions for a batch of data.

_format_predictions(predictions: torch.Tensor, prediction_mode: str, soft: bool, adata: AnnData, indices: list[int]) -> np.ndarray | pd.DataFrame

Format predictions for output based on soft/hard and return type preferences.

_get_unlabeled_category_index() -> int | None

Get the index of the unlabeled category if it exists.

get_ranked_features(adata: AnnData | None = None, attributions: np.ndarray | None = None, top_n: int | None = None) -> pd.DataFrame

Get ranked gene list based on feature attribution scores.

This method takes attribution scores from interpretability methods (IG or GradientShap) and returns a ranked DataFrame of genes/features ordered by their importance.

Parameters:
  • adata – AnnData object that has been registered via setup_anndata(). If None, uses the AnnData object used to initialize the model.

  • attributions – Attribution matrix from interpretability analysis with shape (n_cells, n_genes). Typically obtained from the predict method with interpretability="ig" or "gs".

  • top_n – Number of top features to return. If None, returns all features.

Returns:

DataFrame with ranked features, containing the columns:

  • 'feature': gene/feature name

  • 'feature_idx': original feature index

  • 'attribution_mean': mean attribution score across cells

  • 'attribution_std': standard deviation of attribution scores

  • 'attribution_abs_mean': mean absolute attribution score

  • 'n_cells': number of cells in the analysis

Return type:

pd.DataFrame

Examples

>>> # Get predictions with interpretability
>>> result = model.predict(adata, interpretability="ig")
>>> # Get ranked features
>>> ranked_features = model.get_ranked_features(adata, result["attributions"])
>>> print(ranked_features.head())
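The summary columns can be computed directly from an (n_cells, n_genes) attribution matrix. The sketch below is a hypothetical rank_features helper, not the fadvi implementation; it assumes ranking by attribution_abs_mean (highest first) and the population standard deviation (NumPy's default ddof=0), and the actual sort key and conventions in get_ranked_features may differ.

```python
import numpy as np
import pandas as pd


def rank_features(attributions, feature_names, top_n=None):
    """Hypothetical helper: summarize attributions and rank features by mean |attribution|."""
    attributions = np.asarray(attributions)
    df = pd.DataFrame(
        {
            "feature": feature_names,
            "feature_idx": np.arange(attributions.shape[1]),
            "attribution_mean": attributions.mean(axis=0),
            "attribution_std": attributions.std(axis=0),  # ddof=0 (assumption)
            "attribution_abs_mean": np.abs(attributions).mean(axis=0),
            "n_cells": attributions.shape[0],
        }
    )
    # Assumed ranking key: mean absolute attribution, largest first.
    df = df.sort_values("attribution_abs_mean", ascending=False).reset_index(drop=True)
    return df if top_n is None else df.head(top_n)


attr = np.array(
    [
        [0.1, -2.0, 0.5],
        [0.3, -1.0, -0.5],
    ]
)
ranked = rank_features(attr, ["GeneA", "GeneB", "GeneC"], top_n=2)
print(ranked[["feature", "attribution_mean", "attribution_abs_mean"]])
```

In this toy matrix, GeneC has a mean attribution of 0 but a mean absolute attribution of 0.5, which is why ranking by absolute values can surface features whose effect changes sign across cells.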