Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-12 04:44:52 +00:00
Fixing santacoder (num_kv_heads hardcoded).
commit dbf9292afc
parent 43ef5268fd
@@ -87,6 +87,9 @@ try:
     from text_generation_server.models.pali_gemma import (
         PaliGemmaBatch,
     )
+    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
+        PaliGemmaForConditionalGeneration,
+    )
     from text_generation_server.models.custom_modeling.flash_phi_modeling import (
         FlashPhiForCausalLM,
     )
@@ -489,6 +492,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
                 aliases={"transformer.wte.weight": ["lm_head.weight"]},
+                num_kv_heads=1,
             )
         elif sharded:
             raise NotImplementedError(
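Context for the hardcoded value: santacoder (GPT-BigCode style) uses multi-query attention, so all query heads share a single key/value head and the KV cache must be sized for one head rather than for num_attention_heads. Below is a minimal sketch of what that single-head override means for cache sizing, using hypothetical numbers rather than values from this commit:

# Illustrative only: multi-query attention keeps a single KV head, which is
# what the hardcoded num_kv_heads=1 above encodes for santacoder.
# The config values below are hypothetical, not read from a real checkpoint.
num_attention_heads = 16
head_size = 128

def kv_elements_per_token_per_layer(num_kv_heads: int, head_size: int) -> int:
    # One key vector and one value vector per KV head.
    return 2 * num_kv_heads * head_size

print(kv_elements_per_token_per_layer(1, head_size))                    # 256 (MQA, one shared KV head)
print(kv_elements_per_token_per_layer(num_attention_heads, head_size))  # 4096 (one KV head per query head)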
@@ -464,7 +464,7 @@ class FlashSantacoderModel(nn.Module):


 class FlashSantacoderForCausalLM(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
         self.transformer = FlashSantacoderModel(config, weights)
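The signature change mirrors the other flash_* causal LM classes, which take a (prefix, config, weights) constructor; the prefix namespaces the weight names a model reads from the checkpoint. A tiny sketch of that idea, with illustrative names that are not taken from this diff:

# Sketch of what a weight-name prefix does: it namespaces checkpoint lookups.
# The names here are illustrative, not the exact ones santacoder uses.
def full_weight_name(prefix: str, name: str) -> str:
    return f"{prefix}.{name}" if prefix else name

print(full_weight_name("", "transformer.wte.weight"))       # transformer.wte.weight
print(full_weight_name("model", "transformer.wte.weight"))  # model.transformer.wte.weight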
@@ -825,6 +825,8 @@ class FlashCausalLM(Model):
         config_class: PreTrainedTokenizerBase = AutoConfig,
         default_dtype=torch.float16,
         aliases=None,
+        # Used for Santacoder override of config
+        num_kv_heads=None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -886,7 +888,10 @@ class FlashCausalLM(Model):
             config = text_config
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
-        self.num_kv_heads = config.num_key_value_heads // self.process_group.size()
+        num_heads = getattr(config, "num_key_value_heads", config.n_head)
+        if num_kv_heads is None:
+            num_kv_heads = config.num_key_value_heads
+        self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads

         self.cuda_graphs = {}
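Taken together, the two FlashCausalLM hunks let a caller pass an explicit num_kv_heads (santacoder passes 1) instead of relying on config.num_key_value_heads, which santacoder-style configs may not define. A condensed, standalone sketch of the resulting resolution logic, simplified rather than copied verbatim, with world_size standing in for self.process_group.size():

# Condensed sketch of the num_kv_heads resolution after this change (simplified).
from types import SimpleNamespace

def resolve_num_kv_heads(config, num_kv_heads=None, world_size=1):
    # An explicit override (e.g. num_kv_heads=1 for santacoder) wins; otherwise
    # fall back to the config, which most architectures populate.
    if num_kv_heads is None:
        num_kv_heads = config.num_key_value_heads
    # KV heads are split across tensor-parallel ranks.
    return num_kv_heads // world_size

llama_like = SimpleNamespace(num_key_value_heads=8)
print(resolve_num_kv_heads(llama_like, world_size=2))    # 4
print(resolve_num_kv_heads(llama_like, num_kv_heads=1))  # 1, the santacoder override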