+import pickle
+import typing
+from abc import ABC, abstractmethod
+from configparser import ConfigParser
+
+import torch
+
+from gans.causal_gan import CausalGAN
+from gans.conditional_gan_cat import ConditionalCatGAN
+from gans.conditional_gan_proj import ConditionalProjGAN
+from gans.gan import GAN
+
+
+
+
def parse_list(str_list: str, type_: type) -> list:
    """
    Split a whitespace-separated string and convert each token.

    Parameters
    ----------
    str_list : str
        Whitespace-separated values, e.g. ``"256 256 256"``.
    type_ : type
        Callable applied to each token (e.g. ``int`` or ``float``).

    Returns
    -------
    list
        The converted tokens, in order.
    """
    # str_list.split() is the idiomatic spelling of the original
    # str.split(str_list); both split on runs of whitespace.
    return [type_(token) for token in str_list.split()]
+
class IGANFactory(ABC):
    """
    Abstract factory describing how a GAN and its trainer are built.

    Concrete subclasses read every model and training hyper-parameter from
    the config parser supplied at construction time. The factory does not
    keep references to the objects it creates.
    """

    def __init__(self, parser: ConfigParser) -> None:
        """
        Store the configuration consulted by the factory methods.

        Parameters
        ----------
        parser : ConfigParser
            Parser for config file containing GAN model and training params.
        """
        self.parser = parser

    @abstractmethod
    def get_gan(self) -> GAN:
        """
        Build and return the configured GAN instance.

        Returns
        -------
        GAN
            GAN instance.
        """

    @abstractmethod
    def get_trainer(self) -> typing.Callable:
        """
        Build a zero-argument callable that runs the GAN training loop.

        Returns
        -------
        typing.Callable
            GAN train() function.
        """
+
class GANFactory(IGANFactory):
    """Factory producing a plain (unconditional) GAN and its trainer."""

    def get_gan(self) -> GAN:
        """Instantiate a GAN from the [Data]/[Model]/[Training] sections."""
        cfg = self.parser
        return GAN(
            genes_no=cfg.getint("Data", "number of genes"),
            batch_size=cfg.getint("Training", "batch size"),
            latent_dim=cfg.getint("Model", "latent dim"),
            gen_layers=parse_list(cfg["Model"]["generator layers"], int),
            crit_layers=parse_list(cfg["Model"]["critic layers"], int),
            device=cfg.get("EXPERIMENT", "device", fallback=None),
            library_size=cfg.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """Return a thunk that trains a freshly built GAN when invoked."""
        cfg = self.parser
        gan = self.get_gan()
        # NOTE: 'save_feq' mirrors the parameter name on GAN.train() —
        # presumably a typo for 'save_freq' in the callee; do not rename here.
        return lambda: gan.train(
            train_files=cfg.get("Data", "train"),
            valid_files=cfg.get("Data", "validation"),
            critic_iter=cfg.getint("Training", "critic iterations"),
            max_steps=cfg.getint("Training", "maximum steps"),
            c_lambda=cfg.getfloat("Model", "lambda"),
            beta1=cfg.getfloat("Optimizer", "beta1"),
            beta2=cfg.getfloat("Optimizer", "beta2"),
            gen_alpha_0=cfg.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=cfg.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=cfg.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=cfg.getfloat("Learning Rate", "critic final"),
            checkpoint=cfg.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=cfg.getint("Logging", "summary frequency"),
            plt_freq=cfg.getint("Logging", "plot frequency"),
            save_feq=cfg.getint("Logging", "save frequency"),
            output_dir=cfg.get("EXPERIMENT", "output directory"),
        )
+
class ConditionalCatGANFactory(IGANFactory):
    """Factory for a conditional GAN with concatenation-based conditioning."""

    def get_gan(self) -> ConditionalCatGAN:
        """Instantiate a ConditionalCatGAN from the parsed configuration."""
        cfg = self.parser
        ratios = torch.Tensor(parse_list(cfg["Data"]["label ratios"], float))
        return ConditionalCatGAN(
            genes_no=cfg.getint("Data", "number of genes"),
            batch_size=cfg.getint("Training", "batch size"),
            latent_dim=cfg.getint("Model", "latent dim"),
            gen_layers=parse_list(cfg["Model"]["generator layers"], int),
            crit_layers=parse_list(cfg["Model"]["critic layers"], int),
            num_classes=cfg.getint("Data", "number of classes"),
            label_ratios=ratios,
            device=cfg.get("EXPERIMENT", "device", fallback=None),
            library_size=cfg.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """Return a thunk that trains a freshly built ConditionalCatGAN."""
        cfg = self.parser
        gan = self.get_gan()
        # NOTE: 'save_feq' mirrors the parameter name on the train() callee —
        # presumably a typo for 'save_freq'; do not rename here.
        return lambda: gan.train(
            train_files=cfg.get("Data", "train"),
            valid_files=cfg.get("Data", "validation"),
            critic_iter=cfg.getint("Training", "critic iterations"),
            max_steps=cfg.getint("Training", "maximum steps"),
            c_lambda=cfg.getfloat("Model", "lambda"),
            beta1=cfg.getfloat("Optimizer", "beta1"),
            beta2=cfg.getfloat("Optimizer", "beta2"),
            gen_alpha_0=cfg.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=cfg.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=cfg.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=cfg.getfloat("Learning Rate", "critic final"),
            checkpoint=cfg.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=cfg.getint("Logging", "summary frequency"),
            plt_freq=cfg.getint("Logging", "plot frequency"),
            save_feq=cfg.getint("Logging", "save frequency"),
            output_dir=cfg.get("EXPERIMENT", "output directory"),
        )
+
class ConditionalProjGANFactory(IGANFactory):
    """Factory for a conditional GAN with projection-based conditioning."""

    def get_gan(self) -> ConditionalProjGAN:
        """Instantiate a ConditionalProjGAN from the parsed configuration."""
        cfg = self.parser
        ratios = torch.Tensor(parse_list(cfg["Data"]["label ratios"], float))
        return ConditionalProjGAN(
            genes_no=cfg.getint("Data", "number of genes"),
            batch_size=cfg.getint("Training", "batch size"),
            latent_dim=cfg.getint("Model", "latent dim"),
            gen_layers=parse_list(cfg["Model"]["generator layers"], int),
            crit_layers=parse_list(cfg["Model"]["critic layers"], int),
            num_classes=cfg.getint("Data", "number of classes"),
            label_ratios=ratios,
            device=cfg.get("EXPERIMENT", "device", fallback=None),
            library_size=cfg.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """Return a thunk that trains a freshly built ConditionalProjGAN."""
        cfg = self.parser
        gan = self.get_gan()
        # NOTE: 'save_feq' mirrors the parameter name on the train() callee —
        # presumably a typo for 'save_freq'; do not rename here.
        return lambda: gan.train(
            train_files=cfg.get("Data", "train"),
            valid_files=cfg.get("Data", "validation"),
            critic_iter=cfg.getint("Training", "critic iterations"),
            max_steps=cfg.getint("Training", "maximum steps"),
            c_lambda=cfg.getfloat("Model", "lambda"),
            beta1=cfg.getfloat("Optimizer", "beta1"),
            beta2=cfg.getfloat("Optimizer", "beta2"),
            gen_alpha_0=cfg.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=cfg.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=cfg.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=cfg.getfloat("Learning Rate", "critic final"),
            checkpoint=cfg.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=cfg.getint("Logging", "summary frequency"),
            plt_freq=cfg.getint("Logging", "plot frequency"),
            save_feq=cfg.getint("Logging", "save frequency"),
            output_dir=cfg.get("EXPERIMENT", "output directory"),
        )
+
class CausalGANFactory(IGANFactory):
    """
    Factory for a causal GAN and its two-phase trainer.

    Training happens in two stages: first a "causal controller" (CC) GAN is
    pretrained, then the causal GAN is built on top of the resulting CC
    checkpoint and trained itself.
    """

    def get_cc(self) -> GAN:
        """
        Build the causal-controller GAN from the 'CC *' config sections.

        Returns
        -------
        GAN
            The causal controller, an ordinary GAN configured from the
            ``CC Model`` / ``CC Training`` sections.
        """
        cfg = self.parser
        return GAN(
            genes_no=cfg.getint("Data", "number of genes"),
            batch_size=cfg.getint("CC Training", "batch size"),
            latent_dim=cfg.getint("CC Model", "latent dim"),
            gen_layers=parse_list(cfg["CC Model"]["generator layers"], int),
            crit_layers=parse_list(cfg["CC Model"]["critic layers"], int),
            device=cfg.get("EXPERIMENT", "device", fallback=None),
            library_size=cfg.getint("Preprocessing", "library size"),
        )

    def get_gan(self) -> CausalGAN:
        """
        Build the causal GAN, wiring in the pretrained CC checkpoint.

        The CC checkpoint path is derived from the experiment output
        directory plus the CC maximum step count, so this must be called
        only after CC pretraining has produced that file.

        Returns
        -------
        CausalGAN
            CausalGAN instance.
        """
        cfg = self.parser
        # SECURITY: pickle.load executes arbitrary code on load; only use
        # causal-graph files from trusted sources.
        with open(cfg.get("Data", "causal graph"), "rb") as fp:
            causal_graph = pickle.load(fp)

        cc_checkpoint = (
            cfg.get("EXPERIMENT", "output directory")
            + f"_CC/checkpoints/step_{cfg.getint('CC Training', 'maximum steps')}.pth"
        )
        return CausalGAN(
            genes_no=cfg.getint("Data", "number of genes"),
            batch_size=cfg.getint("Training", "batch size"),
            latent_dim=cfg.getint("Model", "latent dim"),
            noise_per_gene=cfg.getint("Model", "noise per gene"),
            depth_per_gene=cfg.getint("Model", "depth per gene"),
            width_per_gene=cfg.getint("Model", "width per gene"),
            cc_latent_dim=cfg.getint("CC Model", "latent dim"),
            cc_layers=parse_list(cfg["CC Model"]["generator layers"], int),
            cc_pretrained_checkpoint=cc_checkpoint,
            crit_layers=parse_list(cfg["Model"]["critic layers"], int),
            causal_graph=causal_graph,
            labeler_layers=parse_list(cfg["Model"]["labeler layers"], int),
            device=cfg.get("EXPERIMENT", "device", fallback=None),
            library_size=cfg.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Return a thunk running the full two-phase training pipeline.

        The returned callable first trains the causal controller for the
        number of steps given in the ``CC Training`` section, then builds
        the causal GAN from the pretrained CC checkpoint and trains it.

        Returns
        -------
        typing.Callable
            Zero-argument function; returns the CC training result (as the
            original tuple-lambda implementation did).
        """
        cfg = self.parser
        cc = self.get_cc()

        def _train():
            # Phase 1: pretrain the causal controller. Its checkpoints go
            # to "<output directory>_CC".
            cc_result = cc.train(
                train_files=cfg.get("Data", "train"),
                valid_files=cfg.get("Data", "validation"),
                critic_iter=cfg.getint("CC Training", "critic iterations"),
                max_steps=cfg.getint("CC Training", "maximum steps"),
                c_lambda=cfg.getfloat("CC Model", "lambda"),
                beta1=cfg.getfloat("CC Optimizer", "beta1"),
                beta2=cfg.getfloat("CC Optimizer", "beta2"),
                gen_alpha_0=cfg.getfloat("CC Learning Rate", "generator initial"),
                gen_alpha_final=cfg.getfloat("CC Learning Rate", "generator final"),
                crit_alpha_0=cfg.getfloat("CC Learning Rate", "critic initial"),
                crit_alpha_final=cfg.getfloat("CC Learning Rate", "critic final"),
                checkpoint=cfg.get("EXPERIMENT", "checkpoint", fallback=None),
                summary_freq=cfg.getint("CC Logging", "summary frequency"),
                plt_freq=cfg.getint("CC Logging", "plot frequency"),
                save_feq=cfg.getint("CC Logging", "save frequency"),
                output_dir=cfg.get("EXPERIMENT", "output directory") + "_CC",
            )
            # Phase 2: only now can get_gan() load the CC checkpoint
            # produced above; then train the causal GAN itself.
            self.get_gan().train(
                train_files=cfg.get("Data", "train"),
                valid_files=cfg.get("Data", "validation"),
                critic_iter=cfg.getint("Training", "critic iterations"),
                max_steps=cfg.getint("Training", "maximum steps"),
                c_lambda=cfg.getfloat("Model", "lambda"),
                beta1=cfg.getfloat("Optimizer", "beta1"),
                beta2=cfg.getfloat("Optimizer", "beta2"),
                gen_alpha_0=cfg.getfloat("Learning Rate", "generator initial"),
                gen_alpha_final=cfg.getfloat("Learning Rate", "generator final"),
                crit_alpha_0=cfg.getfloat("Learning Rate", "critic initial"),
                crit_alpha_final=cfg.getfloat("Learning Rate", "critic final"),
                labeler_alpha=cfg.getfloat("Learning Rate", "labeler"),
                antilabeler_alpha=cfg.getfloat("Learning Rate", "antilabeler"),
                labeler_training_interval=cfg.getfloat(
                    "Training", "labeler and antilabeler training intervals"
                ),
                checkpoint=cfg.get("EXPERIMENT", "checkpoint", fallback=None),
                summary_freq=cfg.getint("Logging", "summary frequency"),
                plt_freq=cfg.getint("Logging", "plot frequency"),
                save_feq=cfg.getint("Logging", "save frequency"),
                output_dir=cfg.get("EXPERIMENT", "output directory"),
            )
            return cc_result

        return _train
+
def get_factory(cfg: ConfigParser) -> IGANFactory:
    """
    Return the factory for the GAN type based on 'type' key in the
    [Model] section of the parser.

    Parameters
    ----------
    cfg : ConfigParser
        Parser for config file containing GAN model and training params.

    Returns
    -------
    IGANFactory
        Factory for the specified GAN.

    Raises
    ------
    ValueError
        If the model is unknown or not implemented.
    """
    # Map model names to factory *classes* so only the selected factory is
    # instantiated (the original eagerly constructed all four).
    factories = {
        "GAN": GANFactory,
        "proj conditional GAN": ConditionalProjGANFactory,
        "cat conditional GAN": ConditionalCatGANFactory,
        "causal GAN": CausalGANFactory,
    }

    # read the desired GAN
    model = cfg.get("Model", "type")
    try:
        factory_cls = factories[model]
    except KeyError:
        raise ValueError(f"model '{model}' type is invalid") from None
    return factory_cls(cfg)