LUCID-GAN

[1]:
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf

from canonical_sets import LUCIDGAN
from canonical_sets.data import Adult
from canonical_sets.models import ClassifierTF
[2]:
# Potential Direct Discrimination (protected attributes included)
# Fix the global random seed first so the model initialisation and training below are repeatable.
tf.keras.utils.set_random_seed(42)

# Stop training once validation accuracy has gone 5 epochs without improving.
callback = tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5)

adult = Adult()
dir_adult_model = ClassifierTF(2)
dir_adult_model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

# Every split is converted with .to_numpy() before being handed to Keras.
dir_adult_model.fit(
    adult.train_data.to_numpy(),
    adult.train_labels.to_numpy(),
    epochs=200,
    validation_data=(adult.val_data.to_numpy(), adult.val_labels.to_numpy()),
    callbacks=[callback],
    verbose=0,
)

# evaluate() yields [loss, accuracy] for this compile configuration; report the accuracy entry.
test_metrics = dir_adult_model.evaluate(adult.test_data.to_numpy(), adult.test_labels.to_numpy())
print(f"Test accuracy: {test_metrics[1]}")
2023-02-01 10:54:54.783998: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
471/471 [==============================] - 0s 364us/step - loss: 0.3442 - accuracy: 0.8386
Test accuracy: 0.838578999042511
[5]:
# LUCID-GAN
# Column 1 of the classifier output is kept as the prediction score to condition the GAN on.
adult_preds = dir_adult_model.predict(adult.test_data.to_numpy())[:, 1]
adult_test_data = adult.inverse_preprocess(adult.test_data)

# pd.concat(axis=1) aligns rows by index label, not by position. Build the prediction column on
# the same index as the test features so each score stays paired with its own row; with a default
# 0..n-1 index, any other index on `adult_test_data` would mispair rows and introduce NaNs.
adult_pred_frame = pd.DataFrame(adult_preds, columns=["preds"], index=adult_test_data.index)
adult_data = pd.concat([adult_test_data, adult_pred_frame], axis=1)

dir_adult_lucidgan = LUCIDGAN(epochs=5)

# Seed the GAN before fitting so training and sampling are repeatable.
dir_adult_lucidgan.set_random_state(42)
dir_adult_lucidgan.fit(adult_data, conditional=["preds"])

# Draw 1000 samples conditioned on a prediction of 1 and 1000 conditioned on a prediction of 0.
dir_adult_pos_samples = dir_adult_lucidgan.sample(1000, conditional=pd.DataFrame({"preds": [1]}))
dir_adult_neg_samples = dir_adult_lucidgan.sample(1000, conditional=pd.DataFrame({"preds": [0]}))

# Plot the three loss histories recorded on the fitted model on a single, labelled axis.
fig, ax = plt.subplots()
ax.plot(dir_adult_lucidgan.generator_loss, label="generator loss")
ax.plot(dir_adult_lucidgan.reconstruction_loss, label="reconstruction loss")
ax.plot(dir_adult_lucidgan.discriminator_loss, label="discriminator loss")
ax.set(title="LUCID-GAN training losses", ylabel="loss")
ax.legend()
plt.show()
471/471 [==============================] - 0s 278us/step
Epoch 4, Loss G: -0.9575, Loss R:  2.1966, Loss D: -0.0807: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]
_images/example_gan_3_2.png

For more information on LUCID-GAN, check the notebooks in the examples folder on GitHub. A replication of the LUCID-GAN paper’s results is also available.