diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fb7c8cbe..32c2168f 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -87,6 +87,9 @@ try:
     from text_generation_server.models.pali_gemma import (
         PaliGemmaBatch,
     )
+    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+        PaliGemmaForConditionalGeneration,
+    )
     from text_generation_server.models.custom_modeling.flash_phi_modeling import (
         FlashPhiForCausalLM,
     )
@@ -489,6 +492,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 aliases={"transformer.wte.weight": ["lm_head.weight"]},
+                num_kv_heads=1,
             )
         elif sharded:
             raise NotImplementedError(
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index a77a7655..2bc305fe 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -464,7 +464,7 @@ class FlashSantacoderModel(nn.Module):
 
 
 class FlashSantacoderForCausalLM(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
         self.transformer = FlashSantacoderModel(config, weights)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 07e9f97f..5f558caa 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -825,6 +825,8 @@ class FlashCausalLM(Model):
         config_class: PreTrainedTokenizerBase = AutoConfig,
         default_dtype=torch.float16,
         aliases=None,
+        # Used for Santacoder override of config
+        num_kv_heads=None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -886,7 +888,10 @@ class FlashCausalLM(Model):
             config = text_config
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
-        self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
+        num_heads = getattr(config, "num_key_value_heads", config.n_head)
+        if num_kv_heads is None:
+            num_kv_heads = config.num_key_value_heads
+        self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
 
         self.cuda_graphs = {}
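
For context, a minimal standalone sketch (not part of the patch) of what the new `num_kv_heads` plumbing is meant to achieve: Santacoder uses multi-query attention and a GPT-2-style config (`n_head`, no `num_key_value_heads`), so `get_model` now passes an explicit `num_kv_heads=1` override instead of having `FlashCausalLM` read the attribute from the config. The names below (`resolve_num_kv_heads`, `DummyConfig`) are illustrative only, not TGI APIs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyConfig:
    # Stand-in for a transformers PretrainedConfig; only the fields used below.
    num_attention_heads: int
    hidden_size: int
    num_key_value_heads: Optional[int] = None  # Llama-style configs
    n_head: Optional[int] = None  # GPT-2-style configs such as Santacoder


def resolve_num_kv_heads(
    config: DummyConfig, world_size: int, num_kv_heads: Optional[int] = None
) -> int:
    # An explicit override (as get_model now passes for Santacoder) wins;
    # otherwise fall back to the config's key/value head count and shard it
    # across the tensor-parallel group, as FlashCausalLM.__init__ does.
    if num_kv_heads is None:
        num_kv_heads = (
            config.num_key_value_heads
            if config.num_key_value_heads is not None
            else config.n_head
        )
    return num_kv_heads // world_size


# Santacoder: multi-query attention, so the caller overrides with 1.
santacoder = DummyConfig(num_attention_heads=16, hidden_size=2048, n_head=16)
assert resolve_num_kv_heads(santacoder, world_size=1, num_kv_heads=1) == 1

# Llama-style config: behaviour is unchanged, num_key_value_heads is read and sharded.
llama = DummyConfig(num_attention_heads=32, hidden_size=4096, num_key_value_heads=8)
assert resolve_num_kv_heads(llama, world_size=2) == 4
```

The sketch simplifies the patched hunk slightly: in the diff itself, when no override is given, `FlashCausalLM` still reads `config.num_key_value_heads` directly, so the override is what keeps GPT-2-style configs on the generic flash path.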