From 252a086e9b60213926725b8ef65f95d74aa70da5 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 17 Apr 2023 16:29:14 -0700
Subject: [PATCH] feat(server): have FlashGPTNeoXModel support HF accelerate

It seems to work fine and loads 4-10x faster for me, depending on the
storage/page cache (non-sharded 20B parameter model).

However, when loaded this way, inference appears to be 10-15% slower for
some reason.
---
 .../models/custom_modeling/flash_neox_modeling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 16fd4091..3179aebe 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -505,7 +505,7 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
     config_class = GPTNeoXConfig
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = False
-    _no_split_modules = None
+    _no_split_modules = ["FlashNeoXLayer"]
 
 
 class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
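
Note (not part of the patch): _no_split_modules is the list that transformers/accelerate
consult when inferring a device map, so that each listed module (here a whole
FlashNeoXLayer block) is placed on a single device rather than split across devices.
Below is a minimal sketch of the loading path this change enables; the checkpoint name
and dtype are illustrative assumptions, not taken from this patch:

    # Sketch: load the custom model through HF accelerate's device_map,
    # which _no_split_modules makes safe for FlashNeoXLayer blocks.
    # "EleutherAI/gpt-neox-20b" is an illustrative checkpoint, not from the patch.
    import torch
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXModel,
    )

    model = FlashGPTNeoXModel.from_pretrained(
        "EleutherAI/gpt-neox-20b",
        torch_dtype=torch.float16,
        device_map="auto",  # accelerate keeps each FlashNeoXLayer on one device,
                            # as permitted by _no_split_modules above
    )

This loading path streams weights into place instead of materializing the full model
first, which is consistent with the 4-10x load-time speedup reported in the commit
message; the cause of the 10-15% inference slowdown is left open by the author.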