diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index 30ea4c8f..17f47e5e 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -243,10 +243,8 @@ class TransformersFlashCausalLM(FlashCausalLM): adapter_data=None, # not supported, but passed to match original signature ): hidden_states = self.model.model.forward( - input_ids=input_ids.unsqueeze(0), # expand dim to easily fit transformers - position_ids=position_ids.unsqueeze( - 0 - ), # expand dim to easily fit transformers + input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers + position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers past_key_values=None, # we use self.kv_cache instead of transformers cache object use_cache=False, # we use self.kv_cache instead of transformers cache object return_dict=True,