canonical_sets.gan.lucidgan.LUCIDGAN
- class LUCIDGAN(embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=0.0002, generator_decay=1e-06, discriminator_lr=0.0002, discriminator_decay=1e-06, batch_size=500, discriminator_steps=1, log_frequency=True, epochs=300, pac=10)
Bases:
CTGANModel wrapping CTGAN model.
This class is based on the CTGAN class from the ctgan package. It has been modified to fix several bugs (see the PRs on the ctgan GitHub page) and to allow the conditional vector to be extended. Note that parts of the code and comments are identical to the original CTGAN class.
Initialize LUCIDGAN.
- Parameters
embedding_dim (int) – Size of the random noise passed to the generator. Defaults to 128.
generator_dim (tuple of int) – Size of the output samples for each one of the residuals. A residual layer will be created for each value provided. Defaults to (256, 256).
discriminator_dim (tuple of int) – Size of the output samples for each one of the discriminator layers. A linear layer will be created for each value provided. Defaults to (256, 256).
generator_lr (float) – Learning rate for the generator. Defaults to 2e-4.
generator_decay (float) – Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
discriminator_lr (float) – Learning rate for the discriminator. Defaults to 2e-4.
discriminator_decay (float) – Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
batch_size (int) – Number of data samples to process in each step. Defaults to 500.
discriminator_steps (int) – Number of discriminator updates to perform for each generator update. From the WGAN paper: https://arxiv.org/abs/1701.07875. The WGAN paper default is 5; the default used here is 1 to match the original CTGAN implementation.
log_frequency (bool) – Whether to use log frequency of categorical levels in conditional sampling. Defaults to True.
epochs (int) – Number of training epochs. Defaults to 300.
pac (int) – Number of samples to group together when applying the discriminator. Defaults to 10.
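To make generator_dim, discriminator_dim, embedding_dim and pac concrete, the sketch below computes the (in_features, out_features) of each linear layer. It assumes LUCIDGAN keeps the original CTGAN layout: each residual block concatenates its input onto its output, the generator input is the noise vector plus the conditional vector, and the discriminator flattens `pac` rows into one input. The helper names and the widths `cond_dim` and `data_dim` are hypothetical; they are not part of canonical_sets.

```python
from typing import List, Sequence, Tuple


def generator_layer_shapes(embedding_dim: int, cond_dim: int,
                           generator_dim: Sequence[int], data_dim: int) -> List[Tuple[int, int]]:
    """(in, out) of each Linear layer in a CTGAN-style generator (assumed layout)."""
    dim = embedding_dim + cond_dim  # noise concatenated with the conditional vector
    shapes = []
    for item in generator_dim:
        shapes.append((dim, item))  # Linear inside one residual block
        dim += item                 # residual concat [block output, block input] widens the next input
    shapes.append((dim, data_dim))  # final projection to the transformed data width
    return shapes


def discriminator_layer_shapes(data_dim: int, cond_dim: int,
                               discriminator_dim: Sequence[int], pac: int) -> List[Tuple[int, int]]:
    """(in, out) of each Linear layer in a CTGAN-style discriminator (assumed layout)."""
    dim = (data_dim + cond_dim) * pac  # `pac` rows are flattened into one input vector
    shapes = []
    for item in discriminator_dim:
        shapes.append((dim, item))
        dim = item
    shapes.append((dim, 1))  # one critic score per pack
    return shapes


def pack(batch: List[List[float]], pac: int) -> List[List[float]]:
    """Group consecutive rows into packs of `pac` and flatten each pack (PacGAN-style)."""
    if len(batch) % pac != 0:
        raise ValueError("batch size must be a multiple of pac")
    return [[value for row in batch[i:i + pac] for value in row]
            for i in range(0, len(batch), pac)]
```

With the defaults and, say, a 20-wide conditional vector and 50-wide transformed data, the generator layers come out as (148, 256), (404, 256), (660, 50) and the discriminator layers as (700, 256), (256, 256), (256, 1). The packing step is also why batch_size should be a multiple of pac: 500 rows with pac=10 give 50 packs.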
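The log_frequency flag changes how often each categorical level is chosen when building the conditional vector. The sketch below assumes LUCIDGAN keeps CTGAN's rule of damping each level's count to log(count + 1) before normalising; the function names are hypothetical illustrations, and counts are assumed to include at least one non-zero value.

```python
import math
import random
from typing import Dict, List


def category_probabilities(counts: Dict[str, int], log_frequency: bool = True) -> Dict[str, float]:
    """Probability of conditioning on each level of one categorical column."""
    # With log_frequency, rare levels get a larger share than their raw frequency.
    weights = {cat: (math.log(n + 1) if log_frequency else float(n))
               for cat, n in counts.items()}
    total = sum(weights.values())
    return {cat: w / total for cat, w in weights.items()}


def sample_category(counts: Dict[str, int], log_frequency: bool, rng: random.Random) -> str:
    """Draw one level according to category_probabilities."""
    probs = category_probabilities(counts, log_frequency)
    cats: List[str] = list(probs)
    return rng.choices(cats, weights=[probs[c] for c in cats], k=1)[0]
```

For a column with counts {"a": 900, "b": 99}, the rare level "b" is chosen about 10% of the time from raw frequencies and about 40% of the time with log_frequency=True, so the generator sees rare levels during training far more often.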
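The batch_size, epochs and discriminator_steps parameters together fix how many updates training performs. The skeleton below is a sketch that assumes LUCIDGAN computes steps per epoch as CTGAN does, max(n_rows // batch_size, 1); `update_discriminator` and `update_generator` are hypothetical placeholders for the actual optimizer steps.

```python
from typing import Callable


def train_loop(n_rows: int, batch_size: int, epochs: int, discriminator_steps: int,
               update_discriminator: Callable[[], None],
               update_generator: Callable[[], None]) -> int:
    """Alternating GAN update schedule; returns the number of steps per epoch."""
    # At least one step per epoch, even if the data set is smaller than one batch.
    steps_per_epoch = max(n_rows // batch_size, 1)
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            # `discriminator_steps` critic updates precede every generator update.
            for _ in range(discriminator_steps):
                update_discriminator()
            update_generator()
    return steps_per_epoch
```

With 10,000 rows and the defaults (batch_size=500, epochs=300, discriminator_steps=1), this gives 20 steps per epoch and 6,000 generator and 6,000 discriminator updates; the WGAN-style setting discriminator_steps=5 would raise the discriminator count to 30,000.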
Methods
fit
load(path) – Load the model stored in the passed path.
sample
save(path) – Save the model in the passed path.
set_device(device) – Set the device to be used ('GPU' or 'CPU').
Set the random state.
Attributes
random_states
- classmethod load(path)
Load the model stored in the passed path.
- save(path)
Save the model in the passed path.
- set_device(device)
Set the device to be used (‘GPU’ or ‘CPU’).