From 6defe57d7aa1ee7a183603f396f4fb86926d9027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 6 Jan 2025 16:07:58 +0000 Subject: [PATCH] flashinfer: remove `contiguous` calls --- server/text_generation_server/layers/attention/cuda.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 3038602e..7b5af3c4 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -60,8 +60,7 @@ def paged_attention( from text_generation_server.layers.attention.flashinfer import decode_state return decode_state.get().forward( - # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. - query.contiguous(), + query, paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, sm_scale=softmax_scale, @@ -231,8 +230,7 @@ def attention( softcap = 0.0 return prefill_with_paged_kv_state.get().forward( - # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged. - query.contiguous(), + query, causal=causal, paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap,