Llama3 — What We Know So Far
Llama3 is the #1 open-source model¹ and one of the top 10 models overall.
It is an autoregressive decoder-only model (8K context, 70B parameters), relying on massive compute (trained for 7.7M GPU-hours, i.e. ~15 days on 24k H100 GPUs) and massive data (15T tokens). Yet another bitter lesson.
FairScale (a PyTorch extension library) is used for large-scale distributed training.
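To make that concrete, here is a minimal sketch of FairScale's tensor-parallel layers, which shard a linear layer's weights across model-parallel ranks. This is not Meta's actual training setup: the single-process gloo initialization, the layer names w1/w2, and the toy sizes are illustrative assumptions so the snippet runs as-is.

# Minimal single-process sketch of FairScale tensor (model) parallelism.
# Real training runs under torchrun across many GPUs with a larger model-parallel size.
import os
import torch
import torch.distributed as dist
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)  # single process so the sketch runs standalone
initialize_model_parallel(1)  # model-parallel size 1 here; multi-GPU runs use a larger value

dim, hidden_dim = 512, 2048  # toy sizes for illustration

# ColumnParallelLinear splits the output dimension across ranks and keeps outputs sharded.
w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False)
# RowParallelLinear consumes the sharded activations and all-reduces the final output.
w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True)

x = torch.randn(4, dim)
print(w2(w1(x)).shape)  # torch.Size([4, 512])

The column-parallel/row-parallel pairing is the usual way to shard a two-layer MLP: the intermediate activations stay split across ranks, and only the second layer's output needs a single all-reduce.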
The model was trained with a 2K context size, but uses KV caching to serve the 8K context at inference.
In the Attention class, you can see self.cache_k and self.cache_v, which store the key and value tensors across decoding steps so past tokens are not recomputed (each transformer block's Attention keeps its own cache):
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
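To see how those two lines fit together, here is a stripped-down sketch of a single-head attention layer built around such a cache. It is a simplification, not the released class: no rotary embeddings, no grouped-query attention, no causal mask, and toy dimensions; CachedAttention and its parameters are illustrative names.

import torch
import torch.nn.functional as F

class CachedAttention(torch.nn.Module):
    """Hypothetical single-head attention with a preallocated KV cache."""

    def __init__(self, dim: int, max_batch_size: int, max_seq_len: int):
        super().__init__()
        self.dim = dim
        self.wq = torch.nn.Linear(dim, dim, bias=False)
        self.wk = torch.nn.Linear(dim, dim, bias=False)
        self.wv = torch.nn.Linear(dim, dim, bias=False)
        self.wo = torch.nn.Linear(dim, dim, bias=False)
        # caches hold the keys/values of every token processed so far
        self.cache_k = torch.zeros(max_batch_size, max_seq_len, dim)
        self.cache_v = torch.zeros(max_batch_size, max_seq_len, dim)

    def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # write only the new positions; earlier tokens are already cached
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # attend over everything seen so far (causal mask omitted for brevity;
        # the real code masks future positions whenever seqlen > 1)
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        scores = torch.matmul(xq, keys.transpose(1, 2)) / self.dim ** 0.5
        return self.wo(torch.matmul(F.softmax(scores, dim=-1), values))

# incremental decoding: feed the prompt once, then one token at a time
attn = CachedAttention(dim=64, max_batch_size=1, max_seq_len=8192)
with torch.no_grad():
    _ = attn(torch.randn(1, 5, 64), start_pos=0)  # fills cache positions 0..4
    _ = attn(torch.randn(1, 1, 64), start_pos=5)  # reuses the 5 cached keys/values

The point of start_pos is that each step only projects the new tokens; everything before them is read back from the cache instead of being recomputed.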
¹ The Llama3 paper is not out yet. A 400B-parameter dense (no MoE) Llama3 is due sometime this year.
PyTorch Stack
At the most abstract level, when you call torch.mm, two dispatches happen: