OLMo — Points of Interest

Ranko Mosic
Feb 2, 2024

OLMo was trained on both NVIDIA A100 and AMD MI250X hardware. This is of huge importance given the current shortage of NVIDIA GPUs and the uncertainty around the practicalities of the AMD ROCm software stack. It seems it took some direct involvement from AMD engineers to get the PyTorch code running on MI250X GPUs.
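Notably, little of this shows up in user code: ROCm builds of PyTorch expose AMD GPUs through the familiar torch.cuda API (HIP is mapped onto the CUDA namespace). A minimal sketch, not taken from the OLMo codebase, of how one code path covers both vendors:

```python
import torch

# On ROCm builds of PyTorch, MI250X devices appear under torch.cuda,
# so the same selection logic works on A100 and MI250X alike.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name(0))            # e.g. an A100 or MI250X name string
    print(torch.version.hip or torch.version.cuda)  # HIP version on ROCm, CUDA version otherwise
else:
    device = torch.device("cpu")
```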

Training scripts use the MosaicML Command Line Interface (MCLI); MosaicML was recently acquired by Databricks. The team passed on MosaicML Composer for distributed training, since it did not yet support PyTorch 2 at the time, and decided to write their own trainer.

Most of our architecture choices follow the PaLM paper.

The team tried using Triton but ran into compatibility issues with AMD GPUs.¹

We train our models using the ZeRO optimizer strategy (Rajbhandari et al., 2019) via PyTorch’s FSDP framework (Zhao et al., 2023), which reduces memory consumption by sharding the model weights and their corresponding optimizer state across GPUs.
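A minimal sketch of what wrapping a model in FSDP typically looks like, assuming a torchrun launch; MyTransformer and the learning rate are illustrative placeholders, not OLMo's actual code or hyperparameters:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes a torchrun launch, which sets RANK / WORLD_SIZE / LOCAL_RANK.
dist.init_process_group(backend="nccl")  # "nccl" also selects RCCL on ROCm builds
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = MyTransformer()  # hypothetical placeholder for the decoder-only model
# FSDP shards parameters, gradients, and optimizer state across ranks
# (ZeRO-3 style), so each GPU holds only its own shard between steps.
model = FSDP(model, device_id=torch.cuda.current_device())
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
```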

To improve throughput, we employ mixed-precision training (Micikevicius et al., 2017) through FSDP’s built-in settings and PyTorch’s amp module. The latter ensures that certain operations like the softmax always run in full precision to improve stability, while all other operations run in half precision with the bfloat16 format.

Under our specific settings, the sharded model weights and optimizer state local to each GPU are kept in full precision. The weights within each transformer block are only cast to bfloat16 when the full-sized parameters are materialized on each GPU during the forward and backward passes. Gradients are reduced across GPUs in full precision.

The team used Slurm for cluster management.²

¹We’ve been experimenting with a triton implementation of FlashAttention that supports using an arbitrary attention bias, which would allow us to use ALiBi. Unfortunately it doesn’t look like this is going to be a viable option at the moment. This particular implementation only works on an older version of triton that uses a CUDA-specific backend. Therefore it won’t run on AMD GPUs.
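For context, ALiBi replaces positional embeddings with an additive, per-head linear bias on the attention scores. A hypothetical sketch using PyTorch's plain scaled_dot_product_attention rather than the fused Triton FlashAttention kernel the footnote refers to; the slope formula follows the ALiBi paper:

```python
import torch
import torch.nn.functional as F

def alibi_bias(n_heads: int, seq_len: int, device: str = "cpu") -> torch.Tensor:
    # Per-head slopes from the ALiBi paper (assumes n_heads is a power of two):
    # 2^(-8/n), 2^(-16/n), ..., 2^(-8).
    slopes = torch.tensor(
        [2 ** (-8 * (h + 1) / n_heads) for h in range(n_heads)], device=device
    )
    # Relative distance j - i between key position j and query position i.
    pos = torch.arange(seq_len, device=device)
    rel = pos[None, :] - pos[:, None]            # (seq_len, seq_len)
    return slopes[:, None, None] * rel[None]     # (n_heads, seq_len, seq_len)

device = "cuda" if torch.cuda.is_available() else "cpu"
q = k = v = torch.randn(1, 8, 128, 64, device=device)  # (batch, heads, seq, head_dim)
bias = alibi_bias(n_heads=8, seq_len=128, device=device).to(q.dtype)
# In a causal LM the bias would additionally be masked above the diagonal.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```

Passing an arbitrary float attn_mask like this generally rules out the fused flash-attention path, which is exactly why a Triton kernel with bias support was attractive in the first place.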

² Slurm, which stands for Simple Linux Utility for Resource Management, is an open-source, fault-tolerant, and highly scalable cluster management and job scheduling system for Linux clusters. It is used by many of the world's supercomputers and computer clusters.
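When a job is launched with srun, Slurm exports environment variables that map directly onto the process-group settings torch.distributed expects. A hypothetical initialization sketch (the actual OLMo launch scripts may differ):

```python
import os
import torch
import torch.distributed as dist

# Slurm sets SLURM_PROCID / SLURM_NTASKS / SLURM_LOCALID per task;
# MASTER_ADDR / MASTER_PORT are assumed to be exported by the job script.
rank = int(os.environ["SLURM_PROCID"])
world_size = int(os.environ["SLURM_NTASKS"])
local_rank = int(os.environ["SLURM_LOCALID"])

torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
```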
