This post continues my quest to convert the Whisper model to ONNX. Last time, I discussed how to convert the model without the optional performance optimization of key value caching. After some trial and error, I managed to convert the full model. In this post, I will discuss how the Whisper model implements key value caching, before I cover the actual conversion in the next post.
The decoder generates text iteratively: in each iteration, it is passed all tokens generated so far. The idea behind key value caching is to cache computations that are repeated across these iterations. The attention mechanism has three components: keys, values, and queries. While keys and values can be cached, the queries have to be computed anew in every iteration. In the following, I will only discuss the keys, as the values are handled in an equivalent manner in the model.
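To make the repeated work concrete, here is a minimal sketch, not Whisper's code, with made-up tensor names and shapes, of how the self attention keys of all previously generated tokens are recomputed in every iteration when no cache is used:

```python
import torch

# Illustrative shapes only: batch size 1, model dimension 8.
W_k = torch.randn(8, 8)

# Iteration 1 sees one generated token, iteration 2 the same token plus a new one.
x_tokens_step1 = torch.randn(1, 1, 8)
x_tokens_step2 = torch.cat([x_tokens_step1, torch.randn(1, 1, 8)], dim=1)

# Without caching, the key of the first token is computed twice.
keys_step1 = x_tokens_step1 @ W_k  # shape (1, 1, 8)
keys_step2 = x_tokens_step2 @ W_k  # shape (1, 2, 8), recomputes the first key
```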
The model includes two different attention mechanisms with two different ways of computing keys. Cross attention computes keys `x_audio @ W_k` that are independent of the generated tokens and therefore only need to be computed once. Self attention computes keys `x_tokens @ W_k` from the tokens generated so far. This expression can be rewritten as a concatenation of the previously computed keys and the key of the current token:

```python
keys = torch.cat([keys_prev_tokens, x_tokens[:, -1:, :] @ W_k], dim=1)
```
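A quick numerical check, again with made-up tensors rather than anything from the model, confirms that this decomposition matches computing all keys from scratch:

```python
import torch

torch.manual_seed(0)
W_k = torch.randn(8, 8)
x_tokens = torch.randn(1, 3, 8)  # all tokens generated so far

keys_all = x_tokens @ W_k                     # keys computed from scratch
keys_prev_tokens = x_tokens[:, :-1, :] @ W_k  # keys cached in earlier iterations
keys = torch.cat([keys_prev_tokens, x_tokens[:, -1:, :] @ W_k], dim=1)

print(torch.allclose(keys_all, keys))  # True
```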
To avoid these recomputations, the model introduces key value caching in two places: the `forward` method of the `MultiHeadAttention` module and the forward hook attached to the `key` and `value` modules.
Forward hooks are called after the module's `forward` method and allow the return value to be modified. Due to the hook, any call to `key(x)` in the `MultiHeadAttention` module is equivalent to

```python
k = hook(key, x, key.forward(x))
```
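This is standard PyTorch behavior: a hook registered with `register_forward_hook` receives the module, its inputs, and its output, and if it returns a value, that value replaces the module's output. A small sketch, unrelated to Whisper and with a made-up module, illustrates the mechanism:

```python
import torch
import torch.nn as nn

linear = nn.Linear(4, 4)

def double_output(module, inputs, output):
    # Returning a value from a forward hook replaces the module's output.
    return 2 * output

linear.register_forward_hook(double_output)

x = torch.randn(1, 4)
# Calling the module triggers the hook, calling .forward() directly does not.
print(torch.allclose(linear(x), 2 * linear.forward(x)))  # True
```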
The relevant parts of the `forward` method and the hook read, slightly simplified,
```python
# forward
if xa is None or self.key not in kv_cache:
    # self attention, or cross attention before the cache is filled
    k = self.key(x if xa is None else xa)
else:
    # cross attention with a filled cache: reuse the stored keys
    k = kv_cache[self.key]

# hook
if module not in cache:
    cache[module] = output
else:
    cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
```
For cross attention, the audio input `xa` is passed to the module and the terms of the conditions evaluate to

| | `xa is None` | `self.key not in kv_cache` | `module not in cache` |
|---|---|---|---|
| first iteration | False | True | True |
| subsequent iterations | False | False | False |
Hence, the `key` module is only evaluated in the first iteration, and so is the hook. Subsequent calls to `forward` simply retrieve the keys computed in the first call. This results in the following interactions with the cache:
```python
# First call to forward()
k = self.key(xa)

# Due to the hook, this call is equivalent to
kv_cache[self.key] = self.key.forward(xa)
k = kv_cache[self.key]

# Subsequent calls to forward()
k = kv_cache[self.key]
```
For self attention, the audio input `xa` is not passed to the module and the table reads

| | `xa is None` | `self.key not in kv_cache` | `module not in cache` |
|---|---|---|---|
| first iteration | True | True | True |
| subsequent iterations | True | False | False |
Therefore, the `key` module is evaluated in every iteration. The first call to `forward` computes the keys and caches them. Subsequent calls to `forward` append the key of the current token to the cache. This results in the following interactions with the cache:
```python
# First call to forward()
k = self.key(x)

# Due to the hook, this call is equivalent to
kv_cache[self.key] = self.key.forward(x)
k = kv_cache[self.key]

# Subsequent calls to forward()
k = self.key(x)

# Due to the hook, this call is equivalent to
kv_cache[self.key] = torch.cat(
    [kv_cache[self.key], self.key.forward(x)],
    dim=1,
).detach()
k = kv_cache[self.key]
```
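The whole mechanism can be simulated outside the model. The following self-contained sketch is not Whisper's code: the module names, dimensions, and the `cached_key` helper are made up for illustration, but the cache and hook logic mirror the simplified snippets above.

```python
import torch
import torch.nn as nn

d_model = 4
cache = {}

def save_to_cache(module, _inputs, output):
    # Mirror of the hook above: store on the first call, append afterwards.
    if module not in cache:
        cache[module] = output
    else:
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]

def cached_key(module, x, xa):
    # Mirror of the forward logic above: for cross attention (xa given),
    # skip the module entirely once its output is cached.
    if xa is None or module not in cache:
        return module(x if xa is None else xa)
    return cache[module]

cross_key = nn.Linear(d_model, d_model, bias=False)
self_key = nn.Linear(d_model, d_model, bias=False)
for module in (cross_key, self_key):
    module.register_forward_hook(save_to_cache)

xa = torch.randn(1, 3, d_model)  # fixed audio features
x1 = torch.randn(1, 1, d_model)  # token passed in the first iteration
x2 = torch.randn(1, 1, d_model)  # token passed in the second iteration

# First iteration: both modules are evaluated and their outputs cached.
print(cached_key(cross_key, x1, xa).shape)   # (1, 3, d_model)
print(cached_key(self_key, x1, None).shape)  # (1, 1, d_model)

# Second iteration: the cross attention keys come straight from the cache,
# the self attention key of the new token is appended to the cache.
print(cached_key(cross_key, x2, xa).shape)   # (1, 3, d_model)
print(cached_key(self_key, x2, None).shape)  # (1, 2, d_model)
```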
The value modules are treated equivalently. Combined, these mechanisms ensure that each key and value is computed only once. In the next post, I will discuss how to translate this behavior into ONNX.