How to adapt the Transformer architecture to generate Sets
Generation is one of the most complex tasks in Machine Learning, and there is a huge literature on generating images and text. But how do you generate sets?
The problem with set generation
Set generation is the ability to generate unordered subsets of a much bigger set following an implicit probability distribution.
The keyword here is unordered.
An image is not a set of pixels, because there is a notion of geometry between them.
A text is not a set of words, it’s a sequence, because there is a notion of order between words.
An example of a set is the list of ingredients in a dish. This is actually the use case I had to make a generative model for.
This absence of natural order or geometry makes it difficult to directly reuse the models designed for text and images.
Potential architectures: VAE, GAN and AR
There are 3 broad classes of generative models commonly used:
- Variational Autoencoders (VAEs)
- Generative Adversarial Networks (GANs)
- Autoregressive models (ARs)
The first two are mainly used for image generation, while the Transformer architecture, an autoregressive model, is mainly used for text generation.
In theory, all 3 could be used for set generation:
- VAEs generate plain feature vectors. A subset can be represented as a vector of zeros and ones where each index corresponds to an element of the full set (see the short sketch after this list).
- GANs are more of a strategy to create a generator model than a specific architecture. As such, they can be adapted to any use case.
- ARs can be used to generate sets by simply ignoring the order of the generated sequences.
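To make the multi-hot representation mentioned in the first bullet concrete, here is a minimal sketch (the vocabulary and dish are made up for illustration):

# Each index of the vector corresponds to one element of the full set.
vocabulary = ["flour", "butter", "sugar", "eggs", "salt"]
dish = {"flour", "butter", "eggs"}
multi_hot = [1 if ingredient in dish else 0 for ingredient in vocabulary]
# -> [1, 1, 0, 1, 0]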
In practice,
- VAEs are sensitive to a phenomenon called posterior collapse that makes them unsuitable for set generation. For example, when generating a set of ingredients for a pie, a VAE will select a bit of puff, shortcrust, and shortbread pastry instead of choosing one.
- GANs demand a lot of tuning and a big dataset to work, which makes them impractical if you don't have infinite resources.
This makes autoregressive models the most promising option for set generation. Among them, the Transformer architecture is by far the most efficient and has completely replaced previous models like RNNs.
The Transformer architecture
If you need an introduction to the Transformer architecture in the context of sequence generation, the best explanation I know is here, so I won't detail it again. Let's just recapitulate the different layers to see what we need to change.
Baseline model: take the transformer as is
To generate sets, nothing prevents us from reusing the previous architecture as is, using our set elements as the vocabulary and generating sequences. We can then simply ignore the order to get a set.
Below is the PyTorch source code of the transformer generator, split into 3 parts:
- The heart of the Transformer model
- The Positional encoding Layer, specific to the Transformer architecture
- The Cross-entropy loss function adapted to Autoregressive Models
Transformer model
import math

import torch
from torch import nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# This is loosely inspired by the code available here:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class TransformerModel(nn.Module):
    def __init__(self,
                 idx_to_token: [str],
                 embedding_dim_between_layers: int,
                 nb_attention_heads: int,
                 hidden_dim_feed_forward_layers: int,
                 nb_encoder_layers: int,
                 padding_idx: int,
                 dropout: float = 0.5,
                 max_sentence_size: int = 50):
        super(TransformerModel, self).__init__()
        self.src_mask_by_sequence_size = [self._generate_square_subsequent_mask(len_sequence) for len_sequence in range(max_sentence_size + 1)]
        self.vocab_to_embedding = nn.Embedding(len(idx_to_token), embedding_dim_between_layers, padding_idx=padding_idx)
        self.pos_encoder = PositionalEncoding(embedding_dim_between_layers)
        encoder_layers = TransformerEncoderLayer(
            embedding_dim_between_layers,
            nb_attention_heads,
            hidden_dim_feed_forward_layers,
            dropout
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nb_encoder_layers)
        self.embedding_to_vocab = nn.Linear(embedding_dim_between_layers, len(idx_to_token))
        self.embedding_dim_between_layers = embedding_dim_between_layers
        self.idx_to_token = idx_to_token  # so it's serialized with the model
        self._init_weights()

    # Generate a mask for the TransformerEncoderLayer:
    # it allows the generator to use previous but not following tokens to generate the current one.
    def _generate_square_subsequent_mask(self, len_sequence):
        mask = (torch.triu(torch.ones(len_sequence, len_sequence)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _init_weights(self):
        initrange = 0.1
        self.vocab_to_embedding.weight.data.uniform_(-initrange, initrange)
        self.embedding_to_vocab.bias.data.zero_()
        self.embedding_to_vocab.weight.data.uniform_(-initrange, initrange)

    # sequences is a [sequence_size, batch_size] tensor. Each entry is a token index.
    def forward(self, sequences):
        # 1. We transform each token to its corresponding embedding.
        # We get a tensor of dimensions [sequence_size, batch_size, embedding_size]
        # See http://nlp.seas.harvard.edu/2018/04/03/attention.html#encoder for the normalization factor
        sequences = self.vocab_to_embedding(sequences) * math.sqrt(self.embedding_dim_between_layers)
        # 2. We add positional encoding to the embeddings
        sequences = self.pos_encoder(sequences)
        # 3. We pass the sequence through the encoder layers.
        # We get a tensor of dimension [sequence_size, batch_size, embedding_dim_between_layers]
        output = self.transformer_encoder(sequences, self.src_mask_by_sequence_size[len(sequences)])
        # 4. We convert output embeddings to vocabulary-size vectors that can be processed by the softmax layer.
        # We get a tensor of dimension [sequence_size, batch_size, vocabulary_size]
        # The softmax layer is inside the loss function (as it does not contain learnable parameters).
        vocab = self.embedding_to_vocab(output)
        return vocab
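Before moving on, here is a minimal sketch of how such a model could be used at inference time: sample a sequence token by token, then drop the order to get a set. This is not part of the original project; the bos_idx/eos_idx special tokens, the temperature, and the multinomial sampling strategy are illustrative assumptions.

@torch.no_grad()
def sample_set(model: TransformerModel, bos_idx: int, eos_idx: int,
               max_len: int = 50, temperature: float = 1.0) -> set:
    model.eval()
    generated = [bos_idx]
    for _ in range(max_len - 1):
        # the model expects a [sequence_size, batch_size] tensor of token indices
        sequences = torch.tensor(generated, dtype=torch.long).unsqueeze(1)
        logits = model(sequences)  # [sequence_size, 1, vocabulary_size]
        probas = torch.softmax(logits[-1, 0, :] / temperature, dim=-1)
        next_token = torch.multinomial(probas, num_samples=1).item()
        if next_token == eos_idx:
            break
        generated.append(next_token)
    # the order of generation is irrelevant: keep only the set of generated tokens
    return {model.idx_to_token[idx] for idx in generated if idx != bos_idx}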
Positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x
Loss function
class CrossEntropyTransformerLoss(TransformerLoss):
    def __init__(self, nb_classes: int, pad_idx: int, eos_idx: int, label_smoothing_coeff=0.0):
        super(CrossEntropyTransformerLoss, self).__init__()
        self.nb_classes = nb_classes
        self.loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing_coeff, ignore_index=pad_idx)

    # result: Tensor [sequence_size, batch_size, vocab_size] of linear output to give to softmax
    # target: Tensor [sequence_size, batch_size] of token indexes
    def forward(self, result, target) -> TransformerLossResult:
        flatten_result = result.view(-1, self.nb_classes)
        flatten_target = target.view(-1)
        loss_result = self.loss(flatten_result, flatten_target)
        return loss_result
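To make the autoregressive setup explicit, here is a minimal training-step sketch. It assumes batches are [sequence_size, batch_size] tensors of token indices padded with pad_idx and ending with eos_idx; the hyperparameters, the optimizer choice, and the idx_to_token/pad_idx/eos_idx variables are illustrative assumptions, not part of the original code.

model = TransformerModel(idx_to_token, embedding_dim_between_layers=128, nb_attention_heads=4,
                         hidden_dim_feed_forward_layers=256, nb_encoder_layers=3, padding_idx=pad_idx)
criterion = CrossEntropyTransformerLoss(nb_classes=len(idx_to_token), pad_idx=pad_idx, eos_idx=eos_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def training_step(batch):  # batch: [sequence_size, batch_size] of token indices
    optimizer.zero_grad()
    output = model(batch[:-1, :])           # predict token t from tokens 0..t-1
    loss = criterion(output, batch[1:, :])  # targets are the inputs shifted by one position
    loss.backward()
    optimizer.step()
    return loss.item()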
Adaptation 1: Get rid of the positional encoding
The positional encoding is there to give the model information about the relative positions of the previous tokens when generating the next one.
Here, all we care about is the set of previous tokens, regardless of their order. Because the positional encoding is just an addition to the embedding vectors, let's simply remove the following line from the TransformerModel forward function:
sequences = self.pos_encoder(sequences)
To be honest, it didn't significantly improve the loss on the test set in my case, but it's always good to simplify the code when the extra complexity doesn't bring more performance.
Adaptation 2: Use the Soft cross-entropy loss function
This one is a bit trickier. At training time, the standard loss function is the cross-entropy between the output probability distribution of the model and the one-hot vector of the correct next token.
If we remove any notion of order, the correct next token is any token that belongs to the set and has not yet been generated.
Thankfully, the PyTorch CrossEntropyLoss
class can calculate the cross-entropy between two distributions. We can therefore use as target a vector where the probability mass is shared equally between all not-yet-generated elements of the set.
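As a toy example of what such a target looks like (the vocabulary and the training sequence are made up for illustration):

# Vocabulary: [A, B, C, EOS, PAD]; training sequence: A B C EOS
# Step 0 (nothing generated yet):    any of A, B, C is valid -> [1/3, 1/3, 1/3, 0, 0]
# Step 1 (A has been generated):     B or C is valid         -> [0,   1/2, 1/2, 0, 0]
# Step 2 (A and B generated):        only C remains          -> [0,   0,   1,   0, 0]
# Step 3 (the whole set generated):  only EOS is valid       -> [0,   0,   0,   1, 0]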
Soft cross-entropy
class SoftCrossEntropyLossTransformerLoss(TransformerLoss):
    def __init__(self, nb_classes: int, pad_idx: int, eos_idx: int, label_smoothing_coeff=0.0):
        super(SoftCrossEntropyLossTransformerLoss, self).__init__()
        self.nb_classes = nb_classes
        self.pad_idx = pad_idx
        self.eos_idx = eos_idx
        self.loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing_coeff, ignore_index=pad_idx)

    # result: Tensor [sequence_size, batch_size, vocab_size] of linear output to give to softmax
    # target: Tensor [sequence_size, batch_size] of token indexes
    def forward(self, result, target) -> TransformerLossResult:
        # I) ---- PREPARE TARGET ----
        target_proba = self._compute_unordered_probabilistic_target(target)
        # II) ---- COMPUTE LOSS ----
        # we flatten tensors on the softmax dimension
        flatten_result = result.view(-1, self.nb_classes)
        flatten_target = target_proba.view(-1, self.nb_classes)
        loss_result = self.loss(flatten_result, flatten_target)
        return loss_result
Probabilistic target vectors
The target vectors are prepared by the following function:
def _compute_unordered_probabilistic_target(self, target):
    # 1) [sequence_size, batch_size] of indexes to [sequence_size, batch_size, vocabulary_ntokens] one-hot
    target_one_hot = torch.nn.functional.one_hot(target, num_classes=self.nb_classes)
    # 2) a valid next token is any token in the subsequent sequence except pad and eos
    seq_length = len(target)
    for i_seq_idx in range(seq_length - 2, -1, -1):
        # elt[i] = elt[i] + elt[i+1]
        to_paste = target_one_hot[i_seq_idx + 1, :, :].clone()
        to_paste[:, self.pad_idx] = 0
        to_paste[:, self.eos_idx] = 0
        target_one_hot[i_seq_idx, :, :].add_(to_paste)
    # 3) if there are several possible next tokens, give each one equal probability
    nb_targets = target_one_hot.sum(dim=2, dtype=torch.float, keepdim=True)  # keepdim so it broadcasts over the vocabulary dimension
    target_proba = target_one_hot.float()  # cast int to float
    target_proba.div_(nb_targets)
    return target_proba
PyTorch limitation
This should work…except we are met with the following error:
RuntimeError: ignore_index is not supported for floating point target
PyTorch is telling us that we can use either a probabilistic target or the ignore_index
parameter of the loss function (which skips the loss on padded tokens), but not both. This is explained here.
Adaptation 2 (Bis): Recode the Soft cross-entropy loss function
Since PyTorch cannot combine ignore_index
with a probabilistic target, let's do it ourselves.
Soft cross-entropy from scratch
import torch.nn.functional as F


class SoftCrossEntropyTransformerLossFromScratch(TransformerLoss):
    def __init__(self, nb_classes: int, pad_idx: int, eos_idx: int, label_smoothing_coeff=0.0):
        super(SoftCrossEntropyTransformerLossFromScratch, self).__init__()
        self.nb_classes = nb_classes
        self.label_smoothing_coeff = label_smoothing_coeff
        self.pad_idx = pad_idx
        self.eos_idx = eos_idx

    # output: Tensor [sequence_size, batch_size, vocab_size] of linear output to give to softmax
    # target: Tensor [sequence_size, batch_size] of token indexes
    def forward(self, output: torch.Tensor, target: torch.Tensor) -> TransformerLossResult:
        # I) ---- PREPARE TARGET ----
        target_proba_smoothed = self._compute_unordered_probabilistic_target(target)
        # II) ---- COMPUTE LOSS ----
        # we flatten tensors on the softmax dimension
        target_proba_smoothed_flattened = target_proba_smoothed.view(-1, self.nb_classes)
        target_flattened = target.view(-1)
        output_flattened = output.view(-1, self.nb_classes)
        # compute log_softmax
        log_prb_output = F.log_softmax(output_flattened, dim=1)
        # compute cross-entropy against the soft target
        loss_by_softmax = -(target_proba_smoothed_flattened * log_prb_output).sum(dim=1)
        # restrict loss computation to non-padded elements
        non_pad_mask = target_flattened.ne(self.pad_idx)  # vector of True where target is not pad_idx
        loss_reduced_to_not_pad_index = loss_by_softmax.masked_select(non_pad_mask)
        return loss_reduced_to_not_pad_index.mean()
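Since it exposes the same forward(output, target) interface, this class is a drop-in replacement for the built-in loss in the training-step sketch shown earlier (same assumed variable names):

criterion = SoftCrossEntropyTransformerLossFromScratch(nb_classes=len(idx_to_token),
                                                       pad_idx=pad_idx, eos_idx=eos_idx)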
Add label smoothing back into our new loss
We are almost done, but our new loss implementation doesn't handle label smoothing, a regularization technique used in generative models to improve generation.
Let's add the following function and call it at the end of _compute_unordered_probabilistic_target:
def _apply_label_smoothing_to_target(self, nb_targets, target_proba):
    # https://arxiv.org/pdf/1512.00567.pdf, https://arxiv.org/abs/1906.02629
    nb_targets_smoothing = self.nb_classes - nb_targets
    smoothing_value = self.label_smoothing_coeff / nb_targets_smoothing
    target_proba.mul_(1.0 - self.label_smoothing_coeff)  # apply smoothing
    target_proba_smoothed = torch.max(target_proba, smoothing_value)  # use broadcasting
    return target_proba_smoothed
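Wiring it in is then a one-line change: the end of _compute_unordered_probabilistic_target returns the smoothed probabilities instead of the raw ones (sketch; the rest of the method is unchanged):

# end of _compute_unordered_probabilistic_target, with smoothing plugged in
nb_targets = target_one_hot.sum(dim=2, dtype=torch.float, keepdim=True)
target_proba = target_one_hot.float()
target_proba.div_(nb_targets)
return self._apply_label_smoothing_to_target(nb_targets, target_proba)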
Comparison to the standard loss function
Here we should check whether the generator trained with this loss function is better than the previous one. This is not obvious: since the loss function has been modified, the resulting loss values on the test set are not directly comparable.
In my case, it seems better (it generates more variety), but you should check on your own use case.
Let’s sum it up
To adapt the Transformer to set generation, we made two modifications:
- On the network architecture, we removed the positional encoding.
- On the loss function, we replaced the standard cross-entropy with a custom soft cross-entropy.
Source code
The source code for this project is available here.