Speed up your BERT inference with ONNX-Torchscript

In recent years, models based on the Transformer architecture have been the driving force behind NLP breakthroughs in research and industry. BERT, XLNet, GPT and XLM are some of the models that improved the state of the art and reached the top of popular benchmarks like GLUE.

These advances come with a steep computational cost: most transformer-based models are massive, and both the number of parameters and the amount of training data keep increasing. While the original BERT model already had 110 million parameters, GPT-3 has 175 billion, a staggering ~1,600x increase in two years of research.
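A quick sanity check of that ratio, using the rounded parameter counts quoted above:

```python
# Rounded parameter counts quoted in the text
bert_params = 110_000_000        # original BERT (base)
gpt3_params = 175_000_000_000    # GPT-3

growth = gpt3_params / bert_params
print(f"GPT-3 has ~{growth:,.0f}x more parameters than BERT")
```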

These massive models usually need hundreds of GPUs for several days of training to be effective. Fortunately, thanks to transfer learning, we can download pretrained models and quickly fine-tune them on our own much smaller datasets at a low cost.

That being said, once training is done you still have a massive model on your hands that you may want to deploy to production. Inference takes a relatively long time compared to more modest models, and it may be too slow to achieve the throughput you need.

While you could invest in faster hardware or use more servers, there are different ways to reduce the inference time of your model:

  • Model pruning: Reduce the number of layers, the dimension of the embeddings or the number of units in hidden layers.
  • Quantization: Instead of using 32-bit floats (FP32) for weights, use half precision (FP16) or even 8-bit integers.
  • Model export: Convert a model from native Pytorch/Tensorflow to an appropriate format or inference engine (Torchscript/ONNX/TensorRT...).
  • Batching: Predict on batches of samples instead of individual samples.

The first two approaches usually require retraining your model, while the last two are applied post-training and are essentially agnostic to your particular task.
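To give a sense of what the quantization bullet buys you, the sketch below estimates the memory taken by the weights alone at each precision, assuming a BERT-base-sized model of 110 million parameters (activations and framework overhead are ignored; the helper name is illustrative):

```python
# Bytes needed to store one parameter at each precision
BYTES_PER_PARAM = {"FP32": 4, "FP16": 2, "INT8": 1}

def weights_size_mb(n_params: int, precision: str) -> float:
    """Size of the raw weights in megabytes (1 MB = 10**6 bytes)."""
    return n_params * BYTES_PER_PARAM[precision] / 1e6

n_params = 110_000_000  # BERT-base order of magnitude (assumption)
for precision in BYTES_PER_PARAM:
    print(f"{precision}: {weights_size_mb(n_params, precision):.0f} MB")
```

Each halving of precision halves the weight memory; whether the model stays accurate enough at FP16 or INT8 has to be checked on your own task.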

If inference speed is critical for your use case, you will most likely need to experiment with all of these methods to produce a reliable and blazingly fast model. In most cases, however, exporting your model to an appropriate format or framework and predicting on batches will give you much faster results for minimal work. We will focus on this approach here and measure its impact on the throughput of our model.
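Batching itself is simple to implement: instead of calling the model once per sample, group the inputs into fixed-size chunks and run one forward pass per chunk. A minimal, framework-agnostic sketch, where `predict_batch` is a hypothetical placeholder for whatever calls your model on a list of inputs:

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")

def batched(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive chunks of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def predict_all(items: Sequence[T], predict_batch: Callable[[Sequence[T]], List[R]],
                batch_size: int = 32) -> List[R]:
    """Call predict_batch once per chunk and concatenate the results in input order."""
    results: List[R] = []
    for chunk in batched(items, batch_size):
        results.extend(predict_batch(chunk))
    return results

# Usage with a dummy "model" that returns the character length of each sentence
sentences = ["a short one", "another sentence", "and a third"]
print(predict_all(sentences, lambda batch: [len(s) for s in batch], batch_size=2))  # [11, 16, 11]
```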

We will explore the effects of changing model format and batching with a few experiments:

  • Baseline with vanilla Pytorch CPU/GPU
  • Export Pytorch model to Torchscript CPU/GPU
  • Export Pytorch model to ONNX CPU/GPU
  • All experiments run on batches of 1/2/4/8/16/32/64 samples
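A minimal sketch of a timing harness for this kind of sweep: run every batch size several times over the same dataset and report the mean inference time per sequence in milliseconds. `predict_batch` is a hypothetical placeholder for your model call; on GPU you would also need `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously, and a warm-up pass is usually advisable.

```python
import statistics
import time
from typing import Callable, Dict, Sequence, Tuple

def benchmark(predict_batch: Callable[[Sequence[str]], object], dataset: Sequence[str],
              batch_sizes: Tuple[int, ...] = (1, 2, 4, 8, 16, 32, 64),
              repeats: int = 5) -> Dict[int, float]:
    """Return the mean inference time per sequence (ms) for each batch size."""
    results = {}
    for batch_size in batch_sizes:
        per_sequence_ms = []
        for _ in range(repeats):
            start = time.perf_counter()
            for i in range(0, len(dataset), batch_size):
                predict_batch(dataset[i:i + batch_size])
            elapsed = time.perf_counter() - start
            per_sequence_ms.append(1000 * elapsed / len(dataset))
        results[batch_size] = statistics.mean(per_sequence_ms)
    return results

# Usage with a dummy predictor that sleeps 1 ms per call, whatever the batch size
timings = benchmark(lambda batch: time.sleep(0.001), ["sentence"] * 64, batch_sizes=(1, 8), repeats=2)
for batch_size, ms in timings.items():
    print(f"batch_size={batch_size}: {ms:.3f} ms per sequence")
```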

At the time of writing, it is not possible to export a transformer model directly from Pytorch to TensorRT because of the lack of support for the int64 type used by Pytorch embeddings, so we will skip it for now.

We will perform sentence classification with CamemBERT (~100M parameters), a French variant of RoBERTa. Since the vast majority of the computation happens inside the transformer, you should see similar results regardless of your task.

First, we’ll take a quick look at how to export a Pytorch model to each format/framework. If you don’t want to read code, you can skip to the results section further down.

How to export your model

Vanilla Pytorch

Saving and loading a model in Pytorch is quite straightforward, though there are different ways to proceed. For inference, the official documentation recommends saving the ‘state_dict’ of your model, a Python dictionary containing its learnable parameters (and persistent buffers such as batch normalization statistics). This is more lightweight and robust than pickling your whole model.

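```python
import torch

# saving (SequenceClassifier and train_model stand for your own model class and training loop)
model = SequenceClassifier()
train_model(model)
torch.save(model.state_dict(), 'pytorch_model.pt')

# loading
model = SequenceClassifier()
model.load_state_dict(torch.load('pytorch_model.pt'))
model.eval()  # set dropout and batch normalization layers to evaluation mode
with torch.no_grad():  # no gradient tracking needed at inference time
    logits = model(**batch_x)
```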

See the official Pytorch documentation on saving and loading models for additional information.

Torchscript JIT

TorchScript is a way to create serializable and optimizable models from your Pytorch code. Once exported to Torchscript, your model can be run from both Python and C++. There are two ways to convert a model:

  • Trace: An input is sent through the model and all operations are recorded in a graph that will define your torchscript model.
  • Script: If your model is more complex and has control flow such as conditional statements, scripting will inspect the source code of the model and compile it as TorchScript code.
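The difference between the two matters as soon as a model contains data-dependent control flow. A toy illustration (the `Branchy` module is hypothetical and unrelated to CamemBERT): tracing with a positive input records only the branch that was taken, while scripting keeps the `if` statement.

```python
import torch

class Branchy(torch.nn.Module):
    """Toy module whose output depends on a data-dependent branch."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x + 1
        return x - 1

model = Branchy().eval()

# Tracing runs the example input once, so only the "x + 1" branch is recorded
# (PyTorch emits a TracerWarning about converting a tensor to a Python bool).
traced = torch.jit.trace(model, torch.ones(3))

# Scripting compiles the source code, so both branches are preserved
scripted = torch.jit.script(model)

negative = -torch.ones(3)
print(traced(negative))    # tensor([0., 0., 0.]): branch frozen at trace time
print(scripted(negative))  # tensor([-2., -2., -2.]): correct branch taken
```

For a model without such branches (a standard transformer forward pass, for instance), tracing and scripting should produce equivalent graphs, and tracing is usually the simpler option.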

Note that since your model will be serialized, you won’t be able to modify it after it has been saved. You should therefore put it in evaluation mode and move it to the appropriate device before exporting it.

If you want to run inference on both CPU and GPU, you need to save two different models.

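```python
import torch

# saving: trace the model in evaluation mode, on the device it will run on
jit_sample = (batch_x['input_ids'].int().to(device), batch_x['attention_mask'].int().to(device))
model.eval()
model.to(device)
module = torch.jit.trace(model, jit_sample)
torch.jit.save(module, 'model_jit.pt')

# loading
model = torch.jit.load('model_jit.pt', map_location=torch.device(device))
with torch.no_grad():
    # traced modules take their inputs positionally, in the order and dtype used for tracing
    logits = model(batch_x['input_ids'].int().to(device), batch_x['attention_mask'].int().to(device))
```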
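The two precautions mentioned above (evaluation mode, one export per device) can be wrapped in a small helper. This is a minimal sketch assuming a model whose `forward` takes tensors positionally; `export_traced_per_device` is a hypothetical name, and the tiny `Sequential` model only stands in for your own.

```python
import os
import tempfile
from typing import Dict, Sequence

import torch

def export_traced_per_device(model: torch.nn.Module, example_inputs: Sequence[torch.Tensor],
                             out_dir: str, name: str = "model_jit") -> Dict[str, str]:
    """Trace `model` once per available device and save each trace to out_dir."""
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    model.eval()  # freeze dropout / batch norm behaviour before tracing
    paths = {}
    for device in devices:
        model.to(device)
        sample = tuple(t.to(device) for t in example_inputs)
        with torch.no_grad():
            traced = torch.jit.trace(model, sample)
        path = os.path.join(out_dir, f"{name}_{device}.pt")
        torch.jit.save(traced, path)
        paths[device] = path
    return paths

# Usage with a stand-in model that contains dropout
net = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Dropout(p=0.5))
x = torch.randn(3, 4)
paths = export_traced_per_device(net, (x,), tempfile.mkdtemp())
reloaded = torch.jit.load(paths["cpu"], map_location="cpu")
net.to("cpu")  # the helper leaves the model on the last device it traced on
with torch.no_grad():
    # in eval mode dropout is the identity, so the trace is deterministic
    print(torch.allclose(reloaded(x), net(x)))  # True
```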

For a more comprehensive introduction you can follow the official tutorial.

ONNX

ONNX provides an open source format for AI models, and most frameworks can export their models to it. In addition to interoperability between frameworks, ONNX comes with some optimizations that should accelerate inference.

Exporting to ONNX is slightly more complicated, but Pytorch does provide a direct export function; you only need to provide some key information:

  • opset_version: each opset version supports a given set of operators, so some models with more exotic architectures may not be exportable yet.
  • input_names and output_names: the names assigned to the input and output nodes of the graph.
  • dynamic_axes: a dictionary indicating which dimensions of your inputs and outputs may change, for example the batch size or the sequence length.


```python
import onnxruntime
import torch

# saving
input_x = jit_sample  # taking sample from previous example
torch.onnx.export(model, input_x, 'model.onnx',
                  export_params=True,
                  opset_version=11,
                  do_constant_folding=True,
                  input_names=['input_ids', 'attention_mask'],
                  output_names=['output'],
                  dynamic_axes={'input_ids': {0: 'batch_size', 1: 'length'},
                                'attention_mask': {0: 'batch_size', 1: 'length'},
                                'output': {0: 'batch_size'}})

# loading
model = onnxruntime.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
# inputs are numpy arrays with the same dtype as the export sample (int32 here)
batch_x = {'input_ids': sample['input_ids'].int().cpu().numpy(),
           'attention_mask': sample['attention_mask'].int().cpu().numpy()}
logits = model.run(None, batch_x)[0]  # run() returns a list with one array per output
```

ONNX Runtime can be used with a GPU, though it requires specific versions of CUDA, cuDNN and OS, which makes the installation process challenging at first.

For a more comprehensive tutorial you can follow the official documentation.

Experimental results

Each configuration was run five times on a dataset of 1k sentences of various lengths. We tested two popular GPUs, a T4 and a V100, with torch 1.7.1 and ONNX 1.6.0. Keep in mind that results will vary with your specific hardware, package versions and dataset.

Mean inference time in ms per sequence

Depending on the hardware setup, mean inference time on our dataset ranges from around 50 ms to 0.6 ms per sample.

On CPU, ONNX is the clear winner for batch sizes below 32, beyond which the format no longer seems to matter much. When predicting sample by sample, ONNX on CPU is as fast as our GPU baseline, for a fraction of the cost.

As expected, inference is much quicker on a GPU, especially with larger batch sizes. We can also see that the ideal batch size depends on the GPU used:

  • For the T4, the best setup is to run ONNX with batches of 8 samples, which gives a ~12x speedup compared to batch size 1 on Pytorch.
  • For the V100, batches of 32 or 64 achieve up to a ~28x speedup compared to the GPU baseline and ~90x compared to the CPU baseline.

Overall, choosing an appropriate format has a significant impact for smaller batch sizes, but that impact narrows as batches get larger: for batches of 64 samples, the three setups are within ~10% of each other.

Impact of sequence length and batching strategy

Another factor to take into account is sequence length. Transformers such as BERT are usually restricted to sequences of 512 tokens, but speed and memory requirements differ massively across sequence lengths within that range.

Mean inference time in ms per sequence on T4 GPU
Mean inference time in ms per sequence on V100 GPU

Inference time scales roughly linearly with sequence length for larger batches, but not for individual samples. This means that if your data consists of long sequences of text (news articles, for example), you won’t get as big a speedup from batching. As always, this depends on your hardware: a V100 is faster than a T4 and suffers less when predicting long sequences, whereas our CPU gets completely overwhelmed:

Mean inference time in ms per sequence on CPU

If your data is heterogeneous length-wise and you work with batches, these discrepancies cause problems because samples must be padded to the longest one in the batch, which adds a lot of computation. It is therefore usually better to batch samples of similar length together: predicting on several batches of similar length is most likely quicker than predicting on one big batch that is mostly padding tokens.
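A small illustration of the padding overhead: count the tokens actually processed (real tokens plus padding) when each batch is padded to its longest sample, for the same hypothetical list of sequence lengths in arbitrary order and after sorting. Token count is only a rough proxy for compute, since attention cost grows faster than linearly with length.

```python
from typing import List

def padded_tokens(lengths: List[int], batch_size: int) -> int:
    """Total tokens processed when each batch is padded to its longest sequence."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += len(batch) * max(batch)
    return total

lengths = [5, 50, 6, 48, 7, 52, 5, 49]  # hypothetical token counts
print(sum(lengths))                       # 222 real tokens
print(padded_tokens(lengths, 2))          # 398 with unsorted batches
print(padded_tokens(sorted(lengths), 2))  # 226 with length-sorted batches
```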

As a quick check, let’s look at what happens when we sort our dataset by length before running inference:

Mean inference time in ms per sequence

As expected, for larger batch sizes there is a significant incentive to group samples of similar length. With unsorted data, the larger the batch, the higher the probability of including a few long samples that significantly increase the inference time of the whole batch: going from a batch size of 16 to 64 slows down inference by 20%, while it gets 10% faster with sorted data.
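Sorting means predictions come back in a different order than the inputs. One way to handle this, sketched with a hypothetical `predict_batch` callable: sort indices by length, predict batch by batch, then write each prediction back to its original position. `length_fn` defaults to character length; a token count from your tokenizer would be a closer match to what the model sees.

```python
from typing import Callable, List, Sequence, TypeVar

R = TypeVar("R")

def predict_sorted(texts: Sequence[str], predict_batch: Callable[[List[str]], List[R]],
                   batch_size: int = 32, length_fn: Callable[[str], int] = len) -> List[R]:
    """Batch texts of similar length together, then return predictions in input order."""
    order = sorted(range(len(texts)), key=lambda i: length_fn(texts[i]))
    predictions: List[R] = [None] * len(texts)  # type: ignore[list-item]
    for start in range(0, len(order), batch_size):
        batch_idx = order[start:start + batch_size]
        outputs = predict_batch([texts[i] for i in batch_idx])
        for i, output in zip(batch_idx, outputs):
            predictions[i] = output  # restore the original position
    return predictions

texts = ["a much longer sentence here", "hi", "medium text", "ok"]
print(predict_sorted(texts, lambda batch: [t.upper() for t in batch], batch_size=2))
```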

This strategy can also be used to significantly reduce your training time. However, it should be applied with caution since it may negatively impact the performance of your model, especially if there is some correlation between your labels and the length of your samples.

Next steps

While these experiments have been run in Python, both Torchscript and ONNX models can also be loaded directly in C++, which could provide an additional boost in inference speed.

If your model is still too slow for your use case, Pytorch provides different options for quantization. ‘Dynamic quantization’ can be applied post-training but will most likely have an impact on the accuracy of your model, while ‘quantization aware training’ requires retraining but should have less impact on model performance.
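As an example of the post-training option, `torch.quantization.quantize_dynamic` (also exposed as `torch.ao.quantization.quantize_dynamic` in recent versions) converts the weights of selected layer types, here `nn.Linear`, to int8 and quantizes activations on the fly. A minimal sketch on a hypothetical stack of linear layers standing in for a real classifier; the quantized model runs on CPU, and on a real model you would compare accuracy and speed before and after.

```python
import io

import torch

def serialized_size_mb(model: torch.nn.Module) -> float:
    """Size of the model's state_dict once serialized, in megabytes."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6

# Hypothetical stand-in for a classifier: a stack of Linear layers
model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072), torch.nn.ReLU(),
    torch.nn.Linear(3072, 768), torch.nn.Linear(768, 2),
).eval()

# Post-training dynamic quantization: int8 weights for every nn.Linear
quantized = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

x = torch.randn(4, 768)
with torch.no_grad():
    print(model(x).shape, quantized(x).shape)
print(f"{serialized_size_mb(model):.1f} MB -> {serialized_size_mb(quantized):.1f} MB")
```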

Conclusion

As we have seen, there is no straightforward answer to optimizing inference time, since it depends mostly on your specific hardware and the problem you are trying to solve. Consequently, you should run your experiments on your own target hardware and data to obtain reliable results.

Nonetheless there are some guidelines that should hold true and are easy to implement:

  • Predicting on batches can provide a significant speedup up to a certain batch size (depending on your specific hardware), especially if you can batch samples of similar length together.
  • Using Torchscript or ONNX provides a significant speedup for smaller batch sizes and shorter sequences; the effect is particularly strong when running inference on individual samples.
  • ONNX seems to be the best performing of the three configurations we tested, though it is also the most difficult to install for inference on GPU.
  • Torchscript provides a reliable speedup for smaller batch sizes and is very easy to set up.