From 540e710c3f837745b48bf25b2854130c5a526554 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 13:22:38 +0200 Subject: [PATCH] Falcon/DBRX: get correct number of key-value heads (#2205) --- server/text_generation_server/models/__init__.py | 4 ++++ .../models/custom_modeling/flash_dbrx_modeling.py | 12 ++++++++++++ .../models/custom_modeling/flash_rw_modeling.py | 1 + .../text_generation_server/models/flash_causal_lm.py | 11 +++++------ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 58131a3a..ba980195 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -797,6 +797,10 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, config_class=RWConfig, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 41aa5859..44411687 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig): class DbrxConfig(PretrainedConfig): + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + } + def __init__( self, d_model: int = 2048, @@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig): **kwargs, ) + @property + def num_key_value_heads(self): + # We can't use the attribute map, since the number of KV + # heads is not top-level. 
+ return self.attn_config.kv_n_heads + def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index d12ed567..4813e2df 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig): attribute_map = { "num_hidden_layers": "n_layer", "num_attention_heads": "n_head", + "num_key_value_heads": "n_head_kv", } def __init__( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5c086a73..bf1fda4a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -905,13 +905,12 @@ class FlashCausalLM(Model): self.num_layers = config.num_hidden_layers # Validation is done in the model itself if num_kv_heads is None: - # Order is important here. - for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]: - num_kv_heads = getattr(config, attr, None) - if num_kv_heads is not None: - break + num_kv_heads = getattr(config, "num_key_value_heads", None) + # GPT-2 workaround if num_kv_heads is None: - raise ValueError("Cannot get the number of key/value heads") + num_kv_heads = getattr(config, "n_head", None) + if num_kv_heads is None: + raise ValueError("Cannot get the number of key/value heads") self.num_kv_heads = ( num_kv_heads // self.process_group.size() if num_kv_heads > 1