
Introducing seqax: A Simple and Efficient LLM Research Codebase

May 6, 2024. By MatX ML team.

We’re excited to announce seqax, a research-focused LLM codebase that is simple, efficient, and performs well at scales up to 100 GPUs or TPUs. Everything you need to edit, from the math to the parallelism strategy to the memory footprint, is all there in 500 lines of JAX code.

Introduction

Mature LLM codebases are often weighed down by excessive configurability, which makes it hard to tell which code path a given run actually takes. At MatX, we’ve developed seqax, a minimalist codebase designed for small-to-medium-scale LLM pretraining research. In seqax, we avoid most configuration: no if statements, no inheritance. To run an experiment, you edit the code directly.

Key Features

Explicit Parallelism

In seqax, parallelism is made explicit, allowing you to fully understand and control how computations are distributed across devices. Sharding notation is expressed using Python type annotations, making the code both readable and concise.
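
To give a flavor of this style (a sketch only; the exact annotation helper seqax uses may differ), a function signature can carry its sharded shapes directly as annotations:

# Illustrative sketch of the shape-as-type-annotation style. The f32[...]
# strings are hypothetical shorthand, not necessarily seqax's exact API:
# 'B/d L M/t' reads "batch sharded over mesh axis d, sequence length
# unsharded, model dimension sharded over mesh axis t".
def ffn(x: 'f32[B/d L M/t]', w_gate: 'f32[M/d F/t]') -> 'f32[B/d L M/t]':
    ...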

For example, here’s how multihost Fully Sharded Data Parallel (FSDP) and Tensor Parallelism (TP) are implemented for a feedforward network. There are two sharding axes: d for FSDP and t for TP. In our annotation style, F/t means that dimension F is sharded across the chips along the t axis, so each chip holds a slice of size F/t.
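
For concreteness: with F = 8192 and t = 4 chips along the tensor-parallel axis, a dimension annotated F/t holds 8192 / 4 = 2048 elements per chip.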

# Pre-FFN RMSNorm. all_gather materializes the full (unsharded) dims needed locally.
ln2 = shardops.all_gather('M/t/d -> M', layer_weights.ln2)
gx = shardops.all_gather('B/d L M/t -> B/d L M', x)
nx = rms_norm(gx) * ln2

# FFN, using SwiGLU
w_gate = shardops.all_gather('M/d F/t -> M F/t', layer_weights.w_gate)
gate_proj = shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_gate)
w_up = shardops.all_gather('M/d F/t -> M F/t', layer_weights.w_up)
up_proj = shardops.einsum_unreduced('B/d L M, M F/t -> B/d L F/t', nx, w_up)
y = jax.nn.swish(gate_proj) * up_proj
w_down = shardops.all_gather('M/d F/t -> M F/t', layer_weights.w_down)
# Contracting only the local F/t slice leaves a partial sum over F on each chip...
ffn_out = shardops.einsum_unreduced('B/d L F/t, M F/t -> B/d L M', y, w_down)
# ...which psum_scatter reduces across t while re-sharding M over t.
ffn_out = shardops.psum_scatter('B/d L M -> B/d L M/t', ffn_out)
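
For readers unfamiliar with the final pattern: the unreduced einsum contracts each chip’s local F/t slice, leaving a partial sum over the full F dimension, and the psum-scatter then sums those partials across the t axis while scattering M over t. The stand-alone sketch below shows the same reduce-scatter matmul pattern written directly in JAX with shard_map; the mesh axes d and t follow the post, while the function names, mesh shape, and toy sizes are our own choices for illustration.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# A 2x2 device mesh: axis 'd' for FSDP, axis 't' for tensor parallelism.
# Assumes at least 4 devices; reshape to match your hardware.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=('d', 't'))

def local_down_proj(y, w_down):
    # Per-chip blocks: y is (B/d, L, F/t), w_down is (M, F/t).
    # Contracting only the local F/t slice yields a partial sum over F.
    partial_out = jnp.einsum('blf,mf->blm', y, w_down)
    # Sum the partials across 't' and scatter M over 't' in one collective,
    # the same pattern as seqax's psum_scatter('B/d L M -> B/d L M/t', ...).
    return jax.lax.psum_scatter(partial_out, 't', scatter_dimension=2, tiled=True)

down_proj = shard_map(local_down_proj, mesh=mesh,
                      in_specs=(P('d', None, 't'), P(None, 't')),
                      out_specs=P('d', None, 't'))

# Toy usage with global shapes B=8, L=16, F=512, M=256:
y = jnp.ones((8, 16, 512))
w_down = jnp.ones((256, 512))
out = down_proj(y, w_down)  # global shape (8, 16, 256), laid out as B/d L M/t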

Scalability Features

Seqax isn’t just a prototyping tool; it also includes features essential for running experiments at any scale.

Get Started

Seqax is open source and available on GitHub.


Conclusion

We’ve been using seqax internally at MatX for several months and are excited to share it with the community. Stay tuned for our upcoming research publications.

If you’re passionate about working on one of the cleanest LLM codebases and want to be part of a team that values open research, consider joining us at MatX: https://matx.com/jobs.