Whisper in ONNX with key value caching pt 1

This post continues my quest to convert the Whisper model to ONNX. Last time, I discussed how to convert the model without the optional performance optimization of key value caching. After some trial and error, I managed to convert the full model. In this post, I will discuss how the Whisper model implements key value caching, before I cover the actual conversion in the next post.

The decoder generates text iteratively. In each iteration, it is passed all the tokens generated so far. The idea behind key value caching is to cache and reuse computations that would otherwise be repeated across iterations. The attention mechanism involves three components: keys, values, and queries. While the keys and values can be cached, the queries have to be recomputed in each iteration. In the following, I will only discuss the keys, as the values are handled in an equivalent manner in the model.
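
To make this concrete, here is a minimal sketch of a single decoding step with cached keys and values (made-up dimensions and random stand-in tensors, not Whisper code): only the query of the newest token is computed, while the cached keys and values cover all previous tokens.

import torch

d_model = 8
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

# keys and values of the tokens generated so far (batch, n_prev, d_model)
keys_prev = torch.randn(1, 4, d_model)
values_prev = torch.randn(1, 4, d_model)

# only the newest token needs a query
x_new = torch.randn(1, 1, d_model)
q = x_new @ W_q

# extend the cache by the key and value of the newest token
keys = torch.cat([keys_prev, x_new @ W_k], dim=1)
values = torch.cat([values_prev, x_new @ W_v], dim=1)

# the newest token attends to all cached positions
scores = (q @ keys.transpose(-2, -1)) / d_model**0.5
out = torch.softmax(scores, dim=-1) @ values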

The model includes two different attention mechanisms with two different ways of computing keys: cross attention computes keys x_audio @ W_k that are independent of the generated tokens, while self attention computes keys x_tokens @ W_k. The latter expression can be rewritten as the concatenation of the keys of the previously generated tokens and the key of the current token:

keys = concat([keys_prev_tokens, x_tokens[:, -1:, :] @ W_k], dim=1)
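
As a quick sanity check (a standalone sketch with made-up dimensions, not Whisper code), the full recomputation and the incremental concatenation yield the same keys:

import torch

d_model = 8
W_k = torch.randn(d_model, d_model)
x_tokens = torch.randn(1, 5, d_model)  # all tokens generated so far

# recompute the keys of every token from scratch
keys_full = x_tokens @ W_k

# reuse the keys of the previous tokens, compute only the newest one
keys_prev_tokens = x_tokens[:, :-1, :] @ W_k
keys = torch.cat([keys_prev_tokens, x_tokens[:, -1:, :] @ W_k], dim=1)

assert torch.allclose(keys_full, keys, atol=1e-6)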

To avoid these recomputations, the model implements key value caching in two places: the forward method of the MultiHeadAttention module and the forward hook attached to the key and value modules. Forward hooks are called after a module's forward method and can replace its return value. Due to the hook, any call to key(x) in the MultiHeadAttention module is equivalent to

k = hook(key, x, key.forward(x))
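
If you have not worked with forward hooks before, the following standalone sketch (not Whisper code) illustrates the calling convention: the hook receives the module, its inputs, and its output, and whatever the hook returns replaces that output for the caller.

import torch
from torch import nn

cache = {}

def save_to_cache(module, _inputs, output):
    # store the output and hand it back as the module's result
    cache[module] = output
    return cache[module]

key = nn.Linear(8, 8, bias=False)
key.register_forward_hook(save_to_cache)

x = torch.randn(1, 3, 8)
k = key(x)  # runs key.forward(x), then the hook
assert key in cache and torch.equal(cache[key], k)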

The relevant parts of the forward method and the hook read, slightly simplified,

# forward
if xa is None or self.key not in kv_cache:
    k = self.key(x if xa is None else xa)
else:
    k = kv_cache[self.key]

# hook
if module not in cache:
    cache[module] = output
else:
    cache[module] = torch.cat([cache[module], output], dim=1).detach()

return cache[module]
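
To see the two pieces working together, here is a minimal, self-contained mock (the names MiniAttention and install_hooks are hypothetical; only the structure mirrors the snippets above) that wires the branching in forward to a hook on the key module:

import torch
from torch import nn

class MiniAttention(nn.Module):
    def __init__(self, d_model=8):
        super().__init__()
        self.key = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, xa=None, kv_cache=None):
        # same branching as the simplified forward above
        if xa is None or self.key not in kv_cache:
            k = self.key(x if xa is None else xa)
        else:
            k = kv_cache[self.key]
        return k

def install_hooks(model, cache):
    # same logic as the simplified hook above
    def save_to_cache(module, _inputs, output):
        if module not in cache:
            cache[module] = output
        else:
            cache[module] = torch.cat([cache[module], output], dim=1).detach()
        return cache[module]

    model.key.register_forward_hook(save_to_cache)

attn = MiniAttention()
kv_cache = {}
install_hooks(attn, kv_cache)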

For cross attention, the audio input xa is passed to the module and the terms of the conditions evaluate to

                         xa is None   self.key not in kv_cache   module not in cache
first iteration          False        True                        True
subsequent iterations    False        False                       False

Hence, the key module itself is only evaluated in the first iteration, as is the hook. Subsequent calls to forward simply retrieve the keys computed in the first call. This results in the following interactions with the cache:

# First call to forward()
k = self.key(xa)

# Due to the hook, this call is equivalent to
kv_cache[self.key] = self.key.forward(xa)
k = kv_cache[self.key]

# Subsequent calls to forward()
k = kv_cache[self.key]
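
Continuing the MiniAttention mock from above (treating attn as a cross-attention block), we can check that the key projection runs only once and that later calls return the cached tensor unchanged:

xa = torch.randn(1, 10, 8)  # stand-in for the encoded audio
x1 = torch.randn(1, 1, 8)   # first generated token
x2 = torch.randn(1, 1, 8)   # second generated token

k1 = attn(x1, xa=xa, kv_cache=kv_cache)  # computes and caches self.key(xa)
k2 = attn(x2, xa=xa, kv_cache=kv_cache)  # only reads the cache
assert k1 is k2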

For self attention, no audio input xa is passed to the module, and the table reads

                         xa is None   self.key not in kv_cache   module not in cache
first iteration          True         True                        True
subsequent iterations    True         False                       False

Therefore, the key module is always evaluated: the first call to forward computes the keys and caches them, and subsequent calls append the key of the current token to the cache. This results in the following interactions with the cache:

# First call to forward
k = self.key(x)

# Due to the hook, this call is equivalent to
kv_cache[self.key] = self.key.forward(x)
k = kv_cache[self.key]

# Subsequent calls to forward
k = self.key(x)

# Due to the hook, this call is equivalent to
kv_cache[self.key] = torch.cat(
    [kv_cache[self.key], self.key.forward(x)],
    dim=1,
).detach()

k = kv_cache[self.key]
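
Again with the MiniAttention mock, this time with a fresh module and cache standing in for a self-attention block, we can check that the incrementally built cache matches a full recomputation of the keys:

self_attn = MiniAttention()
self_cache = {}
install_hooks(self_attn, self_cache)

tokens = torch.randn(1, 3, 8)  # three decoded tokens
for i in range(tokens.shape[1]):
    # during decoding, only the newest token is passed to self attention
    self_attn(tokens[:, i : i + 1, :], xa=None, kv_cache=self_cache)

keys_cached = self_cache[self_attn.key]
# reference: apply the key projection to all tokens at once (bypassing the hook)
keys_full = tokens @ self_attn.key.weight.T
assert torch.allclose(keys_cached, keys_full, atol=1e-6)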

The value modules are treated equivalently. Combined, these mechanisms ensure that no key or value is computed more than once. In the next post, I will discuss how to translate this behavior into ONNX.