Whisper in ONNX

Recently, I have been looking into using the Whisper speech-to-text model from Rust. While packages like burn or candle look quite promising and do offer Whisper models, I decided to stick with ONNX for now, as it has served me well in the past. In this post, I would like to discuss the steps necessary to convert the Whisper model to ONNX.

To convert a model to ONNX, I usually follow three major steps:

  1. Execute the model in Python to generate reference inputs and outputs
  2. Convert the model to ONNX
  3. Execute the model in ONNX and compare the results
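For step 3, a tolerance-based comparison is usually enough, since ONNX Runtime and PyTorch may use different kernels and need not agree bit for bit. Below is a minimal pure-Python sketch of such a check; the helper names and the default tolerances are my own choices for illustration. Arrays would first be converted with `.tolist()`, and in practice `numpy.testing.assert_allclose` does the same job directly on arrays.

```python
def flatten(values) -> list[float]:
    """Recursively flatten nested lists/tuples of numbers into a flat list of floats."""
    if isinstance(values, (list, tuple)):
        out: list[float] = []
        for v in values:
            out.extend(flatten(v))
        return out
    return [float(values)]


def compare_outputs(reference, candidate, rtol: float = 1e-4, atol: float = 1e-5) -> tuple[bool, float]:
    """Check |candidate - reference| <= atol + rtol * |reference| elementwise.

    Returns (all_close, max_abs_diff). Both inputs must contain the same number of elements.
    """
    ref = flatten(reference)
    cand = flatten(candidate)
    if len(ref) != len(cand):
        raise ValueError(f"size mismatch: {len(ref)} vs {len(cand)}")
    max_diff = 0.0
    ok = True
    for r, c in zip(ref, cand):
        diff = abs(c - r)
        max_diff = max(max_diff, diff)
        if diff > atol + rtol * abs(r):
            ok = False
    return ok, max_diff


# Reference values from PyTorch vs. values from ONNX Runtime (made-up numbers).
ok, diff = compare_outputs([[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.00001], [3.0, 4.0]])
print(ok, diff)
```

Reporting the maximum absolute difference alongside the pass/fail flag makes it easier to tell a tolerance that is slightly too strict from a genuinely broken conversion.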

The Whisper model is somewhat complex. It uses an encoder-decoder architecture with cross-attention (architecture diagram). First, the input audio is converted into a log-mel spectrogram and then passed through a convolution stack followed by a transformer encoder. The encoded audio signal is then used as a conditioning input to an autoregressive transformer decoder, similar to GPT.

Much of OpenAI's implementation is concerned with getting the maximum performance and accuracy out of the model. The repository offers many decoding options and also uses key-value caching as a performance optimization. In the following, I will focus on the simplest decoding approach, greedy decoding without timestamp prediction, and ignore key-value caching.
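To get a feeling for what ignoring key-value caching costs, we can count how many token positions the decoder processes. Without a cache, every step re-runs the decoder over the whole prefix, so for a prompt of length p and n steps the total is n·p + n(n−1)/2 positions; with a cache, the prompt is processed once and each later step processes a single new token. This is a rough proxy for decoder work, not a precise cost model, and the prompt length of 4 below is just an illustration.

```python
def positions_without_cache(prompt_len: int, n_steps: int) -> int:
    """Total token positions fed to the decoder when the full prefix is re-run each step.

    Step i (0-indexed) sees the prompt plus the i tokens generated so far.
    """
    return sum(prompt_len + i for i in range(n_steps))


def positions_with_cache(prompt_len: int, n_steps: int) -> int:
    """Total positions when cached keys/values are reused: the prompt once, then one token per step."""
    if n_steps == 0:
        return 0
    return prompt_len + (n_steps - 1)


for n in (10, 100, 400):
    print(n, positions_without_cache(4, n), positions_with_cache(4, n))
```

The gap grows quadratically with the number of generated tokens, which is why caching matters most for long transcripts.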

To illustrate how the Whisper model works, it's instructive to look at the simplest possible inference in plain Python (see also the full conversion script). Whisper expects a 30-second audio input. I used a small Rust program to capture a 30-second sample from my microphone and convert it into a log-mel spectrogram. This input is then passed through the encoder:

# x_mel shape: [batch, coeff=80, time=3000] 
# x_audio shape: [batch, time=1500, feature=512 for the base model]
x_audio = model.encoder(x_mel)
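The fixed sizes in these comments follow from Whisper's preprocessing constants and its convolution stack. As a sketch, assuming the values from the openai-whisper code (16 kHz audio, a hop length of 160 samples, and two 1D convolutions with kernel size 3 and padding 1, the second with stride 2), the arithmetic is:

```python
SAMPLE_RATE = 16_000  # Hz, the input rate Whisper expects
HOP_LENGTH = 160      # samples between successive mel frames
CHUNK_SECONDS = 30    # fixed input window


def conv1d_out_len(length: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a 1D convolution without dilation."""
    return (length + 2 * padding - kernel) // stride + 1


n_samples = CHUNK_SECONDS * SAMPLE_RATE  # samples in one 30 s window
n_frames = n_samples // HOP_LENGTH       # mel frames fed to the encoder
after_conv1 = conv1d_out_len(n_frames, kernel=3, padding=1)               # keeps the length
after_conv2 = conv1d_out_len(after_conv1, kernel=3, stride=2, padding=1)  # halves the length
print(n_frames, after_conv2)  # 3000 1500
```

The feature width of the encoder output, by contrast, depends on the model size, which is why only the time axes are fixed across Whisper variants.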

The text is decoded by applying the decoder model autoregressively. We initialize the predicted tokens with a fixed sequence that tells the model which task to perform:

# shape: [batch, seq<=448]
x_tokens = torch.tensor(
    [tokenizer.sot_sequence_including_notimestamps],
    dtype=torch.long,
)

We then use the decoder to predict the next token in a loop until either the end-of-text token is predicted or the maximum number of tokens is reached:

# run the decoding loop using greedy decoding
next_token = tokenizer.sot
with torch.no_grad():
    while x_tokens.shape[1] <= model.dims.n_text_ctx and next_token != tokenizer.eot:
        # logits shape: [batch, seq, vocab]
        y_tokens = model.decoder(x_tokens, x_audio)

        # pick the most likely token at the last position
        next_token = y_tokens[0, -1].argmax()
        x_tokens = torch.concat(
            [x_tokens, next_token.reshape(1, 1)],
            dim=1,
        )
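The control flow of this loop does not depend on PyTorch. As a sketch, here it is with the decoder abstracted into a function that maps a token prefix to one row of logits per position; `greedy_decode`, `toy_decoder`, and the token ids in the example are hypothetical stand-ins, not Whisper's API or vocabulary.

```python
from typing import Callable, Sequence

# Hypothetical decoder signature: token prefix -> one row of logits per position.
Decoder = Callable[[Sequence[int]], list[list[float]]]


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(row)), key=row.__getitem__)


def greedy_decode(decoder: Decoder, prompt: Sequence[int], eot: int, max_len: int) -> list[int]:
    """Append the most likely next token until `eot` is produced or `max_len` tokens exist.

    Like the loop above, this re-runs the decoder on the full prefix at every step
    (no key-value cache).
    """
    tokens = list(prompt)
    while len(tokens) < max_len:
        logits = decoder(tokens)
        next_token = argmax(logits[-1])  # only the last position predicts the next token
        tokens.append(next_token)
        if next_token == eot:
            break
    return tokens


def toy_decoder(prefix: Sequence[int]) -> list[list[float]]:
    """Toy model over a 6-token vocabulary that always prefers 'previous token + 1'."""
    vocab = 6
    rows = []
    for tok in prefix:
        row = [0.0] * vocab
        row[min(tok + 1, vocab - 1)] = 1.0
        rows.append(row)
    return rows


print(greedy_decode(toy_decoder, [1], eot=5, max_len=10))  # [1, 2, 3, 4, 5]
```

Only the last row of logits is used at each step; the other rows are recomputed and discarded, which is exactly the work a key-value cache avoids.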

Finally, we can use the tokenizer to map the generated tokens back to text:

print(tokenizer.decode(x_tokens[0].tolist()))
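Since `x_tokens` still contains the task prompt and, usually, a trailing end-of-text token, it can be cleaner to strip those before decoding. As far as I can tell, openai-whisper's own decoding code does something similar, slicing off the prompt and everything from the end-of-text token onward. A minimal sketch; the helper name and the token ids in the example are made up for illustration:

```python
def generated_text_tokens(tokens: list[int], prompt_len: int, eot: int) -> list[int]:
    """Drop the prompt prefix and everything from the first end-of-text token onward."""
    generated = tokens[prompt_len:]
    if eot in generated:
        generated = generated[: generated.index(eot)]
    return generated


# Hypothetical ids: 4 prompt tokens, 3 text tokens, then end-of-text (99).
print(generated_text_tokens([90, 91, 92, 93, 10, 11, 12, 99], prompt_len=4, eot=99))  # [10, 11, 12]
```

With Whisper, the arguments would be `x_tokens[0].tolist()`, `len(tokenizer.sot_sequence_including_notimestamps)`, and `tokenizer.eot`, and the result is passed to `tokenizer.decode`.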

With the inference example running, we can turn to the ONNX conversion. PyTorch offers two main APIs for converting models to ONNX: torch.onnx.export and the newer torch.onnx.dynamo_export. Here I use the former, as the latter failed in my experiments. With torch.onnx.export, we need to specify the dynamic axes explicitly. For the encoder, only the batch dimension is dynamic, as Whisper uses a fixed 30-second audio window. For the decoder, both the batch and the sequence dimension (i.e., the number of tokens generated so far) are dynamic. The full export looks like this:

torch.onnx.export(
    model.encoder, 
    (x_mel,), 
    "./tmp/encoder.onnx", 
    input_names=["x"], 
    output_names=["out"],
    dynamic_axes={
        "x": {0: "batch"},
        "out": {0: "batch"},
    },
)

torch.onnx.export(
    model.decoder, 
    (x_tokens, x_audio), 
    "./tmp/decoder.onnx", 
    input_names=["tokens", "audio"], 
    output_names=["out"], 
    dynamic_axes={
        "tokens": {0: "batch", 1: "seq"},
        "audio": {0: "batch"},
        "out": {0: "batch", 1: "seq"},
    },
)
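To check that the dynamic axes made it into the exported graphs, one can read the input shapes back from the ONNX file. The sketch below walks `model.graph.input` of a loaded ONNX `ModelProto`, where each dimension holds either a symbolic `dim_param` (our axis names) or a fixed `dim_value`. The helper `input_shapes` is my own, and it only relies on those protobuf attribute names, so the offline demo uses a stand-in object with illustrative sizes instead of a real file.

```python
from types import SimpleNamespace


def input_shapes(model) -> dict:
    """Map each graph input name to its dims: the symbolic name if dynamic, else the fixed size."""
    shapes = {}
    for value_info in model.graph.input:
        dims = value_info.type.tensor_type.shape.dim
        # An unset dim_param reads as "", so fall back to the fixed dim_value.
        shapes[value_info.name] = [d.dim_param or d.dim_value for d in dims]
    return shapes


# Offline demo: stand-ins that mimic the ModelProto attribute layout.
def _dim(param: str = "", value: int = 0) -> SimpleNamespace:
    return SimpleNamespace(dim_param=param, dim_value=value)


def _input(name: str, dims: list) -> SimpleNamespace:
    shape = SimpleNamespace(dim=dims)
    return SimpleNamespace(name=name, type=SimpleNamespace(tensor_type=SimpleNamespace(shape=shape)))


fake_decoder = SimpleNamespace(graph=SimpleNamespace(input=[
    _input("tokens", [_dim("batch"), _dim("seq")]),
    _input("audio", [_dim("batch"), _dim(value=1500), _dim(value=512)]),
]))
print(input_shapes(fake_decoder))  # {'tokens': ['batch', 'seq'], 'audio': ['batch', 1500, 512]}
```

On the real file, `input_shapes(onnx.load("./tmp/decoder.onnx"))` should report `"batch"` and `"seq"` for `tokens` if the export worked as intended; a plain number there would mean the axis was baked in at export time.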

To execute the model, we first build the inference sessions:

import onnxruntime

sess_encoder = onnxruntime.InferenceSession("./tmp/encoder.onnx")
sess_decoder = onnxruntime.InferenceSession("./tmp/decoder.onnx")

and can then execute the models just as we did in PyTorch:

import numpy as np

out_encoder, = sess_encoder.run(["out"], {"x": x_mel.numpy()})

# initialize the tokens with the same task prompt as before
tokens = list(tokenizer.sot_sequence_including_notimestamps)
# the decoder's context length (448 for all Whisper models)
max_tokens = model.dims.n_text_ctx

next_token = tokenizer.sot
while len(tokens) <= max_tokens and next_token != tokenizer.eot:
    out_decoder, = sess_decoder.run(
        ["out"],
        {
            "tokens": np.asarray([tokens], dtype="int64"),
            "audio": out_encoder,
        },
    )
    # greedy decoding: take the most likely token at the last position
    next_token = int(out_decoder[0, -1].argmax())
    tokens.append(next_token)

print(tokenizer.decode(tokens))
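For step 3 of the recipe, the simplest end-to-end check here is that greedy decoding produces the same tokens in both runtimes. Small numerical differences in the logits can flip an argmax when two candidates are nearly tied, so it helps to report where the sequences first diverge rather than only whether they match. A sketch, assuming both token lists are plain Python ints (e.g. `x_tokens[0].tolist()` from PyTorch and `tokens` from the ONNX loop); the helper name is hypothetical:

```python
from typing import Optional


def first_divergence(reference: list[int], candidate: list[int]) -> Optional[int]:
    """Index of the first differing token, or None if the sequences are identical.

    If one sequence is a strict prefix of the other, the length of the shorter one is returned.
    """
    for i, (r, c) in enumerate(zip(reference, candidate)):
        if r != c:
            return i
    if len(reference) != len(candidate):
        return min(len(reference), len(candidate))
    return None


print(first_divergence([1, 2, 3, 4], [1, 2, 3, 4]))  # None
print(first_divergence([1, 2, 3, 4], [1, 2, 7, 4]))  # 2
```

If the sequences diverge, comparing the decoder logits at that position with a tolerance-based check shows whether it is a near-tie or a real conversion error.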

On my CPU, the ONNX conversion gives a moderate speed-up. Note, however, that the ONNX version is still 4x slower than the PyTorch model with kv-caching. As next steps, I plan to add kv-caching to the ONNX model as well, to look into distil-whisper, and to use the ort Rust library to execute the model from Rust. You can find the complete conversion script here.