r"""
Latent space / encoder modules for DIRECTi
"""
import itertools
import typing
import torch
import torch.distributions as D
import torch.nn.functional as F
from torch import nn
from . import config, utils
from .rebuild import MLP, Linear
class Regularizer(nn.Module):
    r"""
    Small MLP discriminator with a sigmoid output, used to adversarially
    regularize latent variables towards their prior distributions.
    """
def __init__(
self,
latent_dim: int,
h_dim: int = 128,
depth: int = 1,
dropout: float = 0.0,
name: str = "Reg",
_class: str = "Regularizer",
**kwargs,
) -> None:
super().__init__()
self.latent_dim = latent_dim
self.h_dim = h_dim
self.depth = depth
self.dropout = dropout
self.name = name
self._class = _class
        for key in kwargs:
            utils.logger.warning("Argument `%s` is no longer supported!", key)
i_dim = [latent_dim] + [h_dim] * (depth - 1) if depth > 0 else []
o_dim = [h_dim] * depth
dropout = [dropout] * depth
if depth > 0:
dropout[0] = 0.0
self.mlp = MLP(i_dim, o_dim, dropout)
self.output = Linear(h_dim, 1) if depth > 0 else Linear(latent_dim, 1)
self.sigmoid = nn.Sigmoid()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.sigmoid(self.output(self.mlp(x)))
def get_config(self) -> typing.Mapping:
return {
"h_dim": self.h_dim,
"depth": self.depth,
"dropout": self.dropout,
"name": self.name,
"_class": self._class,
}
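# Usage sketch for `Regularizer` (a hypothetical example; assumes the `MLP`
# and `Linear` layers from `.rebuild` follow the usual shape conventions):
#
#     reg = Regularizer(latent_dim=10, h_dim=128, depth=1)
#     scores = reg(torch.randn(32, 10))  # (32, 1), values in (0, 1)
#
# Downstream, the sigmoid score is read as the discriminator's belief that a
# sample was drawn from the prior rather than produced by the encoder.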
class Latent(nn.Module):
r"""
Abstract base class for latent variable modules.
"""
def __init__(
self,
input_dim: int,
latent_dim: int,
h_dim: int = 128,
depth: int = 1,
dropout: float = 0.0,
lambda_reg: float = 0.0,
fine_tune: bool = False,
deviation_reg: float = 0.0,
name: str = "Latent",
_class: str = "Latent",
**kwargs,
) -> None:
super().__init__()
self.input_dim = input_dim
self.latent_dim = latent_dim
self.h_dim = h_dim
self.depth = depth
self.dropout = dropout
self.lambda_reg = lambda_reg
self.fine_tune = fine_tune
self.deviation_reg = deviation_reg
self.name = name
self._class = _class
self.record_prefix = "discriminator"
        for key in kwargs:
            utils.logger.warning("Argument `%s` is no longer supported!", key)
    @staticmethod
    def gan_d_loss(
        y: torch.Tensor, y_hat: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        # Discriminator loss: drive predictions on prior samples (`y_hat`)
        # towards 1 and predictions on encoded samples (`y`) towards 0.
        return -(torch.log(y_hat + eps) + torch.log(1 - y + eps)).mean()

    @staticmethod
    def gan_g_loss(y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        # Generator loss: drive predictions on encoded samples towards 1.
        return -torch.log(y + eps).mean()
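    # A minimal sketch of how the two losses interact in the adversarial
    # (AAE-style) regularization below (values are illustrative only):
    #
    #     y = torch.full((4, 1), 0.3)      # discriminator score on encoded samples
    #     y_hat = torch.full((4, 1), 0.7)  # discriminator score on prior samples
    #     d = Latent.gan_d_loss(y, y_hat)  # small when y -> 0 and y_hat -> 1
    #     g = Latent.gan_g_loss(y)         # small when the encoder fools the
    #                                      # discriminator, i.e. y -> 1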
def get_config(self) -> typing.Mapping:
return {
"input_dim": self.input_dim,
"latent_dim": self.latent_dim,
"h_dim": self.h_dim,
"depth": self.depth,
"dropout": self.dropout,
"lambda_reg": self.lambda_reg,
"fine_tune": self.fine_tune,
"deviation_reg": self.deviation_reg,
"name": self.name,
"_class": self._class,
}
class Gau(Latent):
r"""
Build a Gaussian latent module. The Gaussian latent variable is used as
cell embedding.
Parameters
----------
input_dim
Dimensionality of the input tensor.
latent_dim
Dimensionality of the latent variable.
h_dim
Dimensionality of the hidden layers in the encoder MLP.
depth
Number of hidden layers in the encoder MLP.
dropout
Dropout rate.
    lambda_reg
        Regularization strength on the latent variable.
    fine_tune
        Whether the module is to be fine-tuned from a pretrained model.
    deviation_reg
        Regularization strength penalizing deviation from the pretrained
        weights during fine-tuning.
    name
        Name of the module.
"""
def __init__(
self,
input_dim: int,
latent_dim: int,
h_dim: int = 128,
depth: int = 1,
dropout: float = 0.0,
lambda_reg: float = 0.001,
fine_tune: bool = False,
deviation_reg: float = 0.0,
name: str = "Gau",
_class: str = "Gau",
**kwargs,
) -> None:
super().__init__(
input_dim,
latent_dim,
h_dim,
depth,
dropout,
lambda_reg,
fine_tune,
deviation_reg,
name,
_class,
**kwargs,
)
self.gau_reg = Regularizer(latent_dim, h_dim, depth, dropout, name="gau")
self.gaup_sampler = D.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
i_dim = [input_dim] + [h_dim] * (depth - 1) if depth > 0 else []
o_dim = [h_dim] * depth
dropout = [dropout] * depth
self.mlp = MLP(i_dim, o_dim, dropout, bias=False, batch_normalization=True)
self.gau = (
Linear(h_dim, latent_dim) if depth > 0 else Linear(input_dim, latent_dim)
)
# fine-tune
def save_origin_state(self) -> None:
self.mlp.save_origin_state()
self.mlp.first_layer_trainable = False
self.gau.save_origin_state()
# fine-tune
def deviation_loss(self) -> torch.Tensor:
return self.deviation_reg * (
self.mlp.deviation_loss() + self.gau.deviation_loss()
)
    # fine-tune
def check_fine_tune(self) -> None:
if self.fine_tune:
self.save_origin_state()
    def forward(self, x: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        gau = self.gau(self.mlp(x))
        # Return (latent, extra) to match the interface of `CatGau.forward`;
        # for a plain Gaussian latent both elements are the same tensor.
        return gau, gau
def fetch_latent(self, x: torch.Tensor) -> torch.Tensor:
gau = self.gau(self.mlp(x))
return gau
def fetch_cat(self, x: torch.Tensor) -> torch.Tensor:
        raise RuntimeError("Model has no intrinsic clustering")
    def fetch_grad(self, x: torch.Tensor, latent_grad: torch.Tensor) -> torch.Tensor:
        x_with_grad = x.requires_grad_(True)  # in-place; `x` is the same tensor
        gau = self.gau(self.mlp(x_with_grad))
        # Vector-Jacobian product: back-propagate `latent_grad` to the input.
        gau.backward(latent_grad)
        return x_with_grad.grad
    def d_loss(
        self,
        gau: torch.Tensor,
        feed_dict: typing.Mapping,
        loss_record: typing.MutableMapping,
    ) -> torch.Tensor:
        gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to(
            config.DEVICE
        )
        gau_pred = self.gau_reg(gau)
        gaup_pred = self.gau_reg(gaup)
        gau_d_loss = self.gan_d_loss(gau_pred, gaup_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.gau_reg.name}/d_loss/d_loss"
        ] += gau_d_loss.item() * gau.shape[0]
        return self.lambda_reg * gau_d_loss
    def g_loss(
        self,
        gau: torch.Tensor,
        feed_dict: typing.Mapping,
        loss_record: typing.MutableMapping,
    ) -> torch.Tensor:
        gau_pred = self.gau_reg(gau)
        gau_g_loss = self.gan_g_loss(gau_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.gau_reg.name}/g_loss/g_loss"
        ] += gau_g_loss.item() * gau.shape[0]
        if self.fine_tune:
            # `deviation_loss` is already scaled by `deviation_reg`.
            return self.lambda_reg * gau_g_loss + self.deviation_loss()
        return self.lambda_reg * gau_g_loss
    def init_loss_record(self, loss_record: typing.MutableMapping) -> None:
        prefix = f"{self.record_prefix}/{self.name}/{self.gau_reg.name}"
        loss_record[f"{prefix}/d_loss/d_loss"] = 0
        loss_record[f"{prefix}/g_loss/g_loss"] = 0
    def parameters_reg(self) -> typing.Iterator[nn.Parameter]:
        return self.gau_reg.parameters()

    def parameters_fit(self) -> typing.Iterator[nn.Parameter]:
        return itertools.chain(
            self.mlp.parameters(),
            self.gau.parameters(),
        )
def get_config(self) -> typing.Mapping:
return {**super().get_config()}
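# Usage sketch for `Gau` (hypothetical dimensions; `x` stands in for a batch
# of preprocessed expression profiles):
#
#     enc = Gau(input_dim=2000, latent_dim=10).to(config.DEVICE)
#     x = torch.randn(64, 2000, device=config.DEVICE)
#     latent, gau = enc(x)       # both are the same (64, 10) tensor here
#     emb = enc.fetch_latent(x)  # (64, 10) cell embedding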
class CatGau(Latent):
r"""
Build a double latent module, with a continuous Gaussian latent variable
and a one-hot categorical latent variable for intrinsic clustering of
    the data. These two latent variables are then combined into a single
cell embedding vector.
Parameters
----------
input_dim
Dimensionality of the input tensor.
latent_dim
Dimensionality of the latent variable.
cat_dim
Number of intrinsic clusters.
h_dim
Dimensionality of the hidden layers in the encoder MLP.
depth
Number of hidden layers in the encoder MLP.
dropout
Dropout rate.
    lambda_reg
        Regularization strength on the latent variables.
    fine_tune
        Whether the module is to be fine-tuned from a pretrained model.
    deviation_reg
        Regularization strength penalizing deviation from the pretrained
        weights during fine-tuning.
    name
        Name of the module.
"""
def __init__(
self,
input_dim: int,
latent_dim: int,
cat_dim: int,
h_dim: int = 128,
depth: int = 1,
dropout: float = 0.0,
lambda_reg: float = 0.001,
fine_tune: bool = False,
deviation_reg: float = 0.0,
name: str = "CatGau",
_class: str = "CatGau",
**kwargs,
) -> None:
super().__init__(
input_dim,
latent_dim,
h_dim,
depth,
dropout,
lambda_reg,
fine_tune,
deviation_reg,
name,
_class,
**kwargs,
)
self.cat_dim = cat_dim
self.gau_reg = Regularizer(latent_dim, h_dim, depth, dropout, name="gau")
self.gaup_sampler = D.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
self.cat_reg = Regularizer(cat_dim, h_dim, depth, dropout, name="cat")
self.catp_sampler = D.OneHotCategorical(probs=torch.ones(cat_dim) / cat_dim)
i_dim = [input_dim] + [h_dim] * (depth - 1) if depth > 0 else []
o_dim = [h_dim] * depth
dropout = [dropout] * depth
self.mlp = MLP(i_dim, o_dim, dropout, bias=False, batch_normalization=True)
self.gau = (
Linear(h_dim, latent_dim) if depth > 0 else Linear(input_dim, latent_dim)
)
self.cat = Linear(h_dim, cat_dim) if depth > 0 else Linear(input_dim, cat_dim)
self.softmax = nn.Softmax(dim=1)
self.mat = Linear(cat_dim, latent_dim, bias=False, init_std=0.1, trunc=False)
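        # `mat` turns the (soft) one-hot cluster assignment into an additive
        # offset in latent space, effectively learning one centroid per
        # cluster: latent = gau + mat(cat) (see `forward` below).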
# fine-tune
def save_origin_state(self) -> None:
self.mlp.save_origin_state()
self.mlp.first_layer_trainable = False
self.gau.save_origin_state()
self.cat.save_origin_state()
self.mat.save_origin_state()
# fine-tune
def deviation_loss(self) -> torch.Tensor:
return self.deviation_reg * (
self.mlp.deviation_loss()
+ self.gau.deviation_loss()
+ self.cat.deviation_loss()
+ self.mat.deviation_loss()
)
    # fine-tune
def check_fine_tune(self) -> None:
if self.fine_tune:
self.save_origin_state()
    def forward(
        self, x: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor]]:
x = self.mlp(x)
gau = self.gau(x)
cat = self.softmax(self.cat(x))
latent = gau + self.mat(cat)
return latent, (gau, cat)
def fetch_latent(self, x: torch.Tensor) -> torch.Tensor:
x = self.mlp(x)
gau = self.gau(x)
cat = self.softmax(self.cat(x))
latent = gau + self.mat(cat)
return latent
def fetch_cat(self, x: torch.Tensor) -> torch.Tensor:
x = self.mlp(x)
cat = self.softmax(self.cat(x))
return cat
    def fetch_grad(self, x: torch.Tensor, latent_grad: torch.Tensor) -> torch.Tensor:
        x_with_grad = x.requires_grad_(True)  # in-place; keep a handle on the input
        x = self.mlp(x_with_grad)
        gau = self.gau(x)
        cat = self.softmax(self.cat(x))
        latent = gau + self.mat(cat)
        # Vector-Jacobian product: back-propagate `latent_grad` to the input.
        latent.backward(latent_grad)
        return x_with_grad.grad
    def d_loss(
        self,
        catgau: typing.Tuple[torch.Tensor, torch.Tensor],
        feed_dict: typing.Mapping,
        loss_record: typing.MutableMapping,
    ) -> torch.Tensor:
        gau, cat = catgau
        gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to(
            config.DEVICE
        )
        gau_pred = self.gau_reg(gau)
        gaup_pred = self.gau_reg(gaup)
        gau_d_loss = self.gan_d_loss(gau_pred, gaup_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.gau_reg.name}/d_loss/d_loss"
        ] += gau_d_loss.item() * gau.shape[0]
        catp = self.catp_sampler.sample((cat.shape[0],)).to(config.DEVICE)
        cat_pred = self.cat_reg(cat)
        catp_pred = self.cat_reg(catp)
        cat_d_loss = self.gan_d_loss(cat_pred, catp_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.cat_reg.name}/d_loss/d_loss"
        ] += cat_d_loss.item() * cat.shape[0]
        return self.lambda_reg * (gau_d_loss + cat_d_loss)
    def g_loss(
        self,
        catgau: typing.Tuple[torch.Tensor, torch.Tensor],
        feed_dict: typing.Mapping,
        loss_record: typing.MutableMapping,
    ) -> torch.Tensor:
        gau, cat = catgau
        gau_pred = self.gau_reg(gau)
        gau_g_loss = self.gan_g_loss(gau_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.gau_reg.name}/g_loss/g_loss"
        ] += gau_g_loss.item() * gau.shape[0]
        cat_pred = self.cat_reg(cat)
        cat_g_loss = self.gan_g_loss(cat_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.cat_reg.name}/g_loss/g_loss"
        ] += cat_g_loss.item() * cat.shape[0]
        if self.fine_tune:
            # `deviation_loss` is already scaled by `deviation_reg`.
            return self.lambda_reg * (gau_g_loss + cat_g_loss) + self.deviation_loss()
        return self.lambda_reg * (gau_g_loss + cat_g_loss)
    def init_loss_record(self, loss_record: typing.MutableMapping) -> None:
        for reg in (self.gau_reg, self.cat_reg):
            prefix = f"{self.record_prefix}/{self.name}/{reg.name}"
            loss_record[f"{prefix}/d_loss/d_loss"] = 0
            loss_record[f"{prefix}/g_loss/g_loss"] = 0
    def parameters_reg(self) -> typing.Iterator[nn.Parameter]:
        return itertools.chain(self.gau_reg.parameters(), self.cat_reg.parameters())

    def parameters_fit(self) -> typing.Iterator[nn.Parameter]:
        return itertools.chain(
            self.mlp.parameters(),
            self.gau.parameters(),
            self.cat.parameters(),
            self.mat.parameters(),
        )
def get_config(self) -> typing.Mapping:
return {"cat_dim": self.cat_dim, **super().get_config()}
class SemiSupervisedCatGau(CatGau):
r"""
Build a double latent module, with a continuous Gaussian latent variable
and a one-hot categorical latent variable for intrinsic clustering of
the data. The categorical latent supports semi-supervision. The two latent
variables are then combined into a single cell embedding vector.
Parameters
----------
input_dim
Dimensionality of the input tensor.
latent_dim
Dimensionality of the Gaussian latent variable.
cat_dim
Number of intrinsic clusters.
h_dim
Dimensionality of the hidden layers in the encoder MLP.
depth
Number of hidden layers in the encoder MLP.
dropout
Dropout rate.
    lambda_sup
        Supervision strength.
    background_catp
        Unnormalized background prior for the intrinsic clustering latent.
        Each supervised cell in a minibatch adds 1 to the unnormalized prior
        probability of its cluster, so this parameter controls how strongly
        the observed class frequencies dominate the prior, balancing between
        following the supervision and identifying new clusters.
    lambda_reg
        Regularization strength on the latent variables.
    fine_tune
        Whether the module is to be fine-tuned from a pretrained model.
    deviation_reg
        Regularization strength penalizing deviation from the pretrained
        weights during fine-tuning.
    name
        Name of the module.
"""
def __init__(
self,
input_dim: int,
latent_dim: int,
cat_dim: int,
h_dim: int = 128,
depth: int = 1,
dropout: float = 0.0,
lambda_sup: float = 10.0,
background_catp: float = 1e-3,
lambda_reg: float = 0.001,
fine_tune: bool = False,
deviation_reg: float = 0.0,
name: str = "SemiSupervisedCatGau",
_class: str = "SemiSupervisedCatGau",
**kwargs,
) -> None:
super().__init__(
input_dim,
latent_dim,
cat_dim,
h_dim,
depth,
dropout,
lambda_reg,
fine_tune,
deviation_reg,
name,
_class,
**kwargs,
)
self.lambda_sup = lambda_sup
self.background_catp = background_catp
    def forward(
        self, x: torch.Tensor
    ) -> typing.Tuple[
        torch.Tensor, typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
x = self.mlp(x)
gau = self.gau(x)
cat_logit = self.cat(x)
cat = self.softmax(cat_logit)
latent = gau + self.mat(cat)
return latent, (gau, cat, cat_logit)
    def d_loss(
        self,
        catgau: typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        feed_dict: typing.Mapping,
        loss_record: typing.MutableMapping,
    ) -> torch.Tensor:
        gau, cat, _ = catgau
        cats = feed_dict[self.name]
        gaup = self.gaup_sampler.sample((gau.shape[0], self.latent_dim)).to(
            config.DEVICE
        )
        gau_pred = self.gau_reg(gau)
        gaup_pred = self.gau_reg(gaup)
        gau_d_loss = self.gan_d_loss(gau_pred, gaup_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.gau_reg.name}/d_loss/d_loss"
        ] += gau_d_loss.item() * gau.shape[0]
        # Build the categorical prior from the background mass plus the
        # supervised label counts observed in this minibatch.
        cat_prob = torch.ones(self.cat_dim) * self.background_catp + cats.cpu().sum(
            dim=0
        )
        catp_sampler = D.OneHotCategorical(probs=cat_prob / cat_prob.sum())
        catp = catp_sampler.sample((cat.shape[0],)).to(config.DEVICE)
        cat_pred = self.cat_reg(cat)
        catp_pred = self.cat_reg(catp)
        cat_d_loss = self.gan_d_loss(cat_pred, catp_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.cat_reg.name}/d_loss/d_loss"
        ] += cat_d_loss.item() * cat.shape[0]
        return self.lambda_reg * (gau_d_loss + cat_d_loss)
    def g_loss(
        self,
        catgau: typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        feed_dict: typing.Mapping,
        loss_record: typing.MutableMapping,
    ) -> torch.Tensor:
        gau, cat, cat_logit = catgau
        cats = feed_dict[self.name]
        # Supervised cells have a non-zero one-hot row in `cats`; rows of all
        # zeros mark unsupervised cells, which are excluded from the
        # supervised loss. (Masking on the softmax output `cat` would be
        # wrong, since its rows always sum to one.)
        mask = cats.sum(dim=1) > 0
        if mask.sum() > 0:
            sup_loss = F.cross_entropy(cat_logit[mask], cats[mask].argmax(dim=1))
        else:
            sup_loss = torch.tensor(0.0, device=cat_logit.device)
        loss_record[f"semi_supervision/{self.name}/supervised_loss"] += (
            sup_loss.item() * cats.shape[0]
        )
        gau_pred = self.gau_reg(gau)
        gau_g_loss = self.gan_g_loss(gau_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.gau_reg.name}/g_loss/g_loss"
        ] += gau_g_loss.item() * gau.shape[0]
        cat_pred = self.cat_reg(cat)
        cat_g_loss = self.gan_g_loss(cat_pred)
        loss_record[
            f"{self.record_prefix}/{self.name}/{self.cat_reg.name}/g_loss/g_loss"
        ] += cat_g_loss.item() * cat.shape[0]
        if self.fine_tune:
            # `deviation_loss` is already scaled by `deviation_reg`.
            return (
                self.lambda_sup * sup_loss
                + self.lambda_reg * (gau_g_loss + cat_g_loss)
                + self.deviation_loss()
            )
        return self.lambda_sup * sup_loss + self.lambda_reg * (
            gau_g_loss + cat_g_loss
        )
    def init_loss_record(self, loss_record: typing.MutableMapping) -> None:
        for reg in (self.gau_reg, self.cat_reg):
            prefix = f"{self.record_prefix}/{self.name}/{reg.name}"
            loss_record[f"{prefix}/d_loss/d_loss"] = 0
            loss_record[f"{prefix}/g_loss/g_loss"] = 0
        loss_record[f"semi_supervision/{self.name}/supervised_loss"] = 0
def get_config(self) -> typing.Mapping:
return {
"lambda_sup": self.lambda_sup,
"background_catp": self.background_catp,
**super().get_config(),
}
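# Usage sketch for `SemiSupervisedCatGau` (hypothetical dimensions): the
# supervision signal enters through `feed_dict` under the module's name, as a
# one-hot matrix in which unsupervised cells are all-zero rows:
#
#     enc = SemiSupervisedCatGau(input_dim=2000, latent_dim=10, cat_dim=8)
#     enc = enc.to(config.DEVICE)
#     x = torch.randn(64, 2000, device=config.DEVICE)
#     cats = torch.zeros(64, 8, device=config.DEVICE)
#     labels = torch.randint(0, 8, (32,), device=config.DEVICE)
#     cats[:32] = F.one_hot(labels, num_classes=8).float()  # half supervised
#     latent, extra = enc(x)
#     loss_record = {}
#     enc.init_loss_record(loss_record)
#     g = enc.g_loss(extra, {enc.name: cats}, loss_record)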