Tokenising Text Data for Pretraining

Author

Junaid Butt

Published

September 20, 2025

After reading Chapter 3 of Build a Large Language Model (From Scratch), I was curious about how text data is turned into tokens and batched for pretraining. This post contains my notes on how this happens; it acts as a companion to the Python code in the book.

When training a decoder Large Language Model (LLM), our input is unstructured text. To use it for model training, the text must be transformed into batches of inputs and targets in accordance with the PyTorch convention. Our goal is to decompose a sequence of unstructured text (consisting of words, punctuation, etc.)

\[D = \{w_0,...,w_{N}\}\]

into batches of tensors \(\{(\mathbf{x}_{i}, \mathbf{y}_{i})\}_{i=1}^{M}\) which can be used to update the model weights after computing the loss for each observation, batch, or epoch.

Below are the steps taken to transform the data and compute loss on this data type:

  1. Tokenize the original unstructured text data
  2. Partition data into input and target sequences
  3. Construct data batches
  4. Compute observation loss via next token probability

Tokenize the original unstructured text data

Tokenize the text into a sequence of token IDs - IDs are usually integer values. Denote the tokenized text by

\[D_{T} = \{t_{0}, ..., t_{P}\}\]
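
As a concrete illustration, here is a minimal sketch using the tiktoken library's GPT-2 byte-pair encoding; any tokenizer that maps text to integer token IDs would serve the same purpose.

```python
import tiktoken  # third-party library: pip install tiktoken

# Load a byte-pair-encoding tokenizer (GPT-2's vocabulary of 50,257 tokens)
tokenizer = tiktoken.get_encoding("gpt2")

raw_text = "Every effort moves you forward."   # stand-in for the raw corpus D
token_ids = tokenizer.encode(raw_text)         # D_T: a list of integer token IDs

print(token_ids)                    # a short list of integers
print(tokenizer.decode(token_ids))  # decoding recovers the original text
```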

Partition Data into Input and Target Sequences

A decoder transformer is a sequence-to-sequence model: it takes a sequence of tokens as input, called an input sequence, and should produce an expected sequence of output tokens, known as a target sequence. A matching (input, target) pair is called an observation. The target sequence is an offset of the input sequence. Let \(L\) be the maximum sequence length and \(S\) be the stride length. We construct the first (input, target) observation as

[(Input: \(D_{T}[0:L]\), Target: \(D_{T}[1: L+1]\))]

where square brackets indicate subsetting. For each observation, the token sequences are

\[i=0: [(t_{0}, t_{1},..., t_{L-1}),\ (t_{1}, t_{2},..., t_{L})]\]
\[i=1: [(t_{S}, t_{S+1},..., t_{S+L-1}),\ (t_{S+1}, t_{S+2},..., t_{S+L})]\]
\[i=j: [(t_{jS}, t_{jS+1},..., t_{jS+L-1}),\ (t_{jS+1}, t_{jS+2},..., t_{jS+L})]\]

The above sequences are paired as tensors.

Remarks:

  • Each input sequence and each target sequence contains \(L\) token IDs (the maximum sequence length)

  • Between an input sequence and its target sequence there is an offset of one token to the right/forward. In general, \([(t_{m}, t_{m+1},..., t_{m+L-1}),\ (t_{m+1}, t_{m+2},..., t_{m+L})]\)

  • Across input sequences/observations, consecutive windows start \(S\) tokens apart, so adjacent input sequences overlap by \(L - S\) tokens (no overlap when \(S \geq L\)). In other words, input sequences are generated by sliding a window of length \(L\) with stride \(S\) over \(D_{T}\). The target sequence for each input sequence is the same window shifted one token to the right.

As a table, the dataset of token sequences is

i \(x_{i}\) \(y_{i}\)
0 \((t_{0}, t_{1},..., t_{L-1})\) \((t_{1}, t_{2},..., t_{L})\)
1 \((t_{S}, t_{S+1},..., t_{S+L-1})\) \((t_{S+1}, t_{S+2},..., t_{S+L})\)
j \((t_{jS}, t_{jS+1},..., t_{jS+L-1})\) \((t_{jS+1}, t_{jS+2},..., t_{jS+L})\)
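
This sliding-window construction maps directly onto a PyTorch Dataset. Below is a minimal sketch; the class name SlidingWindowDataset and the parameter names max_length (for \(L\)) and stride (for \(S\)) are illustrative choices of my own rather than the book's exact code.

```python
import torch
from torch.utils.data import Dataset


class SlidingWindowDataset(Dataset):
    """Builds (input, target) observations from a tokenized corpus D_T.

    max_length corresponds to L and stride corresponds to S in the notation above.
    """

    def __init__(self, token_ids, max_length, stride):
        self.inputs = []
        self.targets = []
        # Slide a window of length L over D_T with step size S.
        for start in range(0, len(token_ids) - max_length, stride):
            x = token_ids[start : start + max_length]          # D_T[start : start + L]
            y = token_ids[start + 1 : start + max_length + 1]  # shifted one token right
            self.inputs.append(torch.tensor(x))
            self.targets.append(torch.tensor(y))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]


# Toy usage: L = 4, S = 2 over a small "corpus" of 12 token IDs.
toy_ids = list(range(12))
dataset = SlidingWindowDataset(toy_ids, max_length=4, stride=2)
print(dataset[0])  # (tensor([0, 1, 2, 3]), tensor([1, 2, 3, 4]))
print(dataset[1])  # (tensor([2, 3, 4, 5]), tensor([3, 4, 5, 6]))
```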

Construct Batches

Take multiple observations and stack them as rows (the order of observations is typically shuffled during training). Denote the resulting batch in matrix form

\[X = \begin{bmatrix} x_{11} & x_{12} & ... & x_{1L} \\ x_{21} & x_{22} & ... & x_{2L} \\ ... \\ x_{P1} & x_{P2} & ... & x_{PL} \end{bmatrix}\]

\[Y = \begin{bmatrix} y_{11} & y_{12} & ... & y_{1L} \\ y_{21} & y_{22} & ... & y_{2L} \\ ... \\ y_{P1} & y_{P2} & ... & y_{PL} \end{bmatrix}\]

\[X, Y \in \mathbb{Z}_{\geq 0}^{P \times L}\]

where

\(\mathbb{Z}_{\geq 0}\) denotes the non-negative integers (because \(x_{ij}, y_{ij}\) are token IDs)

\(P\) is the batch size

\(L\) is the maximum sequence length

\(X\) and \(Y\) represent an input batch and a target batch respectively. In this matrix notation:

\(x_{ij}\): the token ID at position \(j\) of the \(i\)-th input sequence in the batch

\(y_{ij}\): the token ID at position \(j\) of the \(i\)-th target sequence in the batch

It follows that

\[y_{ij}=x_{i,(j+1)} \quad \text{for } j < L,\]

while the final target \(y_{iL}\) is the token that immediately follows \(x_{iL}\) in \(D_{T}\).
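
A PyTorch DataLoader then takes care of the stacking and shuffling. A minimal sketch, reusing the hypothetical SlidingWindowDataset and toy dataset from above:

```python
from torch.utils.data import DataLoader

# Stack P observations as rows; shuffle=True randomises which observations
# end up in the same batch each epoch, drop_last=True discards a final
# partial batch so every batch has exactly P rows.
loader = DataLoader(dataset, batch_size=2, shuffle=True, drop_last=True)

X, Y = next(iter(loader))
print(X.shape, Y.shape)  # both torch.Size([2, 4]), i.e. (P, L) matrices of token IDs
# Within each row, the target is the input shifted one token to the right:
# Y[i, j] == X[i, j + 1] for j < L - 1 (0-indexed).
```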

Compute Observation Loss via Next Token Probability

We compute the loss on the next token. A model autoregressively generates tokens one at a time: given a sequence of tokens, it produces a probability distribution over all tokens in the vocabulary, i.e. the probability of what the next token should be. Given tokens \(x_1, x_2,...,x_k\) we want the probability of generating the next token \(x_{k+1}\). During training we know the actual value of \(x_{k+1}\) and compare it with the model's prediction for this token, \(\hat{x}_{k+1}\). Consider a single observation within a batch; this corresponds to a row \(\mathbf{x}_{j}\) of \(X\). We feed it into the transformer model \(f\) to obtain a matrix of logits. Note: we assume that \(f\) takes token IDs as input, embeds them, adds positional encodings and applies the attention mechanism. The model \(f\) yields:

\[f(\mathbf{x}_{j}) = Z \in \mathbb{R}^{L \times |V|}\]

\(Z\) is a matrix of logits: each row contains unnormalised scores over the vocabulary \(V\). Normalise \(Z\) with the softmax function (applied row-wise) to obtain

\[\hat{Y} = softmax(Z) \in \mathbb{R}^{L \times |V|}\]

Each row in \(\hat{Y}\) is the model's probability distribution over all tokens in the vocabulary, conditioned on the input tokens up to and including that position (causal masking in the attention mechanism ensures position \(k\) attends only to positions \(\leq k\)). More precisely, the k-th row is the distribution the model produces after seeing \(x_{j1},...,x_{jk}\):

\(\mathbf{\hat{y}}_{k} = [\hat{y}_{k1}, \hat{y}_{k2},...,\hat{y}_{k|V|}]\)

The targets for these rows come from the target sequence, taken from \(Y\): if \(\mathbf{x}_{j}\) is the j-th row of \(X\), then the j-th row of \(Y\), \(\mathbf{y}_{j}\), is the corresponding target vector.

\[\mathbf{y}_{j} = [y_{j1}, y_{j2},...,y_{jL}]\]

The k-th element of \(\mathbf{y}_{j}\) is the actual target token ID given \(x_{j1},...,x_{jk}\) as input. Align the softmaxed logit matrix with the target vector:

\[\hat{Y} = softmax(f(\mathbf{x}_{j})) = \begin{bmatrix} \hat{y}_{11} & \hat{y}_{12} & ... & \hat{y}_{1|V|} \\ \hat{y}_{21} & \hat{y}_{22} & ... & \hat{y}_{2|V|} \\ \vdots \\ \hat{y}_{L1} & \hat{y}_{L2} & ... & \hat{y}_{L|V|} \end{bmatrix}, \qquad \mathbf{y}_{j} = \begin{bmatrix} y_{j1} \\ y_{j2} \\ \vdots \\y_{jL} \end{bmatrix}\]

As mentioned before, we match the k-th row of \(\hat{Y}\) with the k-th element of \(\mathbf{y}_{j}\): \(y_{jk}\) is the target token ID that should follow the input tokens \(x_{j1}, x_{j2},...,x_{jk}\). We compute the cross entropy between every row of \(\hat{Y}\) and its target and take the mean. The cross entropy for a single row is

\[ce(\hat{\mathbf{y}}, \mathbf{y}) = - \sum_{v=1}^{|V|} y_{v}\log[\hat{y}_{v}]\]

Note: with cross entropy, the target token ID is expanded/broadcast to a one-hot encoded (OHE) label vector \(\mathbf{y}\) of length \(|V|\), so the sum reduces to the negative log of the probability the model assigns to the target token. We take the mean cross entropy over all rows in \(\hat{Y}\) to get the loss for a single observation

\[loss_{j} = \frac{1}{L} \sum_{k=1}^{L} ce(\hat{\mathbf{y}}_{k}, y_{jk})\]
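
Below is a minimal sketch of this per-observation loss, using random logits in place of a real model \(f\) so that the snippet is self-contained; the dimensions are illustrative.

```python
import torch
import torch.nn.functional as F

# Hypothetical dimensions: L positions, a GPT-2-sized vocabulary.
L, vocab_size = 4, 50257

logits = torch.randn(L, vocab_size)        # Z = f(x_j): one row of logits per position
y_j = torch.randint(0, vocab_size, (L,))   # target token IDs for this observation

probs = torch.softmax(logits, dim=-1)       # Y_hat: each row is a distribution over V
target_probs = probs[torch.arange(L), y_j]  # probability assigned to each target token
loss = -torch.log(target_probs).mean()      # mean cross entropy over the L positions

# F.cross_entropy applies softmax internally and accepts integer class targets
# directly, so it yields the same value.
print(loss, F.cross_entropy(logits, y_j))
```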

The loss would be 0 only if the model assigned probability 1 to the correct next token at every decoding step. This approach generalises to batches with more than one observation: the loss is summed over the observations in a batch to give a batch loss. Epoch loss is the mean loss per observation across all observations processed in the epoch (and thereby all batches).
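
For a whole batch the logits have shape \((P, L, |V|)\) and the targets shape \((P, L)\). A minimal sketch of the batch loss under the same assumptions as above (note that cross_entropy's default reduction is the mean over all positions rather than a sum):

```python
import torch
import torch.nn.functional as F

P, L, vocab_size = 2, 4, 50257
logits = torch.randn(P, L, vocab_size)          # model output for an input batch X
targets = torch.randint(0, vocab_size, (P, L))  # target batch Y

# Flatten the batch and sequence dimensions so every position becomes one
# classification example. The default reduction is the mean over all P * L
# positions; use reduction="sum" for a summed batch loss instead.
batch_loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
print(batch_loss)
```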

References

  1. Sebastian Raschka, Build a Large Language Model (From Scratch), Manning Publications, 2024.