Granite support?

Nicolas Patry 2024-05-13 08:04:14 +00:00
parent fd89d9dfae
commit fff4899e57


@@ -41,22 +41,29 @@ from text_generation_server.layers.layernorm import (
 
 
 def load_attention(config, prefix, weights):
+    bias = config.attention_bias
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=bias,
+        )
     else:
         if config.model_type == "baichuan":
             return TensorParallelColumnLinear.load_qkv(
                 config,
                 prefix=f"{prefix}.W_pack",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         elif config.model_type == "phi3":
             return TensorParallelColumnLinear.load_qkv(
                 config,
                 prefix=f"{prefix}.qkv_proj",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         else:
             return TensorParallelColumnLinear.load_multi(
@@ -64,33 +71,7 @@ def load_attention(config, prefix, weights):
                 prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                 dim=0,
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
-
-
-def _load_gqa(config, prefix: str, weights):
-    assert config.hidden_size % config.num_attention_heads == 0
-    assert config.num_attention_heads % weights.process_group.size() == 0
-
-    weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0,
-    )
-
-    if config.quantize not in ["gptq", "awq"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-
-    head_size = config.hidden_size // config.num_attention_heads
-    num_heads = config.num_attention_heads // weights.process_group.size()
-    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
-        (num_heads + 2 * num_key_value_heads) * head_size,
-        config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
-
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
-    )
 
 
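With the GQA path now routed through the generic TensorParallelColumnLinear.load_multi, the hand-rolled shape check in the deleted _load_gqa has no direct replacement. For reference, that check reduces to the per-shard arithmetic below; this is a plain-Python sketch of the deleted assertions, not TGI code, and the helper name is made up:

def fused_qkv_rows(hidden_size, num_attention_heads, num_key_value_heads, world_size):
    # Per-shard output rows of the fused q/k/v projection, as asserted by the
    # deleted _load_gqa: query heads plus one k and one v group of KV heads,
    # all divided across tensor-parallel ranks.
    assert hidden_size % num_attention_heads == 0
    assert num_attention_heads % world_size == 0
    head_size = hidden_size // num_attention_heads
    num_heads = num_attention_heads // world_size
    num_kv_heads = num_key_value_heads // world_size
    return (num_heads + 2 * num_kv_heads) * head_size

# e.g. 32 query heads, 8 KV heads, hidden size 4096, sharded over 2 GPUs:
# fused_qkv_rows(4096, 32, 8, 2) == (16 + 2 * 4) * 128 == 3072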
@@ -214,12 +195,13 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
+        bias = config.mlp_bias
         if config.model_type == "phi3":
             self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
@@ -227,13 +209,13 @@ class LlamaMLP(nn.Module):
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
-                bias=False,
+                bias=bias,
             )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
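Both blocks now resolve their bias from a single config field each. A minimal sketch of the lookup, assuming the transformers LlamaConfig field names attention_bias and mlp_bias (both default to False there; a Granite-style checkpoint would presumably set them to True, though the diff itself does not say so):

from types import SimpleNamespace

def projection_bias(config):
    # One flag per block type drives every projection in that block,
    # replacing the hardcoded bias=False. getattr guards configs that
    # predate these fields (an assumption, not something the diff does).
    return {
        "attention": getattr(config, "attention_bias", False),
        "mlp": getattr(config, "mlp_bias", False),
    }

print(projection_bias(SimpleNamespace(attention_bias=True, mlp_bias=True)))
# {'attention': True, 'mlp': True}
print(projection_bias(SimpleNamespace()))
# {'attention': False, 'mlp': False}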
@@ -385,9 +367,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.model = FlashLlamaModel(prefix, config, weights)
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
             weights=weights,
         )
 
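The last hunk wires up tied word embeddings: when config.tie_word_embeddings is set, the speculative head loads its weights from the token embedding tensor rather than a separate lm_head. The prefix selection can be sketched as a standalone helper (the function name is hypothetical, not part of the commit):

def lm_head_prefix(prefix, tie_word_embeddings):
    # With tied embeddings the output head reuses model.embed_tokens;
    # otherwise it loads the dedicated lm_head tensor. An enclosing
    # prefix, when present, is prepended to either choice.
    suffix = "model.embed_tokens" if tie_word_embeddings else "lm_head"
    return suffix if not prefix else f"{prefix}.{suffix}"

assert lm_head_prefix("", False) == "lm_head"
assert lm_head_prefix("", True) == "model.embed_tokens"
assert lm_head_prefix("base", True) == "base.model.embed_tokens"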