TLDR: Forked PyTorch and Triton internals. Changed attention so the first layer is linear, the middle layers are quadratic, and the last layer is linear. Inference got much faster with only a small perplexity hit in tests.
Full attention O(n²): 17.96s / 5.6 tok/s
HybridAttention O(n·W + n·D): 0.35s / 286.6 tok/s
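The asymptotic gap behind those numbers can be sketched with a quick op count. The window size W and state dimension D below are assumed values for illustration, not taken from the post, and how the ratio plays out depends on n relative to W + D:

```python
# Illustrative per-pass attention cost: O(n^2) full attention vs the
# O(n*W + n*D) hybrid. W and D are assumptions, not the post's values.
def attn_ops(n, W=64, D=512):
    full = n * n            # every token attends to every prior token
    hybrid = n * W + n * D  # windowed attention + recurrent state update
    return full / hybrid

# the advantage over full attention grows linearly with sequence length
for n in (512, 4096, 32768):
    print(n, round(attn_ops(n), 1))
```

At short contexts the hybrid does roughly comparable work; the win compounds as the sequence length outgrows W + D.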
I have been building a small Rust-focused language model from scratch in PyTorch. This is not a finetune. It is byte-level, trained from random initialization on a Rust-heavy corpus assembled here: https://codeberg.org/JohannaJuntos/Sisyphus
Model and training setup
The model has 25.6M parameters with a 512 context length. It uses a byte level vocabulary of 256, with 8 layers, 8 heads, and 512 dimensional embeddings. Positional embeddings are learned and the embedding and LM head weights are tied.
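As a sanity check, the reported 25.6M parameters line up with the standard GPT back-of-the-envelope estimate of roughly 12·L·d² plus embeddings (a sketch under that assumption; the recurrent path in each hybrid block would shift the count slightly):

```python
# Rough GPT-style parameter estimate from the reported config.
# 12 * L * d^2 covers attention (4*d^2) plus a 4x MLP (8*d^2) per layer.
L, d, vocab, ctx = 8, 512, 256, 512

blocks = 12 * L * d * d   # transformer blocks
embed = vocab * d         # token embedding (LM head is tied, so no extra cost)
pos = ctx * d             # learned positional embeddings
approx_params = blocks + embed + pos
print(approx_params)      # ~25.56M, close to the reported 25.6M
```

With tied embeddings the LM head adds nothing, so almost all of the budget sits in the 8 transformer blocks.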
Training ran for 30k steps on a 173.5M byte Rust corpus using a single RTX 4060 Ti 8GB.
Final metrics were a train loss of 0.5834, validation loss of 0.8217, and perplexity of 2.15. The best validation loss occurred around step 18.5k, which suggests some late overfitting or a plateau.
Architecture
The model is a GPT-style decoder, but replaces standard full attention with a HybridAttention block in each layer. This combines local windowed causal attention with a GRU-like recurrent state path, along with a learned gate that mixes the two.
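A minimal PyTorch sketch of how such a block might look. Module and parameter names here are my own guesses, not the repo's code, and the real implementation presumably runs on the forked Triton kernels rather than plain PyTorch ops:

```python
import torch
import torch.nn as nn

class HybridAttention(nn.Module):
    """Sketch: local windowed causal attention mixed with a GRU-like
    recurrent path through a learned sigmoid gate (assumed design)."""

    def __init__(self, d_model=512, n_heads=8, window=64):
        super().__init__()
        self.n_heads = n_heads
        self.window = window
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.gru = nn.GRU(d_model, d_model, batch_first=True)
        self.gate = nn.Linear(2 * d_model, d_model)

    def forward(self, x):
        B, T, D = x.shape
        H = self.n_heads
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, H, D // H).transpose(1, 2)
        k = k.view(B, T, H, D // H).transpose(1, 2)
        v = v.view(B, T, H, D // H).transpose(1, 2)
        # causal mask restricted to a local window of size W
        i = torch.arange(T)
        mask = (i[:, None] >= i[None, :]) & (i[:, None] - i[None, :] < self.window)
        att = (q @ k.transpose(-2, -1)) / (D // H) ** 0.5
        att = att.masked_fill(~mask, float("-inf")).softmax(dim=-1)
        local = (att @ v).transpose(1, 2).reshape(B, T, D)
        # recurrent path carries context beyond the local window
        recur, _ = self.gru(x)
        # learned gate mixes the two paths per position and channel
        g = torch.sigmoid(self.gate(torch.cat([local, recur], dim=-1)))
        return self.proj(g * local + (1 - g) * recur)
```

The windowed path gives O(n·W) attention work and the recurrent state update gives O(n·D), which is where the O(n·W + n·D) figure above would come from.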