JAX

Interactive notebooks for learning JAX — high-performance numerical computing and machine learning.


2 - UvA Training Models At Scale

2.1 - Single-GPU Training

Single-GPU Techniques: protein localization classifier combining mixed precision, gradient checkpointing, and gradient accumulation. $\nabla_\theta \mathcal{L} = \frac{1}{K}\sum_{k=1}^{K} \nabla_\theta \mathcal{L}_k$

Single-GPU Transformer: DNA sequence transformer with mixed precision, rematerialization, gradient accumulation, and layer scanning. $\text{Attention}(Q,K,V) = \text{softmax}_{\text{f32}}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$
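
The gradient-accumulation formula above averages per-micro-batch gradients. A minimal JAX sketch (the squared-error loss and the shapes here are illustrative assumptions, not taken from the notebooks):

```python
import jax
import jax.numpy as jnp

def loss_fn(theta, x, y):
    # illustrative squared-error loss on a linear model
    pred = x @ theta
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.grad(loss_fn)

def accumulated_grad(theta, xs, ys, K):
    # split the batch into K equal micro-batches and average their gradients:
    # (1/K) * sum_k grad L_k
    micro_x = jnp.split(xs, K)
    micro_y = jnp.split(ys, K)
    grads = [grad_fn(theta, mx, my) for mx, my in zip(micro_x, micro_y)]
    return sum(grads) / K
```

With equal-sized micro-batches and a mean loss, the accumulated gradient matches the full-batch gradient exactly, which is what makes accumulation a drop-in substitute for a larger batch.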
2.2 - Multi-Device Training

Single-Device Training: drug response classifier baseline covering the model, loss, TrainState, a jit-compiled train step, and gradient accumulation. $\theta \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}$

FSDP Step by Step: building FSDP from first principles with parameter sharding, all_gather, psum_scatter, and transparent module wrapping. $W = \text{all\_gather}(W_{\text{shard}})$

Pipeline Parallelism Step by Step: splitting a model across devices by layers with micro-batching, ppermute ring communication, and stage masking. $\text{bubble} = \frac{S-1}{M+S-1}$
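
The core FSDP move, $W = \text{all\_gather}(W_{\text{shard}})$, can be sketched with `jax.lax.all_gather` under `pmap`; the `XLA_FLAGS` trick for emulating multiple devices on one CPU, and the toy parameter sizes, are assumptions for this sketch:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # emulate 4 CPU devices

import jax
import jax.numpy as jnp

# Full parameter vector, sharded so each device holds one slice.
w_full = jnp.arange(8.0)
w_shards = w_full.reshape(4, 2)  # leading axis maps to devices under pmap

# Each device reassembles the full W from its shard before the forward pass,
# then can drop the full copy again; tiled=True concatenates shards.
gather = jax.pmap(
    lambda shard: jax.lax.all_gather(shard, axis_name="devices", tiled=True),
    axis_name="devices",
)
w_gathered = gather(w_shards)  # shape (4, 8): every device now sees the full W
```

Only the shard is kept resident per device; the full parameter exists transiently, which is the memory saving FSDP trades against the extra all-gather communication.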

3 - UvA Deep Learning 1

3.3 - Activation Functions

Activation Functions: comparing six activation functions on FashionMNIST, examining gradient flow, dead neurons, and training stability. $\text{Swish}(x) = x \cdot \sigma(x)$
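
The dead-neuron contrast the notebook examines can be seen directly from $\text{Swish}(x) = x \cdot \sigma(x)$: a short sketch comparing gradients at a negative pre-activation (the sample point is an illustrative assumption):

```python
import jax
import jax.numpy as jnp

def swish(x):
    # Swish(x) = x * sigmoid(x): smooth and non-monotonic,
    # with a non-zero gradient even for x < 0
    return x * jax.nn.sigmoid(x)

# Gradient flow at a negative pre-activation:
x = -2.0
relu_grad = jax.grad(jax.nn.relu)(x)   # zero: a "dead" direction for ReLU
swish_grad = jax.grad(swish)(x)        # small but non-zero
```

This is why ReLU units can stop learning once their pre-activations go negative, while Swish keeps passing a (small) gradient through.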

7 - Recap Project

7.1 - Course Recap

Drug Response Prediction: predicting cancer drug sensitivity from gene expression using a neural network built from scratch. $\hat{y} = W_2 \cdot \text{relu}(W_1 x + b_1) + b_2$
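
The two-layer network $\hat{y} = W_2 \cdot \text{relu}(W_1 x + b_1) + b_2$ takes only a few lines of JAX; the layer sizes below are illustrative placeholders, not the project's actual dimensions:

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # y_hat = W2 @ relu(W1 x + b1) + b2
    W1, b1, W2, b2 = params
    hidden = jax.nn.relu(W1 @ x + b1)
    return W2 @ hidden + b2

# hypothetical sizes: 4 input features, 8 hidden units, 1 output
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = (
    jax.random.normal(k1, (8, 4)) * 0.1,  # W1
    jnp.zeros(8),                          # b1
    jax.random.normal(k2, (1, 8)) * 0.1,  # W2
    jnp.zeros(1),                          # b2
)
y_hat = predict(params, jnp.ones(4))
```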
7.2 - Full Project

Protein Language Model: decoder-only transformer (GPT) trained to model protein sequence grammar. $\text{Attention}(Q,K,V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$
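
A decoder-only model applies the attention formula above with a causal mask so each position sees only earlier tokens. A single-head sketch under that assumption (sequence length and head size here are arbitrary):

```python
import jax
import jax.numpy as jnp

def causal_attention(Q, K, V):
    # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, with a causal mask
    d_k = Q.shape[-1]
    scores = Q @ K.T / jnp.sqrt(d_k)
    T = scores.shape[0]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))   # token t attends to positions <= t
    scores = jnp.where(mask, scores, -jnp.inf)      # masked positions get zero weight
    return jax.nn.softmax(scores, axis=-1) @ V

# toy inputs: 5 tokens, head dimension 4
Q = jax.random.normal(jax.random.PRNGKey(0), (5, 4))
K = jax.random.normal(jax.random.PRNGKey(1), (5, 4))
V = jax.random.normal(jax.random.PRNGKey(2), (5, 4))
out = causal_attention(Q, K, V)
```

Because the first token can only attend to itself, its output row is exactly `V[0]`, a handy sanity check when implementing the mask.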