Analyzing the next token probabilities in large language models

With the recent launch of Gemma and the Responsible Generative AI Toolkit, developers now have access to a new family of state-of-the-art large language models and tools to integrate them responsibly into real-world applications. In this blog post, we'll learn how to measure next-token prediction probabilities and how to perform our own analysis to understand what they tell us about the model's understanding of a piece of text.
Gemma, Google's family of open large language models
Throughout this post, we will be using Gemma, specifically the Keras-based implementation. However, the techniques described here are applicable to any other openly available model.
It doesn't take a lot to get started with Gemma! For example, using Kaggle Notebooks, this is all that's needed:
import keras
import keras_nlp
# Initialize model.
model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
model.preprocessor.sequence_length = 128
# Generate text prediction.
model.generate("Roses are red", max_length=30)
Predicting the next word in a sentence
Figure: Large language model input-output pipeline, where the input is normally a prompt
First, a sequence of words is converted into tokens through a process known as tokenization. Then, using that sequence of tokens, a series of output tokens are generated.
Figure: Top-k candidate token selection process, i.e., the prediction step
To generate each output token in the sequence, the model first calculates a probability score for each possible token in its vocabulary (i.e., the collection of all possible tokens). From those scores, the k most likely tokens are selected as candidates, and the final output token is chosen from among them.
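The selection step can be sketched in a few lines of NumPy. This is a minimal, hypothetical illustration with a made-up five-token vocabulary and hand-picked logits, not Gemma's actual sampler. The helper name `sample_top_k` is our own. The sketch assumes the final token is sampled from the renormalized top-k probabilities, which is one common strategy. Greedy decoding would instead take the highest-scoring candidate, and the sampler used by `model.generate` may differ.

```python
import numpy as np

def sample_top_k(logits, k, rng):
    """Pick a token id from the k highest-scoring entries of `logits` (illustrative)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Indices of the k largest logits (their order within the top-k is arbitrary).
    top_k_idx = np.argpartition(logits, -k)[-k:]
    # Softmax over the candidates only, shifted by the max for numerical stability.
    top_logits = logits[top_k_idx]
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()
    # Sample one candidate according to its renormalized probability.
    return int(rng.choice(top_k_idx, p=probs))

vocab = ["roses", "are", "red", "blue", "violets"]  # toy vocabulary
logits = [0.5, 1.0, 2.0, 4.0, 3.0]                  # toy next-token logits
rng = np.random.default_rng(seed=0)
token_id = sample_top_k(logits, k=2, rng=rng)
print(vocab[token_id])  # either "blue" or "violets"
```

With `k=1`, this reduces to greedy decoding: the highest-scoring token is always returned.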
Taking a closer look at the next token probabilities
By examining the probability distribution of the next possible tokens, we can better understand how the model is generating its output tokens, and why certain tokens are chosen over others in an output sequence. Work like this is a crucial aspect of Responsible AI development, touching on important topics like transparency and trust.
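We can use the Gemma family of models to extract the next token probabilities. Given an input piece of text, we first need to find the offset at which the next token's scores will be found in the output. Because a causal language model's output at each position scores the token that follows it, this offset is the position of the last non-padding token in the preprocessed prompt: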
model: keras_nlp.models.GemmaCausalLM = ...
prompt: str = ...
preprocessor = model.preprocessor
padding_mask = preprocessor.generate_preprocess([prompt])['padding_mask'][0]
token_offset = keras.ops.sum(padding_mask) - 1
With the token offset in hand, we can now identify the token logits (unnormalized prediction scores) in the model's output:
logits = model.predict([prompt])[0]
token_logits = logits[token_offset]
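To make the indexing concrete, here is a toy sketch that uses a hand-written padding mask and random logits in place of real Gemma outputs. The shapes are hypothetical: a sequence length of 8 and a vocabulary of 5. As in this toy mask, the sketch assumes the padding is on the right, after the real tokens.

```python
import numpy as np

# Toy padding mask: 4 real tokens (e.g. <bos> plus 3 prompt tokens), then padding.
padding_mask = np.array([1, 1, 1, 1, 0, 0, 0, 0])
sequence_length, vocab_size = padding_mask.shape[0], 5

# Stand-in for model output: one row of logits per input position.
rng = np.random.default_rng(seed=0)
logits = rng.normal(size=(sequence_length, vocab_size))

# With right-padding, the last real token sits at index sum(mask) - 1,
# and its row of logits scores the *next* token.
token_offset = int(padding_mask.sum()) - 1
token_logits = logits[token_offset]

print(token_offset)        # 3
print(token_logits.shape)  # (5,)
```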
We can put everything together in a single function, so we can reuse it later:
def compute_token_scores(model, prompt):
    # Identify output token offset.
    preprocessor = model.preprocessor
    padding_mask = preprocessor.generate_preprocess([prompt])['padding_mask'][0]
    token_offset = keras.ops.sum(padding_mask) - 1
    # Compute prediction, extract only the next token's logits.
    logits = model.predict([prompt])[0]
    token_logits = logits[token_offset]
    return token_logits
If we wanted to convert scores to probabilities, we could use softmax normalization; for example:
softmax_normalization = keras.layers.Softmax()
token_probs = softmax_normalization(token_logits)
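Softmax maps a logit vector z to exp(z_i) / sum_j exp(z_j). A minimal NumPy equivalent can be handy for checking results outside Keras. It subtracts the maximum logit first, which leaves the result unchanged but avoids overflow for large logits. The function name `softmax` here is our own and is not a Keras API.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = np.asarray(logits, dtype=np.float64)
    # Shifting by the max cancels out in the ratio below,
    # but keeps exp() from overflowing on large logits.
    z = z - z.max(axis=-1, keepdims=True)
    exp_z = np.exp(z)
    return exp_z / exp_z.sum(axis=-1, keepdims=True)

# A naive exp(1000.0) would overflow; the shifted version does not.
probs = softmax([1000.0, 1001.0, 1002.0])
print(probs)  # approximately [0.090, 0.245, 0.665]
```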
Identifying the top-k most likely tokens
Now that we have computed the scores for the next token, we can look, for example, at the top 3 most likely tokens and their associated scores:
import numpy as np

def top_k_tokens(model, token_scores, k=3):
    # Extract the indices of the top-k tokens with the highest probability.
    top_k_idx = np.argpartition(token_scores, kth=-k)[-k:]
    # Get the scores and their corresponding words.
    tokenizer = model.preprocessor.tokenizer
    top_scores = [token_scores[x] for x in top_k_idx]
    top_words = [tokenizer.id_to_token(x) for x in top_k_idx.astype(np.int32)]
    # Return the values as a dictionary.
    return dict(zip(top_words, top_scores))
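`np.argpartition` only guarantees that the k largest scores end up in the last k slots. It does not sort them, so the dictionaries printed in this section are not guaranteed to be in rank order. If a ranked list is preferred, one option is to sort the partitioned indices by score. Below is a minimal sketch that uses a toy score array and plain integer indices instead of the Gemma tokenizer. `top_k_sorted` is a hypothetical helper name.

```python
import numpy as np

def top_k_sorted(token_scores, k=3):
    """Return (index, score) pairs for the k highest scores, best first."""
    token_scores = np.asarray(token_scores)
    # argpartition places the k largest scores in the last k slots, unordered.
    top_k_idx = np.argpartition(token_scores, kth=-k)[-k:]
    # Sort just those k indices by descending score.
    top_k_idx = top_k_idx[np.argsort(-token_scores[top_k_idx])]
    return [(int(i), float(token_scores[i])) for i in top_k_idx]

scores = np.array([0.3, 13.8, -2.0, 6.8, 5.8, 1.1])  # toy scores
print(top_k_sorted(scores, k=3))  # [(1, 13.8), (3, 6.8), (4, 5.8)]
```

With the real model, each index could be mapped back to a token string with `model.preprocessor.tokenizer.id_to_token`, as `top_k_tokens` does.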
To validate whether the scores seem sensible, we can test the process with the prompt "Roses are red, violets are":
prompt = 'Roses are red, violets are'
token_scores = compute_token_scores(model, prompt)
score_map = top_k_tokens(model, token_scores, k=3)
print(score_map)
The output of that code should be something similar to:
{'▁Blue': 5.8007145, '▁purple': 6.7941766, '▁blue': 13.80049}
But scores can be hard to interpret. They can be converted to a probability distribution in one of two ways. If we normalize before the top-k candidates are selected, the result represents the probability across all possible tokens. If we normalize after the candidate selection process, it represents the probability among only the candidate tokens. Here, we normalize across the full vocabulary before selecting the top-k:
softmax_normalization = keras.layers.Softmax()
token_probs = softmax_normalization(token_scores)
# convert_to_numpy works regardless of the Keras backend in use.
probs_map = top_k_tokens(model, keras.ops.convert_to_numpy(token_probs), k=3)
print(probs_map)
Which outputs the following:
{'▁Blue': 0.0003350259, '▁purple': 0.0009047602, '▁blue': 0.9984743}
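Looking at this output, it's clear that the model is highly likely to follow our prompt with the token '▁blue'. It's worth noting that the mapping between words and tokens is not one-to-one! Notice the '▁'. The Gemma tokenizer is based on SentencePiece, which uses this character to mark a token that begins with a space, so '▁blue' and 'blue' are distinct tokens. More generally, different tokens can represent the same word with different casing, leading whitespace, special characters, etc.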
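Because several token strings can map to the same word, one option is to aggregate their probabilities before interpreting them. The sketch below uses a hypothetical helper, `merge_word_variants`, which is not part of KerasNLP. It strips the '▁' marker, optionally lower-cases the rest, and sums the probabilities of tokens that collapse to the same word. The input dictionary reuses the three probabilities printed above. Whether casing should be merged is an analysis choice, not a fixed rule.

```python
from collections import defaultdict

def merge_word_variants(probs_map, lowercase=True):
    """Sum probabilities of token strings that normalize to the same word."""
    merged = defaultdict(float)
    for token, prob in probs_map.items():
        # '▁' (U+2581) marks a leading space in SentencePiece vocabularies.
        word = token.replace("\u2581", "").strip()
        if lowercase:
            word = word.lower()
        merged[word] += float(prob)
    return dict(merged)

probs_map = {"▁Blue": 0.0003350259, "▁purple": 0.0009047602, "▁blue": 0.9984743}
print(merge_word_variants(probs_map))  # 'blue' ≈ 0.99881, 'purple' ≈ 0.00090
```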
Analyzing the probability of arbitrary tokens
Another interesting option is to examine the relative probabilities of an arbitrary set of words. By applying softmax only to the scores of the chosen tokens, we can determine the model's relative bias between them, regardless of how likely those tokens are overall.
def extract_word_probs(model, token_scores, words):
    # Extract only the desired words from the tokenizer.
    tokenizer = model.preprocessor.tokenizer
    word_idx = [tokenizer.token_to_id(x) for x in words]
    # Convert the scores to a probability distribution using softmax.
    softmax_normalization = keras.layers.Softmax()
    probabilities = softmax_normalization(token_scores[word_idx])
    # Return the values as a dictionary (convert_to_numpy is backend-agnostic).
    return dict(zip(words, keras.ops.convert_to_numpy(probabilities)))
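Applying softmax to a subset of the logits, as `extract_word_probs` does, is equivalent to taking the full-vocabulary probabilities of those tokens and renormalizing them so they sum to 1. The shared denominator cancels out. The NumPy sketch below checks this with a toy logit vector, where the indices are arbitrary stand-ins for tokenizer ids. It also shows what the relative view hides: the absolute probability mass of the chosen tokens can be tiny.

```python
import numpy as np

def softmax(z):
    z = np.asarray(z, dtype=np.float64)
    e = np.exp(z - z.max())  # shift by the max for numerical stability
    return e / e.sum()

logits = np.array([2.0, 9.0, 0.5, 1.0, 3.0])  # toy next-token logits
subset = [2, 3]  # stand-ins for two arbitrary token ids

# Option A: softmax over only the selected logits (what extract_word_probs does).
relative = softmax(logits[subset])

# Option B: softmax over the whole vocabulary, then renormalize the subset.
global_probs = softmax(logits)
renormalized = global_probs[subset] / global_probs[subset].sum()

print(np.allclose(relative, renormalized))  # True
print(global_probs[subset].sum())  # about 0.0005: both tokens are unlikely overall
```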
Using the same token scores that we computed earlier, we can extract the relative probabilities of the words "blue" and "purple":
probs_map = extract_word_probs(model, token_scores, ['blue', 'purple'])
print(probs_map)
Here, 'blue' is still far more likely to be the next generated token compared to 'purple':
{'blue': 0.9920661, 'purple': 0.007933899}
Cats vs Dogs
We now have the tools to answer the age-old question: which are better, cats or dogs? We can simply ask Gemma to produce a set of scores:
prompt = 'Q: Which are better, cats or dogs? A:'
token_scores = compute_token_scores(model, prompt)
probs_map = extract_word_probs(model, token_scores, ['Cats', 'Dogs'])
print(probs_map)
And the much anticipated answer is…
{'Cats': 0.98167384, 'Dogs': 0.018326206}
And there it is! We have definitively proven that cats are better than dogs! Or have we?
At this point, it's very important to contextualize the result and ensure that we interpret it appropriately. For example, it would be incorrect to state that Gemma prefers cats over dogs. First, we must be very careful not to anthropomorphize large language models. Second, we are only estimating the relative probability of two arbitrary tokens, and in reality neither of those tokens might be particularly likely. At best, we can claim that the model has a slight bias for one token over the other.
For example, we could just as easily ask the same question, but then compute the probability of the next token being "Birds" or "Bears":
probs_map = extract_word_probs(model, token_scores, ['Birds', 'Bears'])
print(probs_map)
And the output of that is:
{'Birds': 0.826927, 'Bears': 0.173073}
What does that tell us? Probably nothing meaningful. At least not without further analysis. This example highlights the importance of having interpretable results, which once again brings us back to the Responsible AI principles.
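One simple piece of further analysis is to check where the compared tokens rank in the full next-token distribution. If neither token is anywhere near the top, their relative probabilities say little about what the model would actually generate. Below is a minimal sketch with toy logits and integer ids. With the real model, the ids would come from `tokenizer.token_to_id`, as in `extract_word_probs`. `token_ranks` is a hypothetical helper name.

```python
import numpy as np

def token_ranks(token_scores, token_ids):
    """Return the 1-based rank of each token id among all vocabulary scores."""
    token_scores = np.asarray(token_scores)
    # Sort all ids by descending score; position in this order is the rank.
    order = np.argsort(-token_scores)
    rank_of = np.empty_like(order)
    rank_of[order] = np.arange(1, len(order) + 1)
    return {int(i): int(rank_of[i]) for i in token_ids}

scores = np.array([2.0, 9.0, 0.5, 1.0, 3.0])  # toy next-token scores
print(token_ranks(scores, [2, 3]))  # {2: 5, 3: 4}
```

With a real vocabulary of hundreds of thousands of tokens, a rank in the thousands would be a strong hint that the comparison is not meaningful on its own.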
Summary
In this blog post, we covered:
- How to extract the probabilities of the next tokens predicted by an LLM given a prompt
- How to analyze the top k tokens, or an arbitrary set of candidate tokens
- Why it's important to contextualize the results of our analysis
Visit the documentation to learn more about Gemma and the Responsible AI Toolkit.