* launcher: ensure correct detection of Gemma 3 head size
* Support flashinfer for Gemma3 prefill
Gemma3 uses bidirectional attention for images. Flashinfer
supports custom masks. Hook up the mask with flashinfer, so that we do
not have to use the slower SDPA implementation for prefills with images.
* Update Gemma3 test outputs
* Fixed unused import
* Basic flashinfer 0.2 support
This change does not use any of the new features yet, but makes
some small compatibility changes.
* Update to flashinfer 0.2.0.post1
* flashinfer: remove `contiguous` calls
* Fix flashinfer install
* flashinfer: fixup kv cache dtype
* Fix some annoying perturbations
* More output changes
* Add support for FP8 KV cache scales
Since FP8 only has limited dynamic range, we can scale keys/values
before storing them into the cache (and unscale them in attention). To
avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration calibration data and stored
in the checkpoint.
This change adds support for for using key-value scales and loading them
from checkpoints in the two most common formats:
- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).
Currently, scales are only used with an `float8_e4m3fn` cache.
Besides adding support for key/value scales, the `fp8_quantize` function
is also extended to support quantization with a kernel vendored from
vLLM. This is slightly faster than the PyTorch implementation, but also
scales in FP32, potentially improving accuracy.
* Update FP8 KV cache test to use checkpoint with scales
* `can_scale`: check that the attention is flashinfer