feat(server): have FlashGPTNeoXModel support HF accelerate

It seems to work fine and loads 4-10x faster for me, depending on storage/page-cache state (tested with a non-sharded 20B-parameter model).

However, when loaded this way, inference appears to be 10-15% slower; the cause isn't clear yet.
commit 252a086e9b
parent b927244eb5
Author: Nick Hill
Date:   2023-04-17 16:29:14 -07:00


@@ -505,7 +505,7 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
     config_class = GPTNeoXConfig
     base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = False
-    _no_split_modules = None
+    _no_split_modules = ["FlashNeoXLayer"]
 
 
 class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
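
For context, a minimal sketch of what this change enables (the import path, checkpoint id, and dtype below are illustrative assumptions, not part of this commit). With _no_split_modules set, accelerate's device_map placement knows it must keep each FlashNeoXLayer whole on a single device:

    import torch
    # Assumed local module path, for illustration only
    from flash_neox_modeling import FlashGPTNeoXModel

    model = FlashGPTNeoXModel.from_pretrained(
        "EleutherAI/gpt-neox-20b",  # illustrative non-sharded 20B checkpoint
        torch_dtype=torch.float16,
        device_map="auto",          # accelerate consults _no_split_modules here
    )

With the attribute left as None, transformers treats the class as not supporting device_map-based loading at all, so from_pretrained with a device_map would refuse to proceed.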