Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
Granite support?
commit fff4899e57 (parent fd89d9dfae)
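The diff below threads three standard Hugging Face config attributes through TGI's flash Llama modeling code: config.attention_bias controls whether the fused q/k/v projection is loaded with a bias, config.mlp_bias does the same for the MLP projections, and config.tie_word_embeddings lets the LM head fall back to the input embeddings. A minimal sketch of a config object carrying those flags; the values are illustrative assumptions for a Granite-style checkpoint, not taken from any real config.json:

# Illustrative only: the attributes the patched code reads.
# Flag values below are assumptions, not from a real checkpoint.
from types import SimpleNamespace

config = SimpleNamespace(
    model_type="llama",
    attention_bias=True,       # read by load_attention() -> bias=bias
    mlp_bias=True,             # read by LlamaMLP.__init__() -> bias=bias
    tie_word_embeddings=True,  # lm_head reuses model.embed_tokens weights
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
)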
@@ -41,22 +41,29 @@ from text_generation_server.layers.layernorm import (
 
 
 def load_attention(config, prefix, weights):
+    bias = config.attention_bias
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=bias,
+        )
     else:
         if config.model_type == "baichuan":
             return TensorParallelColumnLinear.load_qkv(
                 config,
                 prefix=f"{prefix}.W_pack",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         elif config.model_type == "phi3":
             return TensorParallelColumnLinear.load_qkv(
                 config,
                 prefix=f"{prefix}.qkv_proj",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         else:
             return TensorParallelColumnLinear.load_multi(
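The hunk above swaps the GQA path from the dedicated _load_gqa helper to a plain TensorParallelColumnLinear.load_multi over the q/k/v prefixes, with the bias taken from config.attention_bias. As a sanity check on the idea, here is a single-device sketch in plain torch.nn (not the TGI layer classes) showing that concatenating the three projection weights and biases along dim 0 and splitting the fused output reproduces the separate projections:

import torch

torch.manual_seed(0)
hidden, head_size, n_heads, n_kv_heads = 64, 16, 4, 2

q = torch.nn.Linear(hidden, n_heads * head_size, bias=True)
k = torch.nn.Linear(hidden, n_kv_heads * head_size, bias=True)
v = torch.nn.Linear(hidden, n_kv_heads * head_size, bias=True)

# Fuse along dim 0, mirroring load_multi(prefixes=[q, k, v], dim=0, bias=bias).
fused = torch.nn.Linear(hidden, (n_heads + 2 * n_kv_heads) * head_size, bias=True)
with torch.no_grad():
    fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
    fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))

x = torch.randn(3, hidden)
q_out, k_out, v_out = fused(x).split(
    [n_heads * head_size, n_kv_heads * head_size, n_kv_heads * head_size], dim=-1
)
assert torch.allclose(q_out, q(x), atol=1e-6)
assert torch.allclose(k_out, k(x), atol=1e-6)
assert torch.allclose(v_out, v(x), atol=1e-6)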
@@ -64,33 +71,7 @@ def load_attention(config, prefix, weights):
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=bias,
         )
-
-
-def _load_gqa(config, prefix: str, weights):
-    assert config.hidden_size % config.num_attention_heads == 0
-    assert config.num_attention_heads % weights.process_group.size() == 0
-
-    weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0,
-    )
-
-    if config.quantize not in ["gptq", "awq"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-
-        head_size = config.hidden_size // config.num_attention_heads
-        num_heads = config.num_attention_heads // weights.process_group.size()
-        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-        assert list(weight.shape) == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.hidden_size,
-        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
-
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
-    )
 
 
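The removed _load_gqa helper hard-coded bias=None and also asserted the per-shard shape of the fused weight. For reference, the arithmetic that assertion enforced, with illustrative dimensions and an illustrative tensor-parallel degree:

# Per-shard shape the old assertion expected for the fused q/k/v weight:
# (num_heads + 2 * num_key_value_heads) * head_size rows, hidden_size columns.
hidden_size, num_attention_heads, num_key_value_heads = 4096, 32, 8
world_size = 2  # illustrative tensor-parallel degree

head_size = hidden_size // num_attention_heads      # 128
num_heads = num_attention_heads // world_size       # 16 heads per shard
num_kv_heads = num_key_value_heads // world_size    # 4 kv heads per shard

expected_shape = [(num_heads + 2 * num_kv_heads) * head_size, hidden_size]
print(expected_shape)  # [3072, 4096]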
@@ -214,12 +195,13 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
+        bias = config.mlp_bias
         if config.model_type == "phi3":
             self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
@@ -227,13 +209,13 @@ class LlamaMLP(nn.Module):
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
-                bias=False,
+                bias=bias,
             )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
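On the MLP side, config.mlp_bias now drives the bias of the fused gate/up projection and of down_proj. A minimal dense (non-tensor-parallel) sketch of that structure in plain torch, assuming the usual SiLU gating used by Llama-family models:

import torch
import torch.nn.functional as F

class FusedGateUpMLP(torch.nn.Module):
    # Sketch of the LlamaMLP layout: [gate_proj; up_proj] fused along dim 0,
    # SiLU(gate) * up, then down_proj. `bias` plays the role of config.mlp_bias.
    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool):
        super().__init__()
        self.intermediate_size = intermediate_size
        self.gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=bias)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).split(self.intermediate_size, dim=-1)
        return self.down_proj(F.silu(gate) * up)

mlp = FusedGateUpMLP(hidden_size=64, intermediate_size=256, bias=True)
print(mlp(torch.randn(2, 64)).shape)  # torch.Size([2, 64])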
@@ -385,9 +367,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.model = FlashLlamaModel(prefix, config, weights)
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            prefix=suffix if not prefix else f"{prefix}.suffix",
             weights=weights,
         )
 
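Finally, when config.tie_word_embeddings is set, the speculative head is pointed at model.embed_tokens instead of a separate lm_head tensor. A minimal sketch of what that weight tying means, in plain torch rather than the TGI SpeculativeHead:

import torch

vocab_size, hidden_size = 1000, 64
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Tied embeddings: the output projection reuses the input embedding matrix,
# so the checkpoint only needs to store model.embed_tokens.weight.
lm_head.weight = embed_tokens.weight

logits = lm_head(torch.randn(2, hidden_size))  # shape (2, vocab_size)
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()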