Fixing santacoder (num_kv_heads hardcoded).

commit dbf9292afc
parent 43ef5268fd
Author: Nicolas Patry
Date: 2024-07-02 16:35:08 +00:00

3 changed files with 11 additions and 2 deletions

server/text_generation_server/models/__init__.py

```diff
@@ -87,6 +87,9 @@ try:
     from text_generation_server.models.pali_gemma import (
         PaliGemmaBatch,
     )
+    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+        PaliGemmaForConditionalGeneration,
+    )
     from text_generation_server.models.custom_modeling.flash_phi_modeling import (
         FlashPhiForCausalLM,
     )
@@ -489,6 +492,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 aliases={"transformer.wte.weight": ["lm_head.weight"]},
+                num_kv_heads=1,
             )
         elif sharded:
             raise NotImplementedError(
```

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

```diff
@@ -464,7 +464,7 @@ class FlashSantacoderModel(nn.Module):

 class FlashSantacoderForCausalLM(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
         self.transformer = FlashSantacoderModel(config, weights)
```
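The added `prefix` parameter is accepted but unused: it keeps `FlashSantacoderForCausalLM` call-compatible with the `(prefix, config, weights)` constructor shape the shared loading path uses for the other flash models. A schematic sketch of why the signature matters (class and function names below are illustrative, not TGI's actual API):

```python
class OldSantacoder:
    def __init__(self, config, weights):  # pre-commit signature
        self.config, self.weights = config, weights

class NewSantacoder:
    def __init__(self, prefix, config, weights):  # post-commit signature
        self.config, self.weights = config, weights  # prefix intentionally unused

def build(model_class, prefix, config, weights):
    # Uniform call site: every flash model class is invoked the same way.
    return model_class(prefix, config, weights)

build(NewSantacoder, "", {}, {})    # works
# build(OldSantacoder, "", {}, {})  # TypeError: too many positional arguments
```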

server/text_generation_server/models/flash_causal_lm.py

```diff
@@ -825,6 +825,8 @@ class FlashCausalLM(Model):
         config_class: PreTrainedTokenizerBase = AutoConfig,
         default_dtype=torch.float16,
         aliases=None,
+        # Used for Santacoder override of config
+        num_kv_heads=None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -886,7 +888,10 @@ class FlashCausalLM(Model):
             config = text_config
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
-        self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
+        num_heads = getattr(config, "num_key_value_heads", config.n_head)
+        if num_kv_heads is None:
+            num_kv_heads = config.num_key_value_heads
+        self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
         self.cuda_graphs = {}
```
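The new logic gives an explicit `num_kv_heads` argument precedence over the config, then divides by `self.process_group.size()` so each tensor-parallel shard keeps its slice of the KV heads. A condensed, standalone sketch of the operative branch (the function name and dummy configs are illustrative, not TGI's API):

```python
from types import SimpleNamespace

def resolve_num_kv_heads(config, num_kv_heads=None, world_size=1):
    # An explicit override (1 for santacoder) wins; otherwise the value
    # comes from the config. Either way it is split across the shards.
    if num_kv_heads is None:
        num_kv_heads = config.num_key_value_heads
    return num_kv_heads // world_size

llama_like = SimpleNamespace(num_key_value_heads=8)
santacoder_like = SimpleNamespace(n_head=16)  # no num_key_value_heads

assert resolve_num_kv_heads(llama_like, world_size=2) == 4
assert resolve_num_kv_heads(santacoder_like, num_kv_heads=1) == 1
```

Note that the fallback branch still reads `config.num_key_value_heads` directly, so for multi-query configs that lack the attribute the explicit argument is the only safe path, which is why `get_model` now passes `num_kv_heads=1` for santacoder.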