r"""
DIRECTi, a deep learning model for semi-supervised parametric dimension
reduction and systematic bias removal, extended from scVI.
"""
import os
import tempfile
import time
import typing
from collections import OrderedDict
import anndata as ad
import numpy as np
import pandas as pd
import scipy
import torch
import torch.distributions as D
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from . import config, data, latent, prob, rmbatch, utils
from .config import DEVICE
from .rebuild import RMSprop
_TRAIN = 1
_TEST = 0
class DIRECTi(nn.Module):
r"""
DIRECTi model.
Parameters
----------
genes
Genes to use in the model.
latent_module
Module for latent variable (encoder module).
prob_module
Module for data generative modeling (decoder module).
    rmbatch_modules
        List of modules for batch effect correction, one per batch effect
        to be corrected.
denoising
Whether to add noise to the input during training (source of randomness
in modeling the approximate posterior).
learning_rate
Learning rate.
path
Specifies a path where model configuration, checkpoints,
as well as the final model will be saved.
random_seed
Random seed. If not specified, :data:`config.RANDOM_SEED`
will be used, which defaults to 0.
Attributes
----------
    genes
        List of gene names the model is defined and fitted on.
    batch_effect_list
        List of batch effect names that are corrected by the model.
Examples
--------
    The :func:`fit_DIRECTi` function offers an easy-to-use wrapper around this
    :class:`DIRECTi` model class. It is the preferred API and should satisfy
    most needs. We suggest using the :func:`fit_DIRECTi` wrapper first.
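
    A minimal usage sketch (assuming ``adata`` is a preprocessed
    :class:`anndata.AnnData` object with raw counts; the parameter values
    are illustrative only):

    >>> model = fit_DIRECTi(adata, latent_dim=10)
    >>> embedding = model.inference(adata)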
"""
_TRAIN = 1
_TEST = 0
def __init__(
self,
genes: typing.List[str],
latent_module: "latent.Latent",
prob_module: "prob.ProbModel",
        rmbatch_modules: typing.List["rmbatch.RMBatch"],
denoising: bool = True,
learning_rate: float = 1e-3,
path: typing.Optional[str] = None,
random_seed: int = config._USE_GLOBAL,
_mode: int = _TRAIN,
) -> None:
super().__init__()
if path is None:
path = tempfile.mkdtemp()
random_seed = (
config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed
)
self.ensure_reproducibility(random_seed)
self.genes = genes
self.latent_module = latent_module
self.prob_module = prob_module
self.rmbatch_modules = rmbatch_modules
self.denoising = denoising
self.learning_rate = learning_rate
self.path = path
self.random_seed = random_seed
self._mode = _mode
self.opt_latent_reg = RMSprop(
self.latent_module.parameters_reg(), lr=learning_rate
)
self.opt_latent_fit = RMSprop(
self.latent_module.parameters_fit(), lr=learning_rate
)
self.opt_prob = RMSprop(self.prob_module.parameters(), lr=learning_rate)
self.opts_rmbatch = [
RMSprop(_rmbatch.parameters(), lr=learning_rate)
if _rmbatch._class
in ("Adversarial", "MNNAdversarial", "AdaptiveMNNAdversarial")
else None
for _rmbatch in self.rmbatch_modules
]
@staticmethod
def ensure_reproducibility(random_seed):
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
def get_config(self) -> typing.Mapping:
return {
"genes": self.genes,
"latent_module": self.latent_module.get_config(),
"prob_module": self.prob_module.get_config(),
"rmbatch_modules": [
_module.get_config() for _module in self.rmbatch_modules
],
"denoising": self.denoising,
"learning_rate": self.learning_rate,
"path": self.path,
"random_seed": self.random_seed,
"_mode": self._mode,
}
@staticmethod
def preprocess(
x: torch.Tensor, libs: torch.Tensor, noisy: bool = True
) -> torch.Tensor:
        # Normalize each cell to 10,000 counts, optionally resample counts
        # from a Poisson distribution (the source of randomness used for
        # denoising training), then log-transform
        x = x / (libs / 10000)
        if noisy:
            x = D.Poisson(rate=x).sample()
        x = x.log1p()
return x
def fit(
self,
dataset: data.Dataset,
batch_size: int = 128,
val_split: float = 0.1,
epoch: int = 1000,
patience: int = 30,
tolerance: float = 0.0,
progress_bar: bool = False,
):
os.makedirs(self.path, exist_ok=True)
utils.logger.info("Using model path: %s", self.path)
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(
dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(self.random_seed),
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
generator=torch.Generator().manual_seed(self.random_seed),
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=True,
generator=torch.Generator().manual_seed(self.random_seed),
)
assert self._mode == _TRAIN
self.to(DEVICE)
self.ensure_reproducibility(self.random_seed)
self.save_weights(self.path)
self.latent_module.check_fine_tune()
self.prob_module.check_fine_tune()
patience_remain = patience
best_loss = 1e10
summarywriter = SummaryWriter(log_dir=os.path.join(self.path, "summary"))
for _epoch in range(epoch):
start_time = time.time()
if progress_bar:
train_dataloader = utils.smart_tqdm()(train_dataloader)
train_loss = self.train_epoch(train_dataloader, _epoch, summarywriter)
if progress_bar:
val_dataloader = utils.smart_tqdm()(val_dataloader)
val_loss = self.val_epoch(val_dataloader, _epoch, summarywriter)
report = f"[{self.__class__.__name__} epoch {_epoch}] "
report += f"train={train_loss:.3f}, "
report += f"val={val_loss:.3f}, "
report += f"time elapsed={time.time() - start_time:.1f}s"
            if any(_epoch < _rmbatch.delay for _rmbatch in self.rmbatch_modules):
if _epoch % 10 == 0:
report += " Regular save..."
self.save_weights(self.path)
elif val_loss < best_loss + tolerance:
report += " Best save..."
self.save_weights(self.path)
best_loss = val_loss
patience_remain = patience
else:
patience_remain -= 1
print(report)
if patience_remain < 0:
break
print("Restoring best model...")
self.load_weights(self.path)
self.save_weights(self.path)
def train_epoch(self, train_dataloader, epoch, summarywriter):
self.train()
loss_record = {}
self.latent_module.init_loss_record(loss_record)
self.prob_module.init_loss_record(loss_record)
for _rmbatch in self.rmbatch_modules:
_rmbatch.init_loss_record(loss_record)
loss_record["early_stop_loss"] = 0
loss_record["total_loss"] = 0
datasize = 0
for feed_dict in train_dataloader:
for key, value in feed_dict.items():
feed_dict[key] = value.to(DEVICE)
exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
datasize += libs.shape[0]
x = self.preprocess(exprs, libs, self.denoising)
l, l_components = self.latent_module(x)
latent_d_loss = self.latent_module.d_loss(
l_components, feed_dict, loss_record
)
self.opt_latent_reg.zero_grad()
latent_d_loss.backward()
self.opt_latent_reg.step()
l = l.detach()
for _rmbatch, _opt in zip(self.rmbatch_modules, self.opts_rmbatch):
if epoch >= _rmbatch.delay:
mask = _rmbatch.get_mask(l, feed_dict)
if mask.sum() > 0:
for _ in range(_rmbatch.n_steps):
pred = _rmbatch(l, mask)
rmbatch_d_loss = _rmbatch.d_loss(
pred, feed_dict, mask, loss_record
)
                            if _opt is not None:
_opt.zero_grad()
rmbatch_d_loss.backward()
_opt.step()
x = self.preprocess(exprs, libs, self.denoising)
l, l_components = self.latent_module(x)
latent_g_loss = self.latent_module.g_loss(
l_components, feed_dict, loss_record
)
full_l = [l]
for _rmbatch in self.rmbatch_modules:
full_l.append(feed_dict[_rmbatch.name])
d_components = self.prob_module(full_l, feed_dict)
prob_loss = self.prob_module.loss(d_components, feed_dict, loss_record)
loss = prob_loss + latent_g_loss
for _rmbatch in self.rmbatch_modules:
if epoch >= _rmbatch.delay:
mask = _rmbatch.get_mask(l, feed_dict)
if mask.sum() > 0:
pred = _rmbatch(l, mask)
rmbatch_g_loss = _rmbatch.g_loss(
pred, feed_dict, mask, loss_record
)
loss = loss + rmbatch_g_loss
self.opt_latent_fit.zero_grad()
self.opt_prob.zero_grad()
loss.backward()
self.opt_latent_fit.step()
self.opt_prob.step()
loss_record["early_stop_loss"] += prob_loss.item() * x.shape[0]
loss_record["total_loss"] += loss.item() * x.shape[0]
for key, value in loss_record.items():
summarywriter.add_scalar(key + ":0 (train)", value / datasize, epoch)
return loss_record["early_stop_loss"] / datasize
def val_epoch(self, val_dataloader, epoch, summarywriter):
self.eval()
loss_record = {}
self.latent_module.init_loss_record(loss_record)
self.prob_module.init_loss_record(loss_record)
for _rmbatch in self.rmbatch_modules:
_rmbatch.init_loss_record(loss_record)
loss_record["early_stop_loss"] = 0
loss_record["total_loss"] = 0
datasize = 0
for feed_dict in val_dataloader:
for key, value in feed_dict.items():
feed_dict[key] = value.to(DEVICE)
exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
datasize += libs.shape[0]
with torch.no_grad():
x = self.preprocess(exprs, libs, self.denoising)
l, l_components = self.latent_module(x)
_ = self.latent_module.d_loss(l_components, feed_dict, loss_record)
for _rmbatch in self.rmbatch_modules:
if epoch >= _rmbatch.delay:
mask = _rmbatch.get_mask(l, feed_dict)
if mask.sum() > 0:
for _ in range(_rmbatch.n_steps):
pred = _rmbatch(l, mask)
_ = _rmbatch.d_loss(pred, feed_dict, mask, loss_record)
x = self.preprocess(exprs, libs, self.denoising)
l, l_components = self.latent_module(x)
latent_g_loss = self.latent_module.g_loss(
l_components, feed_dict, loss_record
)
full_l = [l]
for _rmbatch in self.rmbatch_modules:
full_l.append(feed_dict[_rmbatch.name])
d_components = self.prob_module(full_l, feed_dict)
prob_loss = self.prob_module.loss(d_components, feed_dict, loss_record)
loss = prob_loss + latent_g_loss
for _rmbatch in self.rmbatch_modules:
if epoch >= _rmbatch.delay:
mask = _rmbatch.get_mask(l, feed_dict)
if mask.sum() > 0:
pred = _rmbatch(l, mask)
rmbatch_g_loss = _rmbatch.g_loss(
pred, feed_dict, mask, loss_record
)
loss = loss + rmbatch_g_loss
loss_record["early_stop_loss"] += prob_loss.item() * x.shape[0]
loss_record["total_loss"] += loss.item() * x.shape[0]
for key, value in loss_record.items():
summarywriter.add_scalar(key + ":0 (val)", value / datasize, epoch)
return loss_record["early_stop_loss"] / datasize
def save_weights(self, path: str, checkpoint: str = "checkpoint.pk"):
os.makedirs(path, exist_ok=True)
torch.save(self.state_dict(), os.path.join(path, checkpoint))
def load_weights(self, path: str, checkpoint: str = "checkpoint.pk"):
assert os.path.exists(path)
self.load_state_dict(torch.load(os.path.join(path, checkpoint), map_location=DEVICE))
@classmethod
def load_config(cls, configuration: typing.Mapping):
_class = configuration["latent_module"]["_class"]
latent_module = getattr(latent, _class)(**configuration["latent_module"])
_class = configuration["prob_module"]["_class"]
prob_module = getattr(prob, _class)(**configuration["prob_module"])
rmbatch_modules = nn.ModuleList()
for _conf in configuration["rmbatch_modules"]:
_class = _conf["_class"]
rmbatch_modules.append(getattr(rmbatch, _class)(**_conf))
configuration["latent_module"] = latent_module
configuration["prob_module"] = prob_module
configuration["rmbatch_modules"] = rmbatch_modules
model = cls(**configuration)
return model
    def save(
self,
path: typing.Optional[str] = None,
config: str = "config.pk",
weights: str = "weights.pk",
):
r"""
Save model to files
Parameters
----------
path
Path to a directory where the model will be saved
config
Name of the configuration file
weights
Name of the weights file
"""
if path is None:
os.makedirs(self.path, exist_ok=True)
torch.save(self.get_config(), os.path.join(self.path, config))
torch.save(self.state_dict(), os.path.join(self.path, weights))
else:
os.makedirs(path, exist_ok=True)
configuration = self.get_config()
configuration["path"] = path
torch.save(configuration, os.path.join(path, config))
torch.save(self.state_dict(), os.path.join(path, weights))
    @classmethod
def load(
cls,
path: str,
config: str = "config.pk",
weights: str = "weights.pk",
_mode: int = _TRAIN,
    ) -> "DIRECTi":
r"""
Load model from files
Parameters
----------
path
Path to a model directory to load from
config
Name of the configuration file
weights
Name of the weights file
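
        Examples
        --------
        A minimal sketch (assuming a model was previously saved to a
        directory ``./model`` via :meth:`save`; the path is a placeholder):

        >>> model = DIRECTi.load("./model")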
"""
assert os.path.exists(path)
configuration = torch.load(os.path.join(path, config))
        if configuration["_mode"] == _TEST and _mode == _TRAIN:
            raise RuntimeError(
                "The model was saved in minimal mode, please use argument "
                "'_mode=Cell_BLAST.blast.MINIMAL'"
            )
model = cls.load_config(configuration)
model.load_state_dict(torch.load(os.path.join(path, weights), map_location=DEVICE), strict=False)
return model
    def inference(
self,
adata: ad.AnnData,
batch_size: int = 4096,
n_posterior: int = 0,
progress_bar: bool = False,
priority: str = "auto",
random_seed: typing.Optional[int] = config._USE_GLOBAL,
) -> np.ndarray:
r"""
Project expression profiles into the cell embedding space.
Parameters
----------
adata
Dataset for which to compute cell embeddings.
        batch_size
            Minibatch size.
            Changing this may slightly affect speed, but not the result.
n_posterior
How many posterior samples to fetch.
If set to 0, the posterior point estimate is computed.
If greater than 0, produces ``n_posterior`` number of
posterior samples for each cell.
        progress_bar
            Whether to show progress bar during projection.
        priority
            Should be among {"auto", "speed", "memory"}.
            Controls whether speed or memory should be prioritized, by
            default "auto", meaning that data with more than 10,000 cells
            will use "memory" mode and smaller data will use "speed" mode.
random_seed
Random seed used with noisy projection. If not specified,
:data:`config.RANDOM_SEED` will be used, which defaults to 0.
Returns
-------
        latent
            Coordinates in the latent space.
            If ``n_posterior`` is 0, will be in shape :math:`cell \times latent\_dim`.
            If ``n_posterior`` is greater than 0, will be in shape
            :math:`cell \times n\_posterior \times latent\_dim`.
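
        Examples
        --------
        A minimal sketch (assuming ``model`` is a fitted :class:`DIRECTi`
        instance and ``adata`` is a compatible :class:`anndata.AnnData`
        object):

        >>> embedding = model.inference(adata)  # posterior point estimates
        >>> samples = model.inference(adata, n_posterior=50)  # posterior samples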
"""
self.eval()
self.to(DEVICE)
random_seed = (
config.RANDOM_SEED
if random_seed is None or random_seed == config._USE_GLOBAL
else random_seed
)
x = data.select_vars(adata, self.genes).X
if "__libsize__" not in adata.obs.columns:
data.compute_libsize(adata)
l = adata.obs["__libsize__"].to_numpy().reshape((-1, 1))
if n_posterior > 0:
if priority == "auto":
priority = "memory" if x.shape[0] > 1e4 else "speed"
if priority == "speed":
if scipy.sparse.issparse(x):
xrep = x.tocsr()[np.repeat(np.arange(x.shape[0]), n_posterior)]
else:
xrep = np.repeat(x, n_posterior, axis=0)
lrep = np.repeat(l, n_posterior, axis=0)
data_dict = OrderedDict(exprs=xrep, library_size=lrep)
return (
self._fetch_latent(
data.Dataset(data_dict),
batch_size,
True,
progress_bar,
random_seed,
)
.astype(np.float32)
.reshape((x.shape[0], n_posterior, -1))
)
else: # priority == "memory":
data_dict = OrderedDict(exprs=x, library_size=l)
return np.stack(
[
self._fetch_latent(
data.Dataset(data_dict),
batch_size,
True,
progress_bar,
(random_seed + i) if random_seed is not None else None,
).astype(np.float32)
for i in range(n_posterior)
],
axis=1,
)
data_dict = OrderedDict(exprs=x, library_size=l)
return self._fetch_latent(
data.Dataset(data_dict), batch_size, False, progress_bar, random_seed
).astype(np.float32)
    def clustering(
self,
adata: ad.AnnData,
batch_size: int = 4096,
return_confidence: bool = False,
progress_bar: bool = False,
    ) -> typing.Union[np.ndarray, typing.Tuple[np.ndarray, np.ndarray]]:
r"""
Get model intrinsic clustering of the data.
Parameters
----------
adata
Dataset for which to obtain the intrinsic clustering.
        batch_size
            Minibatch size.
            Changing this may slightly affect speed, but not the result.
return_confidence
Whether to return model intrinsic clustering confidence.
progress_bar
Whether to show progress bar during projection.
Returns
-------
        idx
            Model intrinsic clustering index, 1-dimensional.
        confidence (if ``return_confidence`` is True)
            Model intrinsic clustering confidence, 1-dimensional.
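
        Examples
        --------
        A minimal sketch (assuming ``model`` was built with ``cat_dim`` set,
        so that its latent module is a :class:`latent.CatGau`):

        >>> idx, confidence = model.clustering(adata, return_confidence=True)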
"""
self.eval()
self.to(DEVICE)
if not isinstance(self.latent_module, latent.CatGau):
raise Exception("Model has no intrinsic clustering")
x = data.select_vars(adata, self.genes).X
if "__libsize__" not in adata.obs.columns:
data.compute_libsize(adata)
l = adata.obs["__libsize__"].to_numpy().reshape((-1, 1))
data_dict = OrderedDict(exprs=x, library_size=l)
cat = self._fetch_cat(
data.Dataset(data_dict), batch_size, False, progress_bar
).astype(np.float32)
if return_confidence:
return cat.argmax(axis=1), cat.max(axis=1)
return cat.argmax(axis=1)
    def gene_grad(
self,
adata: ad.AnnData,
latent_grad: np.ndarray,
batch_size: int = 4096,
progress_bar: bool = False,
) -> np.ndarray:
r"""
        Fetch gene space gradients with respect to given latent space gradients.
Parameters
----------
        adata
Dataset for which to obtain gene gradients.
latent_grad
Latent space gradients.
        batch_size
            Minibatch size.
            Changing this may slightly affect speed, but not the result.
progress_bar
Whether to show progress bar during projection.
Returns
-------
grad
            Fetched gene-wise gradient.
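
        Examples
        --------
        A minimal sketch (assuming ``model`` is a fitted model with 10 latent
        dimensions and ``latent_grad`` is an arbitrary array of latent space
        gradients; the values here are illustrative only):

        >>> import numpy as np
        >>> latent_grad = np.ones((adata.n_obs, 10), dtype=np.float32)
        >>> grad = model.gene_grad(adata, latent_grad)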
"""
self.eval()
self.to(DEVICE)
x = data.select_vars(adata, self.genes).X
if "__libsize__" not in adata.obs.columns:
data.compute_libsize(adata)
l = adata.obs["__libsize__"].to_numpy().reshape((-1, 1))
data_dict = OrderedDict(exprs=x, library_size=l, output_grad=latent_grad)
return self._fetch_grad(
data.Dataset(data_dict), batch_size=batch_size, progress_bar=progress_bar
)
def _fetch_latent(
self,
dataset: data.Dataset,
batch_size: int,
noisy: bool,
progress_bar: bool,
random_seed: int,
) -> np.ndarray:
self.ensure_reproducibility(random_seed)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=False
)
if progress_bar:
dataloader = utils.smart_tqdm()(dataloader)
with torch.no_grad():
latents = []
for feed_dict in dataloader:
for key, value in feed_dict.items():
feed_dict[key] = value.to(DEVICE)
exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
latents.append(
self.latent_module.fetch_latent(self.preprocess(exprs, libs, noisy))
)
return torch.cat(latents).cpu().numpy()
def _fetch_cat(
self, dataset: data.Dataset, batch_size: int, noisy: bool, progress_bar: bool
    ) -> np.ndarray:
self.ensure_reproducibility(self.random_seed)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=False
)
if progress_bar:
dataloader = utils.smart_tqdm()(dataloader)
with torch.no_grad():
cats = []
for feed_dict in dataloader:
for key, value in feed_dict.items():
feed_dict[key] = value.to(DEVICE)
exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
cats.append(
self.latent_module.fetch_cat(self.preprocess(exprs, libs, noisy))
)
return torch.cat(cats).cpu().numpy()
def _fetch_grad(
self, dataset: data.Dataset, batch_size: int, progress_bar: bool
) -> np.ndarray:
self.ensure_reproducibility(self.random_seed)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=False
)
if progress_bar:
dataloader = utils.smart_tqdm()(dataloader)
grads = []
for feed_dict in dataloader:
for key, value in feed_dict.items():
feed_dict[key] = value.to(DEVICE)
exprs = feed_dict["exprs"]
libs = feed_dict["library_size"]
latent_grad = feed_dict["output_grad"]
grads.append(
self.latent_module.fetch_grad(
self.preprocess(exprs, libs, self.denoising), latent_grad
)
)
return torch.cat(grads).cpu().numpy()
def fit_DIRECTi(
adata: ad.AnnData,
genes: typing.Optional[typing.List[str]] = None,
supervision: typing.Optional[str] = None,
batch_effect: typing.Optional[typing.List[str]] = None,
latent_dim: int = 10,
cat_dim: typing.Optional[int] = None,
h_dim: int = 128,
depth: int = 1,
prob_module: str = "NB",
rmbatch_module: typing.Union[str, typing.List[str]] = "Adversarial",
latent_module_kwargs: typing.Optional[typing.Mapping] = None,
prob_module_kwargs: typing.Optional[typing.Mapping] = None,
rmbatch_module_kwargs: typing.Optional[
typing.Union[typing.Mapping, typing.List[typing.Mapping]]
] = None,
optimizer: str = "RMSPropOptimizer",
learning_rate: float = 1e-3,
batch_size: int = 128,
val_split: float = 0.1,
epoch: int = 1000,
patience: int = 30,
progress_bar: bool = False,
reuse_weights: typing.Optional[str] = None,
random_seed: int = config._USE_GLOBAL,
path: typing.Optional[str] = None,
) -> DIRECTi:
r"""
A convenient one-step function to build and fit DIRECTi models.
Should work well in most cases.
Parameters
----------
adata
Dataset to be fitted.
genes
Genes to fit on, should be a subset of :attr:`anndata.AnnData.var_names`.
If not specified, all genes are used.
supervision
Specifies a column in the :attr:`anndata.AnnData.obs` table for use as
        (semi-)supervision. If the value in the specified column is empty,
the corresponding cells will be treated as unsupervised.
batch_effect
Specifies one or more columns in the :attr:`anndata.AnnData.obs` table
for use as batch effect to be corrected.
latent_dim
Latent space (cell embedding) dimensionality.
cat_dim
Number of intrinsic clusters.
h_dim
Hidden layer dimensionality. It is used consistently across all MLPs
in the model.
depth
Hidden layer depth. It is used consistently across all MLPs in the model.
prob_module
Generative model to fit, should be among {"NB", "ZINB", "LN", "ZILN"}.
        See the :mod:`prob` module for details.
rmbatch_module
Batch effect correction method. If a list is provided, each element
specifies the method to use for a corresponding batch effect in
``batch_effect`` list (in this case the ``rmbatch_module`` list should
have the same length as the ``batch_effect`` list).
latent_module_kwargs
Keyword arguments to be passed to the latent module.
prob_module_kwargs
Keyword arguments to be passed to the prob module.
rmbatch_module_kwargs
Keyword arguments to be passed to the rmbatch module.
If a list is provided, each element specifies keyword arguments
for a corresponding batch effect correction module in the
``rmbatch_module`` list.
optimizer
Name of optimizer used in training.
learning_rate
Learning rate used in training.
batch_size
Size of minibatch used in training.
val_split
Fraction of data to use for validation.
epoch
Maximal training epochs.
patience
        Early stop patience. Model training stops when the best validation
        loss does not decrease for ``patience`` consecutive epochs.
progress_bar
Whether to show progress bars during training.
reuse_weights
Specifies a path where previously stored model weights can be reused.
random_seed
Random seed. If not specified, :data:`config.RANDOM_SEED`
will be used, which defaults to 0.
path
Specifies a path where model checkpoints as well as the final model
will be saved.
Returns
-------
model
A fitted DIRECTi model.
Examples
--------
See the DIRECTi ipython notebook (:ref:`vignettes`) for live examples.
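
    A minimal sketch (assuming ``adata`` contains raw counts and has a
    "batch" column in ``obs``; all parameter values are illustrative):

    >>> model = fit_DIRECTi(
    ...     adata, batch_effect=["batch"], latent_dim=10, cat_dim=20
    ... )
    >>> embedding = model.inference(adata)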
"""
random_seed = (
config.RANDOM_SEED
if random_seed is None or random_seed == config._USE_GLOBAL
else random_seed
)
DIRECTi.ensure_reproducibility(random_seed)
if latent_module_kwargs is None:
latent_module_kwargs = {}
if prob_module_kwargs is None:
prob_module_kwargs = {}
if rmbatch_module_kwargs is None:
rmbatch_module_kwargs = {}
if genes is None:
genes = adata.var_names.values
if isinstance(genes, (pd.Series, pd.Index)):
genes = genes.to_numpy()
if isinstance(genes, np.ndarray):
genes = genes.tolist()
assert isinstance(genes, list)
if "__libsize__" not in adata.obs.columns:
data.compute_libsize(adata)
data_dict = OrderedDict(
library_size=adata.obs["__libsize__"].to_numpy().reshape((-1, 1)),
exprs=data.select_vars(adata, genes).X,
)
if batch_effect is None:
batch_effect = []
elif isinstance(batch_effect, str):
batch_effect = [batch_effect]
    elif isinstance(batch_effect, pd.Series):
        batch_effect = batch_effect.tolist()
elif isinstance(batch_effect, np.ndarray):
batch_effect = batch_effect.tolist()
assert isinstance(batch_effect, list)
for _batch_effect in batch_effect:
data_dict[_batch_effect] = utils.encode_onehot(
adata.obs[_batch_effect], sort=True
) # sorting ensures batch order reproducibility for later tuning
if supervision is not None:
data_dict[supervision] = utils.encode_onehot(
adata.obs[supervision], sort=True
) # sorting ensures supervision order reproducibility for later tuning
if cat_dim is None:
cat_dim = data_dict[supervision].shape[1]
elif cat_dim > data_dict[supervision].shape[1]:
data_dict[supervision] = scipy.sparse.hstack(
[
data_dict[supervision].tocsc(),
scipy.sparse.csc_matrix(
(
data_dict[supervision].shape[0],
cat_dim - data_dict[supervision].shape[1],
)
),
]
).tocsr()
elif cat_dim < data_dict[supervision].shape[1]: # pragma: no cover
raise ValueError(
"`cat_dim` must be greater than or equal to "
"number of supervised classes!"
)
    # else cat_dim == number of supervised classes, nothing to pad
kwargs = dict(input_dim=len(genes), latent_dim=latent_dim, h_dim=h_dim, depth=depth)
if cat_dim:
kwargs.update(dict(cat_dim=cat_dim))
if supervision:
kwargs.update(dict(name=supervision))
kwargs.update(latent_module_kwargs)
latent_module = latent.SemiSupervisedCatGau(**kwargs)
else:
kwargs.update(latent_module_kwargs)
latent_module = latent.CatGau(**kwargs)
else:
kwargs.update(latent_module_kwargs)
latent_module = latent.Gau(**kwargs)
if not isinstance(rmbatch_module, list):
rmbatch_module = [rmbatch_module] * len(batch_effect)
if not isinstance(rmbatch_module_kwargs, list):
rmbatch_module_kwargs = [rmbatch_module_kwargs] * len(batch_effect)
assert len(rmbatch_module_kwargs) == len(rmbatch_module) == len(batch_effect)
rmbatch_list = nn.ModuleList()
full_latent_dim = [latent_dim]
for _batch_effect, _rmbatch_module, _rmbatch_module_kwargs in zip(
batch_effect, rmbatch_module, rmbatch_module_kwargs
):
batch_dim = len(adata.obs[_batch_effect].dropna().unique())
full_latent_dim.append(batch_dim)
kwargs = dict(batch_dim=batch_dim, latent_dim=latent_dim, name=_batch_effect)
if _rmbatch_module in (
"Adversarial",
"MNNAdversarial",
"AdaptiveMNNAdversarial",
):
            kwargs.update(dict(h_dim=h_dim, depth=depth))
        elif _rmbatch_module not in ("RMBatch", "MNN"):  # pragma: no cover
            raise ValueError("Invalid rmbatch method!")
        # else "RMBatch" or "MNN"
        kwargs.update(_rmbatch_module_kwargs)
rmbatch_list.append(getattr(rmbatch, _rmbatch_module)(**kwargs))
kwargs = dict(
output_dim=len(genes), full_latent_dim=full_latent_dim, h_dim=h_dim, depth=depth
)
kwargs.update(prob_module_kwargs)
prob_module = getattr(prob, prob_module)(**kwargs)
model = DIRECTi(
genes=genes,
latent_module=latent_module,
prob_module=prob_module,
rmbatch_modules=rmbatch_list,
learning_rate=learning_rate,
path=path,
random_seed=random_seed,
)
    if reuse_weights is not None:
model.load_state_dict(torch.load(reuse_weights, map_location=DEVICE))
if optimizer != "RMSPropOptimizer":
utils.logger.warning("Argument `optimizer` is no longer supported!")
model.fit(
dataset=data.Dataset(data_dict),
batch_size=batch_size,
val_split=val_split,
epoch=epoch,
patience=patience,
progress_bar=progress_bar,
)
return model
def align_DIRECTi(
model: DIRECTi,
original_adata: ad.AnnData,
new_adata: typing.Union[ad.AnnData, typing.Mapping[str, ad.AnnData]],
rmbatch_module: str = "MNNAdversarial",
rmbatch_module_kwargs: typing.Optional[typing.Mapping] = None,
deviation_reg: float = 0.01,
optimizer: str = "RMSPropOptimizer",
learning_rate: float = 1e-3,
batch_size: int = 256,
val_split: float = 0.1,
epoch: int = 100,
patience: int = 100,
tolerance: float = 0.0,
reuse_weights: bool = True,
progress_bar: bool = False,
random_seed: int = config._USE_GLOBAL,
path: typing.Optional[str] = None,
) -> DIRECTi:
r"""
Align datasets starting with an existing DIRECTi model (fine-tuning)
Parameters
----------
model
A pretrained DIRECTi model.
original_adata
The dataset that the model was originally trained on.
new_adata
A new dataset or a dictionary containing new datasets,
        to be aligned with ``original_adata``.
rmbatch_module
Specifies the batch effect correction method to use for aligning new
datasets.
rmbatch_module_kwargs
Keyword arguments to be passed to the rmbatch module.
deviation_reg
Regularization strength for the deviation from original model weights.
optimizer
Name of optimizer used in training.
learning_rate
Learning rate used in training.
batch_size
Size of minibatches used in training.
val_split
Fraction of data to use for validation.
epoch
Maximal training epochs.
patience
        Early stop patience. Model training stops when the best validation
        loss does not decrease for ``patience`` consecutive epochs.
tolerance
Tolerance of deviation from the lowest validation loss recorded for the
"patience countdown" to be reset. The "patience countdown" is reset if
current validation loss < lowest validation loss recorded + ``tolerance``.
reuse_weights
Whether to reuse weights of the original model.
progress_bar
Whether to show progress bar during training.
random_seed
Random seed. If not specified, :data:`config.RANDOM_SEED`
will be used, which defaults to 0.
path
Specifies a path where model checkpoints as well as the final model
        will be saved.
Returns
-------
aligned_model
Aligned model.
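
    Examples
    --------
    A minimal sketch (assuming ``model`` was fitted on ``ref_adata`` and
    ``query_adata`` is a new dataset sharing the same genes; both names are
    placeholders):

    >>> aligned = align_DIRECTi(model, ref_adata, query_adata)
    >>> embedding = aligned.inference(query_adata)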
"""
random_seed = (
config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed
)
DIRECTi.ensure_reproducibility(random_seed)
if rmbatch_module_kwargs is None:
rmbatch_module_kwargs = {}
if isinstance(new_adata, ad.AnnData):
new_adatas = {"__new__": new_adata}
elif isinstance(new_adata, dict):
        assert (
            "__original__" not in new_adata
        ), "Key `__original__` is not allowed in new datasets."
new_adatas = new_adata.copy() # shallow
else:
        raise TypeError("Invalid type for argument `new_adata`.")
_config = model.get_config()
for _rmbatch_module in _config["rmbatch_modules"]:
_rmbatch_module["delay"] = 0
kwargs = {
"batch_dim": len(new_adatas) + 1,
"latent_dim": model.latent_module.latent_dim,
"delay": 0,
"name": "__align__",
"_class": rmbatch_module,
}
if rmbatch_module in ("Adversarial", "MNNAdversarial", "AdaptiveMNNAdversarial"):
kwargs.update(
dict(
h_dim=model.latent_module.h_dim,
depth=model.latent_module.depth,
dropout=model.latent_module.dropout,
lambda_reg=0.01,
)
)
elif rmbatch_module not in ("RMBatch", "MNN"): # pragma: no cover
raise ValueError("Unknown rmbatch_module!")
# else "RMBatch" or "MNN"
kwargs.update(rmbatch_module_kwargs)
_config["rmbatch_modules"].append(kwargs)
_config["prob_module"]["full_latent_dim"].append(len(new_adatas) + 1)
_config["prob_module"]["fine_tune"] = True
_config["prob_module"]["deviation_reg"] = deviation_reg
_config["learning_rate"] = learning_rate
_config["path"] = path
aligned_model = DIRECTi.load_config(_config)
if reuse_weights:
aligned_model.load_state_dict(model.state_dict(), strict=False)
supervision = (
aligned_model.latent_module.name
if isinstance(aligned_model.latent_module, latent.SemiSupervisedCatGau)
else None
)
assert (
"__align__" not in original_adata.obs.columns
), "Please remove column `__align__` from obs of the original dataset."
original_adata = ad.AnnData(
X=original_adata.X,
obs=original_adata.obs.copy(deep=False),
var=original_adata.var.copy(deep=False),
)
if "__libsize__" not in original_adata.obs.columns:
data.compute_libsize(original_adata)
original_adata = data.select_vars(original_adata, model.genes)
for key in new_adatas.keys():
assert (
"__align__" not in new_adatas[key].obs.columns
), f"Please remove column `__align__` from new dataset {key}."
new_adatas[key] = ad.AnnData(
X=new_adatas[key].X,
obs=new_adatas[key].obs.copy(deep=False),
var=new_adatas[key].var.copy(deep=False),
)
        new_adatas[key].obs = new_adatas[key].obs.loc[
            :, new_adatas[key].obs.columns == "__libsize__"
        ]  # All metadata in new datasets is cleared to avoid interference
if "__libsize__" not in new_adatas[key].obs.columns:
data.compute_libsize(new_adatas[key])
new_adatas[key] = data.select_vars(new_adatas[key], model.genes)
adatas = {"__original__": original_adata, **new_adatas}
for key, val in adatas.items():
val.obs["__align__"] = key
adata = ad.concat(adatas, join="outer", fill_value=0)
data_dict = OrderedDict(
library_size=adata.obs["__libsize__"].to_numpy().reshape((-1, 1)),
exprs=data.select_vars(adata, model.genes).X, # Ensure order
)
for rmbatch_module in aligned_model.rmbatch_modules:
data_dict[rmbatch_module.name] = utils.encode_onehot(
adata.obs[rmbatch_module.name], sort=True
)
if isinstance(aligned_model.latent_module, latent.SemiSupervisedCatGau):
data_dict[supervision] = utils.encode_onehot(adata.obs[supervision], sort=True)
cat_dim = aligned_model.latent_module.cat_dim
if cat_dim > data_dict[supervision].shape[1]:
data_dict[supervision] = scipy.sparse.hstack(
[
data_dict[supervision].tocsc(),
scipy.sparse.csc_matrix(
(
data_dict[supervision].shape[0],
cat_dim - data_dict[supervision].shape[1],
)
),
]
).tocsr()
if optimizer != "RMSPropOptimizer":
        utils.logger.warning("Argument `optimizer` is no longer supported!")
aligned_model.fit(
dataset=data.Dataset(data_dict),
batch_size=batch_size,
val_split=val_split,
epoch=epoch,
patience=patience,
tolerance=tolerance,
progress_bar=progress_bar,
)
return aligned_model