Fixing santacoder (num_kv_heads hardcoded).

This commit is contained in:
Nicolas Patry 2024-07-02 16:35:08 +00:00
parent 43ef5268fd
commit dbf9292afc
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674
3 changed files with 11 additions and 2 deletions

View File

@ -87,6 +87,9 @@ try:
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
@ -489,6 +492,7 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
)
elif sharded:
raise NotImplementedError(

View File

@ -464,7 +464,7 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
config.transpose = config.architectures[0].startswith("GPT2")
self.transformer = FlashSantacoderModel(config, weights)

View File

@ -825,6 +825,8 @@ class FlashCausalLM(Model):
config_class: PreTrainedTokenizerBase = AutoConfig,
default_dtype=torch.float16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads=None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -886,7 +888,10 @@ class FlashCausalLM(Model):
config = text_config
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
num_heads = getattr(config, "num_key_value_heads", config.n_head)
if num_kv_heads is None:
num_kv_heads = config.num_key_value_heads
self.num_kv_heads = num_kv_heads // self.process_group.size()
self.head_size = config.hidden_size // config.num_attention_heads
self.cuda_graphs = {}