Minimalist Transformer

Many Transformer implementations pile on optimization and abstraction at the cost of modularity and readability. This project is an attempt to address that by implementing a Transformer as minimally as possible in PyTorch: no tricks, exactly as Vaswani et al. describe it, written for readability even at the cost of some speed.

Because the model is built in an object-oriented fashion, it is easy to modify any single part and have it tie seamlessly into the rest. As an example, see the forward function of a transformer encoder layer:

    def forward(self, x, source_mask=None):
        # Self-attention sub-layer: normalize, apply multi-head
        # self-attention, dropout, then add the residual.
        r = x
        x = self.self_attention_layer_norm(x)
        x, _ = self.self_attention(
            query=x, key=x, value=x, key_padding_mask=source_mask
        )
        x = self.dropout(x)
        x = x + r

        # Position-wise feed-forward sub-layer: normalize, apply two linear
        # layers with ReLU in between, dropout, then add the residual.
        r = x
        x = self.final_layer_norm(x)
        x = self.fc2(self.dropout(F.relu(self.fc1(x.transpose(0, 1))))).transpose(0, 1)
        x = self.dropout(x)
        x = x + r
        return x
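
For context, here is a minimal sketch of what the surrounding class might look like, so the forward function above runs end to end. The class name, constructor arguments, and defaults are assumptions for illustration, not necessarily the project's actual API; every sub-module is a standard PyTorch building block (nn.MultiheadAttention, nn.LayerNorm, nn.Linear, nn.Dropout).

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class TransformerEncoderLayer(nn.Module):
        # Hypothetical constructor: the argument names and defaults here are
        # illustrative assumptions, not necessarily the project's actual API.
        def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
            super().__init__()
            self.self_attention = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout
            )
            self.self_attention_layer_norm = nn.LayerNorm(embed_dim)
            self.fc1 = nn.Linear(embed_dim, ffn_dim)
            self.fc2 = nn.Linear(ffn_dim, embed_dim)
            self.final_layer_norm = nn.LayerNorm(embed_dim)
            self.dropout = nn.Dropout(dropout)

        # forward is exactly the method shown above, reproduced here so the
        # sketch runs on its own.
        def forward(self, x, source_mask=None):
            r = x
            x = self.self_attention_layer_norm(x)
            x, _ = self.self_attention(
                query=x, key=x, value=x, key_padding_mask=source_mask
            )
            x = self.dropout(x)
            x = x + r

            r = x
            x = self.final_layer_norm(x)
            x = self.fc2(self.dropout(F.relu(self.fc1(x.transpose(0, 1))))).transpose(0, 1)
            x = self.dropout(x)
            x = x + r
            return x


    # Example usage with dummy data laid out as (sequence_length, batch_size,
    # embed_dim), the layout nn.MultiheadAttention expects by default.
    layer = TransformerEncoderLayer(embed_dim=512, num_heads=8, ffn_dim=2048)
    out = layer(torch.zeros(10, 2, 512))
    print(out.shape)  # torch.Size([10, 2, 512])

With every sub-module exposed as a plain attribute, swapping one out, for example a different attention mechanism with the same call signature, is a one-line change.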

The implementation was tested against Hugging Face's Transformer implementation, with minimal differences in performance.