From a26e57f9f329c6cbbaa914f987e23591d7a9f369 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 1 Jul 2024 11:54:34 +0000
Subject: [PATCH] Fixing non flash tests/imports.

---
 server/text_generation_server/models/__init__.py        | 3 ++-
 .../models/custom_modeling/flash_cohere_modeling.py     | 6 +++---
 server/text_generation_server/models/flash_causal_lm.py | 4 ++--
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index f2f0f457..5ea43290 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -12,7 +12,6 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 FLASH_ATTENTION = True
 
 try:
+    from text_generation_server.models.flash_causal_lm import FlashCausalLM
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_gpt2 import FlashGPT2
     from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -92,6 +92,7 @@ except ImportError as e:
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
+    __all__.append(FlashCausalLM)
     __all__.append(FlashGPT2)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashRWSharded)
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index c51cce3b..e088f9aa 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -387,8 +387,8 @@ class FlashCohereLayer(nn.Module):
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        input_lengths,
         slots,
+        input_lengths,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -401,8 +401,8 @@ class FlashCohereLayer(nn.Module):
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
-            input_lengths,
             slots,
+            input_lengths,
             max_s,
         )
 
@@ -475,8 +475,8 @@ class FlashCohereModel(torch.nn.Module):
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
-                input_lengths,
                 slots,
+                input_lengths,
                 max_s,
             )
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7ad1c8c5..49a088a1 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1086,7 +1086,7 @@ class FlashCausalLM(Model):
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        seqlen = Seqlen(input_lengths=input_lengths)
+        input_lengths = Seqlen(input_lengths=input_lengths)
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1097,7 +1097,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=self.kv_cache,
             block_tables=None,
-            seqlen=seqlen,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
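
Note on the __init__.py hunks above: the patch moves the FlashCausalLM import inside the existing try/except guard, so machines without the flash-attention kernels can still import text_generation_server.models, and the class is only re-exported when the import succeeded. Below is a minimal, self-contained sketch of that guarded-import pattern, not the actual TGI code; `flash_kernels` is a placeholder module name used only for illustration.

    # Guarded optional import: the package stays importable even when the
    # optional dependency is missing ("flash_kernels" is a placeholder name).
    FLASH_ATTENTION = True
    try:
        from flash_kernels import FlashCausalLM  # hypothetical optional import
    except ImportError as exc:
        # Fall back gracefully instead of failing at import time.
        print(f"Flash attention models disabled: {exc}")
        FLASH_ATTENTION = False

    __all__ = []  # always-available model classes would be listed here
    if FLASH_ATTENTION:
        # Advertise the flash class only when its import actually worked.
        # (The sketch appends the name as a string, the form __all__ conventionally holds.)
        __all__.append("FlashCausalLM")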