Whisper in ONNX with key value caching pt. 2

As promised in the last post, this post covers the ONNX conversion of the full Whisper model, including key-value caching. The full model resists a straightforward ONNX conversion because it relies on forward hooks and branching control flow. Below, I discuss how to patch the model to make it exportable.

A helpful idea for any ONNX conversion is that there is no need to export the exact model object that was trained. It is possible, and often necessary, to initialize a different model with the trained parameters and export that model instead. A common example is wrapping one model in another to customize its arguments and return values.

In the case of Whisper, however, a simple wrapper is not sufficient. The first call of the model behaves differently from subsequent calls. PyTorch's tracing exporter does not directly support such a change in behavior, because it records only the operations executed for one example input. One option is to rewrite the complete model from scratch. This is a realistic task, as the model comes in at around 300 lines of Python. The Whisper conversion in sherpa-onnx, for example, takes this approach.

Here, I use a different approach and patch the forward method of the MultiHeadAttention module. The new forward method computes the keys as follows (as in the last post, I focus on the keys; the values are handled in the same way):

def forward(self, x, xa=None, mask=None, *, kv_cache):
    # compute the new key
    k = self.key(x if xa is None else xa).detach()

    #  append it to the cached key
    k = torch.cat([kv_cache[self.key], k], dim=1)

    # limit the key to the maximum context length
    k = k[:, -self.n_ctx :, :]

    # store the key for the next iteration
    kv_cache[self.key] = k

    # [..] compute values the same way

    # default implementation of MultiHeadAttention.forward
    q = self.query(x)
    wv, qk = self.qkv_attention(q, k, v, mask)
    return self.out(wv), qk
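
The append-and-truncate logic can be checked in isolation. Below is a minimal NumPy sketch. It assumes arrays of shape (batch, time, n_state) in place of PyTorch tensors, and `update_cache` is a hypothetical helper, not part of the conversion script:

```python
import numpy as np


def update_cache(cache, new, n_ctx):
    """Append `new` to `cache` along the time axis and keep the last n_ctx steps.

    Both arrays have shape (batch, time, n_state), mirroring the torch.cat
    and slicing steps of the patched forward method.
    """
    merged = np.concatenate([cache, new], axis=1)
    return merged[:, -n_ctx:, :]


batch, n_state, n_ctx = 1, 4, 6
rng = np.random.default_rng(0)
keys = rng.normal(size=(batch, 8, n_state)).astype("float32")

# first call: empty cache plus the keys of three initial tokens
cache = np.zeros((batch, 0, n_state), dtype="float32")
cache = update_cache(cache, keys[:, :3], n_ctx)

# subsequent calls: one new key per call
for t in range(3, 8):
    cache = update_cache(cache, keys[:, t : t + 1], n_ctx)

# only the most recent n_ctx keys survive the truncation
assert np.array_equal(cache, keys[:, -n_ctx:])

# a zero time-step input (the cross-attention case) leaves the cache unchanged
empty = np.zeros((batch, 0, n_state), dtype="float32")
assert np.array_equal(update_cache(cache, empty, n_ctx), cache)
```

The slice `[-n_ctx:]` is a no-op while the cache is still shorter than `n_ctx`, so the same code path serves both the first and all later calls.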

This new forward method is combined with the following calling convention:

- In the first call, pass a cache dictionary with empty tensors, the full audio input, and the complete initial tokens.
- In subsequent calls, pass the updated caches, an audio input with zero time steps, and only the newly generated token.

Together with the new forward method, this calling convention reproduces the behavior of the original model. The empty caches make the first call compute all keys and values from its inputs. Afterwards, the zero time-step audio input leaves the cross-attention cache unchanged. In essence, the asymmetry between initial and subsequent calls has moved out of the model and into the code that calls it.
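
To check that this calling convention reproduces the uncached computation, here is a minimal NumPy sketch. It assumes a single toy decoder layer with one attention head, random weights, and random stand-ins for the token embeddings and the encoder output. `attention`, `decoder_step` and `empty_cache` are hypothetical helpers, not part of Whisper or the conversion script:

```python
import numpy as np


def attention(q, k, v, causal):
    """Single-head scaled dot-product attention on (batch, time, dim) arrays."""
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    if causal:
        n_q, n_k = q.shape[1], k.shape[1]
        # the queries are the last n_q positions of the key sequence
        offset = n_k - n_q
        allowed = np.arange(n_k)[None, :] <= offset + np.arange(n_q)[:, None]
        scores = np.where(allowed, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def decoder_step(x, xa, cache, params):
    """Toy decoder layer: append new keys/values to the cache, then attend."""
    for name, src in (("self", x), ("cross", xa)):
        for kind in ("k", "v"):
            key = f"{name}_{kind}"
            cache[key] = np.concatenate([cache[key], src @ params[key]], axis=1)
    h = x + attention(x @ params["q"], cache["self_k"], cache["self_v"], causal=True)
    return h + attention(h @ params["q"], cache["cross_k"], cache["cross_v"], causal=False)


def empty_cache(batch, dim):
    return {k: np.zeros((batch, 0, dim)) for k in ("self_k", "self_v", "cross_k", "cross_v")}


rng = np.random.default_rng(0)
batch, dim, n_audio, n_tokens = 1, 8, 10, 5
params = {
    k: rng.normal(size=(dim, dim)) / np.sqrt(dim)
    for k in ("q", "self_k", "self_v", "cross_k", "cross_v")
}
tokens = rng.normal(size=(batch, n_tokens, dim))  # stand-in for token embeddings
audio = rng.normal(size=(batch, n_audio, dim))  # stand-in for the encoder output

# reference: all tokens and the full audio in a single call with empty caches
reference = decoder_step(tokens, audio, empty_cache(batch, dim), params)

# calling convention: first call with the initial tokens and the full audio ...
cache = empty_cache(batch, dim)
outputs = [decoder_step(tokens[:, :2], audio, cache, params)]
# ... then one token at a time with a zero time-step audio input
for t in range(2, n_tokens):
    outputs.append(decoder_step(tokens[:, t : t + 1], audio[:, :0], cache, params))

incremental = np.concatenate(outputs, axis=1)
assert np.allclose(incremental, reference)
```

The causal mask is computed relative to the end of the key sequence, so one function handles both the multi-token first call and the single-token later calls. The toy layer skips the `n_ctx` truncation and positional embeddings, which the real decoder has to handle.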

The conversion script implements the new forward method in a new module, FunctionalMultiHeadAttention. It inherits from OpenAI's MultiHeadAttention module and overrides the forward method. To use it, we iterate over all attention blocks of the decoder and patch the existing objects to use the new class:

for block in model.decoder.blocks:
    block.attn.__class__ = FunctionalMultiHeadAttention
    block.attn.n_ctx = model.dims.n_text_ctx

    block.cross_attn.__class__ = FunctionalMultiHeadAttention
    block.cross_attn.n_ctx = model.dims.n_audio_ctx
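
Assigning `__class__` swaps the methods of an existing object while keeping its state, so the trained parameters stay where they are. Below is a minimal pure-Python sketch of this mechanism, with hypothetical stand-in classes `Attention` and `FunctionalAttention` in place of the Whisper modules:

```python
class Attention:
    """Stand-in for the original module: holds state and a forward method."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return self.weight * x


class FunctionalAttention(Attention):
    """Subclass that only overrides forward and needs one extra attribute."""

    def forward(self, x):
        return min(self.weight * x, self.n_ctx)


attn = Attention(weight=3)
attn.__class__ = FunctionalAttention  # swap the behaviour, keep the object
attn.n_ctx = 10  # extra attribute required by the new forward

assert attn.weight == 3  # existing state (the "trained parameters") is untouched
assert isinstance(attn, Attention)  # still an instance of the original class
assert attn.forward(2) == 6
assert attn.forward(5) == 10  # clipped by the new forward
```

Because the new class is a subclass of the old one, existing `isinstance` checks keep working. Extra state such as `n_ctx` is not created by the class swap and has to be set by hand. PyTorch modules are ordinary Python objects, so the same mechanism should carry over to them.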

To export the full model, we use an additional wrapper so that the model has only tensor inputs and outputs. It converts between the cache dictionary and stacked cache tensors:

# unpack the cache tensor into a dictionary
kv_cache = dict(zip(keys_self_attn, cache_self_attn))

# [..] call the model

# pack the dictionary into a single cache tensor
cache_self_attn = torch.stack([kv_cache[key] for key in keys_self_attn])
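
Below is a minimal NumPy sketch of such a wrapper. It assumes a dictionary-based model function that updates the cache in place. `call_with_tensors` and `dummy_model` are hypothetical, and the string keys stand in for the attention modules used as dictionary keys in the real code:

```python
import numpy as np


def call_with_tensors(model_fn, keys, x, cache):
    """Wrap a dict-based model so it takes and returns one stacked cache array.

    `cache` has shape (len(keys), batch, time, n_state).
    """
    # unpack the cache tensor into a dictionary
    kv_cache = dict(zip(keys, cache))
    out = model_fn(x, kv_cache)
    # pack the dictionary into a single cache tensor
    new_cache = np.stack([kv_cache[key] for key in keys])
    return out, new_cache


def dummy_model(x, kv_cache):
    """Hypothetical model: appends x as a new time step to every cache entry."""
    for key in list(kv_cache):
        kv_cache[key] = np.concatenate([kv_cache[key], x], axis=1)
    return x.sum()


keys = ["block0.key", "block0.value", "block1.key", "block1.value"]
batch, n_state = 2, 3
cache = np.zeros((len(keys), batch, 0, n_state))

for step in range(4):
    x = np.full((batch, 1, n_state), float(step))
    out, cache = call_with_tensors(dummy_model, keys, x, cache)

assert cache.shape == (len(keys), batch, 4, n_state)
assert np.array_equal(cache[2, 0, :, 0], [0.0, 1.0, 2.0, 3.0])
```

Stacking only works because all entries share one shape. Within the self-attention cache, every key and value has shape (batch, time, n_state) and grows in lockstep. The cross-attention entries have a different time dimension, which is why the two caches are kept as separate tensors.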

After the ONNX export, the complete model can be used as follows:

import numpy as np

# first call: empty caches, the full audio input and all initial tokens
cache_self_attn = np.zeros(shape_empty_cache_self_attn, dtype="float32")
cache_cross_attn = np.zeros(shape_empty_cache_cross_attn, dtype="float32")

logits, cache_self_attn, cache_cross_attn = sess_decoder.run(
    ["logits", "new_cache_self_attn", "new_cache_cross_attn"],
    {
        "x": x_tokens,
        "xa": x_audio,
        "cache_self_attn": cache_self_attn,
        "cache_cross_attn": cache_cross_attn,
    },
)

generated = [x_tokens]
while running:
    # greedy decoding: pick the most likely next token
    x_tokens = logits[:, -1:, :].argmax(axis=2)
    generated.append(x_tokens)

    # subsequent calls: updated caches, zero time-step audio, only the new token
    logits, cache_self_attn, cache_cross_attn = sess_decoder.run(
        ["logits", "new_cache_self_attn", "new_cache_cross_attn"],
        {
            "x": x_tokens,
            "xa": x_audio[:, :0, :],
            "cache_self_attn": cache_self_attn,
            "cache_cross_attn": cache_cross_attn,
        },
    )

generated = np.concatenate(generated, axis=1)
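
The `running` condition of the decoding loop is left open. One way to make it concrete is to stop when the model emits an end-of-text token or a length limit is reached. Below is a sketch, assuming batch size 1. `greedy_decode`, `eot_token`, `max_tokens` and `FakeSession` are hypothetical. The fake session only mimics the `run(output_names, input_feed)` call shape of an ONNX Runtime session, so the loop can be exercised without a model:

```python
import numpy as np

OUTPUT_NAMES = ["logits", "new_cache_self_attn", "new_cache_cross_attn"]


def greedy_decode(sess, x_tokens, x_audio, cache_self, cache_cross, eot_token, max_tokens):
    """Greedy decoding with the first-call / subsequent-call convention (batch size 1)."""
    feed = {"x": x_tokens, "xa": x_audio,
            "cache_self_attn": cache_self, "cache_cross_attn": cache_cross}
    logits, cache_self, cache_cross = sess.run(OUTPUT_NAMES, feed)

    generated = [x_tokens]
    for _ in range(max_tokens):
        x_tokens = logits[:, -1:, :].argmax(axis=2)
        generated.append(x_tokens)
        if x_tokens[0, 0] == eot_token:
            break  # no need to run the model again after end-of-text
        feed = {"x": x_tokens, "xa": x_audio[:, :0, :],
                "cache_self_attn": cache_self, "cache_cross_attn": cache_cross}
        logits, cache_self, cache_cross = sess.run(OUTPUT_NAMES, feed)
    return np.concatenate(generated, axis=1)


class FakeSession:
    """Stand-in for an ONNX Runtime session: always predicts "last token + 1"."""

    def __init__(self, n_vocab):
        self.n_vocab = n_vocab

    def run(self, output_names, input_feed):
        x = input_feed["x"]
        logits = np.zeros((x.shape[0], x.shape[1], self.n_vocab), dtype="float32")
        next_token = (x[:, -1] + 1) % self.n_vocab
        logits[np.arange(x.shape[0]), -1, next_token] = 1.0
        return logits, input_feed["cache_self_attn"], input_feed["cache_cross_attn"]


tokens = greedy_decode(
    FakeSession(n_vocab=10),
    x_tokens=np.array([[1, 2]]),
    x_audio=np.zeros((1, 3, 4), dtype="float32"),
    cache_self=np.zeros(0, dtype="float32"),
    cache_cross=np.zeros(0, dtype="float32"),
    eot_token=6,
    max_tokens=20,
)
print(tokens.tolist())  # [[1, 2, 3, 4, 5, 6]]
```

The length limit guarantees termination even if the model never produces the end-of-text token.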

See the complete conversion script for details. In the next posts of this series, I will describe how to use ONNX Runtime to execute the converted model in Rust.