Source code for Cell_BLAST.prob

r"""
Probabilistic / decoder modules for DIRECTi
"""

import math
import typing

import torch
import torch.nn.functional as F
from torch import nn

from . import utils
from .rebuild import MLP, Linear


class ProbModel(nn.Module):
    r"""
    Abstract base class for generative model modules.
    """

    def __init__(
        self,
        output_dim: int,
        full_latent_dim: typing.Tuple[int],
        h_dim: int = 128,
        depth: int = 1,
        dropout: float = 0.0,
        lambda_reg: float = 0.0,
        fine_tune: bool = False,
        deviation_reg: float = 0.0,
        name: str = "ProbModel",
        _class: str = "ProbModel",
        **kwargs,
    ) -> None:
        super().__init__()
        self.output_dim = output_dim
        self.full_latent_dim = full_latent_dim
        self.h_dim = h_dim
        self.depth = depth
        self.dropout = dropout
        self.lambda_reg = lambda_reg
        self.fine_tune = fine_tune
        self.deviation_reg = deviation_reg
        self.name = name
        self._class = _class
        self.record_prefix = "decoder"
        for key in kwargs.keys():
            utils.logger.warning("Argument `%s` is no longer supported!" % key)
        i_dim = [full_latent_dim] + [h_dim] * (depth - 1) if depth > 0 else []
        o_dim = [h_dim] * depth
        dropout = [dropout] * depth
        if depth > 0:
            dropout[0] = 0.0
        self.mlp = MLP(i_dim, o_dim, dropout)

    def get_config(self) -> typing.Mapping:
        return {
            "output_dim": self.output_dim,
            "full_latent_dim": self.full_latent_dim,
            "h_dim": self.h_dim,
            "depth": self.depth,
            "dropout": self.dropout,
            "lambda_reg": self.lambda_reg,
            "fine_tune": self.fine_tune,
            "deviation_reg": self.deviation_reg,
            "name": self.name,
            "_class": self._class,
        }


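# A minimal, hypothetical illustration (not part of Cell_BLAST) of how
# `ProbModel.__init__` assembles the per-layer dimension lists before handing
# them to `MLP`: the first hidden layer maps from `full_latent_dim`, deeper
# layers map from `h_dim`, and dropout is disabled on the first hidden layer.
# The numbers below are made up for the example.
def _example_probmodel_dims() -> None:
    full_latent_dim, h_dim, depth, dropout = 12, 128, 2, 0.1
    i_dim = [full_latent_dim] + [h_dim] * (depth - 1) if depth > 0 else []
    o_dim = [h_dim] * depth
    drop = [dropout] * depth
    if depth > 0:
        drop[0] = 0.0
    assert i_dim == [12, 128]
    assert o_dim == [128, 128]
    assert drop == [0.0, 0.1]

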
class NB(ProbModel):  # Negative binomial
    r"""
    Build a Negative Binomial generative module.

    Parameters
    ----------
    output_dim
        Dimensionality of the output tensor.
    full_latent_dim
        Dimensionality of the latent variable and numbers of batches.
    h_dim
        Dimensionality of the hidden layers in the decoder MLP.
    depth
        Number of hidden layers in the decoder MLP.
    dropout
        Dropout rate.
    lambda_reg
        Regularization strength for the generative model parameters.
        Here the variance of the log-scale dispersion parameter (``log_theta``)
        is regularized to improve numerical stability.
    fine_tune
        Whether the module is used in fine-tuning.
    deviation_reg
        Regularization strength for the deviation from original model weights.
    name
        Name of the module.
    """

    def __init__(
        self,
        output_dim: int,
        full_latent_dim: typing.Tuple[int],
        h_dim: int = 128,
        depth: int = 1,
        dropout: float = 0.0,
        lambda_reg: float = 0.0,
        fine_tune: bool = False,
        deviation_reg: float = 0.0,
        name: str = "NB",
        _class: str = "NB",
        **kwargs,
    ) -> None:
        super().__init__(
            output_dim,
            full_latent_dim,
            h_dim,
            depth,
            dropout,
            lambda_reg,
            fine_tune,
            deviation_reg,
            name,
            _class,
            **kwargs,
        )
        self.mu = (
            Linear(h_dim, output_dim)
            if depth > 0
            else Linear(full_latent_dim, output_dim)
        )
        self.softmax = nn.Softmax(dim=1)
        self.log_theta = (
            Linear(h_dim, output_dim)
            if depth > 0
            else Linear(full_latent_dim, output_dim)
        )

    # fine-tune
    def save_origin_state(self) -> None:
        self.mlp.save_origin_state()
        self.mu.save_origin_state()
        self.log_theta.save_origin_state()

    # fine-tune
    def deviation_loss(self) -> torch.Tensor:
        return self.deviation_reg * (
            self.mlp.deviation_loss()
            + self.mu.deviation_loss()
            + self.log_theta.deviation_loss()
        )

    # fine-tune
    def check_fine_tune(self) -> None:
        if self.fine_tune:
            self.save_origin_state()

    @staticmethod
    def log_likelihood(
        x: torch.Tensor, mu: torch.Tensor, log_theta: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        theta = torch.exp(log_theta)
        return (
            theta * log_theta
            - theta * torch.log(theta + mu + eps)
            + x * torch.log(mu + eps)
            - x * torch.log(theta + mu + eps)
            + torch.lgamma(x + theta)
            - torch.lgamma(theta)
            - torch.lgamma(x + 1)
        )

    def forward(
        self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping
    ) -> torch.Tensor:
        y = feed_dict["exprs"]
        x = self.mlp(full_x)
        softmax_mu = self.softmax(self.mu(x))
        mu = softmax_mu * y.sum(dim=1, keepdim=True)
        log_theta = self.log_theta(x)
        return mu, log_theta

    def loss(
        self,
        mu_theta: typing.Tuple[torch.Tensor],
        feed_dict: typing.Mapping,
        loss_record: typing.Mapping,
    ) -> torch.Tensor:
        y = feed_dict["exprs"]
        mu, log_theta = mu_theta
        raw_loss = -self.log_likelihood(y, mu, log_theta).mean()
        loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] += (
            raw_loss.item() * mu.shape[0]
        )
        reg_loss = raw_loss + self.lambda_reg * log_theta.var()
        loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] += (
            reg_loss.item() * mu.shape[0]
        )
        if self.fine_tune:
            return reg_loss + self.deviation_reg * self.deviation_loss()
        else:
            return reg_loss

    def init_loss_record(self, loss_record: typing.Mapping) -> None:
        loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] = 0
        loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] = 0


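# A hypothetical sanity check (not part of Cell_BLAST): `NB.log_likelihood`
# uses the mean/dispersion parameterization, so it should agree with
# torch.distributions.NegativeBinomial(total_count=theta, probs=mu / (mu + theta)).
# All tensor values below are made up for the example.
def _example_nb_log_likelihood_check() -> None:
    from torch.distributions import NegativeBinomial

    x = torch.tensor([0.0, 1.0, 5.0])
    mu = torch.tensor([0.5, 2.0, 4.0])
    log_theta = torch.tensor([0.0, 0.5, 1.0])
    theta = log_theta.exp()
    ref = NegativeBinomial(total_count=theta, probs=mu / (mu + theta)).log_prob(x)
    assert torch.allclose(NB.log_likelihood(x, mu, log_theta), ref, atol=1e-4)

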
class ZINB(NB):
    r"""
    Build a Zero-Inflated Negative Binomial generative module.

    Parameters
    ----------
    output_dim
        Dimensionality of the output tensor.
    full_latent_dim
        Dimensionality of the latent variable and numbers of batches.
    h_dim
        Dimensionality of the hidden layers in the decoder MLP.
    depth
        Number of hidden layers in the decoder MLP.
    dropout
        Dropout rate.
    lambda_reg
        Regularization strength for the generative model parameters.
        Here the variance of the log-scale dispersion parameter (``log_theta``)
        is regularized to improve numerical stability.
    fine_tune
        Whether the module is used in fine-tuning.
    deviation_reg
        Regularization strength for the deviation from original model weights.
    name
        Name of the module.
    """

    def __init__(
        self,
        output_dim: int,
        full_latent_dim: typing.Tuple[int],
        h_dim: int = 128,
        depth: int = 1,
        dropout: float = 0.0,
        lambda_reg: float = 0.0,
        fine_tune: bool = False,
        deviation_reg: float = 0.0,
        name: str = "ZINB",
        _class: str = "ZINB",
        **kwargs,
    ) -> None:
        super().__init__(
            output_dim,
            full_latent_dim,
            h_dim,
            depth,
            dropout,
            lambda_reg,
            fine_tune,
            deviation_reg,
            name,
            _class,
            **kwargs,
        )
        self.pi = (
            Linear(h_dim, output_dim)
            if depth > 0
            else Linear(full_latent_dim, output_dim)
        )

    # fine-tune
    def save_origin_state(self) -> None:
        self.mlp.save_origin_state()
        self.mu.save_origin_state()
        self.log_theta.save_origin_state()
        self.pi.save_origin_state()

    # fine-tune
    def deviation_loss(self) -> torch.Tensor:
        return self.deviation_reg * (
            self.mlp.deviation_loss()
            + self.mu.deviation_loss()
            + self.log_theta.deviation_loss()
            + self.pi.deviation_loss()
        )

    # fine-tune
    def check_fine_tune(self) -> None:
        if self.fine_tune:
            self.save_origin_state()

    @staticmethod
    def log_likelihood(
        x: torch.Tensor,
        mu: torch.Tensor,
        log_theta: torch.Tensor,
        pi: torch.Tensor,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        theta = torch.exp(log_theta)
        case_zero = F.softplus(
            -pi + theta * log_theta - theta * torch.log(theta + mu + eps)
        ) - F.softplus(-pi)
        case_non_zero = (
            -pi
            - F.softplus(-pi)
            + theta * log_theta
            - theta * torch.log(theta + mu + eps)
            + x * torch.log(mu + eps)
            - x * torch.log(theta + mu + eps)
            + torch.lgamma(x + theta)
            - torch.lgamma(theta)
            - torch.lgamma(x + 1)
        )
        mask = (x < eps).float()
        res = mask * case_zero + (1 - mask) * case_non_zero
        return res

    def forward(
        self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping
    ) -> torch.Tensor:
        y = feed_dict["exprs"]
        x = self.mlp(full_x)
        softmax_mu = self.softmax(self.mu(x))
        mu = softmax_mu * y.sum(dim=1, keepdim=True)
        log_theta = self.log_theta(x)
        pi = self.pi(x)
        return mu, log_theta, pi

    def loss(
        self,
        mu_theta_pi: typing.Tuple[torch.Tensor],
        feed_dict: typing.Mapping,
        loss_record: typing.Mapping,
    ) -> torch.Tensor:
        y = feed_dict["exprs"]
        mu, log_theta, pi = mu_theta_pi
        raw_loss = -self.log_likelihood(y, mu, log_theta, pi).mean()
        loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] += (
            raw_loss.item() * mu.shape[0]
        )
        reg_loss = raw_loss + self.lambda_reg * log_theta.var()
        loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] += (
            reg_loss.item() * mu.shape[0]
        )
        if self.fine_tune:
            return reg_loss + self.deviation_reg * self.deviation_loss()
        else:
            return reg_loss


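# A hypothetical sanity check (not part of Cell_BLAST): `ZINB.log_likelihood`
# is the log of a two-component mixture in which `sigmoid(pi)` is the dropout
# probability placed on zero counts and `sigmoid(-pi)` weights the NB component.
# All tensor values below are made up for the example.
def _example_zinb_log_likelihood_check() -> None:
    from torch.distributions import NegativeBinomial

    x = torch.tensor([0.0, 3.0])
    mu = torch.tensor([1.5, 1.5])
    log_theta = torch.tensor([0.2, 0.2])
    pi = torch.tensor([-0.7, -0.7])
    theta = log_theta.exp()
    nb = NegativeBinomial(total_count=theta, probs=mu / (mu + theta)).log_prob(x).exp()
    mixture = torch.sigmoid(pi) * (x == 0).float() + torch.sigmoid(-pi) * nb
    assert torch.allclose(
        ZINB.log_likelihood(x, mu, log_theta, pi), mixture.log(), atol=1e-4
    )

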
class LN(ProbModel):
    r"""
    Build a Log Normal generative module.

    Parameters
    ----------
    output_dim
        Dimensionality of the output tensor.
    full_latent_dim
        Dimensionality of the latent variable and numbers of batches.
    h_dim
        Dimensionality of the hidden layers in the decoder MLP.
    depth
        Number of hidden layers in the decoder MLP.
    dropout
        Dropout rate.
    lambda_reg
        NOT USED.
    fine_tune
        Whether the module is used in fine-tuning.
    deviation_reg
        Regularization strength for the deviation from original model weights.
    name
        Name of the module.
    """

    def __init__(
        self,
        output_dim: int,
        full_latent_dim: typing.Tuple[int],
        h_dim: int = 128,
        depth: int = 1,
        dropout: float = 0.0,
        lambda_reg: float = 0.0,
        fine_tune: bool = False,
        deviation_reg: float = 0.0,
        name: str = "LN",
        _class: str = "LN",
        **kwargs,
    ) -> None:
        super().__init__(
            output_dim,
            full_latent_dim,
            h_dim,
            depth,
            dropout,
            lambda_reg,
            fine_tune,
            deviation_reg,
            name,
            _class,
            **kwargs,
        )
        self.mu = (
            Linear(h_dim, output_dim)
            if depth > 0
            else Linear(full_latent_dim, output_dim)
        )
        self.log_var = (
            Linear(h_dim, output_dim)
            if depth > 0
            else Linear(full_latent_dim, output_dim)
        )

    # fine-tune
    def save_origin_state(self) -> None:
        self.mlp.save_origin_state()
        self.mu.save_origin_state()
        self.log_var.save_origin_state()

    # fine-tune
    def deviation_loss(self) -> torch.Tensor:
        return self.deviation_reg * (
            self.mlp.deviation_loss()
            + self.mu.deviation_loss()
            + self.log_var.deviation_loss()
        )

    # fine-tune
    def check_fine_tune(self) -> None:
        if self.fine_tune:
            self.save_origin_state()

    @staticmethod
    def log_likelihood(
        x: torch.Tensor, mu: torch.Tensor, log_var: torch.Tensor
    ) -> torch.Tensor:
        return -0.5 * (
            torch.square(x - mu) / torch.exp(log_var) + math.log(2 * math.pi) + log_var
        )

    def forward(
        self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping
    ) -> torch.Tensor:
        x = self.mlp(full_x)
        mu = torch.expm1(self.mu(x))
        log_var = self.log_var(x)
        return mu, log_var

    def loss(
        self,
        mu_var: typing.Tuple[torch.Tensor],
        feed_dict: typing.Mapping,
        loss_record: typing.Mapping,
    ) -> torch.Tensor:
        y = feed_dict["exprs"]
        mu, log_var = mu_var
        raw_loss = -self.log_likelihood(torch.log1p(y), mu, log_var).mean()
        loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] += (
            raw_loss.item() * mu.shape[0]
        )
        reg_loss = raw_loss
        loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] += (
            reg_loss.item() * mu.shape[0]
        )
        if self.fine_tune:
            return reg_loss + self.deviation_reg * self.deviation_loss()
        else:
            return reg_loss

    def init_loss_record(self, loss_record: typing.Mapping) -> None:
        loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] = 0
        loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] = 0


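# A hypothetical sanity check (not part of Cell_BLAST): `LN.log_likelihood`
# is a plain Gaussian log-density with variance exp(log_var), evaluated by the
# loss above on log1p-transformed counts, so it should agree with
# torch.distributions.Normal. All tensor values below are made up for the example.
def _example_ln_log_likelihood_check() -> None:
    from torch.distributions import Normal

    x = torch.log1p(torch.tensor([0.0, 2.0, 7.0]))
    mu = torch.tensor([0.1, 0.8, 2.0])
    log_var = torch.tensor([-0.5, 0.0, 0.3])
    ref = Normal(mu, torch.exp(0.5 * log_var)).log_prob(x)
    assert torch.allclose(LN.log_likelihood(x, mu, log_var), ref, atol=1e-6)

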
class ZILN(LN):
    r"""
    Build a Zero-Inflated Log Normal generative module.

    Parameters
    ----------
    output_dim
        Dimensionality of the output tensor.
    full_latent_dim
        Dimensionality of the latent variable and numbers of batches.
    h_dim
        Dimensionality of the hidden layers in the decoder MLP.
    depth
        Number of hidden layers in the decoder MLP.
    dropout
        Dropout rate.
    lambda_reg
        NOT USED.
    fine_tune
        Whether the module is used in fine-tuning.
    deviation_reg
        Regularization strength for the deviation from original model weights.
    name
        Name of the module.
    """

    def __init__(
        self,
        output_dim: int,
        full_latent_dim: typing.Tuple[int],
        h_dim: int = 128,
        depth: int = 1,
        dropout: float = 0.0,
        lambda_reg: float = 0.0,
        fine_tune: bool = False,
        deviation_reg: float = 0.0,
        name: str = "ZILN",
        _class: str = "ZILN",
        **kwargs,
    ) -> None:
        super().__init__(
            output_dim,
            full_latent_dim,
            h_dim,
            depth,
            dropout,
            lambda_reg,
            fine_tune,
            deviation_reg,
            name,
            _class,
            **kwargs,
        )
        self.pi = (
            Linear(h_dim, output_dim)
            if depth > 0
            else Linear(full_latent_dim, output_dim)
        )

    # fine-tune
    def save_origin_state(self) -> None:
        self.mlp.save_origin_state()
        self.mu.save_origin_state()
        self.log_var.save_origin_state()
        self.pi.save_origin_state()

    # fine-tune
    def deviation_loss(self) -> torch.Tensor:
        return self.deviation_reg * (
            self.mlp.deviation_loss()
            + self.mu.deviation_loss()
            + self.log_var.deviation_loss()
            + self.pi.deviation_loss()
        )

    # fine-tune
    def check_fine_tune(self) -> None:
        if self.fine_tune:
            self.save_origin_state()

    @staticmethod
    def log_likelihood(
        x: torch.Tensor,
        mu: torch.Tensor,
        log_var: torch.Tensor,
        pi: torch.Tensor,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        case_zero = -F.softplus(-pi)
        case_non_zero = (
            -pi
            - F.softplus(-pi)
            - 0.5
            * (
                torch.square(x - mu) / torch.exp(log_var)
                + math.log(2 * math.pi)
                + log_var
            )
        )
        mask = (x < eps).float()
        res = mask * case_zero + (1 - mask) * case_non_zero
        return res

    def forward(
        self, full_x: typing.Tuple[torch.Tensor], feed_dict: typing.Mapping
    ) -> torch.Tensor:
        x = self.mlp(full_x)
        mu = torch.expm1(self.mu(x))
        log_var = self.log_var(x)
        pi = self.pi(x)
        return mu, log_var, pi

    def loss(
        self,
        mu_var_pi: typing.Tuple[torch.Tensor],
        feed_dict: typing.Mapping,
        loss_record: typing.Mapping,
    ) -> torch.Tensor:
        y = feed_dict["exprs"]
        mu, log_var, pi = mu_var_pi
        raw_loss = -self.log_likelihood(torch.log1p(y), mu, log_var, pi).mean()
        loss_record[self.record_prefix + "/" + self.name + "/raw_loss"] += (
            raw_loss.item() * mu.shape[0]
        )
        reg_loss = raw_loss
        loss_record[self.record_prefix + "/" + self.name + "/regularized_loss"] += (
            reg_loss.item() * mu.shape[0]
        )
        if self.fine_tune:
            return reg_loss + self.deviation_reg * self.deviation_loss()
        else:
            return reg_loss


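# A hypothetical sanity check (not part of Cell_BLAST) of the zero-inflation
# split in `ZILN.log_likelihood`: zero entries contribute log(sigmoid(pi))
# (the dropout probability), while non-zero entries contribute
# log(sigmoid(-pi)) plus the Gaussian term, i.e. the `LN` log-likelihood.
# All tensor values below are made up for the example.
def _example_ziln_log_likelihood_check() -> None:
    x = torch.log1p(torch.tensor([0.0, 4.0]))
    mu = torch.tensor([0.5, 1.2])
    log_var = torch.tensor([0.0, 0.0])
    pi = torch.tensor([0.3, 0.3])
    res = ZILN.log_likelihood(x, mu, log_var, pi)
    assert torch.allclose(res[0], F.logsigmoid(pi)[0], atol=1e-6)
    assert torch.allclose(
        res[1], F.logsigmoid(-pi)[1] + LN.log_likelihood(x, mu, log_var)[1], atol=1e-6
    )

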
class MSE(ProbModel):
    def __init__(self, *args, **kwargs):
        utils.logger.warning(
            "Prob module `MSE` is no longer supported, running as `ProbModel`"
        )
        super().__init__(*args, **kwargs)