LUCID TensorFlow
In this example, we train a simple TensorFlow/Keras classifier on the UCI Adult income data set and generate canonical sets via inverse design.
[1]:
import pandas as pd
import tensorflow as tf
from canonical_sets.data import Adult
from canonical_sets.models import ClassifierTF
from canonical_sets import LUCID
We train the classifier with the Adam optimizer and a binary cross-entropy loss function. We monitor the loss and accuracy, and perform early stopping (to prevent overfitting) based on the validation accuracy (with a patience of 5 epochs). Finally, we assess the model’s performance on the test set.
[2]:
tf.keras.utils.set_random_seed(42)
data = Adult()
model = ClassifierTF(2)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
callback = tf.keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5)
model.fit(data.train_data.to_numpy(), data.train_labels.to_numpy(), epochs=200,
          validation_data=(data.val_data.to_numpy(), data.val_labels.to_numpy()), callbacks=[callback])
model.evaluate(data.test_data.to_numpy(), data.test_labels.to_numpy())
Epoch 1/200
755/755 [==============================] - 2s 2ms/step - loss: 0.4249 - accuracy: 0.7994 - val_loss: 0.3708 - val_accuracy: 0.8294
Epoch 2/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3649 - accuracy: 0.8265 - val_loss: 0.3563 - val_accuracy: 0.8356
Epoch 3/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3572 - accuracy: 0.8312 - val_loss: 0.3523 - val_accuracy: 0.8376
Epoch 4/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3541 - accuracy: 0.8338 - val_loss: 0.3491 - val_accuracy: 0.8351
Epoch 5/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3523 - accuracy: 0.8345 - val_loss: 0.3471 - val_accuracy: 0.8371
Epoch 6/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3511 - accuracy: 0.8344 - val_loss: 0.3465 - val_accuracy: 0.8369
Epoch 7/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3501 - accuracy: 0.8360 - val_loss: 0.3456 - val_accuracy: 0.8357
Epoch 8/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3494 - accuracy: 0.8360 - val_loss: 0.3441 - val_accuracy: 0.8384
Epoch 9/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3488 - accuracy: 0.8361 - val_loss: 0.3433 - val_accuracy: 0.8392
Epoch 10/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3483 - accuracy: 0.8364 - val_loss: 0.3428 - val_accuracy: 0.8405
Epoch 11/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3477 - accuracy: 0.8368 - val_loss: 0.3419 - val_accuracy: 0.8404
Epoch 12/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3472 - accuracy: 0.8363 - val_loss: 0.3440 - val_accuracy: 0.8374
Epoch 13/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3468 - accuracy: 0.8370 - val_loss: 0.3412 - val_accuracy: 0.8419
Epoch 14/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3463 - accuracy: 0.8365 - val_loss: 0.3407 - val_accuracy: 0.8414
Epoch 15/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3458 - accuracy: 0.8375 - val_loss: 0.3405 - val_accuracy: 0.8410
Epoch 16/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3457 - accuracy: 0.8370 - val_loss: 0.3400 - val_accuracy: 0.8392
Epoch 17/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3453 - accuracy: 0.8371 - val_loss: 0.3398 - val_accuracy: 0.8417
Epoch 18/200
755/755 [==============================] - 1s 1ms/step - loss: 0.3450 - accuracy: 0.8377 - val_loss: 0.3387 - val_accuracy: 0.8409
471/471 [==============================] - 1s 1ms/step - loss: 0.3456 - accuracy: 0.8366
[2]:
[0.34558895230293274, 0.8365870118141174]
We use the training data as the example data (note that this training data has already been pre-processed), and set the desired output to a probability of zero for “<=50K” and a probability of one for “>50K”. This means that we want to maximize the positive outcome in this case.
[3]:
example_data = data.train_data
outputs = pd.DataFrame([[0, 1]], columns=["<=50K", ">50K"])
example_data.head()
[3]:
Age | fnlwgt | Education-Num | Capital Gain | Capital Loss | Hours per week | Workclass+Federal-gov | Workclass+Local-gov | Workclass+Private | Workclass+Self-emp-inc | ... | Country+Portugal | Country+Puerto-Rico | Country+Scotland | Country+South | Country+Taiwan | Country+Thailand | Country+Trinadad&Tobago | Country+United-States | Country+Vietnam | Country+Yugoslavia | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.123288 | -0.950895 | 0.066667 | -1.0 | -1.000000 | 0.000000 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
1 | -0.726027 | -0.621532 | 0.066667 | -1.0 | -1.000000 | -0.102041 | 0 | 0 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
2 | -0.150685 | -0.874857 | -0.466667 | -1.0 | -1.000000 | -0.204082 | 0 | 0 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
3 | -0.561644 | -0.787375 | 0.066667 | -1.0 | -1.000000 | -0.102041 | 0 | 0 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
4 | -0.013699 | -0.694464 | 0.333333 | -1.0 | -0.318182 | -0.204082 | 0 | 0 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
5 rows × 104 columns
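The exact preprocessing is handled inside the `Adult` data object, but the column names and value ranges above hint at the general recipe: continuous features scaled to [-1, 1] and categorical features one-hot encoded with `feature+category` column names. A minimal sketch of that idea (a toy frame, not the library's internals):

```python
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

# Toy frame with one continuous and one categorical feature.
df = pd.DataFrame({"Age": [25.0, 38.0, 52.0],
                   "Workclass": ["Private", "State-gov", "Private"]})

# Scale continuous columns to [-1, 1], matching the range seen above.
scaler = MinMaxScaler(feature_range=(-1, 1))
df[["Age"]] = scaler.fit_transform(df[["Age"]])

# One-hot encode categoricals with "feature+category" column names.
df = pd.get_dummies(df, columns=["Workclass"], prefix_sep="+")
print(df.columns.tolist())
# ['Age', 'Workclass+Private', 'Workclass+State-gov']
```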
We run the gradient-based inverse design with the default settings.
[4]:
lucid = LUCID(model, outputs, example_data)
lucid.results.head(12)
100%|██████████| 100/100 [01:47<00:00, 1.07s/it]
[4]:
<=50K | >50K | Age | fnlwgt | Education-Num | Capital Gain | Capital Loss | Hours per week | Workclass+Federal-gov | Workclass+Local-gov | ... | Country+Portugal | Country+Puerto-Rico | Country+Scotland | Country+South | Country+Taiwan | Country+Thailand | Country+Trinadad&Tobago | Country+United-States | Country+Vietnam | Country+Yugoslavia | ||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
sample | epoch | |||||||||||||||||||||
1 | 1 | 0.992438 | 0.007562 | 0.953400 | -0.239609 | 0.846492 | -0.476615 | -0.361806 | -0.763818 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
200 | 0.000913 | 0.999087 | 1.162248 | -0.180637 | 1.129408 | -0.054342 | -0.172412 | -0.466596 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | |
201 | 0.013462 | 0.986537 | 1.162248 | -0.180637 | 1.129408 | -0.054342 | -0.172412 | -0.466596 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | |
2 | 1 | 0.257312 | 0.742688 | -0.008990 | 0.521141 | 0.374523 | -0.503690 | 0.244093 | -0.368689 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
200 | 0.000909 | 0.999091 | 0.095517 | 0.550650 | 0.516092 | -0.292388 | 0.338864 | -0.219961 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | |
201 | 0.507224 | 0.492776 | 0.095517 | 0.550650 | 0.516092 | -0.292388 | 0.338864 | -0.219961 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | |
3 | 1 | 0.017506 | 0.982494 | 0.461862 | 0.364942 | 0.525183 | 0.198825 | -0.529045 | -0.943877 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
200 | 0.000872 | 0.999128 | 0.514914 | 0.379922 | 0.597049 | 0.306091 | -0.480936 | -0.868377 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | |
201 | 0.454597 | 0.545403 | 0.514914 | 0.379922 | 0.597049 | 0.306091 | -0.480936 | -0.868377 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | |
4 | 1 | 0.579365 | 0.420635 | -0.290496 | -0.508702 | -0.064275 | 0.552951 | -0.922651 | 0.564081 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
200 | 0.000905 | 0.999095 | -0.161642 | -0.472319 | 0.110275 | 0.813480 | -0.805800 | 0.747458 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | |
201 | 0.003508 | 0.996492 | -0.161642 | -0.472319 | 0.110275 | 0.813480 | -0.805800 | 0.747458 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
12 rows × 106 columns
Using the pandas multi-index, you can select by epoch or sample.
[5]:
lucid.results.query("sample == 1")
[5]:
<=50K | >50K | Age | fnlwgt | Education-Num | Capital Gain | Capital Loss | Hours per week | Workclass+Federal-gov | Workclass+Local-gov | ... | Country+Portugal | Country+Puerto-Rico | Country+Scotland | Country+South | Country+Taiwan | Country+Thailand | Country+Trinadad&Tobago | Country+United-States | Country+Vietnam | Country+Yugoslavia | ||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
sample | epoch | |||||||||||||||||||||
1 | 1 | 0.992438 | 0.007562 | 0.953400 | -0.239609 | 0.846492 | -0.476615 | -0.361806 | -0.763818 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
200 | 0.000913 | 0.999087 | 1.162248 | -0.180637 | 1.129408 | -0.054342 | -0.172412 | -0.466596 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | |
201 | 0.013462 | 0.986537 | 1.162248 | -0.180637 | 1.129408 | -0.054342 | -0.172412 | -0.466596 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
3 rows × 106 columns
[6]:
lucid.results.query("epoch == 1")
[6]:
<=50K | >50K | Age | fnlwgt | Education-Num | Capital Gain | Capital Loss | Hours per week | Workclass+Federal-gov | Workclass+Local-gov | ... | Country+Portugal | Country+Puerto-Rico | Country+Scotland | Country+South | Country+Taiwan | Country+Thailand | Country+Trinadad&Tobago | Country+United-States | Country+Vietnam | Country+Yugoslavia | ||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
sample | epoch | |||||||||||||||||||||
1 | 1 | 0.992438 | 0.007562 | 0.953400 | -0.239609 | 0.846492 | -0.476615 | -0.361806 | -0.763818 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 1 | 0.257312 | 0.742688 | -0.008990 | 0.521141 | 0.374523 | -0.503690 | 0.244093 | -0.368689 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
3 | 1 | 0.017506 | 0.982494 | 0.461862 | 0.364942 | 0.525183 | 0.198825 | -0.529045 | -0.943877 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 1 | 0.579365 | 0.420635 | -0.290496 | -0.508702 | -0.064275 | 0.552951 | -0.922651 | 0.564081 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 | 1 | 0.966358 | 0.033642 | 0.631150 | -0.486036 | -0.960426 | -0.393935 | -0.271197 | 0.217928 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
96 | 1 | 0.923635 | 0.076365 | -0.243961 | -0.747192 | 0.047886 | 0.942131 | -0.474525 | 0.557954 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
97 | 1 | 0.870405 | 0.129595 | 0.719518 | -0.812895 | -0.828304 | 0.721479 | 0.955049 | 0.709456 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
98 | 1 | 0.006257 | 0.993743 | 0.901062 | 0.068491 | -0.799558 | 0.374923 | 0.670941 | 0.993641 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
99 | 1 | 0.997700 | 0.002300 | -0.045285 | 0.590058 | -0.718306 | 0.808829 | -0.296529 | -0.509131 | 0 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
100 | 1 | 0.662636 | 0.337364 | -0.381372 | -0.254751 | 0.917487 | 0.670918 | -0.212994 | -0.935526 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
100 rows × 106 columns
We can also select specific categorical features using pandas.
[7]:
lucid.results.query("sample == 1").loc[:, lucid.results.columns.str.startswith("Workclass")]
[7]:
Workclass+Federal-gov | Workclass+Local-gov | Workclass+Private | Workclass+Self-emp-inc | Workclass+Self-emp-not-inc | Workclass+State-gov | Workclass+Without-pay | ||
---|---|---|---|---|---|---|---|---|
sample | epoch | |||||||
1 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
200 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | |
201 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
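Any of the standard pandas multi-index tools work here as well; for instance, `xs` takes a cross-section by level name. A toy sketch (the index layout mirrors `lucid.results`, but the values are made up):

```python
import pandas as pd

# Toy frame mimicking the (sample, epoch) multi-index of lucid.results.
idx = pd.MultiIndex.from_product([[1, 2], [1, 200, 201]],
                                 names=["sample", "epoch"])
results = pd.DataFrame({"Age": range(6)}, index=idx)

# Cross-section by level name: all samples at the final epoch.
# The epoch level is dropped, leaving one row per sample.
final = results.xs(201, level="epoch")
```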
The results are not yet transformed back to their original range and can therefore be difficult to interpret. To make this easier, we can apply the process_results method and provide the original sklearn scaler from the data object.
[8]:
lucid.process_results(data.scaler)
[9]:
lucid.results_processed.head(12)
[9]:
<=50K | >50K | Age | fnlwgt | Education-Num | Capital Gain | Capital Loss | Hours per week | Workclass | Education | Martial Status | Occupation | Relationship | Race | Sex | Country | ||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
sample | epoch | ||||||||||||||||
1 | 1 | 0.992438 | 0.007562 | 88.299083 | 5.730126e+05 | 14.848694 | 26168.980694 | 1389.986786 | 12.572941 | Private | Masters | Widowed | Other-service | Unmarried | Other | Male | El-Salvador |
200 | 0.000913 | 0.999087 | 95.922060 | 6.163840e+05 | 16.970563 | 47282.406212 | 1802.485577 | 27.136805 | Private | Masters | Widowed | Other-service | Wife | Other | Male | Italy | |
201 | 0.013462 | 0.986537 | 95.922060 | 6.163840e+05 | 16.970563 | 47282.406212 | 1802.485577 | 27.136805 | Private | Masters | Widowed | Other-service | Wife | Other | Male | Italy | |
2 | 1 | 0.257312 | 0.742688 | 53.171879 | 1.132519e+06 | 11.308924 | 24815.229793 | 2709.634965 | 31.934258 | Self-emp-inc | 12th | Married-spouse-absent | Armed-Forces | Other-relative | White | Female | Hong |
200 | 0.000909 | 0.999091 | 56.986358 | 1.154222e+06 | 12.370691 | 35380.225207 | 2916.046168 | 39.221903 | Self-emp-inc | 12th | Married-spouse-absent | Armed-Forces | Husband | White | Female | Hong | |
201 | 0.507224 | 0.492776 | 56.986358 | 1.154222e+06 | 12.370691 | 35380.225207 | 2916.046168 | 39.221903 | Self-emp-inc | 12th | Married-spouse-absent | Armed-Forces | Husband | White | Female | Hong | |
3 | 1 | 0.017506 | 0.982494 | 70.357980 | 1.017640e+06 | 12.438870 | 59940.653184 | 1025.739436 | 3.750007 | Self-emp-not-inc | HS-grad | Married-spouse-absent | Handlers-cleaners | Husband | Other | Male | Cuba |
200 | 0.000872 | 0.999128 | 72.294370 | 1.028657e+06 | 12.977867 | 65303.889562 | 1130.522453 | 7.449522 | Self-emp-not-inc | HS-grad | Married-spouse-absent | Handlers-cleaners | Husband | Other | Male | Cuba | |
201 | 0.454597 | 0.545403 | 72.294370 | 1.028657e+06 | 12.977867 | 65303.889562 | 1130.522453 | 7.449522 | Self-emp-not-inc | HS-grad | Married-spouse-absent | Handlers-cleaners | Husband | Other | Male | Cuba | |
4 | 1 | 0.579365 | 0.420635 | 42.896910 | 3.751027e+05 | 8.017934 | 77646.750421 | 168.467151 | 77.639946 | Without-pay | Some-college | Separated | Tech-support | Own-child | Asian-Pac-Islander | Male | France |
200 | 0.000905 | 0.999095 | 47.600063 | 4.018616e+05 | 9.327066 | 90673.115326 | 422.966812 | 86.625423 | Self-emp-inc | Some-college | Separated | Tech-support | Wife | Asian-Pac-Islander | Male | France | |
201 | 0.003508 | 0.996492 | 47.600063 | 4.018616e+05 | 9.327066 | 90673.115326 | 422.966812 | 86.625423 | Self-emp-inc | Some-college | Separated | Tech-support | Wife | Asian-Pac-Islander | Male | France |
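Conceptually, the continuous part of this post-processing is an inverse transform of the scaler. A sketch of the idea, assuming an age range of 17–90 as in the raw Adult data (the scaler setup here is an illustration, not the library's internals):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Fit a [-1, 1] scaler on a toy "Age" column with the raw Adult range.
ages = np.array([[17.0], [90.0]])
scaler = MinMaxScaler(feature_range=(-1, 1)).fit(ages)

# A scaled Age value produced by the inverse design (sample 1, epoch 1)...
scaled = np.array([[0.953400]])

# ...mapped back to the original range, as in the processed results above.
print(scaler.inverse_transform(scaled))  # ~[[88.3]]
```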
We can also plot some of the results, such as the predictions at the first and the last epoch, as well as after the last epoch, once the results are categorically formatted.
[10]:
lucid.plot(">50K")
We can also check the distributions of the features in the first and last epochs.
[11]:
lucid.hist(["Age", "Relationship"])
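The same first-versus-last-epoch comparison can also be made numerically with a groupby on the epoch level of the index. A toy sketch (index layout as in `lucid.results`, made-up values):

```python
import pandas as pd

# Toy frame mimicking the (sample, epoch) multi-index of lucid.results.
idx = pd.MultiIndex.from_product([[1, 2, 3], [1, 200]],
                                 names=["sample", "epoch"])
results = pd.DataFrame({"Age": [0.95, 1.16, -0.01, 0.10, 0.46, 0.51]},
                       index=idx)

# Mean Age per epoch: how far the inverse design moved the feature on average.
means = results.groupby(level="epoch")["Age"].mean()
```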