diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 47758d30e..6e23aa2bd 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 21edc79ef..ef3777dad 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 333efe337..45ddd8569 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 30ae95c91..e6125e294 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
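
For reference, a minimal sketch of the two patterns this diff introduces: gating the ROCm fused-SiLU fast path on whether the layer is quantized, and passing a dummy `input_lengths` tensor of ones during warmup instead of `None`. The class, function, and parameter names below (`SketchMLP`, `system`, the layer sizes) are hypothetical and not TGI's actual API; the fused ROCm kernel itself is omitted and only the standard gate/up path is implemented.

```python
# Illustrative sketch only; not the TGI implementation.
from typing import Optional

import torch
import torch.nn as nn


class SketchMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, quantize: Optional[str]):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.intermediate_size = intermediate_size
        self.act = nn.SiLU()
        # Mirrors the hotfix: remember whether the layer is quantized so the
        # ROCm fused SiLU path (which assumes plain fp16/bf16 weights) is skipped.
        self.quantize = quantize

    def forward(self, hidden_states: torch.Tensor, system: str = "cuda") -> torch.Tensor:
        if system == "rocm" and hidden_states.shape[0] == 1 and not self.quantize:
            # The fused ROCm kernel would be dispatched here; omitted in this sketch,
            # so the code falls through to the standard path below.
            pass
        gate_up = self.gate_up_proj(hidden_states)
        gate, up = gate_up.split(self.intermediate_size, dim=-1)
        return self.down_proj(self.act(gate) * up)


# Example usage of the gated path.
mlp = SketchMLP(hidden_size=16, intermediate_size=32, quantize="gptq")
out = mlp(torch.randn(1, 16), system="rocm")

# Warmup change, in the same spirit: a dummy `input_lengths` of ones rather than
# `None`, since some model forwards (e.g. starcoder2) index into it.
seqlen = 8
input_lengths = torch.ones(seqlen, dtype=torch.int32)
```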