fadvi.FADVI
- class fadvi.FADVI(adata: AnnData, registry: dict | None = None, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, **model_kwargs)
Factor Disentanglement Variational Inference model.
This model disentangles batch-related variation, label-related variation, and residual variation using adversarial training and cross-correlation penalties.
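The docstring does not spell out the form of the cross-correlation penalty. As a minimal sketch, assuming the penalty is the sum of squared Pearson correlations between every dimension of one latent block (z_b, z_l or z_r) and every dimension of another, it could look like this. `pearson` and `cross_correlation_penalty` are hypothetical names, not part of fadvi.

```python
import math
from itertools import combinations


def pearson(u, v, eps=1e-8):
    """Pearson correlation between two equal-length sequences of floats."""
    n = len(u)
    mean_u = sum(u) / n
    mean_v = sum(v) / n
    cov = sum((a - mean_u) * (b - mean_v) for a, b in zip(u, v))
    var_u = sum((a - mean_u) ** 2 for a in u)
    var_v = sum((b - mean_v) ** 2 for b in v)
    # eps guards against division by zero for constant dimensions
    return cov / (math.sqrt(var_u * var_v) + eps)


def cross_correlation_penalty(*blocks):
    """Hypothetical penalty: sum of squared correlations between every
    dimension of one latent block and every dimension of another block.

    Each block is a list of rows with shape (n_cells, n_dims).
    """
    total = 0.0
    for block_a, block_b in combinations(blocks, 2):
        # zip(*block) transposes rows into per-dimension columns
        for col_a in zip(*block_a):
            for col_b in zip(*block_b):
                total += pearson(col_a, col_b) ** 2
    return total


z_b = [[1.0], [2.0], [3.0], [4.0]]
z_l = [[2.0], [4.0], [6.0], [8.0]]    # perfectly correlated with z_b
z_r = [[1.0], [-1.0], [-1.0], [1.0]]  # uncorrelated with z_b and z_l
print(round(cross_correlation_penalty(z_b, z_l), 6))  # 1.0
print(round(cross_correlation_penalty(z_b, z_r), 6))  # 0.0
```

A penalty of this form is zero only when the blocks are linearly uncorrelated; it cannot detect non-linear dependence, which is one reason to pair it with adversarial classifiers.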
- Parameters:
adata – AnnData object that has been registered via setup_anndata().
registry – Registry of the datamodule used to train the FADVI model.
n_hidden – Number of nodes per hidden layer.
n_latent_b – Dimensionality of the batch latent space.
n_latent_l – Dimensionality of the label latent space.
n_latent_r – Dimensionality of the residual latent space.
n_layers – Number of hidden layers used for encoder and decoder NNs.
dropout_rate – Dropout rate for neural networks.
dispersion – One of the following:
'gene' - dispersion parameter of NB is constant per gene across cells
'gene-batch' - dispersion can differ between different batches
'gene-label' - dispersion can differ between different labels
'gene-cell' - dispersion can differ for every gene in every cell
gene_likelihood – One of:
'nb' - Negative binomial distribution
'zinb' - Zero-inflated negative binomial distribution
'poisson' - Poisson distribution
use_observed_lib_size – If True, use the observed library size for RNA as the scaling factor in the mean of the conditional distribution.
beta – Weight for the KL divergence term in the ELBO.
lambda_b – Weight for batch classification loss.
lambda_l – Weight for label classification loss.
alpha_bl – Weight for adversarial loss (label prediction from batch latents).
alpha_lb – Weight for adversarial loss (batch prediction from label latents).
alpha_rb – Weight for adversarial loss (batch prediction from residual latents).
alpha_rl – Weight for adversarial loss (label prediction from residual latents).
gamma – Weight for cross-correlation penalty.
**model_kwargs – Keyword arguments for FADVAE.
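The parameter list names the loss weights but not how the terms are combined. The sketch below assumes a plain weighted sum in which every term, including the adversarial ones, is already signed as something the encoder minimizes (for example after gradient reversal). `LossWeights`, `total_loss` and the term keys are hypothetical; only the default values are copied from the signature.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossWeights:
    # Defaults copied from the FADVI constructor signature.
    beta: float = 1.0
    lambda_b: float = 50.0
    lambda_l: float = 50.0
    alpha_bl: float = 1.0
    alpha_lb: float = 1.0
    alpha_rb: float = 1.0
    alpha_rl: float = 1.0
    gamma: float = 1.0


def total_loss(terms, weights=LossWeights()):
    """Hypothetical weighted objective; each term is assumed to be a scalar
    the encoder should minimize (adversarial terms already signed)."""
    w = weights
    return (
        terms["reconstruction"]            # negative log-likelihood
        + w.beta * terms["kl"]
        + w.lambda_b * terms["batch_ce"]   # batch head on z_b
        + w.lambda_l * terms["label_ce"]   # label head on z_l
        + w.alpha_bl * terms["adv_bl"]     # labels predicted from z_b
        + w.alpha_lb * terms["adv_lb"]     # batches predicted from z_l
        + w.alpha_rb * terms["adv_rb"]     # batches predicted from z_r
        + w.alpha_rl * terms["adv_rl"]     # labels predicted from z_r
        + w.gamma * terms["cross_corr"]
    )


unit_terms = dict.fromkeys(
    ["reconstruction", "kl", "batch_ce", "label_ce",
     "adv_bl", "adv_lb", "adv_rb", "adv_rl", "cross_corr"],
    1.0,
)
print(total_loss(unit_terms))  # 107.0
```

With the default weights, the two supervised classification terms dominate the objective (50 each) unless their raw values are much smaller than the other terms.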
Examples
>>> import anndata
>>> import fadvi
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> fadvi.FADVI.setup_anndata(adata, batch_key="batch", labels_key="labels")
>>> model = fadvi.FADVI(adata)
>>> model.train()
>>> adata.obsm["X_fadvi_b"] = model.get_latent_representation(representation="b")
>>> adata.obsm["X_fadvi_l"] = model.get_latent_representation(representation="l")
>>> adata.obsm["X_fadvi_r"] = model.get_latent_representation(representation="r")
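The gene_likelihood options map to standard count distributions. A minimal sketch of their log-probabilities, assuming the mean / inverse-dispersion parameterization (mu, theta) and zero-inflation probability pi used by scVI-style models; that FADVI uses exactly this parameterization is an assumption. With dispersion='gene', theta would be one value per gene shared across cells, while the other dispersion options let it vary by batch, label or cell. The function names are hypothetical.

```python
import math


def log_nb(x, mu, theta):
    """Log-pmf of a negative binomial with mean mu and inverse dispersion theta."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mu))
        + x * math.log(mu / (theta + mu))
    )


def log_zinb(x, mu, theta, pi):
    """Zero-inflated NB: an extra zero with probability pi, NB otherwise."""
    if x == 0:
        return math.log(pi + (1.0 - pi) * math.exp(log_nb(0, mu, theta)))
    return math.log(1.0 - pi) + log_nb(x, mu, theta)


def log_poisson(x, mu):
    """Log-pmf of a Poisson with mean mu."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)


# Sanity check: each pmf sums to ~1 over a wide range of counts.
for name, logp in [
    ("nb", lambda x: log_nb(x, mu=5.0, theta=2.0)),
    ("zinb", lambda x: log_zinb(x, mu=5.0, theta=2.0, pi=0.3)),
    ("poisson", lambda x: log_poisson(x, mu=5.0)),
]:
    print(name, round(sum(math.exp(logp(x)) for x in range(500)), 6))
```

The NB variance is mu + mu**2 / theta, so a smaller theta means more overdispersion; the Poisson is the limit theta to infinity, where variance equals the mean.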
- _training_plan_cls
alias of SemiSupervisedTrainingPlanFixed
- __init__(adata: AnnData, registry: dict | None = None, n_hidden: int = 128, n_latent_b: int = 30, n_latent_l: int = 30, n_latent_r: int = 10, n_layers: int = 2, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'zinb', use_observed_lib_size: bool = True, beta: float = 1.0, lambda_b: float = 50, lambda_l: float = 50, alpha_bl: float = 1.0, alpha_lb: float = 1.0, alpha_rb: float = 1.0, alpha_rl: float = 1.0, gamma: float = 1.0, **model_kwargs)
- _get_label_mapping_for_predictions()
Get the label mapping, excluding the unlabeled category, for predictions.
- _get_code_to_label_for_predictions()
Get the code-to-label mapping, excluding the unlabeled category, for predictions.
- classmethod setup_anndata(adata: AnnData, layer: str | None = None, batch_key: str | None = None, labels_key: str | None = None, unlabeled_category: str = 'Unknown', size_factor_key: str | None = None, categorical_covariate_keys: list[str] | None = None, continuous_covariate_keys: list[str] | None = None, **kwargs) → AnnDataManager | None
Set up AnnData object for FADVI model.
A mapping is created between the data fields used by FADVI and the AnnData object. None of the existing data in adata is modified; only new fields are added (listed under Returns).
- Parameters:
adata (AnnData) – AnnData object. Rows represent cells, columns represent features.
layer – If not None, uses this as the key in adata.layers for raw count data.
batch_key – key in adata.obs for batch information. Categories will automatically be converted into integer categories and saved to adata.obs[‘_scvi_batch’]. If None, assigns the same batch to all the data.
labels_key – key in adata.obs for label information. Categories will automatically be converted into integer categories and saved to adata.obs[‘_scvi_labels’]. If None, assigns the same label to all the data.
unlabeled_category – value in adata.obs[labels_key] that indicates unlabeled observations.
size_factor_key – key in adata.obs for size factor information. Instead of using library size as a size factor, the provided size factor column will be used as offset in the mean of the likelihood. Assumed to be on linear scale.
categorical_covariate_keys – keys in adata.obs that correspond to categorical data. These covariates can be added in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically relevant factors that you do not want to correct for.
continuous_covariate_keys – keys in adata.obs that correspond to continuous data. These covariates can be added in addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically relevant factors that you do not want to correct for.
- Returns:
None. Adds the following fields:
.uns[‘_scvi’] – scvi setup dictionary
.obs[‘_scvi_labels’] – labels encoded as integers
.obs[‘_scvi_batch’] – batch encoded as integers
- Return type:
AnnDataManager | None
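setup_anndata converts batch and label categories into integer codes, and the prediction helpers exclude the unlabeled category. A minimal sketch of that bookkeeping, assuming sorted category order (the pandas Categorical default); whether FADVI orders categories the same way is an assumption, and both functions are hypothetical.

```python
def encode_categories(values, unlabeled_category=None):
    """Map categorical values to integer codes in sorted category order."""
    categories = sorted(set(values))
    mapping = {cat: code for code, cat in enumerate(categories)}
    codes = [mapping[v] for v in values]
    unlabeled_code = mapping.get(unlabeled_category)  # None if absent
    return codes, categories, unlabeled_code


def prediction_categories(categories, unlabeled_category="Unknown"):
    """Categories a label classifier can predict: the unlabeled value is excluded."""
    return [c for c in categories if c != unlabeled_category]


labels = ["T cell", "B cell", "Unknown", "T cell"]
codes, categories, unlabeled_code = encode_categories(labels, "Unknown")
print(codes)           # [1, 0, 2, 1]
print(unlabeled_code)  # 2
print(prediction_categories(categories))  # ['B cell', 'T cell']
```

The default unlabeled_category of 'Unknown' matches the setup_anndata signature; cells carrying that value contribute no label supervision.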
- get_latent_representation(adata: AnnData | None = None, indices: list[int] | None = None, give_mean: bool = True, mc_samples: int = 5000, batch_size: int | None = None, return_dist: bool = False, representation: Literal['full', 'b', 'batch', 'l', 'label', 'r', 'residual', 'lr', 'label_residual'] = 'label') → dict[str, torch.Tensor] | torch.Tensor
Return the latent representation for each cell.
- Parameters:
adata – AnnData object with equivalent structure to initial AnnData. If None, defaults to the AnnData object used to initialize the model.
indices – Indices of cells in adata to use. If None, all cells are used.
give_mean – Give mean of distribution or sample from it.
mc_samples – For distributions that have no closed-form mean (e.g. LogNormal), how many Monte Carlo samples to take for computing mean.
batch_size – Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.
return_dist – If True, a mapping will be returned with key “dist” that contains distributional samples of the latent variables.
representation – Which latent representation to return:
- “full”: concatenated representation [z_b, z_l, z_r]
- “b” or “batch”: batch representation only
- “l” or “label”: label representation only
- “r” or “residual”: residual representation only
- “lr” or “label_residual”: concatenated label and residual representation
- Returns:
latent_representation – Low-dimensional representation for each cell, or a dict of tensors if return_dist is True.
- Return type:
dict[str, torch.Tensor] | torch.Tensor
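The representation options are slices of the concatenated latent [z_b, z_l, z_r] described above. The sketch below selects them from a precomputed full matrix; this only illustrates the layout and aliases and is not necessarily how the method works internally. `ALIASES` and `select_representation` are hypothetical; the default block sizes mirror the constructor.

```python
ALIASES = {
    "full": "full",
    "b": "b", "batch": "b",
    "l": "l", "label": "l",
    "r": "r", "residual": "r",
    "lr": "lr", "label_residual": "lr",
}


def select_representation(z_full, representation="label",
                          n_latent_b=30, n_latent_l=30, n_latent_r=10):
    """Slice rows of the concatenated latent [z_b, z_l, z_r]."""
    key = ALIASES.get(representation)
    if key is None:
        raise ValueError(f"Unknown representation: {representation!r}")
    b_end = n_latent_b
    l_end = b_end + n_latent_l
    r_end = l_end + n_latent_r
    spans = {
        "full": (0, r_end),
        "b": (0, b_end),
        "l": (b_end, l_end),
        "r": (l_end, r_end),
        "lr": (b_end, r_end),  # label block followed by residual block
    }
    start, stop = spans[key]
    for row in z_full:
        if len(row) != r_end:
            raise ValueError(f"Expected {r_end} latent dims, got {len(row)}")
    return [row[start:stop] for row in z_full]


# Toy sizes: 2 batch dims, 2 label dims, 1 residual dim.
z = [[0.0, 1.0, 2.0, 3.0, 4.0]]
dims = dict(n_latent_b=2, n_latent_l=2, n_latent_r=1)
print(select_representation(z, "batch", **dims))           # [[0.0, 1.0]]
print(select_representation(z, "label", **dims))           # [[2.0, 3.0]]
print(select_representation(z, "label_residual", **dims))  # [[2.0, 3.0, 4.0]]
```

With the default sizes, “full” has 70 dimensions, “lr” has 40, and “label” (the default) has 30.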
- get_normalized_expression(adata: AnnData | None = None, indices: list[int] | None = None, transform_batch: str | int | None = None, gene_list: list[str] | None = None, library_size: float | Literal['latent'] = 1, n_samples: int = 1, n_samples_overall: int | None = None, batch_size: int | None = None, return_mean: bool = True, return_numpy: bool | None = None)
Return the normalized (decoded) gene expression.
This is denoted as \(\rho_n\) in the scVI paper.
- Parameters:
adata – AnnData object with equivalent structure to initial AnnData. If None, defaults to the AnnData object used to initialize the model.
indices – Indices of cells in adata to use. If None, all cells are used.
transform_batch –
Batch to condition on. If transform_batch is:
None, the real observed batch is used.
int, batch transform_batch is used.
gene_list – Return frequencies of expression for a subset of genes. This can save memory when working with large datasets and few genes are of interest.
library_size – Scale the expression frequencies to a common library size. This allows gene expression levels to be interpreted on a common scale of relevant magnitude. If set to “latent”, use the latent library size.
n_samples – Number of posterior samples to use for estimation.
n_samples_overall – Number of posterior samples to use for estimation. Overrides n_samples.
batch_size – Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.
return_mean – Whether to return the mean of the samples.
return_numpy – Return an ndarray instead of a DataFrame. The DataFrame includes gene names as columns. If either n_samples=1 or return_mean=True, defaults to False; otherwise, defaults to True.
- Returns:
normalized_expression – If n_samples > 1 and return_mean is False, the shape is (samples, cells, genes). Otherwise, the shape is (cells, genes), and the return type is a DataFrame unless return_numpy is True.
- Return type:
np.ndarray | pd.DataFrame
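The normalized expression \(\rho_n\) is a per-cell vector of gene frequencies that library_size rescales. A minimal sketch for one cell, assuming the decoder output is softmax-normalized over genes as in scVI (an assumption for FADVI) and that return_mean averages over posterior samples; `normalized_expression` here is a hypothetical stand-in, not the method itself.

```python
import math


def softmax(logits):
    """Numerically stable softmax over genes."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def normalized_expression(sample_logits, library_size=1.0, return_mean=True):
    """sample_logits: one list of per-gene decoder logits per posterior sample
    (for a single cell). Returns library-scaled expression frequencies."""
    samples = [[library_size * p for p in softmax(logits)] for logits in sample_logits]
    if not return_mean:
        return samples  # shape (n_samples, n_genes)
    n = len(samples)
    return [sum(gene_values) / n for gene_values in zip(*samples)]  # (n_genes,)


# Two posterior samples for a 3-gene cell, scaled to 10,000 counts.
expr = normalized_expression([[0.0, 0.0, 0.0], [1.0, 0.0, -1.0]], library_size=1e4)
print(round(sum(expr), 6))  # 10000.0
```

With the default library_size=1, each cell's values sum to 1 and can be read directly as expression frequencies.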
- _batch_classifier_for_interpretability(x, batch_index=None, cat_covs=None, cont_covs=None)
Classifier function for batch prediction with interpretability methods.
- _label_classifier_for_interpretability(x, batch_index=None, cat_covs=None, cont_covs=None)
Classifier function for label prediction with interpretability methods.
- _compute_attributions(method: Literal['ig', 'gs'], x: torch.Tensor, predictions: torch.Tensor, soft: bool, batch_index: torch.Tensor | None, cat_covs: torch.Tensor | None, cont_covs: torch.Tensor | None, prediction_mode: str, method_args: dict | None = None) → torch.Tensor
Compute feature attributions using the specified interpretability method.
- Parameters:
method – Interpretability method: “ig” for Integrated Gradients or “gs” for GradientShap
x – Input tensor
predictions – Model predictions (soft or hard)
soft – Whether predictions are soft (probabilities) or hard (class indices)
batch_index – Additional forward arguments for the model
cat_covs – Additional forward arguments for the model
cont_covs – Additional forward arguments for the model
prediction_mode – Prediction mode (“batch” or “label”)
method_args – Additional arguments for the interpretability method
- Returns:
attributions – Feature attributions tensor.
- Return type:
torch.Tensor
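Integrated Gradients attributes a scalar output to each input feature by averaging gradients along the straight path from a baseline to the input. The sketch below is a dependency-free illustration of the technique, not captum's implementation: it uses central finite differences where captum uses autograd, and `toy_score` stands in for what would presumably be a classifier wrapper's probability of the predicted class. GradientShap differs by sampling random baselines and random points along the path and averaging the resulting gradient-times-difference terms.

```python
def integrated_gradients(f, x, baseline=None, steps=50, h=1e-5):
    """Approximate Integrated Gradients of a scalar function f at x.

    attribution_i = (x_i - x'_i) * mean over alpha of df/dx_i(x' + alpha (x - x')),
    using a midpoint Riemann sum over alpha and central differences.
    """
    if baseline is None:
        baseline = [0.0] * len(x)
    attributions = []
    for i in range(len(x)):
        grad_sum = 0.0
        for k in range(steps):
            alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval
            point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
            plus, minus = list(point), list(point)
            plus[i] += h
            minus[i] -= h
            grad_sum += (f(plus) - f(minus)) / (2 * h)
        attributions.append((x[i] - baseline[i]) * grad_sum / steps)
    return attributions


def toy_score(v):
    # Stand-in for a classifier's probability of the predicted class.
    return v[0] ** 2 + 3.0 * v[1]


attr = integrated_gradients(toy_score, [2.0, 1.0])
print([round(a, 4) for a in attr])  # [4.0, 3.0]
```

The attributions satisfy the completeness property: they sum to f(x) - f(baseline), here 7.0 - 0.0.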
- _validate_interpretability_setup(method: str | None) → None
Validate interpretability method and dependencies.
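A minimal sketch of the checks this helper is described as performing: reject unknown method names and confirm the optional captum dependency is importable. The function body, error messages and `SUPPORTED_METHODS` are assumptions; only the accepted values (None, “ig”, “gs”) come from predict's signature.

```python
import importlib.util

SUPPORTED_METHODS = {"ig", "gs"}


def validate_interpretability_setup(method):
    """Raise if `method` is unsupported or captum is missing."""
    if method is None:
        return  # plain prediction: no extra dependency required
    if method not in SUPPORTED_METHODS:
        raise ValueError(
            f"interpretability must be None, 'ig' or 'gs'; got {method!r}"
        )
    if importlib.util.find_spec("captum") is None:
        raise ImportError(
            "Interpretability methods require the captum package."
        )


validate_interpretability_setup(None)  # passes silently
try:
    validate_interpretability_setup("shap")
except ValueError as err:
    print(err)
```

Checking the method name before the import keeps the error message precise when the name is wrong and captum is also missing.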
- predict(adata: AnnData | None = None, indices: list[int] | None = None, prediction_mode: Literal['b', 'batch', 'l', 'label'] = 'label', soft: bool = False, batch_size: int | None = None, return_numpy: bool = True, interpretability: Literal['ig', 'gs'] | None = None, interpretability_args: dict | None = None, return_dict: bool = True) → np.ndarray | pd.DataFrame | tuple[np.ndarray | pd.DataFrame, np.ndarray] | dict
Predict batch or label categories using the supervised classification heads.
This method uses the trained encoders and classification heads to predict either batch categories (using batch latent b) or label categories (using label latent l).
- Parameters:
adata – AnnData object with equivalent structure to initial AnnData. If None, defaults to the AnnData object used to initialize the model.
indices – Indices of cells in adata to use. If None, all cells are used.
prediction_mode – What to predict:
- “b” or “batch”: Predict batch categories using the batch latent (b) and batch classifier
- “l” or “label”: Predict label categories using the label latent (l) and label classifier
soft – If True, return class probabilities. If False, return class predictions.
batch_size – Minibatch size for data loading into model. Defaults to scvi.settings.batch_size.
interpretability – If None, perform regular prediction. If “ig”, use Integrated Gradients to compute feature attributions for predictions. If “gs”, use GradientShap to compute feature attributions for predictions. Requires the captum library.
interpretability_args – Keyword arguments for the interpretability method. Works with both IG and GradientShap.
return_dict – If True, return interpretability results as dict (new format). If False, return as tuple (backward compatibility). Only affects interpretability results.
- Returns:
- If interpretability is None:
If soft=True, returns class probabilities with shape (n_cells, n_classes). If soft=False, returns class predictions with shape (n_cells,). If return_numpy=True, returns a numpy array; otherwise, a torch tensor.
- If interpretability is “ig” or “gs”:
- If return_dict=True (default): returns a dict with keys:
“predictions”: model predictions (same format as above)
“attributions”: feature attributions array with shape (n_cells, n_genes)
- If return_dict=False: returns a tuple (predictions, attributions) for backward compatibility.
- Return type:
np.ndarray | pd.DataFrame | tuple[np.ndarray | pd.DataFrame, np.ndarray] | dict
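The soft/hard and return_dict options amount to simple post-processing of classifier outputs. A minimal sketch, assuming the classification head produces one row of logits per cell; `predict_from_logits` is hypothetical, and mapping class indices back to category names (the job of _format_predictions) is omitted.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict_from_logits(logits, soft=False, attributions=None, return_dict=True):
    """Hypothetical post-processing of classifier logits (one row per cell)."""
    probs = [softmax(row) for row in logits]
    if soft:
        predictions = probs  # (n_cells, n_classes) probabilities
    else:
        # (n_cells,) index of the most probable class
        predictions = [max(range(len(p)), key=p.__getitem__) for p in probs]
    if attributions is None:
        return predictions
    if return_dict:
        return {"predictions": predictions, "attributions": attributions}
    return predictions, attributions  # legacy tuple format


logits = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.0]]
print(predict_from_logits(logits))  # [1, 0]
out = predict_from_logits(logits, attributions=[[0.2, -0.1], [0.0, 0.4]])
print(sorted(out))  # ['attributions', 'predictions']
```

Hard predictions are the argmax of the soft ones, so requesting soft=True loses nothing and lets you apply your own confidence threshold.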
- _compute_predictions(x: Tensor, batch_index: Tensor | None, cat_covs: Tensor | None, cont_covs: Tensor | None, prediction_mode: str, soft: bool) → Tensor
Compute predictions for a batch of data.
- _format_predictions(predictions: torch.Tensor, prediction_mode: str, soft: bool, adata: AnnData, indices: list[int]) → np.ndarray | pd.DataFrame
Format predictions for output based on soft/hard and return type preferences.
- _get_unlabeled_category_index() → int | None
Get the index of the unlabeled category if it exists.
- get_ranked_features(adata: AnnData | None = None, attributions: np.ndarray | None = None, top_n: int | None = None) → pd.DataFrame
Get ranked gene list based on feature attribution scores.
This method takes attribution scores from interpretability methods (IG or GradientShap) and returns a ranked DataFrame of genes/features ordered by their importance.
- Parameters:
adata – AnnData object that has been registered via setup_anndata(). If None, uses the AnnData object used to initialize the model.
attributions – Attribution matrix from interpretability analysis with shape (n_cells, n_genes). Typically obtained from the predict method with interpretability="ig" or "gs".
top_n – Number of top features to return. If None, returns all features.
- Returns:
DataFrame with ranked features, containing the columns:
- ‘feature’: Gene/feature name
- ‘feature_idx’: Original feature index
- ‘attribution_mean’: Mean attribution score across cells
- ‘attribution_std’: Standard deviation of attribution scores
- ‘attribution_abs_mean’: Mean absolute attribution score
- ‘n_cells’: Number of cells in the analysis
- Return type:
pd.DataFrame
Examples
>>> # Get predictions with interpretability
>>> result = model.predict(adata, interpretability="ig")
>>> # Get ranked features
>>> ranked_features = model.get_ranked_features(adata, result["attributions"])
>>> print(ranked_features.head())
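The documented columns can be computed directly from an (n_cells, n_genes) attribution matrix. A minimal sketch that returns a list of dicts instead of a DataFrame to stay dependency-free (pass the result to pd.DataFrame to get a table); ranking by mean absolute attribution and using the population standard deviation are assumptions, since the docstring names the columns but not the sort key or ddof. `rank_features` is hypothetical.

```python
import statistics


def rank_features(attributions, feature_names, top_n=None):
    """attributions: (n_cells, n_genes) nested list. Returns dict rows ranked
    by mean absolute attribution (ranking key is an assumption)."""
    n_cells = len(attributions)
    rows = []
    for idx, name in enumerate(feature_names):
        column = [cell[idx] for cell in attributions]
        rows.append({
            "feature": name,
            "feature_idx": idx,
            "attribution_mean": statistics.fmean(column),
            "attribution_std": statistics.pstdev(column),  # population std (assumed)
            "attribution_abs_mean": statistics.fmean(abs(v) for v in column),
            "n_cells": n_cells,
        })
    rows.sort(key=lambda row: row["attribution_abs_mean"], reverse=True)
    return rows if top_n is None else rows[:top_n]


attr = [[1.0, -4.0, 0.5], [3.0, -1.0, 0.5]]
for row in rank_features(attr, ["GeneA", "GeneB", "GeneC"], top_n=2):
    print(row["feature"], row["attribution_abs_mean"])
# GeneB 2.5
# GeneA 2.0
```

Ranking by the absolute mean keeps genes with strong negative attributions (evidence against a class, like GeneB here) near the top; sorting by attribution_mean instead would push them to the bottom.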