diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fd88d43a..4adf1381 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -201,7 +201,10 @@ def get_model(
     if model_type in ["RefinedWeb", "RefinedWebModel"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config.alibi:
+                if config.alibi or (
+                    config.model_type == "RefinedWebModel"
+                    and config.n_head_kv != config.n_head
+                ):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index f617ec20..93de9648 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -107,11 +107,11 @@ class FlashRWAttention(torch.nn.Module):
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
-            self.num_heads = self.num_heads // process_group.size()
-            self.query_key_value = FastLinear(
+            self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.head_size * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
+                process_group=process_group,
             )
             self.dense = TensorParallelRowLinear(
                 hidden_size,
@@ -120,6 +120,7 @@ class FlashRWAttention(torch.nn.Module):
                 process_group=process_group,
                 reduce=reduce,
             )
+            self.num_heads = self.num_heads // process_group.size()
 
     def forward(
         self,
@@ -231,13 +232,18 @@ class FlashRWLargeAttention(torch.nn.Module):
         if process_group is None:
             self.query_key_value = FastLinear(
                 hidden_size,
-                self.num_groups *
-                self.head_size
+                self.num_groups
+                * self.head_size
                 * (self.num_heads + 2 * self.num_heads_kv),
                 bias=bias,
             )
             self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
         else:
+            if process_group.size() > self.num_groups:
+                raise NotImplementedError(
+                    f"Tensor Parallelism is not implemented for world_size > n groups"
+                )
+
             self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 self.num_groups
@@ -269,10 +275,13 @@ class FlashRWLargeAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
+
+        # Split on group dimension
         query, kv = qkv.split(
             [self.num_heads, 2],
             dim=2,
         )
+        # Merge groups and heads
         query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
 
         # Inplace rotary
@@ -285,8 +294,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[...] = kv
             k, v = kv.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
 
             # output
             attn_output = torch.empty_like(query)
@@ -314,8 +327,12 @@ class FlashRWLargeAttention(torch.nn.Module):
             layer_past[layer_past_present_indices] = kv
             k, v = layer_past.split(1, dim=2)
             # Expand to query shape
-            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
-            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size)
+            k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
+            v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape(
+                -1, self.num_groups * self.num_heads, self.head_size
+            )
 
             # output
             attn_output = torch.empty_like(query)
@@ -338,7 +355,9 @@ class FlashRWLargeAttention(torch.nn.Module):
                 None,
             )
 
-        return self.dense(attn_output.view(-1, self.num_groups * self.num_heads * self.head_size))
+        return self.dense(
+            attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
+        )
 
 
 class FlashMLP(nn.Module):
@@ -389,7 +408,12 @@ class FlashRWLayer(nn.Module):
 
         self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.self_attention = FlashRWAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
         self.post_attention_layernorm = (
             FastLayerNorm(hidden_size, eps=layer_norm_eps)
@@ -397,7 +421,9 @@ class FlashRWLayer(nn.Module):
             else None
         )
 
-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )
 
         self.process_group = process_group
 
@@ -473,10 +499,17 @@ class FlashRWLargeLayer(nn.Module):
         self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
 
         self.self_attention = FlashRWLargeAttention(
-            num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
         )
 
-        self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False)
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )
 
         self.process_group = process_group
 
@@ -492,8 +525,8 @@ class FlashRWLargeLayer(nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        ln_attn, _ = self.ln_attn(hidden_states)
-        ln_mlp, _ = self.ln_mlp(hidden_states)
+        ln_attn, residual = self.ln_attn(hidden_states, residual)
+        ln_mlp, _ = self.ln_mlp(residual)
 
         # Self attention.
         attn_output = self.self_attention(
@@ -516,13 +549,11 @@ class FlashRWLargeLayer(nn.Module):
         if self.process_group is not None:
             torch.distributed.all_reduce(intermediate, group=self.process_group)
 
-        return intermediate + hidden_states, None
+        return intermediate, residual
 
 
 class FlashRWPreTrainedModel(PreTrainedModel):
     config_class = RWConfig
-    supports_gradient_checkpointing = False
-    _no_split_modules = None
 
 
 class FlashRWModel(FlashRWPreTrainedModel):
@@ -559,7 +590,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = (2, self.h[0].self_attention.num_heads_kv, self.h[0].self_attention.head_size)
+            self.cache_size = (
+                2,
+                self.h[0].self_attention.num_heads_kv,
+                self.h[0].self_attention.head_size,
+            )
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
@@ -574,7 +609,11 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     for _ in range(config.num_hidden_layers)
                 ]
            )
-            self.cache_size = (self.h[0].self_attention.num_groups, 2, self.h[0].self_attention.head_size)
+            self.cache_size = (
+                self.h[0].self_attention.num_groups,
+                2,
+                self.h[0].self_attention.head_size,
+            )
         else:
             raise NotImplementedError(
                 f"model_type {config.model_type} is not supported."
@@ -582,8 +621,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
 
         self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
-        self.gradient_checkpointing = False
-
         self.head_size = self.h[0].self_attention.head_size
 
     def post_load_weights(self, quantize: Optional[str] = None):
@@ -629,7 +666,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                     len(hidden_states)
                     if pre_allocate_past_size is None
                     else pre_allocate_past_size,
-                    *self.cache_size
+                    *self.cache_size,
                 )
             )
             layer_past_present_indices = None
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 94efc833..8219ac86 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -113,7 +113,6 @@ class FlashRW(FlashCausalLM):
         model.post_load_weights(quantize)
 
 
-
 class FlashRWSharded(FlashRW):
     def __init__(
         self,
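A minimal, standalone sketch of the multi-query KV expansion that the FlashRWLargeAttention changes above rely on: each group carries a single key head and a single value head, which are broadcast across that group's query heads before attention runs. The shapes below (total_tokens, num_groups, num_heads, head_size) are illustrative assumptions, not values read from a real Falcon checkpoint.

import torch

total_tokens, num_groups, num_heads, head_size = 5, 2, 4, 8

# Fused projection output, viewed as one block of `num_heads` query heads
# plus one key and one value head per group.
qkv = torch.randn(total_tokens, num_groups, num_heads + 2, head_size)

# Split on the group dimension, then merge groups and heads for the query.
query, kv = qkv.split([num_heads, 2], dim=2)
query = query.reshape(-1, num_groups * num_heads, head_size)

# Each group has one K and one V head; expand them so every query head in
# the group sees the same K/V, matching the merged query layout.
k, v = kv.split(1, dim=2)
k = k.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)
v = v.expand(-1, num_groups, num_heads, head_size).reshape(
    -1, num_groups * num_heads, head_size
)

assert query.shape == k.shape == v.shape
assert query.shape == (total_tokens, num_groups * num_heads, head_size)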