from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from scvi import REGISTRY_KEYS
from scvi.module._classifier import Classifier
from scvi.module._vae import VAE
from scvi.module.base import LossOutput, SupervisedModuleClass
from scvi.nn import Encoder
from torch.distributions import Normal
from torch.distributions import kl_divergence as kl
if TYPE_CHECKING:
from collections.abc import Iterable
from typing import Literal
from torch.distributions import Distribution
[docs]
class GradientReversalFunction(torch.autograd.Function):
    """Gradient reversal layer for adversarial training.

    The forward pass hands the input through unchanged; the backward pass
    negates the incoming gradient and scales it by ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return a same-shape view of ``x`` and remember ``alpha`` on ``ctx``."""
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` for ``x`` and ``None`` for ``alpha``."""
        return grad_output.neg() * ctx.alpha, None
[docs]
def gradient_reversal(x, alpha=1.0):
    """Pass ``x`` through :class:`GradientReversalFunction`.

    Parameters
    ----------
    x
        Tensor to forward unchanged.
    alpha
        Factor applied to the negated gradient in the backward pass.

    Returns
    -------
    The result of ``GradientReversalFunction.apply(x, alpha)``.
    """
    reversed_x = GradientReversalFunction.apply(x, alpha)
    return reversed_x
[docs]
class FADVAE(SupervisedModuleClass, VAE):
"""Factor Disentanglement Variational Autoencoder.
This model disentangles batch-related variation (z_b), label-related variation (z_l),
and residual variation (z_r) using adversarial training and cross-correlation penalties.
Parameters
----------
n_input
Number of input genes
n_batch
Number of batches
n_labels
Number of labels
n_hidden
Number of nodes per hidden layer
n_latent_b
Dimensionality of batch latent space
n_latent_l
Dimensionality of label latent space
n_latent_r
Dimensionality of residual latent space
n_layers
Number of hidden layers used for encoder and decoder NNs
n_continuous_cov
Number of continuous covariates
n_cats_per_cov
Number of categories for each extra categorical covariate
dropout_rate
Dropout rate for neural networks
dispersion
Dispersion parameter option
log_variational
Log(data+1) prior to encoding for numerical stability
gene_likelihood
One of 'nb', 'zinb', 'poisson'
use_observed_lib_size
If True, use observed library size
beta
KL divergence weight
lambda_b
Weight for batch classification loss
lambda_l
Weight for label classification loss
alpha_bl
Weight for adversarial loss (label prediction from batch latents)
alpha_lb
Weight for adversarial loss (batch prediction from label latents)
alpha_rb
Weight for adversarial loss (batch prediction from residual latents)
alpha_rl
Weight for adversarial loss (label prediction from residual latents)
gamma
Weight for cross-correlation penalty
use_batch_norm
Whether to use batch norm in layers
use_layer_norm
Whether to use layer norm in layers
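unlabeled_category_id
Integer label code that marks unlabeled cells; cells carrying it are excluded from the label classification and adversarial label losses. If ``None``, every cell is treated as labeled.
**vae_kwargs
Keyword args for VAE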
"""
[docs]
def __init__(
    self,
    n_input: int,
    n_batch: int = 0,
    n_labels: int = 0,
    n_hidden: int = 128,
    n_latent_b: int = 30,
    n_latent_l: int = 30,
    n_latent_r: int = 10,
    n_layers: int = 2,
    n_continuous_cov: int = 0,
    n_cats_per_cov: Iterable[int] | None = None,
    dropout_rate: float = 0.1,
    dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
    log_variational: bool = True,
    gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
    use_observed_lib_size: bool = True,
    beta: float = 1.0,
    lambda_b: float = 50,
    lambda_l: float = 50,
    alpha_bl: float = 1.0,
    alpha_lb: float = 1.0,
    alpha_rb: float = 1.0,
    alpha_rl: float = 1.0,
    gamma: float = 1.0,
    use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
    use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
    unlabeled_category_id: int | None = None,
    **vae_kwargs,
):
    """Build the parent VAE plus three factor encoders and six classifier heads.

    The parent is initialised with a latent size of
    ``n_latent_b + n_latent_l + n_latent_r``. On top of it this constructor
    creates one encoder per factor (batch, label, residual), two supervised
    heads (batch from ``z_b``, label from ``z_l``) and four adversarial
    heads (label from ``z_b``, batch from ``z_l``, batch and label from
    ``z_r``). Loss weights and ``unlabeled_category_id`` are stored on the
    instance.
    """
    # ``n_cats_per_cov`` is annotated as an Iterable and is read twice (by the
    # parent constructor and for the encoders below). Materialise it once so
    # a one-shot iterable is not exhausted before the second read.
    if n_cats_per_cov is not None:
        n_cats_per_cov = list(n_cats_per_cov)

    # NOTE(review): ``n_labels`` is not forwarded to the parent constructor;
    # presumably ``dispersion="gene-label"`` needs it there -- confirm.
    super().__init__(
        n_input,
        n_hidden=n_hidden,
        n_latent=n_latent_b + n_latent_l + n_latent_r,
        n_layers=n_layers,
        n_continuous_cov=n_continuous_cov,
        n_cats_per_cov=n_cats_per_cov,
        dropout_rate=dropout_rate,
        n_batch=n_batch,
        dispersion=dispersion,
        log_variational=log_variational,
        gene_likelihood=gene_likelihood,
        use_observed_lib_size=use_observed_lib_size,
        use_batch_norm=use_batch_norm,
        use_layer_norm=use_layer_norm,
        **vae_kwargs,
    )

    # Latent dimensions and label bookkeeping.
    self.n_latent_b = n_latent_b
    self.n_latent_l = n_latent_l
    self.n_latent_r = n_latent_r
    self.n_labels = n_labels
    self.n_cats_per_cov = n_cats_per_cov
    self._unlabeled_category_id = unlabeled_category_id

    # Loss weights.
    self.beta = beta
    self.lambda_b = lambda_b
    self.lambda_l = lambda_l
    self.alpha_bl = alpha_bl
    self.alpha_lb = alpha_lb
    self.alpha_rb = alpha_rb
    self.alpha_rl = alpha_rl
    self.gamma = gamma

    # Encoders and classifier heads all follow the *encoder* normalisation flags.
    use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
    use_layer_norm_encoder = use_layer_norm in ("encoder", "both")

    # Categorical inputs seen by every encoder: batch first, then extra covariates.
    cat_list = [n_batch] + (n_cats_per_cov or [])

    encoder_kwargs = {
        "n_cat_list": cat_list,
        "n_layers": n_layers,
        "n_hidden": n_hidden,
        "dropout_rate": dropout_rate,
        "use_batch_norm": use_batch_norm_encoder,
        "use_layer_norm": use_layer_norm_encoder,
        "return_dist": True,
    }
    classifier_kwargs = {
        "n_hidden": n_hidden,
        "n_layers": n_layers,
        "dropout_rate": dropout_rate,
        "use_batch_norm": use_batch_norm_encoder,
        "use_layer_norm": use_layer_norm_encoder,
        "logits": True,
    }

    # Creation order is kept identical to the original so parameter
    # registration order (and RNG consumption during init) does not change.

    # Separate encoders for each latent factor.
    self.encoder_b = Encoder(n_input, n_latent_b, **encoder_kwargs)
    self.encoder_l = Encoder(n_input, n_latent_l, **encoder_kwargs)
    self.encoder_r = Encoder(n_input, n_latent_r, **encoder_kwargs)

    # Supervised classification heads.
    self.head_batch = Classifier(n_latent_b, n_labels=n_batch, **classifier_kwargs)
    self.head_label = Classifier(n_latent_l, n_labels=n_labels, **classifier_kwargs)

    # Adversarial classification heads.
    self.adv_head_batch_from_l = Classifier(
        n_latent_l, n_labels=n_batch, **classifier_kwargs
    )
    self.adv_head_label_from_b = Classifier(
        n_latent_b, n_labels=n_labels, **classifier_kwargs
    )
    self.adv_head_batch_from_r = Classifier(
        n_latent_r, n_labels=n_batch, **classifier_kwargs
    )
    self.adv_head_label_from_r = Classifier(
        n_latent_r, n_labels=n_labels, **classifier_kwargs
    )
[docs]
def inference(
    self, x, batch_index=None, cat_covs=None, cont_covs=None, n_samples=1
):
    """Encode ``x`` into the three latent factors and a library size.

    Parameters
    ----------
    x
        Input tensor. It is ``log1p``-transformed before encoding when
        ``self.log_variational`` is set; the observed library size is
        ``log(x.sum(1))`` of the untransformed input.
    batch_index
        Batch tensor handed to every encoder (moved to ``x``'s device).
    cat_covs
        Extra categorical covariates: either a 2-D tensor, which is split
        into one ``(n, 1)`` column per covariate, or a sequence holding one
        entry per covariate.
    cont_covs
        Accepted for interface compatibility; the encoders built by this
        class do not consume it.
    n_samples
        When greater than 1, the latents (and the encoded library, if any)
        are re-sampled with a leading dimension of size ``n_samples``.

    Returns
    -------
    dict
        ``z_b``, ``z_l``, ``z_r`` (latent samples), ``qb``, ``ql``, ``qr``
        (their distributions), ``library``, and ``ql_lib`` only when the
        library size is encoded rather than observed.
    """
    device = x.device
    encoder_input = torch.log1p(x) if self.log_variational else x

    if batch_index is not None:
        batch_index = batch_index.to(device)

    if cat_covs is None:
        categorical_input = ()
    elif isinstance(cat_covs, torch.Tensor):
        # Iterating a tensor walks its rows (cells); the encoders need one
        # argument per covariate, i.e. one per column.
        categorical_input = torch.split(cat_covs.to(device), 1, dim=1)
    else:
        categorical_input = tuple(
            cov.to(device) if isinstance(cov, torch.Tensor) else cov
            for cov in cat_covs
        )

    # One encoder per latent factor.
    qb, z_b = self.encoder_b(encoder_input, batch_index, *categorical_input)
    ql, z_l = self.encoder_l(encoder_input, batch_index, *categorical_input)
    qr, z_r = self.encoder_r(encoder_input, batch_index, *categorical_input)

    if n_samples > 1:
        z_b = qb.rsample((n_samples,))
        z_l = ql.rsample((n_samples,))
        z_r = qr.rsample((n_samples,))

    outputs = {
        "z_b": z_b,
        "z_l": z_l,
        "z_r": z_r,
        "qb": qb,
        "ql": ql,
        "qr": qr,
    }

    if self.use_observed_lib_size:
        outputs["library"] = torch.log(x.sum(1)).unsqueeze(1)
    else:
        # Library encoder comes from the parent VAE.
        ql_lib, library = self.l_encoder(
            encoder_input, batch_index, *categorical_input
        )
        if n_samples > 1:
            library = ql_lib.rsample((n_samples,))
        outputs["library"] = library
        outputs["ql_lib"] = ql_lib

    return outputs
[docs]
def generative(self, z, library, batch_index=None, cat_covs=None, cont_covs=None):
    """Decode ``z`` with the parent class's generative model.

    Every tensor argument is moved to ``z``'s device first. Covariates are
    forwarded by keyword so they cannot be bound to the wrong parameter:
    this method takes ``cat_covs`` before ``cont_covs``, while the parent
    ``generative`` is assumed to accept ``cont_covs`` and ``cat_covs``
    keywords in the opposite positional order (scvi-tools ``VAE``) --
    TODO confirm against the installed version.

    Parameters
    ----------
    z
        Latent tensor fed to the parent decoder.
    library
        Library-size tensor, or ``None``.
    batch_index
        Batch tensor, or ``None``.
    cat_covs
        Categorical covariates: a tensor (kept as a tensor) or a sequence
        (moved element-wise and passed on as a list), or ``None``.
    cont_covs
        Continuous covariates, handled the same way as ``cat_covs``.

    Returns
    -------
    Whatever the parent ``generative`` returns.
    """
    device = z.device

    def _covs_to_device(covs):
        # Keep a tensor a tensor; converting it to a list of rows would hide
        # its column structure from the parent.
        if covs is None:
            return None
        if isinstance(covs, torch.Tensor):
            return covs.to(device)
        return [
            cov.to(device) if isinstance(cov, torch.Tensor) else cov
            for cov in covs
        ]

    if library is not None:
        library = library.to(device)
    if batch_index is not None:
        batch_index = batch_index.to(device)

    return super().generative(
        z,
        library,
        batch_index,
        cont_covs=_covs_to_device(cont_covs),
        cat_covs=_covs_to_device(cat_covs),
    )
[docs]
def cross_covariance_penalty(self, z_b, z_l, z_r):
    """Sum of squared cross-covariances between the three latent blocks.

    Each input is flattened to ``(rows, features)`` (any leading dimensions
    are merged into the row axis), centered per feature, and the three
    pairwise cross-covariance matrices (b-l, b-r, l-r) are formed with an
    ``n - 1`` denominator. The return value is the sum of their squared
    entries.

    Parameters
    ----------
    z_b, z_l, z_r
        Latent tensors sharing the same leading dimensions.

    Returns
    -------
    A 0-d tensor. With fewer than two rows the covariance is undefined
    (the ``n - 1`` denominator is zero), so a zero is returned instead of NaN.
    """
    flat_b, flat_l, flat_r = (
        z.reshape(-1, z.size(-1)) for z in (z_b, z_l, z_r)
    )
    n = flat_b.size(0)
    if n < 2:
        return z_b.new_zeros(())

    centered_b = flat_b - flat_b.mean(dim=0, keepdim=True)
    centered_l = flat_l - flat_l.mean(dim=0, keepdim=True)
    centered_r = flat_r - flat_r.mean(dim=0, keepdim=True)

    def _squared_cross_cov(left, right):
        # Squared Frobenius norm of the sample cross-covariance matrix.
        return (torch.mm(left.t(), right) / (n - 1)).pow(2).sum()

    return (
        _squared_cross_cov(centered_b, centered_l)
        + _squared_cross_cov(centered_b, centered_r)
        + _squared_cross_cov(centered_l, centered_r)
    )
[docs]
def loss(
    self,
    tensors: dict[str, torch.Tensor],
    inference_outputs: dict[str, torch.Tensor | Distribution | None],
    generative_outputs: dict[str, Distribution | None],
    kl_weight: float = 1.0,
    labelled_tensors: dict[str, torch.Tensor] | None = None,
    **kwargs,  # Accept additional keyword arguments for compatibility
):
    """Compute the total loss: ELBO + supervised + adversarial + decorrelation.

    Parameters
    ----------
    tensors
        Minibatch; read at ``REGISTRY_KEYS.X_KEY``, ``BATCH_KEY`` and
        ``LABELS_KEY``.
    inference_outputs
        Must hold ``z_b``, ``z_l``, ``z_r``, ``qb``, ``ql``, ``qr`` and,
        when the library size is not observed, ``ql_lib``.
    generative_outputs
        Must hold the likelihood distribution under ``"px"``.
    kl_weight
        Extra multiplier on the KL term (on top of ``self.beta``).
    labelled_tensors
        Accepted for compatibility; not used.
    **kwargs
        Accepted for compatibility; not used.

    Returns
    -------
    LossOutput
        ``loss`` is ``elbo + sup + adv + xcov``. Classification fields
        (``classification_loss``, ``true_labels``, ``logits``) are attached
        unless ``unlabeled_category_id`` is set and the minibatch contains
        no labeled cell.
    """
    px: Distribution = generative_outputs["px"]
    x: torch.Tensor = tensors[REGISTRY_KEYS.X_KEY]
    batch_index: torch.Tensor = tensors[REGISTRY_KEYS.BATCH_KEY]
    labels: torch.Tensor = tensors[REGISTRY_KEYS.LABELS_KEY]

    z_b = inference_outputs["z_b"]
    z_l = inference_outputs["z_l"]
    z_r = inference_outputs["z_r"]
    qb = inference_outputs["qb"]
    ql = inference_outputs["ql"]
    qr = inference_outputs["qr"]

    # Class-index targets. ``reshape(-1)`` (rather than ``squeeze()``) keeps
    # the batch axis when the minibatch holds a single cell, and ``long()``
    # gives ``cross_entropy`` the integer dtype it needs for class indices.
    batch_targets = batch_index.reshape(-1).long()
    label_targets = labels.reshape(-1).long()

    # --- Reconstruction -------------------------------------------------
    if self.gene_likelihood in ("zinb", "nb"):
        # Negative-binomial families: non-negative integer counts.
        x_safe = torch.clamp(torch.round(x), min=0.0)
    elif self.gene_likelihood == "poisson":
        x_safe = torch.clamp(x, min=0.0)
    else:
        x_safe = torch.clamp(x, min=1e-8)
    reconst_loss = -px.log_prob(x_safe).sum(-1)

    # --- KL terms -------------------------------------------------------
    def _kl_to_standard_normal(q):
        prior = Normal(torch.zeros_like(q.loc), torch.ones_like(q.scale))
        return kl(q, prior).sum(dim=-1)

    kl_divergence = (
        _kl_to_standard_normal(qb)
        + _kl_to_standard_normal(ql)
        + _kl_to_standard_normal(qr)
    )
    if not self.use_observed_lib_size:
        ql_lib = inference_outputs["ql_lib"]
        (
            local_library_log_means,
            local_library_log_vars,
        ) = self._compute_local_library_params(batch_index)
        kl_divergence_lib = kl(
            ql_lib,
            Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
        ).sum(dim=1)
        kl_divergence = kl_divergence + kl_divergence_lib

    elbo_loss = torch.mean(reconst_loss + self.beta * kl_divergence * kl_weight)

    # --- Unlabeled-cell handling (shared by every label head) ------------
    unlabeled_id = getattr(self, "_unlabeled_category_id", None)
    if unlabeled_id is None:
        labeled_mask = None
        has_labeled = True
    else:
        labeled_mask = label_targets != unlabeled_id
        has_labeled = bool(labeled_mask.any())

    def _select_labeled(logits):
        # Restrict logits/targets to labeled cells. When the head has fewer
        # outputs than ``unlabeled_id + 1``, label codes above the unlabeled
        # code are shifted down by one to fit the head's index range.
        if labeled_mask is None:
            return logits, label_targets
        kept_logits = logits[labeled_mask]
        kept_targets = label_targets[labeled_mask]
        if kept_logits.size(1) < (unlabeled_id + 1):
            kept_targets = torch.where(
                kept_targets > unlabeled_id, kept_targets - 1, kept_targets
            )
        return kept_logits, kept_targets

    def _label_cross_entropy(logits):
        # A minibatch without any labeled cell contributes a zero loss.
        if not has_labeled:
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
        return F.cross_entropy(*_select_labeled(logits))

    # --- Supervised heads -------------------------------------------------
    pred_b = self.head_batch(z_b)
    pred_l = self.head_label(z_l)
    sup_loss_b = F.cross_entropy(pred_b, batch_targets)
    sup_loss_l = _label_cross_entropy(pred_l)
    sup_loss = self.lambda_b * sup_loss_b + self.lambda_l * sup_loss_l

    # --- Adversarial heads (inputs pass through gradient reversal) --------
    adv_pred_bl = self.adv_head_label_from_b(gradient_reversal(z_b, self.alpha_bl))
    adv_pred_lb = self.adv_head_batch_from_l(gradient_reversal(z_l, self.alpha_lb))
    adv_pred_rb = self.adv_head_batch_from_r(gradient_reversal(z_r, self.alpha_rb))
    adv_pred_rl = self.adv_head_label_from_r(gradient_reversal(z_r, self.alpha_rl))

    adv_loss_bl = _label_cross_entropy(adv_pred_bl)
    adv_loss_rl = _label_cross_entropy(adv_pred_rl)
    adv_loss_lb = F.cross_entropy(adv_pred_lb, batch_targets)
    adv_loss_rb = F.cross_entropy(adv_pred_rb, batch_targets)
    adv_loss = adv_loss_bl + adv_loss_lb + adv_loss_rb + adv_loss_rl

    # --- Cross-covariance penalty ----------------------------------------
    xcov_loss = self.gamma * self.cross_covariance_penalty(z_b, z_l, z_r)

    total_loss = elbo_loss + sup_loss + adv_loss + xcov_loss

    # NOTE(review): ``z_b``/``z_l``/``z_r`` are full latent tensors, not
    # scalars; confirm the training plan in use accepts non-scalar entries
    # in ``extra_metrics``.
    extra_metrics = {
        "elbo_loss": elbo_loss,
        "sup_loss": sup_loss,
        "sup_loss_b": sup_loss_b,
        "sup_loss_l": sup_loss_l,
        "adv_loss": adv_loss,
        "adv_loss_bl": adv_loss_bl,
        "adv_loss_lb": adv_loss_lb,
        "adv_loss_rb": adv_loss_rb,
        "adv_loss_rl": adv_loss_rl,
        "xcov_loss": xcov_loss,
        "z_b": z_b,
        "z_l": z_l,
        "z_r": z_r,
    }

    if not has_labeled:
        # No labeled cell in this minibatch: omit the classification fields.
        return LossOutput(
            loss=total_loss,
            reconstruction_loss=reconst_loss,
            kl_local=kl_divergence,
            extra_metrics=extra_metrics,
        )

    # Logits/labels for metrics are filtered (and remapped) the same way as
    # for the supervised label loss. The reported classification loss is the
    # weighted batch + label supervised loss, as in the original.
    metrics_logits, metrics_labels = _select_labeled(pred_l)
    return LossOutput(
        loss=total_loss,
        reconstruction_loss=reconst_loss,
        kl_local=kl_divergence,
        classification_loss=sup_loss,
        true_labels=metrics_labels,
        logits=metrics_logits,
        extra_metrics=extra_metrics,
    )