# Source code for canonical_sets.gan.lucidgan

"""LUCID-GAN."""
import warnings
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import tqdm
from ctgan import CTGAN
from ctgan.data_transformer import DataTransformer
from ctgan.synthesizers.base import random_state
from ctgan.synthesizers.ctgan import Discriminator, Generator
from torch import optim

from canonical_sets.gan.sampler import _Sampler


class LUCIDGAN(CTGAN):
    """Model wrapping `CTGAN` model.

    This class is based on the `CTGAN` class from the `ctgan` package. It has
    been modified to fix several bugs (see PRs on the `ctgan` GitHub page) and
    to allow for the extension of the conditional vector. Note that a part of
    the code and comments is identical to the original `CTGAN` class.
    """

    def __init__(
        self,
        embedding_dim: int = 128,
        generator_dim: Tuple[int, int] = (256, 256),
        discriminator_dim: Tuple[int, int] = (256, 256),
        generator_lr: float = 2e-4,
        generator_decay: float = 1e-6,
        discriminator_lr: float = 2e-4,
        discriminator_decay: float = 1e-6,
        batch_size: int = 500,
        discriminator_steps: int = 1,
        log_frequency: bool = True,
        epochs: int = 300,
        pac: int = 10,
    ):
        """Initialize LUCIDGAN.

        The parent ``CTGAN`` is always constructed with ``verbose=False`` and
        ``cuda=False``; these two options are not exposed.

        Parameters
        ----------
        embedding_dim : int
            Size of the random noise passed to the generator. Defaults to 128.
        generator_dim : tuple of int
            Size of the output samples for each one of the residuals. A
            residual Layer will be created for each one of the values
            provided. Defaults to (256, 256).
        discriminator_dim : tuple of int
            Size of the output samples for each one of the discriminator
            layers. A linear Layer will be created for each one of the values
            provided. Defaults to (256, 256).
        generator_lr : float
            Learning rate for the generator. Defaults to 2e-4.
        generator_decay : float
            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
        discriminator_lr : float
            Learning rate for the discriminator. Defaults to 2e-4.
        discriminator_decay : float
            Discriminator weight decay for the Adam Optimizer.
            Defaults to 1e-6.
        batch_size : int
            Number of data samples to process in each step.
        discriminator_steps : int
            Number of discriminator updates to do for each generator update.
            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
            default is 5. Default used is 1 to match original CTGAN
            implementation.
        log_frequency : bool
            Whether to use log frequency of categorical levels in conditional
            sampling. Defaults to ``True``.
        epochs : int
            Number of training epochs. Defaults to 300.
        pac : int
            Number of samples to group together when applying the
            discriminator. Defaults to 10.

        Attributes
        ----------
        generator_loss : list of torch.Tensor
            Generator loss at each epoch.
        reconstruction_loss : list of torch.Tensor
            Reconstruction loss at each epoch.
        discriminator_loss : list of torch.Tensor
            Discriminator loss at each epoch.
        """
        super().__init__(
            embedding_dim=embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            generator_lr=generator_lr,
            generator_decay=generator_decay,
            discriminator_lr=discriminator_lr,
            discriminator_decay=discriminator_decay,
            batch_size=batch_size,
            discriminator_steps=discriminator_steps,
            log_frequency=log_frequency,
            verbose=False,
            epochs=epochs,
            pac=pac,
            cuda=False,
        )

        # Loss histories; `fit` appends one entry per epoch to each list.
        self.generator_loss: List[torch.Tensor] = []
        self.reconstruction_loss: List[torch.Tensor] = []
        self.discriminator_loss: List[torch.Tensor] = []

    @random_state
    def fit(
        self,
        train_data: pd.DataFrame,
        discrete_columns: Optional[List[str]] = None,
        conditional: Optional[List[str]] = None,
    ):
        """Fit LUCID-GAN to the training data.

        Parameters
        ----------
        train_data : pandas.DataFrame
            Training Data. It must be a 2-dimensional pandas.DataFrame where
            each column is a feature and each row is a sample.
        discrete_columns : list of str, optional
            List of discrete columns to be used to create the conditional
            vector. This list should contain the column names. Note that if
            ``None``, we select all columns which dtype is not ``number``.
            See pd.DataFrame.select_dtypes for more information.
        conditional : list of str, optional
            List of columns with the conditional features which should not be
            part of the generated data. Note that the columns in this list
            should be ``numeric``, and not be included in the
            ``discrete_columns``.
        """
        self._n_conditions: Optional[int] = None
        self._conditions: Optional[np.ndarray] = None
        self._conditions_columns: Optional[List[str]] = None

        # The conditional columns are split off: they are only ever fed to
        # the networks as extra inputs and are never generated.
        if conditional is not None:
            self._n_conditions = len(conditional)
            self._conditions = train_data[conditional].to_numpy()
            self._conditions_columns = conditional
            train_data = train_data.drop(conditional, axis=1)

        if discrete_columns is None:
            discrete_columns = train_data.select_dtypes(
                exclude=["number"]
            ).columns.tolist()

        self._validate_discrete_columns(train_data, discrete_columns)

        self._transformer = DataTransformer()
        self._transformer.fit(train_data, discrete_columns)
        train_data = self._transformer.transform(train_data)
        data_dim = self._transformer.output_dimensions

        self._data_sampler = _Sampler(
            train_data,
            self._transformer.output_info_list,
            self._log_frequency,
            self._conditions,
        )

        # Both networks receive the conditional vector and, when conditions
        # are used, the condition columns as additional inputs.
        extra_dim = self._data_sampler.dim_cond_vec()
        if self._conditions is not None:
            extra_dim = extra_dim + self._n_conditions

        self._generator = Generator(
            self._embedding_dim + extra_dim,
            self._generator_dim,
            data_dim,
        )
        discriminator = Discriminator(
            data_dim + extra_dim,
            self._discriminator_dim,
            pac=self.pac,
        )

        optimizerG = optim.Adam(
            self._generator.parameters(),
            lr=self._generator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._generator_decay,
        )
        optimizerD = optim.Adam(
            discriminator.parameters(),
            lr=self._discriminator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._discriminator_decay,
        )

        mean = torch.zeros(self._batch_size, self._embedding_dim)
        std = mean + 1

        steps_per_epoch = max(len(train_data) // self._batch_size, 1)
        with tqdm.trange(self._epochs) as epochs_bar:
            for i in epochs_bar:
                for _ in range(steps_per_epoch):
                    # discriminator training loop
                    for _ in range(self._discriminator_steps):
                        # generate noise for fake data
                        fakez = torch.normal(mean=mean, std=std)

                        # None if there are no discrete columns, otherwise
                        # the tuple (c1, m1, col, opt)
                        condvec = self._data_sampler.sample_condvec(
                            self._batch_size
                        )

                        if condvec is None:
                            # if conditions are provided, sample both real
                            # data and conditions, and append a random
                            # permutation of the conditions to the noise
                            # for the fake data generation
                            if self._conditions is not None:
                                (
                                    real,
                                    conditions_real,
                                ) = self._data_sampler.sample_data(
                                    self._batch_size
                                )
                                perm = np.arange(self._batch_size)
                                np.random.shuffle(perm)
                                conditions_fake = conditions_real[perm, :]
                                fakez = torch.cat(
                                    [fakez, conditions_fake], dim=1
                                )
                            # else, only sample real data unconditionally,
                            # i.e. no conditional vector and conditions
                            else:
                                real = self._data_sampler.sample_data(
                                    self._batch_size
                                )
                        # if discrete columns are provided, sample
                        # conditional vectors and create a random
                        # permutation of the conditional vectors for the
                        # real data
                        else:
                            c1, _, col, opt = condvec
                            c1 = torch.from_numpy(c1)
                            fakez = torch.cat([fakez, c1], dim=1)

                            perm = np.arange(self._batch_size)
                            np.random.shuffle(perm)
                            c2 = c1[perm]

                            # if there are no conditions, sample only real
                            # data, conditional on the permuted
                            # conditional vector
                            if self._conditions is None:
                                real = self._data_sampler.sample_data(
                                    self._batch_size, col[perm], opt[perm]
                                )
                            # else sample both real data and conditions;
                            # the sampled conditions line up with c2, so
                            # the inverse permutation (argsort of perm)
                            # re-aligns them with c1 for the fake data
                            else:
                                (
                                    real,
                                    conditions,
                                ) = self._data_sampler.sample_data(
                                    self._batch_size, col[perm], opt[perm]
                                )
                                conditions_c2 = conditions
                                conditions_c1 = conditions_c2[
                                    np.argsort(perm), :
                                ]
                                fakez = torch.cat(
                                    [fakez, conditions_c1], dim=1
                                )

                        fake = self._generator(fakez)
                        fakeact = self._apply_activate(fake)

                        # discriminator inputs: data followed by whatever
                        # conditioning information is in use
                        if condvec is None:
                            if self._conditions is None:
                                real_cat = real
                                fake_cat = fakeact
                            else:
                                real_cat = torch.cat(
                                    [real, conditions_real], dim=1
                                )
                                fake_cat = torch.cat(
                                    [fakeact, conditions_fake], dim=1
                                )
                        else:
                            if self._conditions is None:
                                real_cat = torch.cat([real, c2], dim=1)
                                fake_cat = torch.cat([fakeact, c1], dim=1)
                            else:
                                real_cat = torch.cat(
                                    [real, c2, conditions_c2], dim=1
                                )
                                fake_cat = torch.cat(
                                    [fakeact, c1, conditions_c1], dim=1
                                )

                        y_fake = discriminator(fake_cat)
                        y_real = discriminator(real_cat)

                        pen = discriminator.calc_gradient_penalty(
                            real_cat, fake_cat, pac=self.pac
                        )
                        loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                        optimizerD.zero_grad()
                        pen.backward(retain_graph=True)
                        loss_d.backward()
                        optimizerD.step()

                    # generator training step
                    fakez = torch.normal(mean=mean, std=std)
                    condvec = self._data_sampler.sample_condvec(
                        self._batch_size
                    )

                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1)
                        m1 = torch.from_numpy(m1)
                        fakez = torch.cat([fakez, c1], dim=1)

                    if self._conditions is not None:
                        # only the conditions are needed here; the real
                        # rows returned alongside them are discarded
                        _, conditions = self._data_sampler.sample_data(
                            self._batch_size, col, opt
                        )
                        fakez = torch.cat([fakez, conditions], dim=1)

                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake)

                    if condvec is None:
                        if self._conditions is None:
                            y_fake = discriminator(fakeact)
                        else:
                            y_fake = discriminator(
                                torch.cat([fakeact, conditions], dim=1)
                            )
                    else:
                        if self._conditions is None:
                            y_fake = discriminator(
                                torch.cat([fakeact, c1], dim=1)
                            )
                        else:
                            y_fake = discriminator(
                                torch.cat([fakeact, c1, conditions], dim=1)
                            )

                    if condvec is None:
                        cross_entropy = torch.zeros(1)
                    else:
                        cross_entropy = self._cond_loss(fake, c1, m1)

                    loss_g = -torch.mean(y_fake) + cross_entropy

                    optimizerG.zero_grad()
                    loss_g.backward()
                    optimizerG.step()

                # NOTE(review): placed at epoch level to match the documented
                # "loss at each epoch"; the values are those of the last
                # step of the epoch.
                epochs_bar.set_description(
                    f"Epoch {i}, "
                    f"Loss G: {-torch.mean(y_fake).detach().cpu(): .4f}, "
                    f"Loss R: {cross_entropy.detach().cpu().item(): .4f}, "
                    f"Loss D: {loss_d.detach().cpu(): .4f}"
                )

                self.generator_loss.append(loss_g.detach().cpu())
                self.reconstruction_loss.append(cross_entropy.detach().cpu())
                self.discriminator_loss.append(loss_d.detach().cpu())

    @staticmethod
    def _batch_of_conditional_rows(
        conditional: torch.Tensor, step: int, batch_size: int
    ) -> torch.Tensor:
        """Return the ``batch_size`` rows of ``conditional`` for ``step``.

        Rows ``step * batch_size`` onwards are taken, wrapping around to the
        first row when the end of ``conditional`` is reached, so every batch
        has exactly ``batch_size`` rows whatever the length of
        ``conditional``.
        """
        start = step * batch_size
        rows = torch.arange(start, start + batch_size) % conditional.shape[0]
        return conditional[rows]

    @random_state
    def sample(
        self,
        n: int,
        condition_column: Optional[Union[List[str], str]] = None,
        condition_value: Optional[Union[List[str], str]] = None,
        conditional: Optional[pd.DataFrame] = None,
        empirical: bool = False,
    ):
        """Sample data from LUCIDGAN.

        Choosing a condition_column and condition_value will increase the
        probability of the discrete condition_value happening in the
        condition_column.

        Parameters
        ----------
        n : int
            Number of samples to generate.
        condition_column : str or list of str, optional
            Name(s) of the discrete column(s) to condition on. Both
            condition_column and condition_value must be specified to
            condition on a discrete column. Otherwise this argument is
            ignored.
        condition_value : str or list of str, optional
            Name(s) of the category in the condition_column(s) which we wish
            to increase the probability of happening. Both condition_value
            and condition_column must be specified to condition on a discrete
            column. Otherwise this argument is ignored.
        conditional : pandas.DataFrame, optional
            A 2-dimensional ``pandas.DataFrame`` with ``numeric`` values for
            the conditional features. The number of rows of the 2d array
            should be the same as ``n`` or be equal to one (in which case it
            will be repeated ``n`` times). The number of columns should be
            equal to the number of conditional features. Ignored (with a
            warning) if the model was fitted without conditional features.
        empirical : bool
            Whether to use the empirical distribution of the data to sample
            from. If False, the generator will be used to sample from the
            learned distribution. Default is False.

        Returns
        -------
        pandas.DataFrame
            A ``pandas.DataFrame`` with ``n`` samples. If the model was
            fitted with conditional features, the conditions used for each
            sample are appended as extra columns.

        Raises
        ------
        ValueError
            If ``conditional`` contains non-numeric values, or if its number
            of rows is neither one nor ``n``.
        """
        self._generator.eval()

        if condition_column is not None and condition_value is not None:
            if isinstance(condition_column, str) and isinstance(
                condition_value, str
            ):
                condition_column = [condition_column]
                condition_value = [condition_value]

            condition_info = self._convert_column_name_value_to_id(
                condition_column, condition_value  # type: ignore
            )
            global_condition_vec = (
                self._data_sampler.generate_cond_from_condition_column_info(
                    condition_info, self._batch_size
                )
            )
        else:
            global_condition_vec = None

        use_conditions = self._conditions is not None

        if conditional is not None:
            if use_conditions:
                if not all(
                    conditional.apply(
                        lambda s: pd.to_numeric(s, errors="coerce")
                        .notnull()
                        .all()
                    )
                ):
                    raise ValueError("The conditional data must be numeric.")

                conditional = torch.from_numpy(
                    conditional.to_numpy().astype("float32")
                )
                if conditional.shape[0] != 1 and conditional.shape[0] != n:
                    raise ValueError(
                        "conditional must have either length one or n."
                    )
            else:
                warnings.warn(
                    "self._conditions is None so the specified conditional "
                    "argument will be ignored."
                )

        # True only when user-supplied conditions were validated and
        # converted to a tensor above.
        user_conditions = conditional is not None and use_conditions

        def zero_condvec() -> np.ndarray:
            # An all-zero conditional vector, i.e. no discrete condition.
            return np.zeros(
                (self._batch_size, self._data_sampler.dim_cond_vec()),
                dtype=np.float32,
            )

        # One more batch than strictly needed is generated; the surplus is
        # cut off after concatenation.
        steps = n // self._batch_size + 1
        data = []
        if use_conditions:
            data_cond = []

        mean = torch.zeros(self._batch_size, self._embedding_dim)
        std = mean + 1

        for step in range(steps):
            fakez = torch.normal(mean=mean, std=std)

            if global_condition_vec is not None:
                condvec = global_condition_vec.copy()
                if user_conditions:
                    # Take a batch-sized window so the rows always match
                    # the noise tensor, also when n != batch_size.
                    conditions = self._batch_of_conditional_rows(
                        conditional, step, self._batch_size
                    )
                elif use_conditions:
                    # Draw random training conditions; they must be a
                    # float32 tensor to be concatenated with the noise.
                    picks = np.random.randint(
                        self._conditions.shape[0], size=self._batch_size
                    )
                    conditions = torch.from_numpy(
                        self._conditions[picks].astype("float32")
                    )
            elif user_conditions:
                if empirical:
                    (
                        condvec,
                        _,
                    ) = self._data_sampler.sample_original_condvec(
                        self._batch_size
                    )
                else:
                    condvec = zero_condvec()
                conditions = self._batch_of_conditional_rows(
                    conditional, step, self._batch_size
                )
            elif use_conditions:
                if empirical:
                    (
                        condvec,
                        conditions,
                    ) = self._data_sampler.sample_original_condvec(
                        self._batch_size
                    )
                else:
                    (
                        _,
                        conditions,
                    ) = self._data_sampler.sample_original_condvec(
                        self._batch_size
                    )
                    condvec = zero_condvec()
            else:
                if empirical:
                    condvec = self._data_sampler.sample_original_condvec(
                        self._batch_size
                    )  # type: ignore
                else:
                    condvec = zero_condvec()

            if condvec is not None:
                c1 = torch.from_numpy(condvec)
                fakez = torch.cat([fakez, c1], dim=1)

            if use_conditions:
                fakez = torch.cat([fakez, conditions], dim=1)

            with torch.no_grad():
                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)

            data.append(fakeact.detach().cpu().numpy())
            if use_conditions:
                data_cond.append(conditions.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        data = data[:n]

        if not use_conditions:
            return self._transformer.inverse_transform(data)

        data_cond = np.concatenate(data_cond, axis=0)
        data_cond = data_cond[:n]

        transformed_data = self._transformer.inverse_transform(data)
        transformed_cond = pd.DataFrame(
            data_cond, columns=self._conditions_columns
        )
        return pd.concat([transformed_data, transformed_cond], axis=1)

    def _convert_column_name_value_to_id(
        self, column_names: List[str], values: List[str]
    ) -> List[Dict[str, int]]:
        """Get the ids of the given `column_name`.

        Parameters
        ----------
        column_names : List[str]
            The column names.
        values : List[str]
            The values of the columns.

        Returns
        -------
        List[Dict[str, int]]
            One dictionary per (column name, value) pair with the keys
            ``discrete_column_id``, ``column_id`` and ``value_id``.

        Raises
        ------
        ValueError
            If a column name is not known to the fitted transformer, or if a
            value does not exist in its column.
        """
        results = []

        for column_name, value in zip(column_names, values):
            discrete_counter = 0
            column_id = 0
            for (
                column_transform_info
            ) in self._transformer._column_transform_info_list:
                if column_transform_info.column_name == column_name:
                    break
                if column_transform_info.column_type == "discrete":
                    discrete_counter += 1

                column_id += 1
            else:
                # The loop ran to completion without a match.
                raise ValueError(
                    f"The column_name `{column_name}` "
                    f"doesn't exist in the data."
                )

            ohe = column_transform_info.transform
            data = pd.DataFrame(
                [value], columns=[column_transform_info.column_name]
            )
            one_hot = ohe.transform(data).to_numpy()[0]
            # An all-zero encoding means the category is unknown.
            if sum(one_hot) == 0:
                raise ValueError(
                    f"The value `{value}` doesn't exist in the "
                    f"column `{column_name}`."
                )

            results.append(
                {
                    "discrete_column_id": discrete_counter,
                    "column_id": column_id,
                    "value_id": np.argmax(one_hot),
                }
            )

        return results  # type: ignore