As promised in the last post, I will cover the ONNX conversion of the full Whisper model, including key-value caching, in this post. The full model resists a straightforward ONNX conversion due to its reliance on hooks and branching control flow. In this post I will discuss how to patch the model to make it exportable.
One helpful idea in any ONNX conversion is the fact that there is no need to export the model that was previously trained. It is possible, and often necessary, to initialize a different model with the trained parameters and then export this model. One common example is to wrap one model in another to customize arguments and return values.
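As an illustration of this wrapping idea, here is a minimal sketch (the inner model and its return_dict argument are hypothetical, not part of Whisper): the wrapper pins an awkward keyword argument and returns plain tensors, while sharing the trained parameters of the inner model.

```python
import torch

class Inner(torch.nn.Module):
    """Stand-in for a trained model with an export-unfriendly interface."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x, return_dict=True):
        y = self.linear(x)
        return {"logits": y} if return_dict else y

class ExportWrapper(torch.nn.Module):
    """Wrapper that fixes keyword arguments and returns only tensors."""
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x, return_dict=False)

wrapped = ExportWrapper(Inner())  # reuses the trained parameters
out = wrapped(torch.zeros(1, 4))  # a plain tensor, suitable for tracing
```

The wrapper, not the original model, is then passed to the exporter.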
In the case of Whisper, though, a simple wrapper is not sufficient. There is an asymmetry between the first call and subsequent calls of the model. Such a change in behavior is not directly supported by the tracing exporter of PyTorch. One option would be to rewrite the complete model from scratch, a quite realistic task, as the model comes in at around 300 lines of Python. This approach is taken, for example, by the Whisper conversion of sherpa-onnx.
Here, I will use a different approach and patch the forward method of the MultiHeadAttention module. The new forward method will result in the following key computation (as in the last post, I will focus on the keys; the values are handled equivalently):
def forward(self, x, xa=None, mask=None, *, kv_cache):
# compute the new key
k = self.key(x if xa is None else xa).detach()
# append it to the cached key
k = torch.cat([kv_cache[self.key], k], dim=1)
# limit the key to the maximum context length
k = k[:, -self.n_ctx :, :]
# store the key for the next iteration
kv_cache[self.key] = k
# [..] compute values the same way
# default implementation of MultiHeadAttention.forward
q = self.query(x)
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
This new forward method is combined with the following calling convention: in the first call, pass a cache dictionary with empty tensors, the full audio input, and the complete initial tokens. In subsequent calls, pass the updated caches, a zero time-step audio input, and only the newly generated token. This calling convention in combination with the new forward method results in the same behavior as the original model. In essence, the asymmetry between initial and subsequent calls has been moved out of the model into the outer code calling the model.
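The caching logic can be seen in isolation with a small sketch (a single hypothetical key projection, not the real Whisper module): the cache maps the projection layer to its accumulated key tensor, which starts empty and grows by one step per call.

```python
import torch

n_ctx = 4  # assumed maximum context length
key_proj = torch.nn.Linear(3, 3, bias=False)
kv_cache = {key_proj: torch.zeros(1, 0, 3)}  # first call: empty cache

def cached_key(x):
    # compute the new key and append it to the cached key
    k = key_proj(x).detach()
    k = torch.cat([kv_cache[key_proj], k], dim=1)
    # limit the key to the maximum context length, then store it
    k = k[:, -n_ctx:, :]
    kv_cache[key_proj] = k
    return k

# first call: the complete initial sequence (here two tokens)
k = cached_key(torch.randn(1, 2, 3))
assert k.shape == (1, 2, 3)
# subsequent call: only the newly generated token
k = cached_key(torch.randn(1, 1, 3))
assert k.shape == (1, 3, 3)
```

Both call patterns go through the same code path, which is what makes the model traceable.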
The conversion script includes the new forward method as part of a new module, FunctionalMultiHeadAttention. It inherits from OpenAI's MultiHeadAttention module and overrides the forward method. To use it, we iterate over all attention blocks and patch the objects to use the new class:
for block in model.decoder.blocks:
block.attn.__class__ = FunctionalMultiHeadAttention
block.attn.n_ctx = model.dims.n_text_ctx
block.cross_attn.__class__ = FunctionalMultiHeadAttention
block.cross_attn.n_ctx = model.dims.n_audio_ctx
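Reassigning __class__ on an existing instance swaps out its methods while leaving all parameters and buffers untouched, so no state needs to be copied. A minimal sketch of the technique with generic modules (not the Whisper classes):

```python
import torch

class Base(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

class Patched(Base):
    # same parameters, different forward
    def forward(self, x):
        return self.linear(x) * 2

m = Base()
x = torch.ones(1, 2)
y_before = m(x)
m.__class__ = Patched  # rebind the instance to the subclass
y_after = m(x)
assert torch.allclose(y_after, 2 * y_before)
```

Since Patched is a subclass of Base, the instance remains a valid Base as well; only the overridden method changes.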
To fully export the model, we use an additional wrapper so that all inputs and outputs are tensors. It converts the cache dictionary into tensor inputs and outputs:
# unpack the cache tensor into a dictionary
kv_cache = dict(zip(keys_self_attn, cache_self_attn))
# [..] call the model
# pack the dictionary into a single cache tensor
cache_self_attn = torch.stack([kv_cache[key] for key in keys_self_attn])
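The pack/unpack round trip can be sketched on its own (the key names and shapes below are assumptions; in the conversion script the cache is keyed by the projection modules themselves): stacking turns the per-layer cache entries into one tensor with a leading entry dimension, and iterating over that dimension restores the dictionary.

```python
import torch

# hypothetical cache entries for one attention block (key and value)
keys_self_attn = ["block0.attn.key", "block0.attn.value"]
kv_cache = {k: torch.randn(1, 5, 8) for k in keys_self_attn}

# pack: dictionary -> single stacked tensor (one ONNX output)
cache_self_attn = torch.stack([kv_cache[k] for k in keys_self_attn])
assert cache_self_attn.shape == (2, 1, 5, 8)

# unpack: stacked tensor -> dictionary (from one ONNX input);
# iterating a tensor yields its slices along the first dimension
restored = dict(zip(keys_self_attn, cache_self_attn))
assert torch.equal(restored["block0.attn.value"], kv_cache["block0.attn.value"])
```

This keeps the exported signature to a fixed set of tensors, which is what ONNX requires.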
The complete model, after the ONNX export, can be used as follows:
cache_self_attn = np.zeros(shape_empty_cache_self_attn, dtype="float32")
cache_cross_attn = np.zeros(shape_empty_cache_cross_attn, dtype="float32")
logits, cache_self_attn, cache_cross_attn = sess_decoder.run(
["logits", "new_cache_self_attn", "new_cache_cross_attn"],
{
"x": x_tokens,
"xa": x_audio,
"cache_self_attn": cache_self_attn,
"cache_cross_attn": cache_cross_attn,
},
)
generated = [x_tokens]
while running:  # placeholder condition, e.g. stop at the end-of-text token or the maximum length
x_tokens = logits[:, -1:, :].argmax(axis=2)
generated.append(x_tokens)
    logits, cache_self_attn, cache_cross_attn = sess_decoder.run(
["logits", "new_cache_self_attn", "new_cache_cross_attn"],
{
"x": x_tokens,
"xa": x_audio[:, :0, :],
"cache_self_attn": cache_self_attn,
"cache_cross_attn": cache_cross_attn,
},
)
generated = np.concatenate(generated, axis=1)
See the complete conversion script for details. Continuing this series, the next posts will describe how to leverage the ONNX runtime to execute the converted model in Rust.