Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing santacoder (num_kv_heads hardcoded).
Commit dbf9292afc (parent 43ef5268fd)
server/text_generation_server/models/__init__.py

@@ -87,6 +87,9 @@ try:
     from text_generation_server.models.pali_gemma import (
         PaliGemmaBatch,
     )
+    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+        PaliGemmaForConditionalGeneration,
+    )
     from text_generation_server.models.custom_modeling.flash_phi_modeling import (
         FlashPhiForCausalLM,
     )
@@ -489,6 +492,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 aliases={"transformer.wte.weight": ["lm_head.weight"]},
+                num_kv_heads=1,
             )
         elif sharded:
             raise NotImplementedError(
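Why num_kv_heads=1: SantaCoder (GPT-BigCode) uses multi-query attention, sharing a single key/value head across all query heads, so generic sizing based on the attention head count would over-allocate the KV cache. A minimal sketch of the effect, using a hypothetical stand-in config (not TGI code):

# Minimal sketch (not TGI code) of why the override matters. SantaCoder /
# GPT-BigCode uses multi-query attention: one key/value head shared by all
# query heads, rather than a num_key_value_heads field on its config.
from dataclasses import dataclass

@dataclass
class SantacoderLikeConfig:
    # Hypothetical stand-in for the HF GPTBigCodeConfig.
    num_attention_heads: int = 16
    hidden_size: int = 2048
    multi_query: bool = True

def kv_cache_shape(config, num_kv_heads, num_blocks=128, block_size=16):
    # Per-layer KV cache shape in the usual paged layout:
    # (num_blocks, num_kv_heads, block_size, head_size).
    head_size = config.hidden_size // config.num_attention_heads
    return (num_blocks, num_kv_heads, block_size, head_size)

config = SantacoderLikeConfig()
# Sizing by attention heads would allocate 16x the KV memory the model uses:
print(kv_cache_shape(config, config.num_attention_heads))  # (128, 16, 16, 128)
# The hardcoded override sizes it for the single shared KV head:
print(kv_cache_shape(config, num_kv_heads=1))              # (128, 1, 16, 128)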
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -464,7 +464,7 @@ class FlashSantacoderModel(nn.Module):
 
 
 class FlashSantacoderForCausalLM(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
         self.transformer = FlashSantacoderModel(config, weights)
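The extra prefix argument brings FlashSantacoderForCausalLM in line with the single calling convention the shared FlashCausalLM loader uses for every flash model. A minimal sketch of that convention (hypothetical helper, not TGI code):

# Minimal sketch (hypothetical helper, not TGI code) of the uniform calling
# convention this change aligns FlashSantacoderForCausalLM with: the shared
# loader builds every flash model the same way, so a model whose __init__
# lacks `prefix` fails with a TypeError at load time.

def build_model(model_class, config, weights, prefix=""):
    # `prefix` is "" for plain checkpoints and non-empty when the model's
    # tensors live under a sub-module of a larger checkpoint.
    return model_class(prefix, config, weights)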
server/text_generation_server/models/flash_causal_lm.py

@@ -825,6 +825,8 @@ class FlashCausalLM(Model):
         config_class: PreTrainedTokenizerBase = AutoConfig,
         default_dtype=torch.float16,
         aliases=None,
+        # Used for Santacoder override of config
+        num_kv_heads=None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -886,7 +888,10 @@ class FlashCausalLM(Model):
             config = text_config
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
-        self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
+        num_heads = getattr(config, "num_key_value_heads", config.n_head)
+        if num_kv_heads is None:
+            num_kv_heads = config.num_key_value_heads
+        self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
 
         self.cuda_graphs = {}
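In plain terms, the new resolution order is: an explicit num_kv_heads passed by the caller (1 for Santacoder) wins; otherwise the value comes from the config, and either way it is divided across the tensor-parallel group. A minimal sketch of that order (not the exact TGI code):

# Minimal sketch (not the exact TGI code) of the resolution order the hunk
# above introduces for the number of KV heads.

def resolve_num_kv_heads(config, num_kv_heads=None, world_size=1):
    if num_kv_heads is None:
        # Default: trust the config (multi- or grouped-query attention).
        num_kv_heads = config.num_key_value_heads
    # Tensor parallelism splits the KV heads across ranks.
    return num_kv_heads // world_size

class _Cfg:
    num_key_value_heads = 8

print(resolve_num_kv_heads(_Cfg(), world_size=2))    # 4, taken from the config
print(resolve_num_kv_heads(_Cfg(), num_kv_heads=1))  # 1, the Santacoder override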