diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 47ec7072..b3982171 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -18,9 +18,10 @@ from text_generation_server.utils.layers import (
     TensorParallelHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
-    get_linear
+    get_linear,
 )

+
 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_sharded(f"{prefix}.weight", dim=1)
     if bias and weights.process_group.rank() == 0:
@@ -100,7 +101,9 @@ class RWConfig(PretrainedConfig):
 class FlashRWAttention(torch.nn.Module):
     def __init__(
         self,
-        config, prefix, weights,
+        config,
+        prefix,
+        weights,
         reduce=True,
     ):
         super().__init__()
@@ -109,12 +112,21 @@ class FlashRWAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding.static(dim=self.head_size, base=10000.0, device=weights.device)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            dim=self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)
-        self.num_heads = self.num_heads //weights.process_group.size()
+        self.num_heads = self.num_heads // weights.process_group.size()

-        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
-        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )

     def forward(
         self,
@@ -204,26 +216,29 @@ class FlashRWAttention(torch.nn.Module):
 class FlashRWLargeAttention(torch.nn.Module):
     def __init__(
         self,
-        config, prefix, weights,
-        # num_heads,
-        # num_heads_kv,
-        # hidden_size,
-        # bias,
-        # process_group=None,
-        reduce=True,
+        config,
+        prefix,
+        weights,
     ):
         super().__init__()

+        hidden_size = config.hidden_size
+        num_heads = config.n_head
+        num_heads_kv = config.n_head_kv
+
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads

-        self.rotary_emb = PositionRotaryEmbedding.static(self.head_size, base=10000.0, device=weights.device)
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            self.head_size, base=10000.0, device=weights.device
+        )
         self.softmax_scale = self.head_size ** (-0.5)

         self.num_groups = num_heads // (num_heads_kv * 2)
         self.num_heads = num_heads // self.num_groups
         self.num_heads_kv = num_heads_kv // self.num_groups
         process_group = weights.process_group
+
         if process_group.size() > self.num_groups:
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for world_size > n groups"
@@ -232,9 +247,17 @@ class FlashRWLargeAttention(torch.nn.Module):
             raise NotImplementedError(
                 f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
             )
+
         self.num_groups = self.num_groups // process_group.size()
-        self.query_key_value = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.query_key_value", weights=weights, bias=config.bias)
-        self.dense = load_row(config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias)
+        self.query_key_value = TensorParallelColumnLinear.load(
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            bias=config.bias,
+        )
+        self.dense = load_row(
+            config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
+        )

     def forward(
         self,
@@ -331,12 +354,16 @@ class FlashRWLargeAttention(torch.nn.Module):


 class FlashMLP(nn.Module):
-    def __init__(self, config, prefix, weights, reduce=True):
+    def __init__(self, config, prefix, weights):
         super().__init__()
         self.act = torch.nn.functional.gelu

-        self.dense_h_to_4h = TensorParallelColumnLinear.load(config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias)
-        self.dense_4h_to_h = load_row(config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias)
+        self.dense_h_to_4h = TensorParallelColumnLinear.load(
+            config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
+        )
+        self.dense_4h_to_h = load_row(
+            config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
+        )

     def forward(self, hidden_states):
         hidden_states = self.dense_h_to_4h(hidden_states)
@@ -351,20 +378,9 @@ class FlashRWLayer(nn.Module):
         layer_id,
         config,
         weights,
-        # num_heads,
-        # num_heads_kv,
-        # hidden_size,
-        # bias,
-        # layer_norm_eps,
-        # parallel_attn,
-        # process_group=None,
     ):
         super().__init__()

-        n_head = config.n_head
-        n_head_kv = config.n_head_kv
-        hidden_size = config.hidden_size
-        bias = config.bias
         parallel_attn = config.parallel_attn
         self.parallel_attn = parallel_attn

@@ -376,31 +392,26 @@ class FlashRWLayer(nn.Module):
             eps=config.layer_norm_epsilon,
         )
         self.self_attention = FlashRWAttention(
-            # num_heads,
-            # num_heads_kv,
-            # hidden_size,
-            # bias,
-            # process_group=process_group,
-            config,
+            config,
             prefix=f"{prefix}.self_attention",
             weights=weights,
             reduce=False,
         )
         self.post_attention_layernorm = (
-            FastLayerNorm.load(
-                prefix=f"{prefix}.post_attention_layernorm",
-                weights=weights,
-                eps=config.layer_norm_epsilon,
-            ) if not parallel_attn
+            FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.layer_norm_epsilon,
+            )
+            if not parallel_attn
             else None
         )

         self.mlp = FlashMLP(
-            # hidden_size, bias, process_group=process_group, reduce=False
             config,
             prefix=f"{prefix}.mlp",
             weights=weights,
-            reduce=False
+            reduce=False,
         )

         self.process_group = weights.process_group
@@ -461,11 +472,9 @@ class FlashRWLayer(nn.Module):

 class FlashRWLargeLayer(nn.Module):
-    def __init__(
-        self,
-        config, prefix, weights
-    ):
+    def __init__(self, layer_id, config, weights):
         super().__init__()
+        prefix = f"transformer.h.{layer_id}"
         self.ln_attn = FastLayerNorm.load(
             prefix=f"{prefix}.ln_attn",
             weights=weights,
             eps=config.layer_norm_epsilon,
@@ -478,13 +487,13 @@ class FlashRWLargeLayer(nn.Module):
         )

         self.self_attention = FlashRWLargeAttention(
-            config, prefix=f"{prefix}.self_attention", weights=weights,
-            reduce=False,
+            config,
+            prefix=f"{prefix}.self_attention",
+            weights=weights,
         )
+        assert config.parallel_attn, "This version doesn't support non parallel_attn"

-        self.mlp = FlashMLP(
-            config, prefix=f"{prefix}.mlp", weights=weights, reduce=False
-        )
+        self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)

         self.process_group = weights.process_group

@@ -541,7 +550,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
             self.h = nn.ModuleList(
                 [
                     FlashRWLayer(
-                        layer_id, config, weights
+                        layer_id,
+                        config,
+                        weights
                         # config.n_head,
                         # config.n_head_kv,
                         # config.hidden_size,
@@ -561,15 +572,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         elif config.model_type == "RefinedWeb":
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(
-                        layer_id, config, weights
-                        # config.n_head,
-                        # config.n_head_kv,
-                        # config.hidden_size,
-                        # config.bias,
-                        # config.layer_norm_epsilon,
-                        # process_group,
-                    )
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 9fd31c76..5945f210 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -310,11 +310,12 @@ try:

         @staticmethod
         def static(dim, base, device):
-            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
-                                                    dtype=torch.float32) / dim))
+            inv_freq = 1.0 / (
+                base
+                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            )
             return PositionRotaryEmbedding(inv_freq)

-
         @staticmethod
         def load(prefix, weights):
             # XXX: Always load this in float32 !
@@ -324,7 +325,6 @@ try:
             weights.dtype = dtype
             return PositionRotaryEmbedding(inv_freq)

-
         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
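
For reviewers, a minimal sketch (not part of the patch) of the inverse-frequency expression that the reformatted `PositionRotaryEmbedding.static` hunk above keeps intact. The standalone `rotary_inv_freq` helper below is hypothetical; it only mirrors `1.0 / (base ** (torch.arange(0, dim, 2) / dim))` to show that splitting the expression across lines leaves the computed values unchanged.

# hypothetical_rotary_check.py -- illustration only, not part of the patched modules
import torch


def rotary_inv_freq(dim: int, base: float = 10000.0, device: str = "cpu") -> torch.Tensor:
    # inv_freq[i] = 1 / base ** (2i / dim) for i = 0 .. dim // 2 - 1,
    # the same expression used by PositionRotaryEmbedding.static in layers.py.
    exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
    return 1.0 / (base**exponents)


if __name__ == "__main__":
    # Example: a head size of 64 yields 32 frequencies, from 1.0 down to ~1.3e-4.
    inv_freq = rotary_inv_freq(64)
    print(inv_freq.shape)  # torch.Size([32])
    print(inv_freq[0].item(), inv_freq[-1].item())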