Source code for canonical_sets.lucid

"""LUCID."""

import warnings
from typing import List, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
import torch.nn as nn
import torch.optim as optim
from numpy.random import default_rng
from sklearn.base import TransformerMixin
from tqdm import tqdm

from canonical_sets.utils import safe_isinstance


class LUCID:
    """Gradient-based inverse design to generate canonical sets.

    This class generates a canonical set via inverse design and attributes
    the ``pd.DataFrame`` to ``results``.

    Attributes
    ----------
    results : pd.DataFrame
        A dataframe with the canonical inputs.
    results_processed: pd.DataFrame
        A dataframe with the processed canonical inputs.

    Examples
    --------
    >>> model = tf.keras.Model()
    >>> outputs = pd.DataFrame([[0, 1]], columns=["No", "Yes"])
    >>> example_data = train_data
    >>> lucid = LUCID(model, outputs, example_data)
    >>> lucid.results
    """

    results: pd.DataFrame
    results_processed: pd.DataFrame

    def __init__(
        self,
        model: Union[torch.nn.Module, tf.keras.Model],
        outputs: pd.DataFrame,
        example_data: pd.DataFrame,
        numb_of_samples: int = 100,
        numb_of_epochs: int = 200,
        lr: float = 0.1,
        low: float = -1,
        high: float = 1,
        seed: int = 1234,
        index: bool = True,
        extra_epoch: bool = True,
        one_hot_pre: bool = False,
        one_hot_post: bool = True,
        log_every_n: int = 0,
        prefix_sep: str = "+",
    ):
        """Initialize and run the inverse design.

        Parameters
        ----------
        model : torch.nn.Module or tf.keras.Model
            The trained model to use for inverse design.
        outputs : pd.DataFrame
            The outputs to use for inverse design. These are the
            targets/labels that have been used during training. For example,
            ``pd.DataFrame([[0, 1]], columns=["<=50K", ">50K"])`` in the
            ``Adult`` data set.
        example_data : pd.DataFrame
            The example data to infer columns, dtypes, ... This is often (a
            part of) the training data itself, but can also be an artificial
            example.
        numb_of_samples : int
            The number of samples to generate. The default is 100.
        numb_of_epochs : int
            The number of optimization steps per sample. Must be at least 1.
            The default is 200.
        lr : float
            The learning rate for the optimizer. The default is 0.1.
        low : float
            The lower bound for the random uniform distribution. The default
            is -1.
        high : float
            The upper bound for the random uniform distribution. The default
            is 1.
        seed : int
            The seed for the random number generator. The default is 1234.
        index : bool
            If True the sample and epoch numbers are used as indices in the
            results ``pd.DataFrame``. Otherwise they are just columns. The
            default is True.
        extra_epoch : bool
            If True an additional forward pass is run for every sample after
            its categorical features have been one-hot encoded
            (post-processed). That result is saved as epoch
            ``numb_of_epochs`` + 1. If there are no categorical features the
            argument is ignored. The default is True.
        one_hot_pre : bool
            If True, the initial values for the categorical features are
            pre-processed to be one-hot, and the inverse design starts from
            this one-hot sample. If False, the inverse design starts from the
            randomly drawn initial vectors. If there are no categorical
            features the argument is ignored. The default is False.
        one_hot_post : bool
            If True, the values for the categorical features in ``results``
            are post-processed to be one-hot. Note that the predictions
            during the inverse design are made with the original values of
            the categorical features and not with the post-processed values.
            To run an additional forward pass with the post-processed values
            check the ``extra_epoch`` argument. If there are no categorical
            features the argument is ignored. The default is True.
        log_every_n : int
            Log results every ``log_every_n`` epochs; epoch 1 is always
            logged. If 0, this argument is set equal to ``numb_of_epochs``,
            which makes it a static analysis with only the start and end
            samples. The default is 0.
        prefix_sep : str
            The separator for the prefix of the column names. The one-hot
            encoded features are grouped by the part of the column name
            before the first separator (e.g., ``Categorical+category-name``
            belongs to group ``Categorical``). The default is "+".

        Raises
        ------
        ValueError
            If ``numb_of_epochs`` is smaller than 1 or ``log_every_n`` is
            negative.
        ValueError
            If any columns are neither integer (one-hot encoded) or float
            (numerical).
        ValueError
            If the model is neither a torch.nn.Module or (tf.)keras.Model.
        """
        # Fail early: zero epochs would otherwise end in a ZeroDivisionError
        # (log_every_n == 0) or an unbound loop variable.
        if numb_of_epochs < 1:
            raise ValueError("numb_of_epochs must be at least 1.")
        if log_every_n < 0:
            raise ValueError("log_every_n must be non-negative.")

        self.model = model
        self.outputs = outputs
        self.example_data = example_data
        self.numb_of_samples = numb_of_samples
        self.numb_of_epochs = numb_of_epochs
        self.lr = lr
        self.low = low
        self.high = high
        self.seed = seed
        self.index = index
        self.extra_epoch = extra_epoch
        self.one_hot_pre = one_hot_pre
        self.one_hot_post = one_hot_post
        self.log_every_n = log_every_n
        self.prefix_sep = prefix_sep

        if self.log_every_n == 0:
            self.log_every_n = self.numb_of_epochs

        self._processed = False

        # check whether all columns are either float or integer
        supported = self.example_data.select_dtypes(
            include=["float", "integer"]
        )
        if supported.shape[1] != self.example_data.shape[1]:
            raise ValueError(
                (
                    "Some columns are neither integer (one-hot encoded) "
                    "or float (numerical)"
                )
            )

        # store column information
        self._input_columns = self.example_data.columns.to_list()
        self._numb_of_inputs = len(self._input_columns)
        self._output_columns = self.outputs.columns.to_list()
        self.numerical_cols = self.example_data.select_dtypes(
            include="float"
        ).columns.to_list()
        self.categories_cols = self.example_data.select_dtypes(
            include="integer"
        ).columns.to_list()
        # unique group prefixes, in first-seen order
        self.categories = list(
            dict.fromkeys(
                item.split(self.prefix_sep)[0] for item in self.categories_cols
            )
        )

        # generate random inputs
        self._rng = default_rng(seed=self.seed)
        self._inputs = self._rng.uniform(
            low=self.low,
            high=self.high,
            size=(self.numb_of_samples, self._numb_of_inputs),
        )

        # one-hot-encode categorical features (pre-processing, optional)
        if self.one_hot_pre and self.categories_cols:
            self._inputs = self._one_hot_encode(self._inputs)

        # initialize results (replaced by the inverse design loop)
        self._columns = (
            ["sample", "epoch"] + self._output_columns + self._input_columns
        )
        self.results = pd.DataFrame(columns=self._columns)

        # run inverse design loop
        if safe_isinstance(self.model, "torch.nn.Module"):
            self._inverse_loop_pt()
        elif safe_isinstance(self.model, "keras.Model"):
            self._inverse_loop_tf()
        else:
            raise ValueError(
                "model must be a torch.nn.Module or (tf.)keras.Model."
            )

        # set types in results
        self.results["sample"] = self.results["sample"].astype(int)
        self.results["epoch"] = self.results["epoch"].astype(int)

        # save additional info: epoch 1 is always logged; the last logged
        # epoch is the largest multiple of log_every_n not exceeding
        # numb_of_epochs, or epoch 1 when log_every_n > numb_of_epochs.
        self._min_epoch = 1
        self._max_epoch = max(
            self._min_epoch,
            self.numb_of_epochs - (self.numb_of_epochs % self.log_every_n),
        )

        # set index
        if self.index:
            self.results.set_index(["sample", "epoch"], inplace=True)

        # one-hot-encode categorical features (post-processing, optional)
        if self.one_hot_post and self.categories_cols:
            self.results = self._one_hot_encode(self.results)

    def process_results(self, scaler: TransformerMixin = None) -> None:
        """Process the results by applying inverse scaler and one-hot
        encoding to categories.

        The processed frame holds the outputs, the (optionally inverse
        scaled) numerical features, and one column per categorical group
        containing the category name with the highest value.

        Parameters
        ----------
        scaler: sklearn.base.TransformerMixin, optional
            Any of the ``sklearn`` preprocessing modules. The default is None
            which means there is no transformation on numerical features.
        """
        if self._processed:
            warnings.warn("Results have already been processed before.")

        numbers = pd.DataFrame(index=self.results.index)
        dummies = pd.DataFrame(index=self.results.index)

        if self.numerical_cols:
            numbers = self.results[self.numerical_cols]
            if scaler is not None:
                numbers = pd.DataFrame(
                    scaler.inverse_transform(numbers),
                    index=self.results.index,
                    columns=self.numerical_cols,
                )

        if self.categories_cols:
            # group prefix -> whether its columns carry a "prefix+category"
            # name that must be collapsed into a single label column
            cols2collapse = {
                item.split(self.prefix_sep)[0]: (self.prefix_sep in item)
                for item in self.categories_cols
            }
            series_list = []
            for prefix, needs_to_collapse in cols2collapse.items():
                if needs_to_collapse:
                    group = self.results[self._group_columns(prefix)]
                    labels = (
                        group.astype(float)
                        .idxmax(axis=1)
                        .apply(
                            lambda name: name.split(
                                self.prefix_sep, maxsplit=1
                            )[1]
                        )
                        .rename(prefix)
                    )
                    series_list.append(labels)
                else:
                    series_list.append(self.results[prefix])
            dummies = pd.concat(series_list, axis=1)

        results_processed = numbers.join(dummies)
        self.results_processed = self.results[self._output_columns].join(
            results_processed
        )

        # keep sample/epoch available when they are columns, not the index
        if not self.index:
            self.results_processed = self.results[["sample", "epoch"]].join(
                self.results_processed
            )

        self._processed = True

    def plot(self, output: str) -> None:
        """Scatter an output per sample at the first and last logged epoch.

        When the extra post-processed forward pass was run (categorical
        features present and ``extra_epoch`` True), epoch
        ``numb_of_epochs`` + 1 is drawn as well.

        Parameters
        ----------
        output: str
            The name of the output to plot.
        """
        if self.index:
            results = self.results
        else:
            results = self.results.set_index(["sample", "epoch"])

        epochs = [(self._min_epoch, "blue"), (self._max_epoch, "red")]
        # the extra epoch is stored as numb_of_epochs + 1, which differs from
        # _max_epoch + 1 when numb_of_epochs is not a multiple of log_every_n
        if self.categories_cols and self.extra_epoch:
            epochs.append((self.numb_of_epochs + 1, "gray"))

        for epoch, color in epochs:
            subset = results.query(f"epoch == {epoch}")[output]
            plt.scatter(
                subset.index.get_level_values("sample"),
                subset.values,
                color=color,
            )

        plt.title(f"{output}")
        plt.legend([f"epoch {epoch}" for epoch, _ in epochs])
        plt.show()

    def hist(self, features: Union[str, List[str]]) -> None:
        """Plot histograms of processed features at the first and last
        logged epoch.

        Parameters
        ----------
        features: str or list of str
            The feature(s) to plot (either 1, 2, 3, 4, 6 or 8).

        Raises
        ------
        ValueError
            If the ``features`` are neither a string or a list of strings of
            size 1, 2, 3, 4, 6 or 8.

        Note
        ----
        If the ``results`` are not yet processed, they will be with
        ``process_results``.
        """
        if not self._processed:
            self.process_results()

        if self.index:
            results = self.results_processed
        else:
            results = self.results_processed.set_index(["sample", "epoch"])

        first = results.query(f"epoch == {self._min_epoch}")
        last = results.query(f"epoch == {self._max_epoch}")
        legend = [f"epoch {self._min_epoch}", f"epoch {self._max_epoch}"]

        # a single-element list is plotted exactly like a plain string
        if not isinstance(features, str) and len(features) == 1:
            features = features[0]

        if isinstance(features, str):
            plt.hist(first[[features]], alpha=0.45, color="blue")
            plt.hist(last[[features]], alpha=0.45, color="red")
            plt.title(f"{features}")
            plt.legend(legend)
            plt.show()
            return None

        # number of features -> (rows, columns) of the subplot grid
        layouts = {2: (1, 2), 3: (1, 3), 4: (2, 2), 6: (2, 3), 8: (2, 4)}
        if len(features) not in layouts:
            raise ValueError(
                "features must be either a string or list of "
                "size 2, 3, 4, 6 or 8."
            )
        nrow, ncol = layouts[len(features)]

        _, axs = plt.subplots(nrows=nrow, ncols=ncol)
        for feature, ax in zip(features, axs.ravel()):
            ax.hist(first[[feature]], alpha=0.45, color="blue")
            ax.hist(last[[feature]], alpha=0.45, color="red")
            ax.set_title(f"{feature}")
            ax.legend(legend)
        plt.show()

    def _group_columns(self, prefix: str) -> List[str]:
        """Return the integer columns whose group prefix equals ``prefix``.

        Unlike ``DataFrame.filter(like=prefix)``, which matches substrings
        anywhere in a column name (e.g. "age" would also match
        "marriage+yes"), only the exact part before the first
        ``prefix_sep`` is compared.
        """
        return [
            col
            for col in self.categories_cols
            if col.split(self.prefix_sep)[0] == prefix
        ]

    def _is_logged_epoch(self, epoch: int) -> bool:
        """Return whether the zero-based ``epoch`` should be recorded."""
        return (epoch + 1) % self.log_every_n == 0 or epoch == 0

    def _one_hot_encode(
        self, data: Union[pd.DataFrame, np.ndarray]
    ) -> Union[pd.DataFrame, np.ndarray]:
        """One-hot encode the data.

        This method takes a dataframe or a numpy array and one-hot encodes
        the categorical features. The return type is the same as the input
        type.

        Within each categorical group, every column whose value equals the
        row maximum becomes 1 and all others 0 (ties keep all maxima).

        Parameters
        ----------
        data: pd.DataFrame or np.ndarray
            The data where the categorical features should be one-hot
            encoded. A numpy array must have the input columns in order.

        Returns
        -------
        results: pd.DataFrame or np.ndarray
            The data where the categorical features are one-hot encoded.
            Non-input columns of a dataframe are kept first, unchanged.
        """
        is_numpy_array = isinstance(data, np.ndarray)
        if is_numpy_array:
            data = pd.DataFrame(data, columns=self._input_columns)

        # non-input columns (outputs, sample/epoch when not indexed)
        results = data.drop(self.numerical_cols + self.categories_cols, axis=1)

        numerical = set(self.numerical_cols)
        pending = set(self.categories)
        encoded = pd.DataFrame(index=data.index)
        for col in self._input_columns:
            if col in numerical:
                encoded[col] = data[col]
                continue
            prefix = col.split(self.prefix_sep)[0]
            if prefix not in pending:
                continue  # this group has already been encoded
            pending.discard(prefix)
            group = data[self._group_columns(prefix)].astype(float)
            # NOTE(review): a standalone integer column (no prefix_sep) forms
            # a one-column group and is therefore always set to 1 here.
            hot = group.eq(group.max(axis=1), axis=0).astype(int)
            encoded = pd.concat([encoded, hot], axis=1)

        encoded = encoded[self._input_columns]
        results = results.join(encoded)

        if is_numpy_array:
            return results.to_numpy()

        return results

    def _inverse_loop_pt(self) -> None:
        """Inverse design loop for PyTorch; stores rows in ``results``."""
        # transform output
        y = torch.tensor(self.outputs.to_numpy(), dtype=torch.float32)
        criterion = nn.CrossEntropyLoss()

        # rows are collected and turned into a DataFrame once at the end;
        # concatenating per row is quadratic in the number of rows
        rows = []
        for i, x in enumerate(tqdm(self._inputs)):
            x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
            x.requires_grad_()
            optimizer = optim.SGD([x], lr=self.lr)

            for epoch in range(self.numb_of_epochs):
                # snapshot of the input that produces this epoch's prediction
                sample = x.detach().numpy().flatten()
                optimizer.zero_grad()
                outputs = self.model(x)
                loss = criterion(outputs, y)
                # only accumulate into x.grad; a plain backward() would also
                # pile up gradients in the user's model parameters
                loss.backward(inputs=[x])
                optimizer.step()

                if self._is_logged_epoch(epoch):
                    pred = outputs.detach().numpy().flatten()
                    rows.append(
                        np.concatenate(
                            [np.array([i + 1, epoch + 1]), pred, sample]
                        )
                    )

            if self.categories_cols and self.extra_epoch:
                sample = self._one_hot_encode(np.expand_dims(sample, 0))
                with torch.no_grad():
                    output = self.model(
                        torch.tensor(sample, dtype=torch.float32)
                    )
                pred = output.numpy().flatten()
                rows.append(
                    np.concatenate(
                        [
                            np.array([i + 1, self.numb_of_epochs + 1]),
                            pred,
                            sample.ravel(),
                        ]
                    )
                )

        self.results = pd.DataFrame(rows, columns=self._columns)

    def _inverse_loop_tf(self) -> None:
        """Inverse design loop for Tensorflow; stores rows in ``results``."""
        # transform output
        y = self.outputs.to_numpy()
        criterion = tf.keras.losses.BinaryCrossentropy()

        rows = []
        for i, x in enumerate(tqdm(self._inputs)):
            x = tf.Variable(np.expand_dims(x, axis=0))
            # a fresh optimizer per sample: its state is bound to the variable
            optimizer = tf.keras.optimizers.SGD(learning_rate=self.lr)

            for epoch in range(self.numb_of_epochs):
                sample = x.numpy()
                with tf.GradientTape() as tape:
                    tape.watch(x)
                    outputs = self.model(x, training=False)
                    loss = criterion(y, outputs)
                grads = tape.gradient(loss, x)
                optimizer.apply_gradients(zip([grads], [x]))

                if self._is_logged_epoch(epoch):
                    # ravel (not squeeze) keeps 1-D even for one output
                    rows.append(
                        np.concatenate(
                            [
                                np.array([i + 1, epoch + 1]),
                                outputs.numpy().ravel(),
                                sample.ravel(),
                            ]
                        )
                    )

            if self.categories_cols and self.extra_epoch:
                sample = self._one_hot_encode(sample)
                output = self.model(sample, training=False)
                pred = output.numpy()
                # retained for backward compatibility: the last
                # post-processed sample and its prediction
                self.sample = sample
                self.output = output
                self.pred = pred
                rows.append(
                    np.concatenate(
                        [
                            np.array([i + 1, self.numb_of_epochs + 1]),
                            pred.ravel(),
                            sample.ravel(),
                        ]
                    )
                )

        self.results = pd.DataFrame(rows, columns=self._columns)