Module: Cell_BLAST.directi
DIRECTi, a deep learning model for semi-supervised parametric dimension reduction and systematic bias removal, extended from scVI.
Classes:
DIRECTi – DIRECTi model.
Functions:
align_DIRECTi – Align datasets starting with an existing DIRECTi model (fine-tuning).
fit_DIRECTi – A convenient one-step function to build and fit DIRECTi models.
- class Cell_BLAST.directi.DIRECTi(genes, latent_module, prob_module, rmbatch_modules, denoising=True, learning_rate=0.001, path=None, random_seed='__UsE_gLoBaL__', _mode=1)
DIRECTi model.
- Parameters:
latent_module (Latent) – Module for the latent variable (encoder module).
prob_module (ProbModel) – Module for generative modeling of the data (decoder module).
batch_effect – Batch effects to be corrected.
rmbatch_modules (Tuple[RMBatch]) – Modules for batch effect correction.
denoising (bool) – Whether to add noise to the input during training (source of randomness in modeling the approximate posterior).
learning_rate (float) – Learning rate.
path (Optional[str]) – Path where the model configuration, checkpoints, and final model will be saved.
random_seed (int) – Random seed. If not specified, config.RANDOM_SEED will be used, which defaults to 0.
- genes
List of gene names the model is defined and fitted on.
- batch_effect_list
List of names of the batch effects to be corrected.
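Several signatures on this page default random_seed to the string '__UsE_gLoBaL__' and document that config.RANDOM_SEED (default 0) is used when no seed is given. The sketch below shows one way such a sentinel default can be resolved at call time. The Config class and resolve_seed function are hypothetical stand-ins for illustration, not Cell_BLAST's own code.

```python
# Sketch of a sentinel-based seed default. "Config" and "resolve_seed" are
# hypothetical stand-ins, not part of the Cell_BLAST API.
_USE_GLOBAL = "__UsE_gLoBaL__"  # sentinel string copied from the documented signatures


class Config:
    """Hypothetical global settings object mirroring config.RANDOM_SEED."""
    RANDOM_SEED = 0


config = Config()


def resolve_seed(random_seed=_USE_GLOBAL):
    """Return the explicit seed, or the global default if none was given."""
    if random_seed == _USE_GLOBAL:
        # Looked up at call time, so later changes to config are honoured.
        return config.RANDOM_SEED
    return random_seed


print(resolve_seed())    # 0, taken from the global default
print(resolve_seed(42))  # 42, explicit seed wins
config.RANDOM_SEED = 7
print(resolve_seed())    # 7, the updated global value is picked up
```

Comparing against a sentinel inside the function, instead of writing random_seed=config.RANDOM_SEED in the signature, matters because Python evaluates default arguments only once, when the function is defined; the sentinel lets a later change to the global setting take effect.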
Examples
The fit_DIRECTi() function offers an easy-to-use wrapper around this DIRECTi model class. It is the preferred API and should satisfy most needs, so we suggest trying the fit_DIRECTi() wrapper first.
Methods:
clustering(adata[, batch_size, ...]) – Get model intrinsic clustering of the data.
gene_grad(adata, latent_grad[, batch_size, ...]) – Fetch gene space gradients with respect to latent space gradients.
inference(adata[, batch_size, n_posterior, ...]) – Project expression profiles into the cell embedding space.
load(path[, config, weights, _mode]) – Load model from files.
save([path, config, weights]) – Save model to files.
- clustering(adata, batch_size=4096, return_confidence=False, progress_bar=False)
Get model intrinsic clustering of the data.
- Parameters:
adata (AnnData) – Dataset for which to obtain the intrinsic clustering.
batch_size (int) – Minibatch size. Changing this may slightly affect speed, but not the result.
return_confidence (bool) – Whether to also return the model intrinsic clustering confidence.
progress_bar (bool) – Whether to show a progress bar during projection.
- Returns:
idx – Model intrinsic clustering index, 1-dimensional.
confidence (if return_confidence is True) – Model intrinsic clustering confidence, 1-dimensional.
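The clustering() entry does not say how the intrinsic clusters are computed. One plausible reading, given that fit_DIRECTi() accepts a cat_dim argument for a categorical latent component, is that each cell is assigned to its most probable category and the confidence is that category's probability. The sketch below illustrates this assumption with NumPy on a made-up probability matrix; intrinsic_clustering is a hypothetical helper, not the library's method.

```python
import numpy as np


def intrinsic_clustering(cat_prob, return_confidence=False):
    """Hypothetical helper: cluster labels from per-cell category probabilities.

    cat_prob has shape (n_cells, n_categories); each row is assumed to sum to 1.
    """
    cat_prob = np.asarray(cat_prob, dtype=float)
    idx = cat_prob.argmax(axis=1)  # 1-D: most probable category per cell
    if return_confidence:
        confidence = cat_prob.max(axis=1)  # 1-D: probability of that category
        return idx, confidence
    return idx


# Made-up probabilities for three cells over two categories.
probs = [[0.9, 0.1],
         [0.2, 0.8],
         [0.5, 0.5]]
idx, confidence = intrinsic_clustering(probs, return_confidence=True)
print(idx)         # [0 1 0]  (ties go to the first category)
print(confidence)  # [0.9 0.8 0.5]
```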
- gene_grad(adata, latent_grad, batch_size=4096, progress_bar=False)
Fetch gene space gradients with respect to latent space gradients.
- Parameters:
- Returns:
grad – Fetched gene-wise gradient.
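gene_grad() maps a gradient defined on the latent space back to gene space, which amounts to a vector-Jacobian product through the encoder. The sketch below shows the idea for a hypothetical linear encoder z = x W, where the chain rule reduces to multiplying the latent gradient by W transposed. The real DIRECTi encoder is a nonlinear MLP, so there the Jacobian depends on each cell's expression profile and would come from automatic differentiation; linear_gene_grad is illustrative only.

```python
import numpy as np


def linear_gene_grad(x, W, latent_grad):
    """Pull a latent-space gradient back to gene space for z = x @ W.

    x: (n_cells, n_genes) expression; W: (n_genes, latent_dim) weights;
    latent_grad: (n_cells, latent_dim) gradient with respect to z.
    Returns the gradient with respect to x, shape (n_cells, n_genes).
    """
    x = np.asarray(x, dtype=float)
    W = np.asarray(W, dtype=float)
    latent_grad = np.asarray(latent_grad, dtype=float)
    if x.shape[1] != W.shape[0] or latent_grad.shape != (x.shape[0], W.shape[1]):
        raise ValueError("shape mismatch between x, W and latent_grad")
    # For a linear map the Jacobian dz/dx is W.T for every cell, so the chain
    # rule gives latent_grad @ W.T. A nonlinear encoder would make it depend on x.
    return latent_grad @ W.T


x = np.ones((1, 3))                    # one cell, three genes (made-up values)
W = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])             # three genes -> two latent dimensions
g = np.array([[1.0, 0.0]])             # push along the first latent axis
print(linear_gene_grad(x, W, g))       # [[1. 3. 5.]]
```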
- inference(adata, batch_size=4096, n_posterior=0, progress_bar=False, priority='auto', random_seed='__UsE_gLoBaL__')
Project expression profiles into the cell embedding space.
- Parameters:
adata (AnnData) – Dataset for which to compute cell embeddings.
batch_size (int) – Minibatch size. Changing this may slightly affect speed, but not the result.
n_posterior (int) – How many posterior samples to fetch. If set to 0, the posterior point estimate is computed. If greater than 0, n_posterior posterior samples are produced for each cell.
progress_bar (bool) – Whether to show a progress bar during projection.
priority (str) – Should be one of {“auto”, “speed”, “memory”}. Controls whether speed or memory is prioritized. The default, “auto”, uses “memory” mode for data with more than 100,000 cells and “speed” mode for smaller data.
random_seed (Optional[int]) – Random seed used with noisy projection. If not specified, config.RANDOM_SEED will be used, which defaults to 0.
- Returns:
latent – Coordinates in the latent space. If n_posterior is 0, the shape is \(cell \times latent\_dim\). If n_posterior is greater than 0, the shape is \(cell \times n\_posterior \times latent\_dim\).
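Two rules in the inference() entry are easy to misread: how priority='auto' is resolved, and the shape of the returned array. The helpers below restate those documented rules in plain Python so they can be checked; resolve_priority and expected_latent_shape are hypothetical names, and they describe the documented behaviour rather than reproduce the library's implementation.

```python
def resolve_priority(n_cells, priority="auto"):
    """Resolve the priority option as documented for inference()."""
    if priority not in {"auto", "speed", "memory"}:
        raise ValueError('priority should be among {"auto", "speed", "memory"}')
    if priority == "auto":
        # More than 100,000 cells -> memory mode; otherwise speed mode.
        return "memory" if n_cells > 100_000 else "speed"
    return priority


def expected_latent_shape(n_cells, latent_dim, n_posterior=0):
    """Shape of the returned coordinates for a given n_posterior."""
    if n_posterior < 0:
        raise ValueError("n_posterior should be 0 or greater")
    if n_posterior == 0:
        return (n_cells, latent_dim)            # posterior point estimate
    return (n_cells, n_posterior, latent_dim)  # posterior samples on the middle axis


print(resolve_priority(100_000))           # speed
print(resolve_priority(100_001))           # memory
print(expected_latent_shape(500, 10))      # (500, 10)
print(expected_latent_shape(500, 10, 50))  # (500, 50, 10)
```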
- Cell_BLAST.directi.align_DIRECTi(model, original_adata, new_adata, rmbatch_module='MNNAdversarial', rmbatch_module_kwargs=None, deviation_reg=0.01, optimizer='RMSPropOptimizer', learning_rate=0.001, batch_size=256, val_split=0.1, epoch=100, patience=100, tolerance=0.0, reuse_weights=True, progress_bar=False, random_seed='__UsE_gLoBaL__', path=None)
Align datasets starting with an existing DIRECTi model (fine-tuning)
- Parameters:
model (DIRECTi) – A pretrained DIRECTi model.
original_adata (AnnData) – The dataset that the model was originally trained on.
new_adata (Union[AnnData, Mapping[str, AnnData]]) – A new dataset, or a dictionary of new datasets, to be aligned with original_adata.
rmbatch_module (str) – Batch effect correction method used to align the new datasets.
rmbatch_module_kwargs (Optional[Mapping]) – Keyword arguments to be passed to the rmbatch module.
deviation_reg (float) – Regularization strength for the deviation from the original model weights.
optimizer (str) – Name of the optimizer used in training.
learning_rate (float) – Learning rate used in training.
batch_size (int) – Size of minibatches used in training.
val_split (float) – Fraction of data to use for validation.
epoch (int) – Maximal number of training epochs.
patience (int) – Early stopping patience. Training stops when the best validation loss has not decreased for patience consecutive epochs.
tolerance (float) – Tolerance of deviation from the lowest recorded validation loss for the “patience countdown” to be reset. The countdown is reset if the current validation loss < lowest recorded validation loss + tolerance.
reuse_weights (bool) – Whether to reuse weights of the original model.
progress_bar (bool) – Whether to show a progress bar during training.
random_seed (int) – Random seed. If not specified, config.RANDOM_SEED will be used, which defaults to 0.
path (Optional[str]) – Path where model checkpoints and the final model will be saved.
- Returns:
aligned_model – Aligned model.
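The patience and tolerance parameters of align_DIRECTi() describe an early-stopping countdown. The function below is a minimal sketch of that rule as worded in the parameter list: the countdown resets whenever the current validation loss is below the lowest recorded loss plus tolerance, and training stops once the countdown reaches zero. early_stopping_epoch is a hypothetical helper, and details such as exactly when the stopping epoch is counted may differ in the actual training loop.

```python
import math


def early_stopping_epoch(val_losses, patience, tolerance=0.0):
    """Return the 0-based epoch at which training would stop, or None.

    The countdown restarts at `patience` whenever the current validation loss is
    below (lowest loss recorded so far + tolerance); otherwise it decreases by one.
    """
    best = math.inf
    countdown = patience
    for epoch, loss in enumerate(val_losses):
        if loss < best + tolerance:
            countdown = patience  # reset the "patience countdown"
        else:
            countdown -= 1
        best = min(best, loss)    # lowest validation loss recorded so far
        if countdown <= 0:
            return epoch
    return None                   # ran out of epochs without stopping early


losses = [1.0, 0.9, 0.95, 0.96, 0.97]  # made-up validation losses
print(early_stopping_epoch(losses, patience=2))                 # 3
print(early_stopping_epoch(losses, patience=2, tolerance=0.1))  # None
```

With tolerance=0.0 the rule reduces to the plain patience criterion, which is how fit_DIRECTi() describes its own patience parameter.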
- Cell_BLAST.directi.fit_DIRECTi(adata, genes=None, supervision=None, batch_effect=None, latent_dim=10, cat_dim=None, h_dim=128, depth=1, prob_module='NB', rmbatch_module='Adversarial', latent_module_kwargs=None, prob_module_kwargs=None, rmbatch_module_kwargs=None, optimizer='RMSPropOptimizer', learning_rate=0.001, batch_size=128, val_split=0.1, epoch=1000, patience=30, progress_bar=False, reuse_weights=None, random_seed='__UsE_gLoBaL__', path=None)
A convenient one-step function to build and fit DIRECTi models. Should work well in most cases.
- Parameters:
adata (AnnData) – Dataset to be fitted.
genes (Optional[List[str]]) – Genes to fit on; should be a subset of anndata.AnnData.var_names. If not specified, all genes are used.
supervision (Optional[str]) – Column in the anndata.AnnData.obs table to use for (semi-)supervision. Cells with an empty value in this column are treated as unsupervised.
batch_effect (Optional[List[str]]) – One or more columns in the anndata.AnnData.obs table to use as batch effects to be corrected.
latent_dim (int) – Latent space (cell embedding) dimensionality.
h_dim (int) – Hidden layer dimensionality, used consistently across all MLPs in the model.
depth (int) – Hidden layer depth, used consistently across all MLPs in the model.
prob_module (str) – Generative model to fit; should be one of {“NB”, “ZINB”, “LN”, “ZILN”}. See the prob module for details.
rmbatch_module (Union[str, List[str]]) – Batch effect correction method. If a list is provided, each element specifies the method to use for the corresponding batch effect in the batch_effect list (in this case, the rmbatch_module list should have the same length as the batch_effect list).
latent_module_kwargs (Optional[Mapping]) – Keyword arguments to be passed to the latent module.
prob_module_kwargs (Optional[Mapping]) – Keyword arguments to be passed to the prob module.
rmbatch_module_kwargs (Union[Mapping, List[Mapping], None]) – Keyword arguments to be passed to the rmbatch module. If a list is provided, each element specifies keyword arguments for the corresponding batch effect correction module in the rmbatch_module list.
optimizer (str) – Name of the optimizer used in training.
learning_rate (float) – Learning rate used in training.
batch_size (int) – Size of minibatches used in training.
val_split (float) – Fraction of data to use for validation.
epoch (int) – Maximal number of training epochs.
patience (int) – Early stopping patience. Training stops when the best validation loss has not decreased for patience consecutive epochs.
progress_bar (bool) – Whether to show progress bars during training.
reuse_weights (Optional[str]) – Path to previously stored model weights to be reused.
random_seed (int) – Random seed. If not specified, config.RANDOM_SEED will be used, which defaults to 0.
path (Optional[str]) – Path where model checkpoints and the final model will be saved.
- Returns:
model – A fitted DIRECTi model.
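The supervision parameter of fit_DIRECTi() treats cells with an empty label as unsupervised. The sketch below shows one way to split a label column into supervised and unsupervised cells. What counts as "empty" here (None, an empty string, or a float NaN) is an assumption, and supervision_mask is a hypothetical helper rather than part of Cell_BLAST.

```python
import math


def is_empty_label(value):
    """Assumed definition of an empty label: None, "", or a float NaN."""
    if value is None:
        return True
    if isinstance(value, str):
        return value == ""
    return isinstance(value, float) and math.isnan(value)


def supervision_mask(labels):
    """True for cells used for supervision, False for cells treated as unsupervised."""
    return [not is_empty_label(v) for v in labels]


labels = ["B cell", "", None, "T cell", float("nan")]  # made-up obs column values
print(supervision_mask(labels))  # [True, False, False, True, False]
```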
Examples
See the DIRECTi IPython notebook (Vignettes) for live examples.
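When fit_DIRECTi() is given several batch effects, rmbatch_module and rmbatch_module_kwargs may be lists whose elements correspond one-to-one with batch_effect. The sketch below pairs them up and enforces the documented equal-length requirement. Applying a single string or mapping to every batch effect is an assumption on our part, and pair_rmbatch as well as the column names "dataset" and "donor" are hypothetical.

```python
from collections.abc import Mapping


def pair_rmbatch(batch_effect, rmbatch_module="Adversarial", rmbatch_module_kwargs=None):
    """Return (batch_effect, method, kwargs) triples, one per batch effect."""
    batch_effect = list(batch_effect or [])
    n = len(batch_effect)
    # Assumption: a single method name applies to every batch effect.
    if isinstance(rmbatch_module, str):
        modules = [rmbatch_module] * n
    else:
        modules = list(rmbatch_module)
    # Assumption: None means no extra kwargs; a single mapping is shared by all.
    if rmbatch_module_kwargs is None:
        kwargs = [{} for _ in range(n)]
    elif isinstance(rmbatch_module_kwargs, Mapping):
        kwargs = [dict(rmbatch_module_kwargs) for _ in range(n)]
    else:
        kwargs = [dict(k) for k in rmbatch_module_kwargs]
    # Documented requirement: list arguments must match batch_effect in length.
    if len(modules) != n or len(kwargs) != n:
        raise ValueError("rmbatch_module(_kwargs) lists must match batch_effect in length")
    return list(zip(batch_effect, modules, kwargs))


print(pair_rmbatch(["dataset", "donor"], ["MNNAdversarial", "Adversarial"]))
# [('dataset', 'MNNAdversarial', {}), ('donor', 'Adversarial', {})]
```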