Run Flux fast on H100s with torch.compile
Update: To speed up inference by another >2x, check out the additional optimization techniques we tried in this blog post!
In this guide, we’ll run Flux as fast as possible on Modal using open source tools.
We’ll use torch.compile and NVIDIA H100 GPUs.
Setting up the image and dependencies
We’ll make use of the full CUDA toolkit in this example, so we’ll build our container image off of the nvidia/cuda base.
Now we install most of our dependencies with apt and pip.
We install Hugging Face’s Diffusers library from GitHub source, and so pin to a specific commit for reproducibility.
PyTorch added faster attention kernels for Hopper GPUs in version 2.5.
Later, we’ll also use torch.compile to increase the speed further.
Torch compilation needs to be re-executed when each new container starts,
so we turn on some extra caching to reduce compile times for later containers.
Finally, we construct our Modal App,
set its default image to the one we just constructed,
and import FluxPipeline for downloading and running Flux.1.
Defining a parameterized Model inference class
Next, we map the model’s setup and inference code onto Modal.
We run the model setup in the method decorated with @modal.enter(). This includes loading the weights and moving them to the GPU, along with an optional torch.compile step (see details below). The @modal.enter() decorator ensures that this method runs only once, when a new container starts, instead of in the path of every call.

We run the actual inference in methods decorated with @modal.method().
Note: Access to the Flux.1-schnell model on Hugging Face is gated by a license agreement which you must agree to here.
After you have accepted the license, create a Modal Secret with the name huggingface-secret following the instructions in the template.
Calling our inference function
To generate an image we just need to call the Model’s generate method
with .remote appended to it.
You can call .generate.remote from any Python environment that has access to your Modal credentials.
The local environment will get back the image as bytes.
Here, we wrap the call in a Modal local_entrypoint so that it can be run with modal run:
By default, we call generate twice to demonstrate how much faster
the inference is after cold start. In our tests, clients received images in about 1.2 seconds.
We save the output bytes to a temporary file.
Speeding up Flux with torch.compile
By default, we do some basic optimizations, like adjusting memory layout and re-expressing the attention head projections as a single matrix multiplication. But there are additional speedups to be had!
PyTorch 2 added a compiler that optimizes the compute graphs created dynamically during PyTorch execution. This feature helps close the gap with the performance of static graph frameworks like TensorRT and TensorFlow.
Here, we follow the suggestions from Hugging Face’s guide to fast diffusion inference, which we verified with our own internal benchmarks. Review that guide for detailed explanations of the choices made below.
The resulting compiled Flux schnell deployment returns images to the client in under a second (~700 ms), according to our testing. Super schnell!
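Those optimizations could be collected into a helper like the sketch below, loosely following the Hugging Face guide; the exact modes and flags here are assumptions rather than the benchmarked configuration.

```python
def optimize(pipe, compile: bool = True):
    """Sketch of the optimizations described above, applied to a FluxPipeline."""
    import torch

    # basic optimizations: channels_last memory layout and fused QKV projections
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    if compile:
        # compile the two heaviest modules; "max-autotune" searches harder for
        # fast kernels at the cost of a much longer first compilation
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune", fullgraph=True
        )
        pipe.vae.decode = torch.compile(
            pipe.vae.decode, mode="max-autotune", fullgraph=True
        )
    return pipe
```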
Compilation takes up to twenty minutes on the first iteration.
As of the time of writing (late 2024),
the compilation artifacts cannot be fully serialized,
so some compilation work must be re-executed every time a new container is started.
That includes when scaling up an existing deployment or the first time a Function is invoked with modal run.
We cache compilation outputs from nvcc, triton, and inductor,
which can reduce compilation time by up to an order of magnitude.
For details see this tutorial.
You can turn on compilation with the --compile flag.
Try it out with:
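Assuming the example is saved as flux.py (the filename here is a placeholder for wherever you put this code):

```shell
modal run flux.py --compile
```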
The compile option is passed via a modal.parameter on our class.
Each different choice for a parameter creates a separate auto-scaling deployment.
That means your client can use arbitrary logic to decide whether to hit a compiled or eager endpoint.