Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00.
Commit 5882768682 ("nit"); parent commit 5a1512c025.
@ -148,23 +148,27 @@ class LlamaRMSNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def _load_gqa(config, prefix: str, weights):
    """Load and fuse the q/k/v projection weights for grouped-query attention.

    The three projections are loaded as a single column-sharded tensor
    (concatenated along dim 0) so attention can run as one fused matmul.

    Args:
        config: model config; must provide ``hidden_size``,
            ``num_attention_heads``, ``num_key_value_heads`` and ``quantize``.
        prefix: weight-name prefix for this attention layer
            (e.g. ``"model.layers.0.self_attn"``).
        weights: sharded weight loader exposing ``get_multi_weights_col``,
            ``dtype``, ``device`` and ``process_group``.

    Returns:
        A ``TensorParallelColumnLinear`` wrapping the fused projection
        (no bias is loaded).
    """
    # Head dimensions must divide evenly, both per-head and per-shard.
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize != "gptq":
        # GPTQ weights are packed tensors with their own dtype/layout, so the
        # cast and the dense shape sanity-check apply only to the non-gptq path.
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        # Fused layout: Q contributes num_heads heads, K and V contribute
        # num_key_value_heads heads each, all of size head_size.
        # Fix: the failure message previously interpolated the *global*
        # config.num_key_value_heads instead of the per-shard count, so the
        # reported expected shape was wrong whenever weights are sharded.
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

    bias = None  # q/k/v projections carry no bias here.
    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
Loading…
Reference in New Issue
Block a user