diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index 06f92016..828d7d14 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -91,11 +91,28 @@ def load_row(config, transpose: bool, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
+
     return TensorParallelRowLinear(
         get_linear(weight, bias, config.quantize), process_group=weights.process_group
     )
 
 
+def load_col(config, transpose: bool, prefix: str, weights, bias: bool):
+    if transpose:
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
+    else:
+        weight = weights.get_multi_weights_col(
+            [prefix], quantize=config.quantize, dim=0
+        )
+
+    if bias:
+        bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashGPT2Attention(torch.nn.Module):
     def __init__(
         self,
@@ -106,8 +123,8 @@ class FlashGPT2Attention(torch.nn.Module):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
 
+        self.head_size = self.hidden_size // self.num_heads
         self.softmax_scale = self.head_size**-0.5
 
         if self.num_heads % weights.process_group.size() != 0:
@@ -116,7 +133,6 @@ class FlashGPT2Attention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.num_key_value_heads = self.num_heads // weights.process_group.size()
 
         self.query_key_value = _load_qkv(
             config,
@@ -133,10 +149,10 @@ class FlashGPT2Attention(torch.nn.Module):
             weights=weights,
             bias=True,
         )
-        self.num_groups = self.num_heads // self.num_key_value_heads
+
         self.kv_head_mapping = torch.arange(
-            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
-        ).repeat_interleave(self.num_groups)
+            0, self.num_heads, dtype=torch.int32, device=weights.device
+        )
 
     def forward(
         self,
@@ -204,7 +220,7 @@ class GPT2MLP(nn.Module):
             )
         )
 
-        self.c_fc = load_row(
+        self.c_fc = load_col(
             config, prefix=f"{prefix}.c_fc", weights=weights, transpose=True, bias=True
         )
         self.c_proj = load_row(
@@ -309,7 +325,6 @@ class FlashGPT2Model(torch.nn.Module):
 
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
-        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
     def forward(
         self,
diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py
index ebcbf14e..42019546 100644
--- a/server/text_generation_server/models/flash_gpt2.py
+++ b/server/text_generation_server/models/flash_gpt2.py
@@ -87,7 +87,7 @@ class FlashGPT2(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_key_value_heads,
+            num_kv_heads=model.model.num_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
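
Note on the change: GPT-2 checkpoints use Conv1D layout, storing weights as [in_features, out_features] (transposed relative to torch.nn.Linear), so a column-parallel load of c_fc shards the checkpoint tensor along dim=1 and transposes the shard into Linear layout; the bias shards along dim=0. The sketch below illustrates that sharding math with plain torch slicing; shard_col_transposed and its rank/world_size parameters are illustrative stand-ins, not TGI's Weights API. The kv_head_mapping change follows the same logic as the num_kv_heads change in flash_gpt2.py: GPT-2 uses vanilla multi-head attention, so there are no grouped key/value heads, num_kv_heads equals num_heads, and the mapping is a plain arange.

import torch

def shard_col_transposed(weight, bias, rank, world_size):
    # Conv1D weight is [in_features, out_features]; column parallelism
    # splits out_features, which is dim=1 in this layout.
    out_features = weight.shape[1]
    block = out_features // world_size
    w = weight[:, rank * block:(rank + 1) * block].T  # shard, then -> Linear layout
    b = bias[rank * block:(rank + 1) * block]         # bias shards along dim=0
    return w, b

# Toy check: a 4 -> 6 Conv1D projection split across two shards.
weight, bias = torch.randn(4, 6), torch.randn(6)
w0, b0 = shard_col_transposed(weight, bias, rank=0, world_size=2)
assert w0.shape == (3, 4) and b0.shape == (3,)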