text-generation-inference/launcher/src/gpu.rs
Daniël de Kok 5b6b74e21d
Improve support for GPUs with capability < 8 (#2575)
* Improve support for GPUs with capability < 8

- For models that cannot use flashinfer, use flash-attn v1 + paged
  attention for models with a compute capability older than 8.
- Disable prefix caching when using paged attention.
- When using flash-attn v1, pass the key/value, rather than the
  cache, since v1 cannot use block tables.

* nix: add flash-attn-v1 to the server environment

* Move disabling prefix caching into the block of exceptions

* Capability as `usize`s
2024-09-27 16:19:42 +02:00


use std::sync::LazyLock;

/// CUDA compute capability of the current device, queried once and memoized.
/// `None` if the capability cannot be determined (e.g. no GPU, no torch).
pub static COMPUTE_CAPABILITY: LazyLock<Option<(usize, usize)>> =
    LazyLock::new(get_cuda_capability);

fn get_cuda_capability() -> Option<(usize, usize)> {
    use pyo3::prelude::*;

    // Ask torch.cuda.get_device_capability() via pyo3. Extract as `isize`
    // first so that bogus negative values can be rejected before the cast.
    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {
        let torch = py.import_bound("torch.cuda")?;
        let get_device_capability = torch.getattr("get_device_capability")?;
        get_device_capability.call0()?.extract()
    };

    match pyo3::Python::with_gil(py_get_capability) {
        Ok((major, minor)) if major < 0 || minor < 0 => {
            tracing::warn!("Ignoring negative GPU compute capabilities: {major}.{minor}");
            None
        }
        Ok((major, minor)) => Some((major as usize, minor as usize)),
        Err(err) => {
            tracing::warn!("Cannot determine GPU compute capability: {}", err);
            None
        }
    }
}
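
A minimal sketch of how this value could feed the backend selection described in the commit message (flashinfer for capability >= 8, flash-attn v1 + paged attention with prefix caching disabled otherwise). The `Backend` enum and `resolve_attention` function are illustrative names only, not the launcher's actual API:

// Hypothetical sketch: gate the attention backend on the detected capability.
enum Backend {
    FlashInfer,
    FlashAttnV1Paged { prefix_caching: bool },
}

fn resolve_attention() -> Backend {
    match *COMPUTE_CAPABILITY {
        // Capability >= 8 can use flashinfer.
        Some((major, _)) if major >= 8 => Backend::FlashInfer,
        // Older or unknown GPUs: assume the fallback path from the commit
        // message, flash-attn v1 + paged attention, prefix caching off.
        _ => Backend::FlashAttnV1Paged {
            prefix_caching: false,
        },
    }
}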