diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 16fd4091..3179aebe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -505,7 +505,7 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): config_class = GPTNeoXConfig base_model_prefix = "gpt_neox" supports_gradient_checkpointing = False - _no_split_modules = None + _no_split_modules = ["FlashNeoXLayer"] class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):