diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index a7969494..982d5326 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -41,22 +41,29 @@ from text_generation_server.layers.layernorm import (
 
 
 def load_attention(config, prefix, weights):
+    bias = config.attention_bias
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=bias,
+        )
     else:
         if config.model_type == "baichuan":
             return TensorParallelColumnLinear.load_qkv(
                 config,
                 prefix=f"{prefix}.W_pack",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         elif config.model_type == "phi3":
             return TensorParallelColumnLinear.load_qkv(
                 config,
                 prefix=f"{prefix}.qkv_proj",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         else:
             return TensorParallelColumnLinear.load_multi(
@@ -64,36 +71,10 @@ def load_attention(config, prefix, weights):
                 prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                 dim=0,
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
 
 
-def _load_gqa(config, prefix: str, weights):
-    assert config.hidden_size % config.num_attention_heads == 0
-    assert config.num_attention_heads % weights.process_group.size() == 0
-
-    weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0,
-    )
-
-    if config.quantize not in ["gptq", "awq"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-
-    head_size = config.hidden_size // config.num_attention_heads
-    num_heads = config.num_attention_heads // weights.process_group.size()
-    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
-        (num_heads + 2 * num_key_value_heads) * head_size,
-        config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
-
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
-    )
-
-
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
@@ -214,12 +195,13 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
+        bias = config.mlp_bias
         if config.model_type == "phi3":
             self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
@@ -227,13 +209,13 @@ class LlamaMLP(nn.Module):
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
-                bias=False,
+                bias=bias,
             )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
@@ -385,9 +367,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.model = FlashLlamaModel(prefix, config, weights)
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
             weights=weights,
         )
 
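# --- Illustrative note, not part of the patch above ---
# A minimal standalone sketch of the tied-embeddings prefix resolution introduced
# in the last hunk. It assumes `tie_word_embeddings` behaves like the standard
# Hugging Face LlamaConfig flag; the helper name `resolve_lm_head_prefix` is
# hypothetical and exists only for illustration.
def resolve_lm_head_prefix(prefix: str, tie_word_embeddings: bool) -> str:
    # With tied embeddings, the lm_head reuses the input embedding weights,
    # so the head is loaded from "model.embed_tokens" instead of "lm_head".
    suffix = "model.embed_tokens" if tie_word_embeddings else "lm_head"
    # Nested models (non-empty prefix) get the suffix appended to that prefix.
    return suffix if not prefix else f"{prefix}.{suffix}"


if __name__ == "__main__":
    assert resolve_lm_head_prefix("", tie_word_embeddings=True) == "model.embed_tokens"
    assert resolve_lm_head_prefix("text_model", tie_word_embeddings=False) == "text_model.lm_head"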