r"""
Cell BLAST based on DIRECTi models
"""
import collections
import os
import re
import tempfile
import typing
import anndata
import joblib
import numba
import numpy as np
import pandas as pd
import scipy.sparse
import scipy.stats
import sklearn.neighbors
from . import config, data, directi, metrics, utils
NORMAL = 1
MINIMAL = 0
def _wasserstein_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cover
x_sorter = np.argsort(x)
y_sorter = np.argsort(y)
xy = np.concatenate((x, y))
xy.sort()
deltas = np.diff(xy)
x_cdf = np.searchsorted(x[x_sorter], xy[:-1], "right") / x.size
y_cdf = np.searchsorted(y[y_sorter], xy[:-1], "right") / y.size
return np.sum(np.multiply(np.abs(x_cdf - y_cdf), deltas))
def _energy_distance_impl(x: np.ndarray, y: np.ndarray): # pragma: no cover
x_sorter = np.argsort(x)
y_sorter = np.argsort(y)
xy = np.concatenate((x, y))
xy.sort()
deltas = np.diff(xy)
x_cdf = np.searchsorted(x[x_sorter], xy[:-1], "right") / x.size
y_cdf = np.searchsorted(y[y_sorter], xy[:-1], "right") / y.size
return np.sqrt(2 * np.sum(np.multiply(np.square(x_cdf - y_cdf), deltas)))
@numba.extending.overload(
scipy.stats.wasserstein_distance, jit_options={"nogil": True, "cache": True}
)
def _wasserstein_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover
if (
x == numba.float32[::1] and y == numba.float32[::1]
) or (
x == numba.float64[::1] and y == numba.float64[::1]
):
return _wasserstein_distance_impl
@numba.extending.overload(
scipy.stats.energy_distance, jit_options={"nogil": True, "cache": True}
)
def _energy_distance(x: np.ndarray, y: np.ndarray): # pragma: no cover
if x == numba.float32[::1] and y == numba.float32[::1]:
return _energy_distance_impl
@numba.jit(nopython=True, nogil=True, cache=True)
def ed(x: np.ndarray, y: np.ndarray): # pragma: no cover
r"""
x : latent_dim
y : latent_dim
"""
return np.sqrt(np.square(x - y).sum())
@numba.jit(nopython=True, nogil=True, cache=True)
def _md(
x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray
) -> np.ndarray: # pragma: no cover
r"""
x : latent_dim
y : latent_dim
x_posterior : n_posterior * latent_dim
"""
if np.all(x == y):
return 0.0
x_posterior = x_posterior - x
cov_x = np.dot(x_posterior.T, x_posterior) / x_posterior.shape[0]
dev = np.expand_dims(y - x, axis=1)
return np.sqrt(np.dot(dev.T, np.dot(np.linalg.inv(cov_x), dev)))[0, 0]
@numba.jit(nopython=True, nogil=True, cache=True)
def md(
x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray, y_posterior: np.ndarray
) -> np.ndarray: # pragma: no cover
r"""
x : latent_dim
y : latent_dim
x_posterior : n_posterior * latent_dim
y_posterior : n_posterior * latent_dim
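
    Examples
    --------
    A toy sketch with random arrays; the shapes follow the conventions above,
    and ``latent_dim = 10`` / ``n_posterior = 50`` are purely illustrative:

    >>> import numpy as np
    >>> rs = np.random.RandomState(0)
    >>> x, y = rs.randn(10), rs.randn(10)
    >>> x_posterior, y_posterior = rs.randn(50, 10), rs.randn(50, 10)
    >>> d = md(x, y, x_posterior, y_posterior)  # symmetrized Mahalanobis distance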
"""
if np.all(x == y):
return 0.0
return 0.5 * (_md(x, y, x_posterior) + _md(y, x, y_posterior))
@numba.jit(nopython=True, nogil=True, cache=True)
def _compute_pcasd(
x: np.ndarray, x_posterior: np.ndarray, eps: float
) -> np.ndarray: # pragma: no cover
r"""
x : latent_dim
x_posterior : n_posterior * latent_dim
"""
centered_x_posterior = (
x_posterior - np.sum(x_posterior, axis=0) / x_posterior.shape[0]
)
cov_x = np.dot(centered_x_posterior.T, centered_x_posterior)
v = np.real(
np.linalg.eig(cov_x.astype(np.complex64))[1]
) # Suppress domain change due to rounding errors
x_posterior = np.dot(x_posterior - x, v)
squared_x_posterior = np.square(x_posterior)
asd = np.empty((2, x_posterior.shape[1]), dtype=np.float32)
for p in range(x_posterior.shape[1]):
mask = x_posterior[:, p] < 0
asd[0, p] = (
np.sqrt((np.sum(squared_x_posterior[mask, p])) / max(np.sum(mask), 1)) + eps
)
asd[1, p] = (
np.sqrt((np.sum(squared_x_posterior[~mask, p])) / max(np.sum(~mask), 1))
+ eps
)
return np.concatenate((v, asd), axis=0)
@numba.jit(nopython=True, nogil=True, cache=True)
def _compute_pcasd_across_models(
x: np.ndarray, x_posterior: np.ndarray, eps: float = 1e-1
) -> np.ndarray: # pragma: no cover
r"""
x : n_models * latent_dim
x_posterior : n_models * n_posterior * latent_dim
"""
result = np.empty((x.shape[0], x.shape[-1] + 2, x.shape[-1]), dtype=np.float32)
for i in range(x.shape[0]):
result[i] = _compute_pcasd(x[i], x_posterior[i], eps)
return result
@numba.jit(nopython=True, nogil=True, cache=True)
def _amd(
x: np.ndarray,
y: np.ndarray,
x_posterior: np.ndarray,
eps: float,
x_is_pcasd: bool = False,
) -> np.ndarray: # pragma: no cover
r"""
x : latent_dim
y : latent_dim
x_posterior : n_posterior * latent_dim
"""
if np.all(x == y):
return 0.0
if not x_is_pcasd:
x_posterior = _compute_pcasd(x, x_posterior, eps)
v = x_posterior[:-2]
asd = x_posterior[-2:]
y = np.dot(y - x, v)
for p in range(y.size):
if y[p] < 0:
y[p] /= asd[0, p]
else:
y[p] /= asd[1, p]
return np.linalg.norm(y)
@numba.jit(nopython=True, nogil=True, cache=True)
def amd(
x: np.ndarray,
y: np.ndarray,
x_posterior: np.ndarray,
y_posterior: np.ndarray,
eps: float = 1e-1,
x_is_pcasd: bool = False,
y_is_pcasd: bool = False,
) -> np.ndarray: # pragma: no cover
r"""
x : latent_dim
y : latent_dim
x_posterior : n_posterior * latent_dim
y_posterior : n_posterior * latent_dim
"""
if np.all(x == y):
return 0.0
return 0.5 * (
_amd(x, y, x_posterior, eps, x_is_pcasd)
+ _amd(y, x, y_posterior, eps, y_is_pcasd)
)
@numba.jit(nopython=True, nogil=True, cache=True)
def npd_v1(
x: np.ndarray,
y: np.ndarray,
x_posterior: np.ndarray,
y_posterior: np.ndarray,
eps: float = 0.0,
) -> np.ndarray: # pragma: no cover
r"""
x : latent_dim
y : latent_dim
x_posterior : n_posterior * latent_dim
y_posterior : n_posterior * latent_dim
"""
projection = x - y # latent_dim
if np.all(projection == 0.0):
projection[...] = 1.0 # any projection is equivalent
projection /= np.linalg.norm(projection)
x_posterior = np.sum(x_posterior * projection, axis=1) # n_posterior_samples
y_posterior = np.sum(y_posterior * projection, axis=1) # n_posterior_samples
xy_posterior = np.concatenate((x_posterior, y_posterior))
xy_posterior1 = (xy_posterior - np.mean(x_posterior)) / (
np.std(x_posterior) + np.float32(eps)
)
xy_posterior2 = (xy_posterior - np.mean(y_posterior)) / (
np.std(y_posterior) + np.float32(eps)
)
return 0.5 * (
scipy.stats.wasserstein_distance(
xy_posterior1[: len(x_posterior)], xy_posterior1[-len(y_posterior) :]
)
+ scipy.stats.wasserstein_distance(
xy_posterior2[: len(x_posterior)], xy_posterior2[-len(y_posterior) :]
)
)
@numba.jit(nopython=True, nogil=True, cache=True)
def _npd_v2(
x: np.ndarray, y: np.ndarray, x_posterior: np.ndarray, eps: float
) -> np.ndarray: # pragma: no cover
r"""
x : latent_dim
y : latent_dim
x_posterior : n_posterior * latent_dim
"""
if np.all(x == y):
return 0.0
dev = y - x
udev = dev / np.linalg.norm(dev)
projected_noise = np.sum((x_posterior - x) * udev, axis=1)
projected_y = np.sum((y - x) * udev)
mask = (projected_noise * projected_y) >= 0
scaler = np.sqrt(np.sum(np.square(projected_noise[mask])) / max(np.sum(mask), 1))
return np.abs(projected_y) / (scaler + eps)
@numba.jit(nopython=True, nogil=True, cache=True)
def npd_v2(
x: np.ndarray,
y: np.ndarray,
x_posterior: np.ndarray,
y_posterior: np.ndarray,
eps: float = 1e-1,
) -> np.ndarray: # pragma: no cover
r"""
x : latent_dim
y : latent_dim
x_posterior : n_posterior * latent_dim
y_posterior : n_posterior * latent_dim
"""
if np.all(x == y):
return 0.0
return 0.5 * (_npd_v2(x, y, x_posterior, eps) + _npd_v2(y, x, y_posterior, eps))
@numba.jit(nopython=True, nogil=True, cache=True)
def _hit_ed_across_models(
query_latent: np.ndarray, ref_latent: np.ndarray
) -> np.ndarray: # pragma: no cover
r"""
query_latent : n_models * latent_dim
ref_latent : n_hits * n_models * latent_dim
returns : n_hits * n_models
"""
dist = np.empty(ref_latent.shape[:-1]) # n_hits * n_models
for i in range(dist.shape[1]): # model index
x = query_latent[i, ...] # latent_dim
for j in range(dist.shape[0]): # hit index
y = ref_latent[j, i, ...] # latent_dim
dist[j, i] = ed(x, y)
return dist
@numba.jit(nopython=True, nogil=True, cache=True)
def _hit_md_across_models(
query_latent: np.ndarray,
ref_latent: np.ndarray,
query_posterior: np.ndarray,
ref_posterior: np.ndarray,
) -> np.ndarray: # pragma: no cover
r"""
query_latent : n_models * latent_dim
ref_latent : n_hits * n_models * latent_dim
query_posterior : n_models * n_posterior * latent_dim
ref_posterior : n_hits * n_models * n_posterior * latent_dim
returns : n_hits * n_models
"""
dist = np.empty(ref_latent.shape[:-1]) # n_hits * n_models
for i in range(dist.shape[1]): # model index
x = query_latent[i, ...] # latent_dim
x_posterior = query_posterior[i, ...] # n_posterior * latent_dim
for j in range(dist.shape[0]): # hit index
y = ref_latent[j, i, ...] # latent_dim
y_posterior = ref_posterior[j, i, ...] # n_posterior * latent_dim
dist[j, i] = md(x, y, x_posterior, y_posterior)
return dist
@numba.jit(nopython=True, nogil=True, cache=True)
def _hit_amd_across_models(
query_latent: np.ndarray,
ref_latent: np.ndarray,
query_posterior: np.ndarray,
ref_posterior: np.ndarray,
eps: float = 1e-1,
) -> np.ndarray: # pragma: no cover
r"""
query_latent : n_models * latent_dim
ref_latent : n_hits * n_models * latent_dim
query_posterior : n_models * n_posterior * latent_dim
ref_posterior : n_hits * n_models * n_posterior * latent_dim
returns : n_hits * n_models
"""
dist = np.empty(ref_latent.shape[:-1]) # n_hits * n_models
for i in range(dist.shape[1]): # model index
x = query_latent[i, ...] # latent_dim
x_posterior = query_posterior[i, ...] # n_posterior * latent_dim
for j in range(dist.shape[0]): # hit index
y = ref_latent[j, i, ...] # latent_dim
y_posterior = ref_posterior[j, i, ...] # n_posterior * latent_dim
dist[j, i] = amd(
x,
y,
x_posterior,
y_posterior,
eps=eps,
x_is_pcasd=False,
y_is_pcasd=True,
)
return dist
@numba.jit(nopython=True, nogil=True, cache=True)
def _hit_npd_v1_across_models(
query_latent: np.ndarray,
ref_latent: np.ndarray,
query_posterior: np.ndarray,
ref_posterior: np.ndarray,
eps: float = 0.0,
) -> np.ndarray: # pragma: no cover
r"""
query_latent : n_models * latent_dim
ref_latent : n_hits * n_models * latent_dim
query_posterior : n_models * n_posterior * latent_dim
ref_posterior : n_hits * n_models * n_posterior * latent_dim
returns : n_hits * n_models
"""
dist = np.empty(ref_latent.shape[:-1]) # n_hits * n_models
for i in range(dist.shape[1]): # model index
x = query_latent[i, ...] # latent_dim
x_posterior = query_posterior[i, ...] # n_posterior * latent_dim
for j in range(dist.shape[0]): # hit index
y = ref_latent[j, i, ...] # latent_dim
y_posterior = ref_posterior[j, i, ...] # n_posterior * latent_dim
dist[j, i] = npd_v1(x, y, x_posterior, y_posterior, eps=eps)
return dist
@numba.jit(nopython=True, nogil=True, cache=True)
def _hit_npd_v2_across_models(
query_latent: np.ndarray,
ref_latent: np.ndarray,
query_posterior: np.ndarray,
ref_posterior: np.ndarray,
eps: float = 1e-1,
) -> np.ndarray: # pragma: no cover
r"""
query_latent : n_models * latent_dim
ref_latent : n_hits * n_models * latent_dim
query_posterior : n_models * n_posterior * latent_dim
ref_posterior : n_hits * n_models * n_posterior * latent_dim
returns : n_hits * n_models
"""
dist = np.empty(ref_latent.shape[:-1]) # n_hits * n_models
for i in range(dist.shape[1]): # model index
x = query_latent[i, ...] # latent_dim
x_posterior = query_posterior[i, ...] # n_posterior * latent_dim
for j in range(dist.shape[0]): # hit index
y = ref_latent[j, i, ...] # latent_dim
y_posterior = ref_posterior[j, i, ...] # n_posterior * latent_dim
dist[j, i] = npd_v2(x, y, x_posterior, y_posterior, eps=eps)
return dist
DISTANCE_METRIC_ACROSS_MODELS = {
ed: _hit_ed_across_models,
md: _hit_md_across_models,
amd: _hit_amd_across_models,
npd_v1: _hit_npd_v1_across_models,
npd_v2: _hit_npd_v2_across_models,
}
class BLAST(object):
r"""
Cell BLAST
Parameters
----------
models
A list of "DIRECTi" models.
ref
A reference dataset.
distance_metric
Cell-to-cell distance metric to use, should be among
{"npd_v1", "npd_v2", "md", "amd", "ed"}.
n_posterior
How many samples from the posterior distribution to use for
estimating posterior distance. Irrelevant for distance_metric="ed".
n_empirical
Number of random cell pairs to use when estimating empirical
distribution of cell-to-cell distance.
cluster_empirical
Whether to build an empirical distribution for each intrinsic cluster
independently.
eps
A small number added to the normalization factors used in certain
posterior-based distance metrics to improve numeric stability.
If not specified, a recommended value will be used according to the
specified distance metric.
force_components
Whether to compute all the necessary components upon initialization.
If set to False, necessary components will be computed
on the fly when performing queries.
Examples
--------
A typical BLAST pipeline is described below.
Assuming we have a list of :class:`directi.DIRECTi` models already fitted
on some reference data, we can construct a BLAST object by feeding the
pretrained models and the reference data to the
:class:`BLAST` constructor.
>>> blast = BLAST(models, reference)
We can efficiently query the reference and obtain initial hits via the
:meth:`BLAST.query` method:
>>> hits = blast.query(query)
    Then we filter the initial hits using more accurate metrics
    (e.g. empirical p-values based on posterior distances), and pool together
    information across multiple models.
>>> hits = hits.reconcile_models().filter(by="pval", cutoff=0.05)
Finally, we use the :meth:`BLAST.annotate` method to obtain predictions
based on reference annotations, e.g. "cell_ontology_class" in this case.
>>> annotation = hits.annotate("cell_ontology_class")
See the BLAST ipython notebook (:ref:`vignettes`) for live examples.
"""
def __init__(
self,
models: typing.List[directi.DIRECTi],
ref: anndata.AnnData,
distance_metric: str = "npd_v1",
n_posterior: int = 50,
n_empirical: int = 10000,
cluster_empirical: bool = False,
eps: typing.Optional[float] = None,
force_components: bool = True,
**kwargs,
) -> None:
self.models = models
self.ref = anndata.AnnData(
X=ref.X, obs=ref.obs.copy(), var=ref.var.copy(), uns=ref.uns.copy()
) # X and uns are shallow copied, obs and var are deep copied
self.latent = None
self.nearest_neighbors = None
self.cluster = None
self.posterior = np.array([None] * self.ref.shape[0])
self.empirical = None
self.distance_metric = (
globals()[distance_metric]
if isinstance(distance_metric, str)
else distance_metric
)
self.n_posterior = n_posterior if self.distance_metric is not ed else 0
self.n_empirical = n_empirical
self.cluster_empirical = cluster_empirical
self.eps = eps
if force_components:
self._force_components(**kwargs)
def __len__(self) -> int:
return len(self.models)
def __getitem__(self, s) -> "BLAST":
blast = BLAST(
np.array(self.models)[s].tolist(),
self.ref,
self.distance_metric,
self.n_posterior,
self.n_empirical,
self.cluster_empirical,
self.eps,
force_components=False,
)
blast.latent = self.latent[:, s, ...] if self.latent is not None else None
blast.cluster = self.cluster[:, s, ...] if self.cluster is not None else None
blast.nearest_neighbors = (
np.array(self.nearest_neighbors)[s].tolist()
if self.nearest_neighbors is not None
else None
)
if self.posterior is not None:
for i in range(self.posterior.size):
if self.posterior[i] is not None:
blast.posterior[i] = self.posterior[i][s, ...]
blast.empirical = (
[item for item in np.array(self.empirical)[s]]
if self.empirical is not None
else None
)
return blast
def _get_latent(self, n_jobs: int) -> np.ndarray: # n_cells * n_models * latent_dim
if self.latent is None:
utils.logger.info("Projecting to latent space...")
self.latent = np.stack(
joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="threading")(
joblib.delayed(model.inference)(self.ref) for model in self.models
),
axis=1,
)
return self.latent
def _get_cluster(self, n_jobs: int) -> np.ndarray: # n_cells * n_models
if self.cluster is None:
utils.logger.info("Obtaining intrinsic clustering...")
self.cluster = np.stack(
joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="threading")(
joblib.delayed(model.clustering)(self.ref) for model in self.models
),
axis=1,
)
return self.cluster
def _get_posterior(
self, n_jobs: int, random_seed: int, idx: typing.Optional[np.ndarray] = None
) -> np.ndarray: # n_cells * (n_models * n_posterior * latent_dim)
if idx is None:
idx = np.arange(self.ref.shape[0])
new_idx = np.intersect1d(
np.unique(idx),
np.where(np.vectorize(lambda x: x is None)(self.posterior))[0],
)
if new_idx.size:
utils.logger.info("Sampling from posteriors...")
new_ref = self.ref[new_idx, :]
new_posterior = np.stack(
joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="loky")(
joblib.delayed(model.inference)(
new_ref, n_posterior=self.n_posterior, random_seed=random_seed
)
for model in self.models
),
axis=1,
) # n_cells * n_models * n_posterior * latent_dim
            # NOTE: Slow discontiguous memcopy here, but that's necessary since
# we will be caching values by cells. It also makes values more
# contiguous and faster to access in later cell-based operations.
if self.distance_metric is amd:
dist_kws = {"eps": self.eps} if self.eps is not None else {}
new_latent = self._get_latent(n_jobs)[new_idx]
self.posterior[new_idx] = joblib.Parallel(
n_jobs=n_jobs, backend="threading"
)(
joblib.delayed(_compute_pcasd_across_models)(
_new_latent, _new_posterior, **dist_kws
)
for _new_latent, _new_posterior in zip(new_latent, new_posterior)
)
else:
self.posterior[new_idx] = [
item for item in new_posterior
] # NOTE: No memcopy here
return self.posterior[idx]
def _get_nearest_neighbors(
self, n_jobs: int
) -> typing.List[sklearn.neighbors.NearestNeighbors]: # n_models
if self.nearest_neighbors is None:
latent = self._get_latent(n_jobs).swapaxes(0, 1)
# NOTE: Makes cells discontiguous, but for nearest neighbor tree,
# there's no influence on performance
utils.logger.info("Fitting nearest neighbor trees...")
self.nearest_neighbors = joblib.Parallel(
n_jobs=min(n_jobs, len(self)), backend="loky"
)(
joblib.delayed(self._fit_nearest_neighbors)(_latent)
for _latent in latent
)
return self.nearest_neighbors
def _get_empirical(
self, n_jobs: int, random_seed: int
) -> np.ndarray: # n_models * [n_clusters * n_empirical]
if self.empirical is None:
utils.logger.info("Generating empirical null distributions...")
if not self.cluster_empirical:
self.cluster = np.zeros((self.ref.shape[0], len(self)), dtype=int)
latent = self._get_latent(n_jobs)
cluster = self._get_cluster(n_jobs)
rs = np.random.RandomState(random_seed)
bg = rs.choice(latent.shape[0], size=self.n_empirical)
if self.distance_metric is not ed:
bg_posterior = self._get_posterior(n_jobs, random_seed, idx=bg)
self.empirical = []
dist_kws = {"eps": self.eps} if self.eps is not None else {}
if self.distance_metric is amd:
dist_kws["x_is_pcasd"] = True
dist_kws["y_is_pcasd"] = True
for k in range(len(self)): # model_idx
empirical = np.zeros((np.max(cluster[:, k]) + 1, self.n_empirical))
for c in np.unique(cluster[:, k]): # cluster_idx
fg = rs.choice(
np.where(cluster[:, k] == c)[0], size=self.n_empirical
)
if self.distance_metric is ed:
empirical[c] = np.sort(
joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(self.distance_metric)(
latent[fg[i]], latent[bg[i]]
)
for i in range(self.n_empirical)
)
)
else:
fg_posterior = self._get_posterior(n_jobs, random_seed, idx=fg)
empirical[c] = np.sort(
joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(self.distance_metric)(
latent[fg[i], k],
latent[bg[i], k],
fg_posterior[i][k],
bg_posterior[i][k],
**dist_kws,
)
for i in range(self.n_empirical)
)
)
self.empirical.append(empirical)
return self.empirical
def _force_components(
self, n_jobs: int = config._USE_GLOBAL, random_seed: int = config._USE_GLOBAL
) -> None:
n_jobs = config.N_JOBS if n_jobs == config._USE_GLOBAL else n_jobs
random_seed = (
config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed
)
self._get_nearest_neighbors(n_jobs)
if self.distance_metric is not ed:
self._get_posterior(n_jobs, random_seed)
self._get_empirical(n_jobs, random_seed)
@staticmethod
def _fit_nearest_neighbors(x: np.ndarray) -> sklearn.neighbors.NearestNeighbors:
return sklearn.neighbors.NearestNeighbors().fit(x)
@staticmethod
def _nearest_neighbor_search(
nn: sklearn.neighbors.NearestNeighbors, query: np.ndarray, n_neighbors: int
) -> np.ndarray:
return nn.kneighbors(query, n_neighbors=n_neighbors)[1]
@staticmethod
@numba.jit(nopython=True, nogil=True, cache=True)
def _nearest_neighbor_merge(x: np.ndarray) -> np.ndarray: # pragma: no cover
return np.unique(x)
@staticmethod
@numba.jit(nopython=True, nogil=True, cache=True)
def _empirical_pvalue(
hits: np.ndarray, dist: np.ndarray, cluster: np.ndarray, empirical: np.ndarray
) -> np.ndarray: # pragma: no cover
r"""
hits : n_hits
dist : n_hits * n_models
cluster : n_cells * n_models
empirical : n_models * [n_cluster * n_empirical]
"""
pval = np.empty(dist.shape)
for i in range(dist.shape[1]): # model index
for j in range(dist.shape[0]): # hit index
pval[j, i] = (
np.searchsorted(empirical[i][cluster[hits[j], i]], dist[j, i])
/ empirical[i].shape[1]
)
return pval
    def save(self, path: str, only_used_genes: bool = True) -> None:
r"""
Save BLAST object to a directory.
Parameters
----------
path
Specifies a path to save the BLAST object.
only_used_genes
Whether to preserve only the genes used by models.
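
        Examples
        --------
        A minimal sketch; ``blast`` is an existing :class:`BLAST` object and the
        target directory name is purely illustrative:

        >>> blast.save("./blast_index")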
"""
if not os.path.exists(path):
os.makedirs(path)
if self.ref is not None:
if only_used_genes:
if "__libsize__" not in self.ref.obs.columns:
data.compute_libsize(
self.ref
) # So that align will still work properly
                ref = data.select_vars(
                    self.ref,
                    np.unique(np.concatenate([model.genes for model in self.models])),
                )
            else:
                ref = self.ref  # keep the full gene set
ref.uns["distance_metric"] = self.distance_metric.__name__
ref.uns["n_posterior"] = self.n_posterior
ref.uns["n_empirical"] = self.n_empirical
ref.uns["cluster_empirical"] = self.cluster_empirical
ref.uns["eps"] = self.eps if self.eps is not None else "None"
if self.latent is not None:
ref.uns["latent"] = self.latent
            if self.cluster is not None:
ref.uns["cluster"] = self.cluster
if self.empirical is not None:
ref.uns["empirical"] = {
str(i): item for i, item in enumerate(self.empirical)
}
ref.uns["posterior"] = {
str(i): item
for i, item in enumerate(self.posterior)
if item is not None
}
ref.write(os.path.join(path, "ref.h5ad"))
for i in range(len(self)):
self.models[i].save(os.path.join(path, f"model_{i}"))
    @classmethod
def load(cls, path: str, mode: int = NORMAL, **kwargs):
r"""
Load BLAST object from a directory.
Parameters
----------
path
Specifies a path to load from.
mode
            If mode is set to MINIMAL, model loading will be accelerated by only
            loading the encoders, but aligning BLAST (fine-tuning) will not be
            available. Should be among {cb.blast.NORMAL, cb.blast.MINIMAL}.
Returns
-------
blast
Loaded BLAST object.
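
        Examples
        --------
        A minimal sketch, assuming a BLAST object was previously saved to the
        illustrative directory ``./blast_index``:

        >>> blast = BLAST.load("./blast_index")
        >>> blast_minimal = BLAST.load("./blast_index", mode=MINIMAL)  # encoders only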
"""
assert mode in (NORMAL, MINIMAL)
ref = anndata.read_h5ad(os.path.join(path, "ref.h5ad"))
models = []
model_paths = sorted(
os.path.join(path, d)
for d in os.listdir(path)
if re.fullmatch(r"model_[0-9]+", d) and os.path.isdir(os.path.join(path, d))
)
for model_path in model_paths:
models.append(directi.DIRECTi.load(model_path, _mode=mode))
blast = cls(
models,
ref,
ref.uns["distance_metric"],
ref.uns["n_posterior"],
ref.uns["n_empirical"],
ref.uns["cluster_empirical"],
None if ref.uns["eps"] == "None" else ref.uns["eps"],
force_components=False,
)
blast.latent = blast.ref.uns["latent"] if "latent" in blast.ref.uns else None
        blast.cluster = (
            blast.ref.uns["cluster"] if "cluster" in blast.ref.uns else None
        )
blast.empirical = (
[blast.ref.uns["empirical"][str(i)] for i in range(len(blast))]
if "empirical" in blast.ref.uns
else None
)
if "posterior" in blast.ref.uns:
for i in range(ref.shape[0]):
if str(i) in blast.ref.uns["posterior"]:
blast.posterior[i] = blast.ref.uns["posterior"][str(i)]
blast._force_components(**kwargs)
return blast
    def query(
self,
query: anndata.AnnData,
n_neighbors: int = 5,
store_dataset: bool = False,
n_jobs: int = config._USE_GLOBAL,
random_seed: int = config._USE_GLOBAL,
) -> "Hits":
r"""
BLAST query
Parameters
----------
query
Query transcriptomes.
n_neighbors
Initial number of nearest neighbors to search in each model.
store_dataset
Whether to store query dataset in the returned hit object.
            Note that this is necessary if :meth:`Hits.gene_gradient`
            is to be used.
n_jobs
Number of parallel jobs to run when performing query. If not
specified, :data:`config.N_JOBS` will be used.
Note that each (tensorflow) job could be distributed on multiple
CPUs for a single "job".
random_seed
Random seed for posterior sampling. If not specified,
            :data:`config.RANDOM_SEED` will be used.
Returns
-------
hits
Query hits
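
        Examples
        --------
        A minimal sketch; ``blast`` is a fitted :class:`BLAST` object and
        ``query_adata`` is an illustrative :class:`anndata.AnnData` object
        containing query expression data:

        >>> hits = blast.query(query_adata, n_neighbors=10, store_dataset=True)
        >>> hits = hits.reconcile_models().filter(by="pval", cutoff=0.05)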
"""
n_jobs = config.N_JOBS if n_jobs == config._USE_GLOBAL else n_jobs
random_seed = (
config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed
)
utils.logger.info("Projecting to latent space...")
query_latent = joblib.Parallel(
n_jobs=min(n_jobs, len(self)), backend="threading"
)(
joblib.delayed(model.inference)(query) for model in self.models
) # n_models * [n_cells * latent_dim]
utils.logger.info("Doing nearest neighbor search...")
nearest_neighbors = self._get_nearest_neighbors(n_jobs)
nni = np.stack(
joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="threading")(
joblib.delayed(self._nearest_neighbor_search)(
_nearest_neighbor, _query_latent, n_neighbors
)
for _nearest_neighbor, _query_latent in zip(
nearest_neighbors, query_latent
)
),
axis=2,
) # n_cells * n_neighbors * n_models
utils.logger.info("Merging hits across models...")
hits = joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(self._nearest_neighbor_merge)(_nni) for _nni in nni
) # n_cells * [n_hits]
hitsu, hitsi = np.unique(np.concatenate(hits), return_inverse=True)
hitsi = np.split(hitsi, np.cumsum([item.size for item in hits])[:-1])
query_latent = np.stack(query_latent, axis=1) # n_cell * n_model * latent_dim
ref_latent = self._get_latent(n_jobs) # n_cell * n_model * latent_dim
if self.distance_metric is ed:
utils.logger.info("Computing Euclidean distances...")
dist = joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(_hit_ed_across_models)(
query_latent[i], ref_latent[hits[i]]
)
for i in range(len(hits))
) # list of n_hits * n_models
else:
utils.logger.info("Computing posterior distribution distances...")
query_posterior = np.stack(
joblib.Parallel(n_jobs=min(n_jobs, len(self)), backend="loky")(
joblib.delayed(model.inference)(
query, n_posterior=self.n_posterior, random_seed=random_seed
)
for model in self.models
),
axis=1,
) # n_cells * n_models * n_posterior_samples * latent_dim
ref_posterior = np.stack(
self._get_posterior(n_jobs, random_seed, idx=hitsu)
) # n_cells * n_models * n_posterior_samples * latent_dim
distance_metric = DISTANCE_METRIC_ACROSS_MODELS[self.distance_metric]
dist_kws = {"eps": self.eps} if self.eps is not None else {}
dist = joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(distance_metric)(
query_latent[i],
ref_latent[hits[i]],
query_posterior[i],
ref_posterior[hitsi[i]],
**dist_kws,
)
for i in range(len(hits))
) # list of n_hits * n_models
utils.logger.info("Computing empirical p-values...")
empirical = self._get_empirical(n_jobs, random_seed)
cluster = self._get_cluster(n_jobs)
pval = joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(self._empirical_pvalue)(_hits, _dist, cluster, empirical)
for _hits, _dist in zip(hits, dist)
) # list of n_hits * n_models
return Hits(
self,
hits,
dist,
pval,
query
if store_dataset
else anndata.AnnData(
X=scipy.sparse.csr_matrix((query.shape[0], 0)),
obs=pd.DataFrame(index=query.obs.index),
var=pd.DataFrame(),
uns={},
),
)
    def align(
self,
query: typing.Union[anndata.AnnData, typing.Mapping[str, anndata.AnnData]],
n_jobs: int = config._USE_GLOBAL,
random_seed: int = config._USE_GLOBAL,
path: typing.Optional[str] = None,
**kwargs,
) -> "BLAST":
r"""
Align internal DIRECTi models with query datasets (fine tuning).
Parameters
----------
query
A query dataset or a dict of query datasets, which will be aligned
to the reference.
n_jobs
Number of parallel jobs to run when building the BLAST index,
If not specified, :data:`config.N_JOBS` will be used.
Note that each (tensorflow) job could be distributed on multiple
CPUs for a single "job".
random_seed
Random seed for posterior sampling. If not specified,
:data:`config.RANDOM_SEED` will be used.
path
Specifies a path to store temporary files.
kwargs
Additional keyword parameters passed to
:meth:`directi.align_DIRECTi`.
Returns
-------
blast
A new BLAST object with aligned internal models.
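
        Examples
        --------
        A minimal sketch; ``blast`` and ``query_adata`` are illustrative names,
        and the dict form allows aligning several query batches at once:

        >>> aligned_blast = blast.align({"query_batch": query_adata})
        >>> hits = aligned_blast.query(query_adata)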
"""
if any(
model._mode == directi._TEST for model in self.models
): # pragma: no cover
raise Exception("Align not available!")
n_jobs = config.N_JOBS if n_jobs == config._USE_GLOBAL else n_jobs
random_seed = (
config.RANDOM_SEED if random_seed == config._USE_GLOBAL else random_seed
)
path = path or tempfile.mkdtemp()
aligned_models = joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(directi.align_DIRECTi)(
self.models[i],
self.ref,
query,
random_seed=random_seed,
path=os.path.join(path, f"aligned_model_{i}"),
**kwargs,
)
for i in range(len(self))
)
return BLAST(
aligned_models,
self.ref,
distance_metric=self.distance_metric,
n_posterior=self.n_posterior,
n_empirical=self.n_empirical,
cluster_empirical=self.cluster_empirical,
eps=self.eps,
)
class Hits(object):
r"""
BLAST hits
Parameters
----------
blast
The :class:`BLAST` object producing the hits
hits
        Indices of hit cells in the reference dataset.
Each list element contains hit cell indices for a query cell.
dist
Hit cell distances.
Each list element contains distances for a query cell.
Each list element is a :math:`n\_hits \times n\_models` matrix,
with matrix entries corresponding to the distance to
each hit cell under each model.
pval
Hit cell empirical p-values.
Each list element contains p-values for a query cell.
Each list element is a :math:`n\_hits \times n\_models` matrix,
with matrix entries corresponding to the empirical p-value of
each hit cell under each model.
query
Query dataset
"""
FILTER_BY_DIST = 0
FILTER_BY_PVAL = 1
def __init__(
self,
blast: BLAST,
hits: typing.List[np.ndarray],
dist: typing.List[np.ndarray],
pval: typing.List[np.ndarray],
query: anndata.AnnData,
) -> None:
self.blast = blast
self.hits = np.asarray(hits, dtype=object)
self.dist = np.asarray(dist, dtype=object)
self.pval = np.asarray(pval, dtype=object)
self.query = query
if (
not self.hits.shape[0]
== self.dist.shape[0]
== self.pval.shape[0]
== self.query.shape[0]
):
raise ValueError("Inconsistent shape!")
def __len__(self) -> int:
return self.query.shape[0]
def __iter__(self):
for idx, (_hits, _dist, _pval) in enumerate(
zip(self.hits, self.dist, self.pval)
):
yield Hits(self.blast, [_hits], [_dist], [_pval], self.query[idx, :])
def __getitem__(self, s):
s = [s] if isinstance(s, (int, np.integer)) else s
return Hits(
self.blast, self.hits[s], self.dist[s], self.pval[s], self.query[s, :]
)
    def to_data_frames(self) -> typing.Mapping[str, pd.DataFrame]:
r"""
Construct hit data frames for query cells.
Note that only reconciled ``Hits`` objects are supported.
Returns
-------
data_frame_dicts
            Each element is the hit data frame for one query cell.
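
        Examples
        --------
        A minimal sketch; ``hits`` is an illustrative :class:`Hits` object
        returned by :meth:`BLAST.query`:

        >>> df_dict = hits.reconcile_models().to_data_frames()
        >>> first_cell = next(iter(df_dict))
        >>> df_dict[first_cell].head()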
"""
if self.dist[0].ndim != 1 or self.pval[0].ndim != 1:
raise RuntimeError("Please call `reconcile_models` first!")
df_dict = collections.OrderedDict()
for i, name in enumerate(self.query.obs_names):
df_dict[name] = self.blast.ref.obs.iloc[self.hits[i], :].copy()
df_dict[name].loc[:, "hits"] = self.hits[i]
df_dict[name].loc[:, "dist"] = self.dist[i]
df_dict[name].loc[:, "pval"] = self.pval[i]
return df_dict
    def reconcile_models(
self, dist_method: str = "mean", pval_method: str = "gmean"
) -> "Hits":
r"""
Integrate model-specific distances and empirical p-values.
Parameters
----------
dist_method
            Specifies how to integrate distances across different models.
Should be among {"mean", "gmean", "min", "max"}.
pval_method
Specifies how to integrate empirical p-values across different
models. Should be among {"mean", "gmean", "min", "max"}.
Returns
-------
reconciled_hits
            Hit object containing reconciled distances and p-values.
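
        Examples
        --------
        A minimal sketch; ``hits`` is an illustrative :class:`Hits` object:

        >>> reconciled = hits.reconcile_models(dist_method="mean", pval_method="gmean")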
"""
dist_method = self._get_reconcile_method(dist_method)
dist = [dist_method(item, axis=1) for item in self.dist]
pval_method = self._get_reconcile_method(pval_method)
pval = [pval_method(item, axis=1) for item in self.pval]
return Hits(self.blast, self.hits, dist, pval, self.query)
@staticmethod
@numba.jit(nopython=True, nogil=True, cache=True)
def _filter_hits(
hits: np.ndarray,
dist: np.ndarray,
pval: np.ndarray,
by: int,
cutoff: float,
model_tolerance: int,
) -> typing.Tuple[np.ndarray, np.ndarray, np.ndarray]: # pragma: no cover
r"""
hits : n_hits
dist : n_hits * n_models
pval : n_hits * n_models
"""
if by == 0: # Hits.FILTER_BY_DIST
hit_mask = (dist.shape[1] - (dist <= cutoff).sum(axis=1)) <= model_tolerance
else: # Hits.FILTER_BY_PVAL
hit_mask = (pval.shape[1] - (pval <= cutoff).sum(axis=1)) <= model_tolerance
return hits[hit_mask], dist[hit_mask], pval[hit_mask]
@staticmethod
@numba.jit(nopython=True, nogil=True, cache=True)
def _filter_reconciled_hits(
hits: np.ndarray, dist: np.ndarray, pval: np.ndarray, by: int, cutoff: float
) -> typing.Tuple[np.ndarray, np.ndarray, np.ndarray]: # pragma: no cover
r"""
hits : n_hits
dist : n_hits
pval : n_hits
"""
if by == 0: # Hits.FILTER_BY_DIST
hit_mask = dist <= cutoff
else: # Hits.FILTER_BY_PVAL
hit_mask = pval <= cutoff
return hits[hit_mask], dist[hit_mask], pval[hit_mask]
    def filter(
self,
by: str = "pval",
cutoff: float = 0.05,
model_tolerance: int = 0,
n_jobs: int = 1,
) -> "Hits":
r"""
Filter hits by posterior distance or p-value
Parameters
----------
by
Specifies a metric based on which to filter hits.
Should be among {"dist", "pval"}.
cutoff
Cutoff when filtering hits.
model_tolerance
            Maximal number of models in which the cutoff may be violated;
            hits violating the cutoff in more models are discarded.
Irrelevant for reconciled hits.
n_jobs
Number of parallel jobs to run.
Returns
-------
filtered_hits
Hit object containing remaining hits after filtering
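
        Examples
        --------
        A minimal sketch; ``hits`` is an illustrative :class:`Hits` object:

        >>> filtered = hits.reconcile_models().filter(by="pval", cutoff=0.05)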
"""
if by == "pval":
assert self.pval is not None
by = Hits.FILTER_BY_PVAL
else: # by == "dist"
by = Hits.FILTER_BY_DIST
if self.dist[0].ndim == 1:
hits, dist, pval = [
_
for _ in zip(
*joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(self._filter_reconciled_hits)(
_hits, _dist, _pval, by, cutoff
)
for _hits, _dist, _pval in zip(self.hits, self.dist, self.pval)
)
)
]
else:
hits, dist, pval = [
_
for _ in zip(
*joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(self._filter_hits)(
_hits, _dist, _pval, by, cutoff, model_tolerance
)
for _hits, _dist, _pval in zip(self.hits, self.dist, self.pval)
)
)
]
return Hits(self.blast, hits, dist, pval, self.query)
    def annotate(
self,
field: str,
min_hits: int = 2,
majority_threshold: float = 0.5,
return_evidence: bool = False,
) -> pd.DataFrame:
r"""
Annotate query cells based on existing annotations of hit cells
via majority voting.
Parameters
----------
field
Specifies a meta column in `anndata.AnnData.obs`.
min_hits
Minimal number of hits required for annotating a query cell,
otherwise the query cell will be rejected.
majority_threshold
Minimal majority fraction (not inclusive) required for confident
annotation. Only effective when predicting
categorical variables. If the threshold is not met, annotation
will be "ambiguous".
return_evidence
Whether to return evidence level of the annotations.
Returns
-------
prediction
Each row contains the inferred annotation for a query cell.
If ``return_evidence`` is set to False, the data frame contains only
one column, i.e. the inferred annotation.
If ``return_evidence`` is set to True, the data frame also contains
the number of hits, as well as the majority fraction (only for
categorical annotations) for each query cell.
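
        Examples
        --------
        A minimal sketch; ``hits`` is an illustrative :class:`Hits` object and the
        reference is assumed to carry a "cell_ontology_class" column in ``obs``:

        >>> pred = hits.annotate("cell_ontology_class", return_evidence=True)
        >>> pred["cell_ontology_class"].value_counts()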
"""
ref = self.blast.ref.obs[field].to_numpy().ravel()
n_hits = np.repeat(0, len(self.hits))
if np.issubdtype(ref.dtype.type, np.character) or np.issubdtype(
ref.dtype.type, np.object_
):
prediction = np.repeat("rejected", len(self.hits)).astype(object)
majority_frac = np.repeat(np.nan, len(self.hits))
for i, _hits in enumerate(self.hits):
hits = ref[_hits.astype(int)]
hits = hits[~utils.isnan(hits)] if hits.size else hits
n_hits[i] = hits.size
if n_hits[i] < min_hits:
continue
label, count = np.unique(hits, return_counts=True)
best_idx = count.argmax()
majority_frac[i] = count[best_idx] / hits.size
if majority_frac[i] <= majority_threshold:
prediction[i] = "ambiguous"
continue
prediction[i] = label[best_idx]
prediction = prediction.astype(ref.dtype.type)
elif np.issubdtype(ref.dtype.type, np.number):
prediction = np.repeat(np.nan, len(self.hits))
for i, _hits in enumerate(self.hits):
hits = ref[_hits.astype(int)]
hits = hits[~utils.isnan(hits)] if hits.size else hits
n_hits[i] = hits.size
if n_hits[i] < min_hits:
continue
prediction[i] = hits.mean()
            # np.stack handles the 0-d values produced by taking means of 1-d arrays
            prediction = np.stack(prediction, axis=0)
else: # pragma: no cover
raise ValueError("Unsupported data type!")
result = collections.OrderedDict()
result[field] = prediction
if return_evidence:
result["n_hits"] = n_hits
if "majority_frac" in locals():
result["majority_frac"] = majority_frac
return pd.DataFrame(result, index=self.query.obs_names)
    def blast2co(
self,
cl_dag: utils.CellTypeDAG,
cl_field: str = "cell_ontology_class",
min_hits: int = 2,
thresh: float = 0.5,
min_path: int = 4,
) -> pd.DataFrame:
r"""
Annotate query cells based on existing annotations of hit cells
via the cell-ontology-aware BLAST2CO method.
Parameters
----------
cl_dag
Cell ontology DAG
cl_field
Specify the ``obs`` column containing cell ontology annotation
in the reference dataset
min_hits
Minimal number of hits required for annotating a query cell,
otherwise the query cell will be rejected.
thresh
Scoring threshold based on 1 - pvalue.
min_path
Minimal allowed value of the maximal distance to root
for a prediction to be made.
Returns
-------
prediction
Each row contains the inferred annotation for a query cell.
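
        Examples
        --------
        A minimal sketch; ``hits`` and ``cl_dag`` are illustrative names, where
        ``cl_dag`` is assumed to be a pre-built :class:`utils.CellTypeDAG`
        consistent with the reference annotation:

        >>> pred = hits.reconcile_models().blast2co(cl_dag)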
"""
if self.dist[0].ndim != 1 or self.pval[0].ndim != 1:
raise RuntimeError("Please call `reconcile_models` first!")
prediction = np.repeat("rejected", len(self.hits)).astype(object)
ref = self.blast.ref.obs[cl_field].to_numpy().ravel()
for i, (_hits, _pval) in enumerate(zip(self.hits, self.pval)):
hits = ref[_hits.astype(int)]
if hits.size:
mask = ~utils.isnan(hits)
hits = hits[mask]
_pval = _pval[mask]
if hits.size < min_hits:
continue
cl_dag.value_reset()
for cl in np.unique(hits): # 1 - pval
cl_dag.value_set(cl, np.sum(1 - _pval[hits == cl]) / np.sum(1 - _pval))
cl_dag.value_update()
leaves = cl_dag.best_leaves(
thresh=thresh, min_path=min_path, retrieve=cl_field
)
if len(leaves) == 1:
prediction[i] = leaves[0]
elif len(leaves) > 1:
prediction[i] = "ambiguous"
return pd.DataFrame({cl_field: prediction}, index=self.query.obs_names)
@staticmethod
def _get_reconcile_method(method: str):
if method == "mean":
return np.mean
if method == "gmean":
return scipy.stats.gmean
if method == "min":
return np.min
if method == "max":
return np.max
raise ValueError("Unknown method!") # pragma: no cover
    def gene_gradient(
self,
eval_point: str = "query",
normalize_deviation: bool = True,
avg_models: bool = True,
n_jobs: int = config._USE_GLOBAL,
) -> typing.List[np.ndarray]:
r"""
Compute gene-wise gradient for each pair of query-hit cells
based on query-hit deviation in the latent space. Useful for model
interpretation.
Parameters
----------
eval_point
At which point should the gradient be evaluated.
Valid options include: {"query", "ref", "both"}
normalize_deviation
            Whether to normalize query-hit deviation in the latent space.
avg_models
Whether to average gene-wise gradients across different models
n_jobs
Number of parallel jobs to run when performing query. If not
specified, :data:`config.N_JOBS` will be used.
Note that each (tensorflow) job could be distributed on multiple
CPUs for a single "job".
Returns
-------
gene_gradient
A list with length equal to the number of query cells, where each
element is a :class:`np.ndarray` containing gene-wise gradient for
every hit cell of a query cell. The :class:`np.ndarray`s are of
shape :math:`n\_hits \times n\_genes` if ``avg_models`` is set to
True, or :math:`n\_hits \times n\_models \times n\_genes` if
``avg_models`` is set to False.
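
        Examples
        --------
        A minimal sketch; ``hits`` must come from a query run with
        ``store_dataset=True`` (see :meth:`BLAST.query`):

        >>> grads = hits.gene_gradient(eval_point="both")
        >>> grads[0].shape  # n_hits * n_genes for the first query cell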
"""
n_jobs = config.N_JOBS if n_jobs == config._USE_GLOBAL else n_jobs
if self.query.shape[1] == 0:
raise RuntimeError(
'No query data available! Please set "store_dataset" to True '
"when calling BLAST.query()"
)
ref_idx = np.concatenate(self.hits)
query_idx = np.concatenate(
[idx * np.ones_like(_hits) for idx, _hits in enumerate(self.hits)]
)
ref = self.blast.ref[ref_idx, :]
query = self.query[query_idx, :]
query_latent = joblib.Parallel(
n_jobs=min(n_jobs, len(self)), backend="threading"
)(joblib.delayed(model.inference)(self.query) for model in self.blast.models)
query_latent = np.stack(query_latent)[
:, query_idx, :
] # n_models * sum(n_hits) * latent_dim
ref_latent = self.blast._get_latent(n_jobs) # n_cells * n_models * latent_dim
ref_latent = ref_latent[ref_idx].swapaxes(
0, 1
) # n_models * sum(n_hits) * latent_dim
deviation = query_latent - ref_latent # n_models * sum(n_hits) * latent_dim
if normalize_deviation:
deviation /= np.linalg.norm(deviation, axis=2, keepdims=True)
if eval_point in ("ref", "both"):
gene_dev_ref = joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(model.gene_grad)(ref, latent_grad=_deviation)
for model, _deviation in zip(self.blast.models, deviation)
) # n_models * [sum(n_hits) * n_genes]
gene_dev_ref = np.stack(
gene_dev_ref, axis=1
) # sum(n_hits) * n_models * n_genes
if eval_point in ("query", "both"):
gene_dev_query = joblib.Parallel(n_jobs=n_jobs, backend="threading")(
joblib.delayed(model.gene_grad)(query, latent_grad=_deviation)
for model, _deviation in zip(self.blast.models, deviation)
) # n_models * [sum(n_hits) * n_genes]
gene_dev_query = np.stack(
gene_dev_query, axis=1
) # sum(n_hits) * n_models * n_genes
if eval_point == "ref":
gene_dev = gene_dev_ref
elif eval_point == "query":
gene_dev = gene_dev_query
else: # eval_point == "both"
gene_dev = (gene_dev_ref + gene_dev_query) / 2
if avg_models:
gene_dev = np.mean(gene_dev, axis=1)
split_idx = np.cumsum([_hits.size for _hits in self.hits])[:-1]
gene_dev = np.split(
gene_dev, split_idx
) # n_queries * [n_hits * <n_models> * n_genes]
return gene_dev
def sankey(
query: np.ndarray,
ref: np.ndarray,
title: str = "Sankey",
width: int = 500,
height: int = 500,
tint_cutoff: int = 1,
font: str = "Arial",
font_size: float = 10.0,
suppress_plot: bool = False,
) -> dict: # pragma: no cover
r"""
Make a sankey diagram of query-reference mapping (only works in
ipython notebooks).
Parameters
----------
query
1-dimensional array of query annotation.
ref
1-dimensional array of BLAST prediction based on reference database.
title
Diagram title.
width
Graph width.
height
Graph height.
tint_cutoff
        Cutoff below which sankey flows are drawn in a lighter (tinted) color.
font
Font family used for the plot.
font_size
Font size for the plot.
suppress_plot
Whether to suppress plotting and only return the figure dict.
Returns
-------
fig
Figure object fed to `iplot` of the `plotly` module to produce the plot.
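
    Examples
    --------
    A minimal sketch; ``query_adata`` and ``prediction`` are illustrative names
    for the query dataset and the data frame returned by :meth:`Hits.annotate`:

    >>> fig = sankey(
    ...     query_adata.obs["cell_ontology_class"].to_numpy(),
    ...     prediction["cell_ontology_class"].to_numpy(),
    ...     title="Cell BLAST predictions", suppress_plot=True,
    ... )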
"""
cf = metrics.confusion_matrix(query, ref)
cf["query"] = cf.index.to_numpy()
cf = cf.melt(
id_vars=["query"], var_name="reference", value_name="count"
).sort_values(by="count")
query_i, query_c = utils.encode_integer(cf["query"])
ref_i, ref_c = utils.encode_integer(cf["reference"])
sankey_data = dict(
type="sankey",
node=dict(
pad=15,
thickness=20,
line=dict(color="black", width=0.5),
label=np.concatenate([query_c, ref_c], axis=0),
color=["#E64B35"] * len(query_c) + ["#4EBBD5"] * len(ref_c),
),
link=dict(
source=query_i.tolist(),
target=(ref_i + len(query_c)).tolist(),
value=cf["count"].tolist(),
color=np.vectorize(lambda x: "#F0F0F0" if x <= tint_cutoff else "#CCCCCC")(
cf["count"]
),
),
)
sankey_layout = dict(
title=title, width=width, height=height, font=dict(family=font, size=font_size)
)
fig = dict(data=[sankey_data], layout=sankey_layout)
if not suppress_plot:
import plotly.offline
plotly.offline.init_notebook_mode()
plotly.offline.iplot(fig, validate=False)
return fig