add gpt neox

This commit is contained in:
Cyril Vallez 2025-02-18 11:41:51 +01:00
parent 188b150b57
commit aeb6429bd1
No known key found for this signature in database

View File

@ -858,6 +858,15 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
return CausalLM(
model_id=model_id,