Source code for canonical_sets.models.classifiers

"""Classifiers."""

import tensorflow as tf
import torch.nn as nn
import torch.nn.functional as F


class ClassifierPT(nn.Module):
    """Classifier for PyTorch.

    A single linear layer followed by a softmax, producing class
    probabilities.

    Parameters
    ----------
    input_dim : int
        Input dimension.
    output_dim : int
        Output dimension.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Compute class probabilities.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(*, input_dim)``: a single sample, a batch,
            or any tensor whose last dimension holds the features.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(*, output_dim)`` whose last dimension sums
            to one.
        """
        # ``nn.Linear`` maps the last axis, so the class scores always live
        # on the last axis. Normalizing over ``dim=-1`` is therefore correct
        # for unbatched (1-D), batched (2-D) and higher-rank inputs, whereas
        # a fixed ``dim=1`` fails on 1-D input and picks the wrong axis for
        # 3-D+ input.
        outputs = F.softmax(self.linear(x), dim=-1)
        return outputs
class ClassifierTF(tf.keras.Model):
    """Classifier for Keras.

    Wraps one dense layer with a softmax activation, so the model
    outputs class probabilities.

    Parameters
    ----------
    output_dim : int
        Output dimension.
    """

    def __init__(self, output_dim: int):
        super().__init__()
        # No input width is given here; the Dense layer infers it from the
        # first input it is called on.
        self.linear = tf.keras.layers.Dense(
            units=output_dim,
            activation=tf.nn.softmax,
        )

    def call(self, x):
        """Run the dense layer and return its softmax probabilities."""
        return self.linear(x)