Surviving CUDA OOMs by halving the batch

This is a pattern from the inference engine I built at Superlinked, where embedding throughput matters a lot and GPU memory is the cliff.

Every GPU inference service I have ever written has eventually hit a CUDA OOM in production. The input that caused it was always something completely benign. A slightly longer text, a slightly larger image, a batch where one item was a 2MB PNG instead of a 200KB one. The service tipped over for an input that on its own would have been fine.

The lazy fix is to lower the batch size everywhere and over-provision the GPU. That works, and costs you 2-3x in throughput. The better fix is to treat the OOM as a signal, not a failure.

The pattern

import logging

import torch

log = logging.getLogger(__name__)


async def encode_with_backoff(model, inputs):
    batch_size = len(inputs)
    while True:
        try:
            return await _encode_batches(model, inputs, batch_size)
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            if batch_size == 1:
                break  # already at the smallest batch; give up honestly
            previous, batch_size = batch_size, batch_size // 2
            log.warning(
                "cuda OOM at batch_size=%d, retrying at %d",
                previous, batch_size,
            )
    raise RuntimeError("OOM even at batch_size=1; input is too large for this GPU")

That is the whole idea. Catch torch.cuda.OutOfMemoryError, empty the cache, halve the batch, try again. Continue until you reach 1, then fail honestly.
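
_encode_batches is just the chunking helper that actually runs the model; I have not shown it, but a minimal version looks something like this, assuming the model exposes a blocking .encode() that takes a list of items and returns a list of embeddings (swap in whatever your model actually provides):

import asyncio


async def _encode_batches(model, inputs, batch_size):
    # Chunk the inputs and run the model one chunk at a time. The blocking
    # encode call goes to a worker thread so the event loop stays responsive.
    outputs = []
    for start in range(0, len(inputs), batch_size):
        chunk = inputs[start:start + batch_size]
        # model.encode is a stand-in for whatever blocking call your model exposes
        outputs.extend(await asyncio.to_thread(model.encode, chunk))
    return outputs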

The serialisation trick


When the batch hits 1 you have a different problem. If the service is concurrent, two single-item calls can still OOM together because they fight each other for VRAM. So when the batch size collapses to 1, route the call through a semaphore so only one runs at a time:

self._gpu_sem = asyncio.Semaphore(1)  # created once, in the service's __init__

async def _encode_one(self, model, item):
    # Only one single-item GPU pass at a time; concurrent callers queue here.
    async with self._gpu_sem:
        return await _encode_batches(model, [item], 1)

CPU-side work (decoding, tokenisation) keeps running concurrently. Only the GPU pass is serialised. Throughput drops on the offending input, but the service stays up, which is the point.
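
To make that split concrete, the per-item path looks roughly like this. tokenise here stands in for whatever CPU-side preparation your request needs; the name is illustrative, not real code from the service:

async def _handle_one(self, model, raw_item):
    # CPU-side prep: runs concurrently across requests, no semaphore needed.
    tokens = await asyncio.to_thread(tokenise, raw_item)
    # GPU pass: serialised behind the semaphore.
    async with self._gpu_sem:
        return await _encode_batches(model, [tokens], 1)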

Logging is half the value


The retry alone keeps you alive. The logs are what let you actually fix the root cause. Every OOM I have ever hit told me one of two things. Either a specific input was unusually large (find it, document it, decide if it is real traffic or a probe), or the configured batch size was just wrong for the GPU (the service repeatedly halves to the same number under load; lower the default to that number and stop re-discovering it every cold start).
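
The warning only helps if it carries enough context to tell those two cases apart, so it is worth attaching the shape of the offending batch. A sketch, assuming the inputs are texts or byte blobs with a meaningful len() (the field names are illustrative):

def _log_oom(old_batch_size, new_batch_size, inputs):
    # Enough context to distinguish "one unusually large input" from
    # "default batch size too high for this GPU".
    log.warning(
        "cuda OOM at batch_size=%d, retrying at %d (items=%d, max_item_len=%d)",
        old_batch_size,
        new_batch_size,
        len(inputs),
        max(len(item) for item in inputs),
    )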

I keep the warning at WARN, not ERROR. ERROR pages people for something the service has already handled, and that is rude.

The reason this pattern is not already standard, I think, is that most tutorials run on a single example, not under concurrent load on a shared GPU. The retry is obvious in retrospect and absent from every “deploy a model with X” guide I have read.

If you ship inference, write the retry loop on day one. The first time it fires in production you will be glad you did, and you will not learn anything new from the page that wakes you up at 3am.