"""A simple, flexible implementation of a GPT model. Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py """ import math import os import warnings from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from einops import rearrange from packaging import version from text_generation_server.layers import ( TensorParallelEmbedding, TensorParallelColumnLinear, TensorParallelRowLinear, SpeculativeHead, get_linear, ) EPS = 1e-5 def load_col(config, prefix, weights, bias): assert config.quantize != "gptq", NotImplementedError slice_ = weights._get_slice(f"{prefix}.weight") rank = weights.process_group.rank() size = weights.process_group.size() h3, h = slice_.get_shape() block_size = h // size q_part = slice_[rank * block_size : (rank + 1) * block_size] k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size] v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size] weight = torch.cat([q_part, k_part, v_part], dim=0) if weight.dtype != torch.int32: weight = weight.to(dtype=weights.dtype) weight = weight.to(device=weights.device) if bias: bias_slice_ = weights._get_slice(f"{prefix}.bias") bias_rank = weights.process_group.rank() bias_size = weights.process_group.size() bias_h = bias_slice_.get_shape() bias_h = bias_h[0] bias_block_size = bias_h // bias_size bias_q_part = bias_slice_[ bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size ] bias_k_part = bias_slice_[ bias_h + bias_rank * bias_block_size : bias_h + (bias_rank + 1) * bias_block_size ] bias_v_part = bias_slice_[ 2 * bias_h + bias_rank * bias_block_size : 2 * bias_h + (bias_rank + 1) * bias_block_size ] bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0) if bias.dtype != torch.int32: bias = bias.to(dtype=weights.dtype) bias = bias.to(device=weights.device) else: bias = None linear = get_linear(weight, bias, config.quantize) return TensorParallelColumnLinear(linear) def _reset_is_causal( num_query_tokens: int, num_key_tokens: int, original_is_causal: bool ): if original_is_causal and num_query_tokens != num_key_tokens: if num_query_tokens != 1: raise NotImplementedError( "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." ) else: return False return original_is_causal def scaled_multihead_dot_product_attention( query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False, ): q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) kv_n_heads = 1 if multiquery else n_heads k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) if past_key_value is not None: if len(past_key_value) != 0: k = torch.cat([past_key_value[0], k], dim=3) v = torch.cat([past_key_value[1], v], dim=2) past_key_value = (k, v) (b, _, s_q, d) = q.shape s_k = k.size(-1) attn_weight = q.matmul(k) * softmax_scale if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - s_q) _s_k = max(0, attn_bias.size(3) - s_k) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if ( attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) ): raise RuntimeError( f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." ) attn_weight = attn_weight + attn_bias min_val = torch.finfo(q.dtype).min if key_padding_mask is not None: if attn_bias is not None: warnings.warn( "Propogating key_padding_mask to the attention module " + "and applying it within the attention module can cause " + "unneccessary computation/memory usage. Consider integrating " + "into attn_bias once and passing that to each attention " + "module instead." ) attn_weight = attn_weight.masked_fill( ~key_padding_mask.view((b, 1, 1, s_k)), min_val ) if is_causal and (not q.size(2) == 1): s = max(s_q, s_k) causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) causal_mask = causal_mask.tril() causal_mask = causal_mask.to(torch.bool) causal_mask = ~causal_mask causal_mask = causal_mask[-s_q:, -s_k:] attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) attn_weight = torch.softmax(attn_weight, dim=-1) if dropout_p: attn_weight = torch.nn.functional.dropout( attn_weight, p=dropout_p, training=training, inplace=True ) out = attn_weight.to(v.dtype).matmul(v) out = rearrange(out, "b h s d -> b s (h d)") if needs_weights: return (out, attn_weight, past_key_value) return (out, None, past_key_value) def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): for tensor in tensors: if tensor.dtype not in valid_dtypes: raise TypeError( f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." ) if not tensor.is_cuda: raise TypeError( f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." ) def flash_attn_fn( query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False, ): try: from flash_attn import bert_padding, flash_attn_interface except: raise RuntimeError("Please install flash-attn==1.0.3.post0") check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - query.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: raise NotImplementedError(f"attn_bias not implemented for flash attn.") (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) query_padding_mask = key_padding_mask[:, -query.size(1) :] (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( query, query_padding_mask ) query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( key, key_padding_mask ) key_unpad = rearrange( key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads ) (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) value_unpad = rearrange( value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads ) if multiquery: key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) value_unpad = value_unpad.expand( value_unpad.size(0), n_heads, value_unpad.size(-1) ) dropout_p = dropout_p if training else 0.0 reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) output_unpad = flash_attn_interface.flash_attn_unpadded_func( query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights, ) output = bert_padding.pad_input( rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen ) return (output, None, past_key_value) def triton_flash_attn_fn( query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False, ): try: from .flash_attn_triton import flash_attn_func except: _installed = False if version.parse(torch.__version__) < version.parse("2.0.0"): _installed = True try: from flash_attn.flash_attn_triton import flash_attn_func except: _installed = False if not _installed: raise RuntimeError( "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." ) check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: key = torch.cat([past_key_value[0], key], dim=1) value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: _s_q = max(0, attn_bias.size(2) - query.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1)) attn_bias = attn_bias[:, :, _s_q:, _s_k:] if dropout_p: raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") if needs_weights: raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") if key_padding_mask is not None: warnings.warn( "Propagating key_padding_mask to the attention module " + "and applying it within the attention module can cause " + "unnecessary computation/memory usage. Consider integrating " + "into attn_bias once and passing that to each attention " + "module instead." ) (b_size, s_k) = key_padding_mask.shape[:2] if attn_bias is None: attn_bias = query.new_zeros(b_size, 1, 1, s_k) attn_bias = attn_bias.masked_fill( ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min ) query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) if multiquery: key = key.expand(*key.shape[:2], n_heads, key.size(-1)) value = value.expand(*value.shape[:2], n_heads, value.size(-1)) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) attn_output = flash_attn_func( query, key, value, attn_bias, reset_is_causal, softmax_scale ) output = attn_output.view(*attn_output.shape[:2], -1) return (output, None, past_key_value) class MultiheadAttention(nn.Module): """Multi-head self attention. Using torch or triton attention implementation enables user to also use additive bias. """ def __init__( self, config, prefix, weights, ): super().__init__() attn_impl = config.attn_config["attn_impl"] self.attn_impl = config.attn_config["attn_impl"] self.clip_qkv = config.attn_config["clip_qkv"] self.qk_ln = config.attn_config["qk_ln"] self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads self.softmax_scale = config.attn_config["softmax_scale"] if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.attn_dropout_p = config.attn_config["attn_pdrop"] if self.n_heads % weights.process_group.size() != 0: raise ValueError( f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " f"and `num_shards`: {weights.process_group.size()}" ) self.n_heads = self.n_heads // weights.process_group.size() self.Wqkv = load_col( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias ) if self.qk_ln: bias = not config.no_bias hidden_size = config.d_model head_dim = hidden_size // self.n_heads self.q_ln = LPLayerNorm( d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights ) self.k_ln = LPLayerNorm( self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights ) if self.attn_impl == "flash": self.attn_fn = flash_attn_fn elif self.attn_impl == "triton": self.attn_fn = triton_flash_attn_fn elif self.attn_impl == "torch": self.attn_fn = scaled_multihead_dot_product_attention else: raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") self.out_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.out_proj", weights=weights, bias=not config.no_bias, ) def forward( self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False, ): qkv = self.Wqkv(x) if self.clip_qkv: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) (query, key, value) = qkv.chunk(3, dim=2) key_padding_mask = attention_mask if self.qk_ln: dtype = query.dtype query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) (context, attn_weights, past_key_value) = self.attn_fn( query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, ) out = self.out_proj(context) return (out, attn_weights, past_key_value) class MultiQueryAttention(nn.Module): """Multi-Query self attention. Using torch or triton attention implementation enables user to also use additive bias. """ def __init__(self, config, prefix, weights): super().__init__() attn_impl = config.attn_config["attn_impl"] self.attn_impl = config.attn_config["attn_impl"] self.clip_qkv = config.attn_config["clip_qkv"] self.qk_ln = config.attn_config["qk_ln"] self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads self.softmax_scale = config.attn_config["softmax_scale"] if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.head_dim) self.attn_dropout_p = config.attn_config["attn_pdrop"] # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) self.Wqkv = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias ) fuse_splits = (d_model, d_model + self.head_dim) if self.qk_ln: raise NotImplementedError("qk_ln not supported") if self.attn_impl == "flash": self.attn_fn = flash_attn_fn elif self.attn_impl == "triton": self.attn_fn = triton_flash_attn_fn if verbose: warnings.warn( "While `attn_impl: triton` can be faster than `attn_impl: flash` " + "it uses more memory. When training larger models this can trigger " + "alloc retries which hurts performance. If encountered, we recommend " + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." ) elif self.attn_impl == "torch": self.attn_fn = scaled_multihead_dot_product_attention if torch.cuda.is_available() and verbose: warnings.warn( "Using `attn_impl: torch`. If your model does not use `alibi` or " + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " + "we recommend using `attn_impl: triton`." ) else: raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") self.out_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.out_proj", weights=weights, bias=not config.no_bias, ) # self.out_proj._is_residual = True def forward( self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False, ): qkv = self.Wqkv(x) if self.clip_qkv: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) (query, key, value) = qkv.split( [self.d_model, self.head_dim, self.head_dim], dim=2 ) key_padding_mask = attention_mask if self.qk_ln: dtype = query.dtype query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) (context, attn_weights, past_key_value) = self.attn_fn( query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True, ) return (self.out_proj(context), attn_weights, past_key_value) def attn_bias_shape( attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id ): if attn_impl == "flash": return None elif attn_impl in ["torch", "triton"]: if alibi: if (prefix_lm or not causal) or use_sequence_id: return (1, n_heads, seq_len, seq_len) return (1, n_heads, 1, seq_len) elif prefix_lm or use_sequence_id: return (1, 1, seq_len, seq_len) return None else: raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") def build_attn_bias( attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 ): if attn_impl == "flash": return None elif attn_impl in ["torch", "triton"]: if alibi: (device, dtype) = (attn_bias.device, attn_bias.dtype) attn_bias = attn_bias.add( build_alibi_bias( n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype, ) ) return attn_bias else: raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") def gen_slopes(n_heads, alibi_bias_max=8, device=None): _n_heads = 2 ** math.ceil(math.log2(n_heads)) m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) m = m.mul(alibi_bias_max / _n_heads) slopes = 1.0 / torch.pow(2, m) if _n_heads != n_heads: slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] return slopes.view(1, n_heads, 1, 1) def build_alibi_bias( n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None ): alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( 1, 1, 1, seq_len ) if full: alibi_bias = alibi_bias - torch.arange( 1 - seq_len, 1, dtype=torch.int32, device=device ).view(1, 1, seq_len, 1) alibi_bias = alibi_bias.abs().mul(-1) slopes = gen_slopes(n_heads, alibi_bias_max, device=device) alibi_bias = alibi_bias * slopes return alibi_bias.to(dtype=dtype) ATTN_CLASS_REGISTRY = { "multihead_attention": MultiheadAttention, "multiquery_attention": MultiQueryAttention, } """GPT Blocks used for the GPT Model.""" class MPTMLP(nn.Module): def __init__(self, config, prefix, weights): super().__init__() # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) self.up_proj = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.up_proj", weights=weights, bias=not config.no_bias ) self.act = nn.GELU(approximate="none") # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) self.down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=not config.no_bias, ) # self.down_proj._is_residual = True def forward(self, x): return self.down_proj(self.act(self.up_proj(x))) class MPTBlock(nn.Module): def __init__(self, config, prefix, weights): super().__init__() self.prefix = prefix if config.attn_config["attn_type"] != "multihead_attention": raise NotImplementedError( f"""Not implemented attn {config.attn_config["attn_type"]}""" ) resid_pdrop = config.resid_pdrop if config.no_bias: self.norm_1 = nn.LayerNorm.load_no_bias( prefix=f"{prefix}.norm_1", weights=weights, eps=EPS ) self.norm_2 = nn.LayerNorm.load_no_bias( prefix=f"{prefix}.norm_2", weights=weights, eps=EPS ) else: self.norm_1 = nn.LayerNorm.load( prefix=f"{prefix}.norm_1", weights=weights, eps=EPS ) self.norm_2 = nn.LayerNorm.load( prefix=f"{prefix}.norm_2", weights=weights, eps=EPS ) self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) def forward( self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: a = self.norm_1(x) (b, attn_weights, past_key_value) = self.attn( a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, ) x = x + self.resid_attn_dropout(b) m = self.norm_2(x) n = self.ffn(m) x = x + self.resid_ffn_dropout(n) return (x, attn_weights, past_key_value) def _cast_if_autocast_enabled(tensor): if torch.is_autocast_enabled(): if tensor.device.type == "cuda": dtype = torch.get_autocast_gpu_dtype() elif tensor.device.type == "cpu": dtype = torch.get_autocast_cpu_dtype() else: raise NotImplementedError() return tensor.to(dtype=dtype) return tensor class LPLayerNorm(torch.nn.LayerNorm): def __init__( self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None, bias: Optional[bool] = True, prefix=None, weights=None, ): super().__init__( normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype, bias=bias, ) if weights is not None: self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0)) if bias: self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0)) self.normalized_shape = self.weight.shape def forward(self, x): module_device = x.device downcast_x = _cast_if_autocast_enabled(x) downcast_weight = ( _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight ) downcast_bias = ( _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias ) with torch.autocast(enabled=False, device_type=module_device.type): return torch.nn.functional.layer_norm( downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps, ) def rms_norm(x, weight=None, eps=1e-05): output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) if weight is not None: return output * weight return output class RMSNorm(torch.nn.Module): def __init__( self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None ): super().__init__() self.eps = eps if weight: self.weight = torch.nn.Parameter( torch.ones(normalized_shape, dtype=dtype, device=device) ) else: self.register_parameter("weight", None) def forward(self, x): return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) class LPRMSNorm(RMSNorm): def __init__( self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None ): super().__init__( normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device, ) def forward(self, x): downcast_x = _cast_if_autocast_enabled(x) downcast_weight = ( _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight ) with torch.autocast(enabled=False, device_type=x.device.type): return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) NORM_CLASS_REGISTRY = { "layernorm": torch.nn.LayerNorm, "low_precision_layernorm": LPLayerNorm, "rmsnorm": RMSNorm, "low_precision_rmsnorm": LPRMSNorm, } Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] class MPTPreTrainedModel(PreTrainedModel): base_model_prefix = "model" _no_split_modules = ["MPTBlock"] class MPTModel(MPTPreTrainedModel): def __init__(self, config, weights): # config._validate_config() super().__init__(config) self.world_size = weights.process_group.size() self.rank = weights.process_group.rank() self.n_heads = config.n_heads self.attn_impl = config.attn_config["attn_impl"] self.prefix_lm = config.attn_config["prefix_lm"] self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] self.alibi = config.attn_config["alibi"] self.alibi_bias_max = config.attn_config["alibi_bias_max"] if config.init_device == "mixed": if dist.get_local_rank() == 0: config.init_device = "cpu" else: config.init_device = "meta" if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) raise NotImplementedError( f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." ) if config.norm_type.lower() != "low_precision_layernorm": raise NotImplementedError( f"Requested norm type ({config.norm_type}) is not implemented within this repo." ) self.wte = TensorParallelEmbedding("transformer.wte", weights) if not self.alibi: self.wpe = TensorParallelEmbedding("transformer.wpe", weights) self.blocks = nn.ModuleList( [ MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) for i in range(config.n_layers) ] ) if config.no_bias: self.norm_f = nn.LayerNorm.load_no_bias( prefix="transformer.norm_f", weights=weights, eps=EPS ) else: self.norm_f = nn.LayerNorm.load( prefix="transformer.norm_f", weights=weights, eps=EPS ) self.is_causal = not self.prefix_lm self._attn_bias_initialized = False self.attn_bias = None self.attn_bias_shape = attn_bias_shape( self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id, ) if config.no_bias: for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): if config.verbose: warnings.warn(f"Removing bias ({module.bias}) from {module}.") module.register_parameter("bias", None) if hasattr(self.config, "verbose"): if config.verbose and config.verbose > 2: print(self) if "verbose" not in self.config.init_config: self.config.init_config["verbose"] = self.config.verbose if self.config.init_config["verbose"] > 1: init_fn_name = self.config.init_config["name"] warnings.warn(f"Using {init_fn_name} initialization.") @torch.no_grad() def _attn_bias( self, device, dtype, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, ): if not self._attn_bias_initialized: if self.attn_bias_shape: self.attn_bias = torch.zeros( self.attn_bias_shape, device=device, dtype=dtype ) self.attn_bias = build_attn_bias( self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max, ) assert self.n_heads % self.world_size == 0 block_size = self.n_heads // self.world_size self.attn_bias = self.attn_bias[ :, self.rank * block_size : (self.rank + 1) * block_size ] self._attn_bias_initialized = True if self.attn_impl == "flash": return (self.attn_bias, attention_mask) if self.attn_bias is not None: self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) attn_bias = self.attn_bias if self.prefix_lm: assert isinstance(attn_bias, torch.Tensor) assert isinstance(prefix_mask, torch.Tensor) attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) if self.attn_uses_sequence_id and sequence_id is not None: assert isinstance(attn_bias, torch.Tensor) attn_bias = self._apply_sequence_id(attn_bias, sequence_id) if attention_mask is not None: s_k = attention_mask.shape[-1] if attn_bias is None: attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) else: _s_k = max(0, attn_bias.size(-1) - s_k) attn_bias = attn_bias[:, :, :, _s_k:] if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: raise ValueError( f"attention_mask shape={attention_mask.shape} " + f"and prefix_mask shape={prefix_mask.shape} are not equal." ) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill( ~attention_mask.view(-1, 1, 1, s_k), min_val ) return (attn_bias, None) def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): (s_k, s_q) = attn_bias.shape[-2:] if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: raise ValueError( "attn_bias does not match the expected shape. " + f"The last two dimensions should both be {self.config.max_length} " + f"but are {s_k} and {s_q}." ) seq_len = prefix_mask.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError( f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" ) attn_bias = attn_bias[..., :seq_len, :seq_len] causal = torch.tril( torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) ).view(1, 1, seq_len, seq_len) prefix = prefix_mask.view(-1, 1, 1, seq_len) cannot_attend = ~torch.logical_or(causal, prefix.bool()) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias def _apply_sequence_id( self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor ): seq_len = sequence_id.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError( f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" ) attn_bias = attn_bias[..., :seq_len, :seq_len] cannot_attend = torch.logical_not( torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) ).unsqueeze(1) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias def forward( self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, ): return_dict = ( return_dict if return_dict is not None else self.config.return_dict ) use_cache = use_cache if use_cache is not None else self.config.use_cache if attention_mask is not None: attention_mask = attention_mask.bool() if prefix_mask is not None: prefix_mask = prefix_mask.bool() if not return_dict: raise NotImplementedError( "return_dict False is not implemented yet for MPT" ) if output_attentions: if self.attn_impl != "torch": raise NotImplementedError( "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`." ) if ( attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training ): raise NotImplementedError( "MPT does not support training with left padding." ) if self.prefix_lm and prefix_mask is None: raise ValueError( "prefix_mask is a required argument when MPT is configured with prefix_lm=True." ) if self.training: if self.attn_uses_sequence_id and sequence_id is None: raise ValueError( "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + "and the model is in train mode." ) elif self.attn_uses_sequence_id is False and sequence_id is not None: warnings.warn( "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." ) S = input_ids.size(1) assert ( S <= self.config.max_seq_len ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" tok_emb = self.wte(input_ids) if self.alibi: x = tok_emb else: past_position = 0 if past_key_values is not None: if len(past_key_values) != self.config.n_layers: raise ValueError( f"past_key_values must provide a past_key_value for each attention " + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." ) past_position = past_key_values[0][0].size(1) if self.attn_impl == "torch": past_position = past_key_values[0][0].size(3) if S + past_position > self.config.max_seq_len: raise ValueError( f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." ) pos = torch.arange( past_position, S + past_position, dtype=torch.long, device=input_ids.device, ).unsqueeze(0) if attention_mask is not None: pos = torch.clamp( pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ :, past_position: ], min=0, ) pos_emb = self.wpe(pos) x = tok_emb + pos_emb (attn_bias, attention_mask) = self._attn_bias( device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, ) if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers)] all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) past_key_value = ( past_key_values[b_idx] if past_key_values is not None else None ) (x, attn_weights, past_key_value) = block( x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, ) if past_key_values is not None: past_key_values[b_idx] = past_key_value if output_attentions: assert all_self_attns is not None all_self_attns = all_self_attns + (attn_weights,) x = self.norm_f(x) if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) return BaseModelOutputWithPast( last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) class MPTForCausalLM(MPTPreTrainedModel): def __init__(self, config, weights): super().__init__(config) if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") self.transformer = MPTModel(config, weights) self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) self.logit_scale = None if config.logit_scale is not None: logit_scale = config.logit_scale if isinstance(logit_scale, str): if logit_scale == "inv_sqrt_d_model": logit_scale = 1 / math.sqrt(config.d_model) else: raise ValueError( f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." ) self.logit_scale = logit_scale def forward( self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, ): return_dict = ( return_dict if return_dict is not None else self.config.return_dict ) use_cache = use_cache if use_cache is not None else self.config.use_cache outputs = self.transformer( input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, ) logits, speculative_logits = self.lm_head(outputs.last_hidden_state) if self.logit_scale is not None: if self.logit_scale == 0: warnings.warn( f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." ) logits *= self.logit_scale loss = None if labels is not None: labels = torch.roll(labels, shifts=-1) labels[:, -1] = -100 loss = F.cross_entropy( logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) ) return ( CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ), speculative_logits, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs ): if inputs_embeds is not None: raise NotImplementedError("inputs_embeds is not implemented for MPT yet") attention_mask = kwargs["attention_mask"].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError( "MPT does not support generation with right padding." ) if self.transformer.attn_uses_sequence_id and self.training: sequence_id = torch.zeros_like(input_ids[:1]) else: sequence_id = None if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: prefix_mask = torch.ones_like(attention_mask) if kwargs.get("use_cache") == False: raise NotImplementedError( "MPT with prefix_lm=True does not support use_cache=False." ) else: prefix_mask = None return { "input_ids": input_ids, "attention_mask": attention_mask, "prefix_mask": prefix_mask, "sequence_id": sequence_id, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache", True), } @staticmethod def _reorder_cache(past_key_values, beam_idx): """Used by HuggingFace generate when using beam search with kv-caching. See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 for an example in transformers. """ reordered_past = [] for layer_past in past_key_values: reordered_past += [ tuple( (past_state.index_select(0, beam_idx) for past_state in layer_past) ) ] return reordered_past