diff --git a/backends/gaudi/server/text_generation_server/layers/rotary.py b/backends/gaudi/server/text_generation_server/layers/rotary.py index 1f8a6bd1d..b25f9fab1 100644 --- a/backends/gaudi/server/text_generation_server/layers/rotary.py +++ b/backends/gaudi/server/text_generation_server/layers/rotary.py @@ -77,6 +77,11 @@ class PositionRotaryEmbedding(nn.Module): inv_freq = _create_inv_freq(dim, base, device) scaling_factor = None rope_scaling = _get_rope_config(config) + if not hasattr(config, "max_position_embeddings") and hasattr( + config, "max_seq_len" + ): + # handling for dbrx + config.max_position_embeddings = config.max_seq_len if rope_scaling is not None: # `rope_type` is now standard in transformers, but some existing models # have `type` instead. diff --git a/backends/gaudi/server/text_generation_server/models/__init__.py b/backends/gaudi/server/text_generation_server/models/__init__.py index 9229bcf2b..dfdec9dce 100644 --- a/backends/gaudi/server/text_generation_server/models/__init__.py +++ b/backends/gaudi/server/text_generation_server/models/__init__.py @@ -286,16 +286,6 @@ class ModelType(enum.Enum): "name": "Qwen 2.5 VL", "url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e", } - OPT = { - "type": "opt", - "name": "Opt", - "url": "https://huggingface.co/facebook/opt-6.7b", - } - T5 = { - "type": "t5", - "name": "T5", - "url": "https://huggingface.co/google/flan-t5-xxl", - } GALACTICA = { "type": "galactica", "name": "Galactica", @@ -306,16 +296,6 @@ class ModelType(enum.Enum): "name": "SantaCoder", "url": "https://huggingface.co/bigcode/santacoder", } - BLOOM = { - "type": "bloom", - "name": "Bloom", - "url": "https://huggingface.co/bigscience/bloom-560m", - } - MPT = { - "type": "mpt", - "name": "Mpt", - "url": "https://huggingface.co/mosaicml/mpt-7b-instruct", - } GPT2 = { "type": "gpt2", "name": "Gpt2", diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 0f1338caa..b335a81f0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -43,9 +43,7 @@ from text_generation_server.layers.rotary import ( from text_generation_server.layers.layernorm import ( FastLayerNorm, ) - - -moe_kernels = None +from vllm_hpu_extension.ops import DynamicFusedMOE class DbrxAttentionConfig(PretrainedConfig): @@ -497,19 +495,15 @@ class BlockSparseMoE(nn.Module): self.process_group = weights.process_group + self.hpu_fused_moe = DynamicFusedMOE(self.num_experts) + for i in range(self.num_experts): + self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i]) + self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i]) + def forward(self, x: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(x) - - out = moe_kernels.fused_moe( - x, - self.wv1, - self.w2, - router_logits, - self.top_k, - renormalize=self.moe_normalize_expert_weights, - inplace=True, - ) + out = self.hpu_fused_moe(x, router_logits, self.top_k) # Reduce sum if self.process_group.size() > 1: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py deleted file mode 100644 index 988a74a39..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ /dev/null @@ -1,1215 +0,0 @@ -"""A simple, flexible implementation of a GPT model. 
- -Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py -""" - -import math -import warnings -from typing import List, Optional, Tuple, Union -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from einops import rearrange -from packaging import version -from text_generation_server.layers import ( - TensorParallelEmbedding, - TensorParallelColumnLinear, - TensorParallelRowLinear, - SpeculativeHead, - get_linear, -) - -EPS = 1e-5 - - -def load_col(config, prefix, weights, bias): - assert config.quantize != "gptq", NotImplementedError - slice_ = weights._get_slice(f"{prefix}.weight") - rank = weights.process_group.rank() - size = weights.process_group.size() - - h3, h = slice_.get_shape() - block_size = h // size - - q_part = slice_[rank * block_size : (rank + 1) * block_size] - k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size] - v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size] - - weight = torch.cat([q_part, k_part, v_part], dim=0) - if weight.dtype != torch.int32: - weight = weight.to(dtype=weights.dtype) - weight = weight.to(device=weights.device) - - if bias: - bias_slice_ = weights._get_slice(f"{prefix}.bias") - bias_rank = weights.process_group.rank() - bias_size = weights.process_group.size() - - bias_h = bias_slice_.get_shape() - bias_h = bias_h[0] - bias_block_size = bias_h // bias_size - - bias_q_part = bias_slice_[ - bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size - ] - bias_k_part = bias_slice_[ - bias_h - + bias_rank * bias_block_size : bias_h - + (bias_rank + 1) * bias_block_size - ] - bias_v_part = bias_slice_[ - 2 * bias_h - + bias_rank * bias_block_size : 2 * bias_h - + (bias_rank + 1) * bias_block_size - ] - - bias = torch.cat([bias_q_part, bias_k_part, 
bias_v_part], dim=0) - if bias.dtype != torch.int32: - bias = bias.to(dtype=weights.dtype) - bias = bias.to(device=weights.device) - else: - bias = None - linear = get_linear(weight, bias) - return TensorParallelColumnLinear(linear) - - -def _reset_is_causal( - num_query_tokens: int, num_key_tokens: int, original_is_causal: bool -): - if original_is_causal and num_query_tokens != num_key_tokens: - if num_query_tokens != 1: - raise NotImplementedError( - "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." - ) - else: - return False - return original_is_causal - - -def scaled_multihead_dot_product_attention( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) - kv_n_heads = 1 if multiquery else n_heads - k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) - v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) - if past_key_value is not None: - if len(past_key_value) != 0: - k = torch.cat([past_key_value[0], k], dim=3) - v = torch.cat([past_key_value[1], v], dim=2) - past_key_value = (k, v) - (b, _, s_q, d) = q.shape - s_k = k.size(-1) - attn_weight = q.matmul(k) * softmax_scale - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - s_q) - _s_k = max(0, attn_bias.size(3) - s_k) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if ( - attn_bias.size(-1) != 1 - and attn_bias.size(-1) != s_k - or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) - ): - raise RuntimeError( - f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." 
- ) - attn_weight = attn_weight + attn_bias - min_val = torch.finfo(q.dtype).min - if key_padding_mask is not None: - if attn_bias is not None: - warnings.warn( - "Propogating key_padding_mask to the attention module " - + "and applying it within the attention module can cause " - + "unneccessary computation/memory usage. Consider integrating " - + "into attn_bias once and passing that to each attention " - + "module instead." - ) - attn_weight = attn_weight.masked_fill( - ~key_padding_mask.view((b, 1, 1, s_k)), min_val - ) - if is_causal and (not q.size(2) == 1): - s = max(s_q, s_k) - causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) - causal_mask = causal_mask.tril() - causal_mask = causal_mask.to(torch.bool) - causal_mask = ~causal_mask - causal_mask = causal_mask[-s_q:, -s_k:] - attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) - attn_weight = torch.softmax(attn_weight, dim=-1) - if dropout_p: - attn_weight = torch.nn.functional.dropout( - attn_weight, p=dropout_p, training=training, inplace=True - ) - out = attn_weight.to(v.dtype).matmul(v) - out = rearrange(out, "b h s d -> b s (h d)") - if needs_weights: - return (out, attn_weight, past_key_value) - return (out, None, past_key_value) - - -def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): - for tensor in tensors: - if tensor.dtype not in valid_dtypes: - raise TypeError( - f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." - ) - if not tensor.is_cuda: - raise TypeError( - f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." 
- ) - - -def flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - try: - from flash_attn import bert_padding, flash_attn_interface - except Exception: - raise RuntimeError("Please install flash-attn==1.0.3.post0") - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if attn_bias is not None: - raise NotImplementedError("attn_bias not implemented for flash attn.") - (batch_size, seqlen) = query.shape[:2] - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1) :] - (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( - query, query_padding_mask - ) - query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) - (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( - key, key_padding_mask - ) - key_unpad = rearrange( - key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) - (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) - value_unpad = rearrange( - value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads - ) - if multiquery: - key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) - value_unpad = value_unpad.expand( - value_unpad.size(0), n_heads, value_unpad.size(-1) - ) - dropout_p = dropout_p if training else 0.0 - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - 
output_unpad = flash_attn_interface.flash_attn_unpadded_func( - query_unpad, - key_unpad, - value_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale=softmax_scale, - causal=reset_is_causal, - return_attn_probs=needs_weights, - ) - output = bert_padding.pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen - ) - return (output, None, past_key_value) - - -def triton_flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): - try: - from .flash_attn_triton import flash_attn_func - except Exception: - _installed = False - if version.parse(torch.__version__) < version.parse("2.0.0"): - _installed = True - try: - from flash_attn.flash_attn_triton import flash_attn_func - except Exception: - _installed = False - if not _installed: - raise RuntimeError( - "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." 
- ) - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if dropout_p: - raise NotImplementedError("Dropout not implemented for attn_impl: triton.") - if needs_weights: - raise NotImplementedError("attn_impl: triton cannot return attn weights.") - if key_padding_mask is not None: - warnings.warn( - "Propagating key_padding_mask to the attention module " - + "and applying it within the attention module can cause " - + "unnecessary computation/memory usage. Consider integrating " - + "into attn_bias once and passing that to each attention " - + "module instead." - ) - (b_size, s_k) = key_padding_mask.shape[:2] - if attn_bias is None: - attn_bias = query.new_zeros(b_size, 1, 1, s_k) - attn_bias = attn_bias.masked_fill( - ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min - ) - query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) - key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) - value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) - if multiquery: - key = key.expand(*key.shape[:2], n_heads, key.size(-1)) - value = value.expand(*value.shape[:2], n_heads, value.size(-1)) - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - attn_output = flash_attn_func( - query, key, value, attn_bias, reset_is_causal, softmax_scale - ) - output = attn_output.view(*attn_output.shape[:2], -1) - return (output, None, past_key_value) - - -class MultiheadAttention(nn.Module): - """Multi-head self attention. - - Using torch or triton attention implementation enables user to also use - additive bias. 
- """ - - def __init__( - self, - config, - prefix, - weights, - ): - super().__init__() - attn_impl = config.attn_config.attn_impl - self.attn_impl = config.attn_config.attn_impl - self.clip_qkv = config.attn_config.clip_qkv - self.qk_ln = config.attn_config.qk_ln - self.d_model = config.d_model - d_model = config.d_model - self.n_heads = config.n_heads - self.softmax_scale = config.attn_config.softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.attn_dropout_p = config.attn_config.attn_pdrop - - if self.n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads = self.n_heads // weights.process_group.size() - self.Wqkv = load_col( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - if self.qk_ln: - bias = not config.no_bias - hidden_size = config.d_model - head_dim = hidden_size // self.n_heads - - self.q_ln = LPLayerNorm( - d_model, bias=bias, prefix=f"{prefix}.q_ln", weights=weights - ) - self.k_ln = LPLayerNorm( - self.n_heads * head_dim, prefix=f"{prefix}.k_ln", weights=weights - ) - if self.attn_impl == "flash": - self.attn_fn = flash_attn_fn - elif self.attn_impl == "triton": - self.attn_fn = triton_flash_attn_fn - elif self.attn_impl == "torch": - self.attn_fn = scaled_multihead_dot_product_attention - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - - def forward( - self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.chunk(3, dim=2) - - 
key_padding_mask = attention_mask - if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - ) - out = self.out_proj(context) - return (out, attn_weights, past_key_value) - - -class MultiQueryAttention(nn.Module): - """Multi-Query self attention. - - Using torch or triton attention implementation enables user to also use - additive bias. - """ - - def __init__(self, config, prefix, weights, verbose=False): - super().__init__() - attn_impl = config.attn_config.attn_impl - self.attn_impl = config.attn_config.attn_impl - self.clip_qkv = config.attn_config.clip_qkv - self.qk_ln = config.attn_config.qk_ln - self.d_model = config.d_model - d_model = config.d_model - self.n_heads = config.n_heads - self.softmax_scale = config.attn_config.softmax_scale - if self.softmax_scale is None: - self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = config.attn_config.attn_pdrop - # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) - self.Wqkv = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - (d_model, d_model + self.head_dim) - if self.qk_ln: - raise NotImplementedError("qk_ln not supported") - if self.attn_impl == "flash": - self.attn_fn = flash_attn_fn - elif self.attn_impl == "triton": - self.attn_fn = triton_flash_attn_fn - if verbose: - warnings.warn( - "While `attn_impl: triton` can be faster than `attn_impl: flash` " - + "it uses more memory. When training larger models this can trigger " - + "alloc retries which hurts performance. 
If encountered, we recommend " - + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." - ) - elif self.attn_impl == "torch": - self.attn_fn = scaled_multihead_dot_product_attention - if torch.cuda.is_available() and verbose: - warnings.warn( - "Using `attn_impl: torch`. If your model does not use `alibi` or " - + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " - + "we recommend using `attn_impl: triton`." - ) - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - # self.out_proj._is_residual = True - - def forward( - self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): - qkv = self.Wqkv(x) - if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.split( - [self.d_model, self.head_dim, self.head_dim], dim=2 - ) - key_padding_mask = attention_mask - if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( - query, - key, - value, - self.n_heads, - past_key_value=past_key_value, - softmax_scale=self.softmax_scale, - attn_bias=attn_bias, - key_padding_mask=key_padding_mask, - is_causal=is_causal, - dropout_p=self.attn_dropout_p, - training=self.training, - needs_weights=needs_weights, - multiquery=True, - ) - return (self.out_proj(context), attn_weights, past_key_value) - - -def attn_bias_shape( - attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id -): - if attn_impl == "flash": - return None - elif attn_impl in ["torch", "triton"]: - if alibi: - if (prefix_lm or not causal) or use_sequence_id: - return (1, n_heads, seq_len, seq_len) - return (1, n_heads, 1, seq_len) - elif prefix_lm or use_sequence_id: - return (1, 1, seq_len, 
seq_len) - return None - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - - -def build_attn_bias( - attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 -): - if attn_impl == "flash": - return None - elif attn_impl in ["torch", "triton"]: - if alibi: - (device, dtype) = (attn_bias.device, attn_bias.dtype) - attn_bias = attn_bias.add( - build_alibi_bias( - n_heads, - seq_len, - full=not causal, - alibi_bias_max=alibi_bias_max, - device=device, - dtype=dtype, - ) - ) - return attn_bias - else: - raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") - - -def gen_slopes(n_heads, alibi_bias_max=8, device=None): - _n_heads = 2 ** math.ceil(math.log2(n_heads)) - m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) - m = m.mul(alibi_bias_max / _n_heads) - slopes = 1.0 / torch.pow(2, m) - if _n_heads != n_heads: - slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] - return slopes.view(1, n_heads, 1, 1) - - -def build_alibi_bias( - n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None -): - alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( - 1, 1, 1, seq_len - ) - if full: - alibi_bias = alibi_bias - torch.arange( - 1 - seq_len, 1, dtype=torch.int32, device=device - ).view(1, 1, seq_len, 1) - alibi_bias = alibi_bias.abs().mul(-1) - slopes = gen_slopes(n_heads, alibi_bias_max, device=device) - alibi_bias = alibi_bias * slopes - return alibi_bias.to(dtype=dtype) - - -ATTN_CLASS_REGISTRY = { - "multihead_attention": MultiheadAttention, - "multiquery_attention": MultiQueryAttention, -} - -"""GPT Blocks used for the GPT Model.""" - - -class MPTMLP(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) - self.up_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.up_proj", weights=weights, bias=not config.no_bias 
- ) - self.act = nn.GELU(approximate="none") - # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) - self.down_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.down_proj", - weights=weights, - bias=not config.no_bias, - ) - # self.down_proj._is_residual = True - - def forward(self, x): - return self.down_proj(self.act(self.up_proj(x))) - - -class MPTBlock(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.prefix = prefix - if config.attn_config.attn_type != "multihead_attention": - raise NotImplementedError( - f"""Not implemented attn {config.attn_config.attn_type}""" - ) - resid_pdrop = config.resid_pdrop - if config.no_bias: - self.norm_1 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load_no_bias( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) - else: - self.norm_1 = nn.LayerNorm.load( - prefix=f"{prefix}.norm_1", weights=weights, eps=EPS - ) - self.norm_2 = nn.LayerNorm.load( - prefix=f"{prefix}.norm_2", weights=weights, eps=EPS - ) - self.attn = MultiheadAttention(config, prefix=f"{prefix}.attn", weights=weights) - self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) - self.resid_attn_dropout = nn.Dropout(resid_pdrop) - self.resid_ffn_dropout = nn.Dropout(resid_pdrop) - - def forward( - self, - x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attn_bias: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.ByteTensor] = None, - is_causal: bool = True, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: - a = self.norm_1(x) - (b, attn_weights, past_key_value) = self.attn( - a, - past_key_value=past_key_value, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=is_causal, - ) - x = x + self.resid_attn_dropout(b) - m = self.norm_2(x) - n = self.ffn(m) - x = x + self.resid_ffn_dropout(n) - return (x, attn_weights, past_key_value) - - 
-def _cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == "cuda": - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == "cpu": - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor - - -class LPLayerNorm(torch.nn.LayerNorm): - def __init__( - self, - normalized_shape, - eps=1e-05, - elementwise_affine=True, - device=None, - dtype=None, - bias: Optional[bool] = True, - prefix=None, - weights=None, - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - device=device, - dtype=dtype, - bias=bias, - ) - if weights is not None: - self.weight = nn.Parameter(weights.get_sharded(f"{prefix}.weight", dim=0)) - if bias: - self.bias = nn.Parameter(weights.get_sharded(f"{prefix}.bias", dim=0)) - self.normalized_shape = self.weight.shape - - def forward(self, x): - module_device = x.device - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - downcast_bias = ( - _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias - ) - with torch.autocast(enabled=False, device_type=module_device.type): - return torch.nn.functional.layer_norm( - downcast_x, - self.normalized_shape, - downcast_weight, - downcast_bias, - self.eps, - ) - - -def rms_norm(x, weight=None, eps=1e-05): - output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) - if weight is not None: - return output * weight - return output - - -class RMSNorm(torch.nn.Module): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): - super().__init__() - self.eps = eps - if weight: - self.weight = torch.nn.Parameter( - torch.ones(normalized_shape, dtype=dtype, device=device) - ) - else: - self.register_parameter("weight", None) - - def forward(self, x): - return 
rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) - - -class LPRMSNorm(RMSNorm): - def __init__( - self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - weight=weight, - dtype=dtype, - device=device, - ) - - def forward(self, x): - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = ( - _cast_if_autocast_enabled(self.weight) - if self.weight is not None - else self.weight - ) - with torch.autocast(enabled=False, device_type=x.device.type): - return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) - - -NORM_CLASS_REGISTRY = { - "layernorm": torch.nn.LayerNorm, - "low_precision_layernorm": LPLayerNorm, - "rmsnorm": RMSNorm, - "low_precision_rmsnorm": LPRMSNorm, -} - -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - - -class MPTPreTrainedModel(PreTrainedModel): - base_model_prefix = "model" - _no_split_modules = ["MPTBlock"] - - -class MPTModel(MPTPreTrainedModel): - def __init__(self, prefix: str, config, weights): - # config._validate_config() - super().__init__(config) - self.world_size = weights.process_group.size() - self.rank = weights.process_group.rank() - self.n_heads = config.n_heads - self.attn_impl = config.attn_config.attn_impl - self.prefix_lm = config.attn_config.prefix_lm - self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id - self.alibi = config.attn_config.alibi - self.alibi_bias_max = config.attn_config.alibi_bias_max - if config.init_device == "mixed": - # TODO: reimplement mixed device initialization - # dist.get_local_rank() == 0: - if True: - config.init_device = "cpu" - else: - config.init_device = "meta" - if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): - norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) - raise NotImplementedError( - f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." 
- ) - if config.norm_type.lower() != "low_precision_layernorm": - raise NotImplementedError( - f"Requested norm type ({config.norm_type}) is not implemented within this repo." - ) - - self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) - - if not self.alibi: - self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) - self.blocks = nn.ModuleList( - [ - MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) - for i in range(config.n_layers) - ] - ) - if config.no_bias: - self.norm_f = nn.LayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) - else: - self.norm_f = nn.LayerNorm.load( - prefix="transformer.norm_f", weights=weights, eps=EPS - ) - self.is_causal = not self.prefix_lm - self._attn_bias_initialized = False - self.attn_bias = None - self.attn_bias_shape = attn_bias_shape( - self.attn_impl, - config.n_heads, - config.max_seq_len, - self.alibi, - prefix_lm=self.prefix_lm, - causal=self.is_causal, - use_sequence_id=self.attn_uses_sequence_id, - ) - if config.no_bias: - for module in self.modules(): - if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): - if config.verbose: - warnings.warn(f"Removing bias ({module.bias}) from {module}.") - module.register_parameter("bias", None) - if hasattr(self.config, "verbose"): - if config.verbose and config.verbose > 2: - print(self) - if "verbose" not in self.config.init_config: - self.config.init_config["verbose"] = self.config.verbose - if self.config.init_config["verbose"] > 1: - init_fn_name = self.config.init_config["name"] - warnings.warn(f"Using {init_fn_name} initialization.") - - @torch.no_grad() - def _attn_bias( - self, - device, - dtype, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - ): - if not self._attn_bias_initialized: - if self.attn_bias_shape: - self.attn_bias = torch.zeros( - self.attn_bias_shape, device=device, dtype=dtype - 
) - self.attn_bias = build_attn_bias( - self.attn_impl, - self.attn_bias, - self.config.n_heads, - self.config.max_seq_len, - causal=self.is_causal, - alibi=self.alibi, - alibi_bias_max=self.alibi_bias_max, - ) - assert self.n_heads % self.world_size == 0 - block_size = self.n_heads // self.world_size - self.attn_bias = self.attn_bias[ - :, self.rank * block_size : (self.rank + 1) * block_size - ] - self._attn_bias_initialized = True - if self.attn_impl == "flash": - return (self.attn_bias, attention_mask) - if self.attn_bias is not None: - self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) - attn_bias = self.attn_bias - if self.prefix_lm: - assert isinstance(attn_bias, torch.Tensor) - assert isinstance(prefix_mask, torch.Tensor) - attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) - if self.attn_uses_sequence_id and sequence_id is not None: - assert isinstance(attn_bias, torch.Tensor) - attn_bias = self._apply_sequence_id(attn_bias, sequence_id) - if attention_mask is not None: - s_k = attention_mask.shape[-1] - if attn_bias is None: - attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) - else: - _s_k = max(0, attn_bias.size(-1) - s_k) - attn_bias = attn_bias[:, :, :, _s_k:] - if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: - raise ValueError( - f"attention_mask shape={attention_mask.shape} " - + f"and prefix_mask shape={prefix_mask.shape} are not equal." - ) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill( - ~attention_mask.view(-1, 1, 1, s_k), min_val - ) - return (attn_bias, None) - - def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): - (s_k, s_q) = attn_bias.shape[-2:] - if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: - raise ValueError( - "attn_bias does not match the expected shape. " - + f"The last two dimensions should both be {self.config.max_length} " - + f"but are {s_k} and {s_q}." 
- ) - seq_len = prefix_mask.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" - ) - attn_bias = attn_bias[..., :seq_len, :seq_len] - causal = torch.tril( - torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) - ).view(1, 1, seq_len, seq_len) - prefix = prefix_mask.view(-1, 1, 1, seq_len) - cannot_attend = ~torch.logical_or(causal, prefix.bool()) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - return attn_bias - - def _apply_sequence_id( - self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor - ): - seq_len = sequence_id.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" - ) - attn_bias = attn_bias[..., :seq_len, :seq_len] - cannot_attend = torch.logical_not( - torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) - ).unsqueeze(1) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - return attn_bias - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - if attention_mask is not None: - attention_mask = attention_mask.bool() - if prefix_mask is not None: - prefix_mask = prefix_mask.bool() - if not return_dict: - raise NotImplementedError( - 
"return_dict False is not implemented yet for MPT" - ) - if output_attentions: - if self.attn_impl != "torch": - raise NotImplementedError( - "output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`." - ) - if ( - attention_mask is not None - and attention_mask[:, 0].sum() != attention_mask.shape[0] - and self.training - ): - raise NotImplementedError( - "MPT does not support training with left padding." - ) - if self.prefix_lm and prefix_mask is None: - raise ValueError( - "prefix_mask is a required argument when MPT is configured with prefix_lm=True." - ) - if self.training: - if self.attn_uses_sequence_id and sequence_id is None: - raise ValueError( - "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " - + "and the model is in train mode." - ) - elif self.attn_uses_sequence_id is False and sequence_id is not None: - warnings.warn( - "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " - + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." - ) - S = input_ids.size(1) - assert ( - S <= self.config.max_seq_len - ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" - tok_emb = self.wte(input_ids) - if self.alibi: - x = tok_emb - else: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - "past_key_values must provide a past_key_value for each attention " - + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." 
- ) - past_position = past_key_values[0][0].size(1) - if self.attn_impl == "torch": - past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: - raise ValueError( - f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." - ) - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - pos = torch.clamp( - pos - - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ - :, past_position: - ], - min=0, - ) - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb - (attn_bias, attention_mask) = self._attn_bias( - device=x.device, - dtype=torch.float32, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - ) - if use_cache and past_key_values is None: - past_key_values = [() for _ in range(self.config.n_layers)] - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - for b_idx, block in enumerate(self.blocks): - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - past_key_value = ( - past_key_values[b_idx] if past_key_values is not None else None - ) - (x, attn_weights, past_key_value) = block( - x, - past_key_value=past_key_value, - attn_bias=attn_bias, - attention_mask=attention_mask, - is_causal=self.is_causal, - ) - if past_key_values is not None: - past_key_values[b_idx] = past_key_value - if output_attentions: - assert all_self_attns is not None - all_self_attns = all_self_attns + (attn_weights,) - x = self.norm_f(x) - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - return BaseModelOutputWithPast( - last_hidden_state=x, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - 
attentions=all_self_attns, - ) - - -class MPTForCausalLM(MPTPreTrainedModel): - def __init__(self, prefix: str, config, weights): - super().__init__(config) - - if not prefix: - prefix = "transformer" - else: - prefix = f"{prefix}.transformer" - - if not config.tie_word_embeddings: - raise ValueError("MPTForCausalLM only supports tied word embeddings") - self.transformer = MPTModel(prefix, config, weights) - self.lm_head = SpeculativeHead.load( - config, prefix=f"{prefix}.wte", weights=weights - ) - self.logit_scale = None - if config.logit_scale is not None: - logit_scale = config.logit_scale - if isinstance(logit_scale, str): - if logit_scale == "inv_sqrt_d_model": - logit_scale = 1 / math.sqrt(config.d_model) - else: - raise ValueError( - f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." - ) - self.logit_scale = logit_scale - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - prefix_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - use_cache: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else self.config.return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - return_dict=return_dict, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=use_cache, - ) - logits, speculative_logits = self.lm_head(outputs.last_hidden_state) - if self.logit_scale is not None: - if self.logit_scale == 0: 
- warnings.warn( - f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." - ) - logits *= self.logit_scale - loss = None - if labels is not None: - labels = torch.roll(labels, shifts=-1) - labels[:, -1] = -100 - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) - ) - return ( - CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): - if inputs_embeds is not None: - raise NotImplementedError("inputs_embeds is not implemented for MPT yet") - attention_mask = kwargs["attention_mask"].bool() - if attention_mask[:, -1].sum() != attention_mask.shape[0]: - raise NotImplementedError( - "MPT does not support generation with right padding." - ) - if self.transformer.attn_uses_sequence_id and self.training: - sequence_id = torch.zeros_like(input_ids[:1]) - else: - sequence_id = None - if past_key_values is not None: - input_ids = input_ids[:, -1].unsqueeze(-1) - if self.transformer.prefix_lm: - prefix_mask = torch.ones_like(attention_mask) - if kwargs.get("use_cache") is False: - raise NotImplementedError( - "MPT with prefix_lm=True does not support use_cache=False." - ) - else: - prefix_mask = None - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "prefix_mask": prefix_mask, - "sequence_id": sequence_id, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache", True), - } - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - """Used by HuggingFace generate when using beam search with kv-caching. 
- - See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 - for an example in transformers. - """ - reordered_past = [] - for layer_past in past_key_values: - reordered_past += [ - tuple( - (past_state.index_select(0, beam_idx) for past_state in layer_past) - ) - ] - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py deleted file mode 100644 index 06731a6f9..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ /dev/null @@ -1,796 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" PyTorch GPTNeoX model.""" - -from typing import Optional, Tuple, Union - -import os -import torch -import torch.distributed -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - - -CUSTOM_KERNELS_ENABLED = False -if ( - torch.cuda.is_available() - and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True" -): - try: - from custom_kernels import fused_attention_cuda - - CUSTOM_KERNELS_ENABLED = True - except ImportError: - pass - - -def make_causal_mask( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.ones( - (target_length, target_length + past_key_values_length), - dtype=torch.bool, - device=device, - ) - mask = mask.triu(1 + past_key_values_length) - - expanded_mask = mask.unsqueeze(0).expand( - batch_size, target_length, target_length + past_key_values_length - ) - return expanded_mask - - -def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: - """ - Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. 
- """ - batch_size, src_length = mask.shape - tgt_length = tgt_length if tgt_length is not None else src_length - - expanded_mask = ~(mask[:, None, :].to(torch.bool)) - return expanded_mask.expand(batch_size, tgt_length, src_length) - - -def prepare_attn_mask( - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - past_key_values_length: int, -) -> torch.BoolTensor: - # create causal mask - # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - if src_length > 1: - combined_attention_mask = make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length] -> [batch_size, tgt_length, src_length] - expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - - -class GPTNeoXPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - -class GPTNeoXAttention(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.num_attention_heads = config.num_attention_heads - self.hidden_size = config.hidden_size - self.head_size = self.hidden_size // self.num_attention_heads - self.rotary_ndims = int(self.head_size * config.rotary_pct) - # ??? 
TODO - # self.register_buffer( - # "bias", - # torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - # 1, 1, max_positions, max_positions - # ), - # ) - # self.register_buffer("masked_bias", torch.tensor(-1e9)) - self.rotary_emb = RotaryEmbedding( - self.rotary_ndims, - config.max_position_embeddings, - base=config.rotary_emb_base, - ) - self.rotary_emb.inv_freq = nn.Parameter( - weights.get_tensor(f"{prefix}.rotary_emb.inv_freq") - ) - self.inv_norm_factor = 1.0 / torch.sqrt( - torch.tensor(self.head_size, dtype=torch.float32) - ).to(torch.get_default_dtype()) - - if self.num_attention_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_attention_heads` must be divisible by `num_shards` " - f"(got `num_attention_heads`: {self.num_attention_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_attention_heads = ( - self.num_attention_heads // weights.process_group.size() - ) - self.query_key_value = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True - ) - self.dense = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense", weights=weights, bias=True - ) - - def forward( - self, - hidden_states, - position_ids, - attention_mask, - head_mask=None, - layer_past=None, - use_cache=False, - output_attentions=False, - ): - has_layer_past = layer_past is not None - - # Compute QKV - # Attention heads [batch, seq_len, hidden_size] - # --> [batch, seq_len, (np * 3 * head_size)] - qkv = self.query_key_value(hidden_states) - - # [batch, seq_len, (num_heads * 3 * head_size)] - # --> [batch, seq_len, num_heads, 3 * head_size] - new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) - qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3) - # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] - query, key, value = qkv.split(self.head_size, -1) - - # Compute token 
offset for rotary embeddings (when decoding) - seq_len = key.shape[-2] - if has_layer_past: - seq_len += layer_past[0].shape[-2] - - # Compute rotary embeddings on rotary_ndims - query_rot = query[..., : self.rotary_ndims] - key_rot = key[..., : self.rotary_ndims] - - query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len) - - query[..., : self.rotary_ndims] = query_rot - key[..., : self.rotary_ndims] = key_rot - - if CUSTOM_KERNELS_ENABLED: - attn_output, present, attn_weights = fused_attention_cuda.forward( - query, - key, - value, - layer_past, - attention_mask, - head_mask, - self.inv_norm_factor, - self.num_attention_heads, - use_cache, - ) - else: - # Cache QKV values - if has_layer_past: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - present = (key, value) if use_cache else None - - # Compute attention - attn_output, attn_weights = self._attn( - query, key, value, attention_mask, head_mask - ) - - # Reshape outputs - attn_output = self._merge_heads( - attn_output, self.num_attention_heads, self.head_size - ) - - attn_output = self.dense(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs - - @classmethod - def _split_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Splits hidden dim into attn_head_size and num_attention_heads - """ - # tensor: [bs, seq_len, hidden_size] - new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view(new_shape) - # -> [bs, num_attention_heads, seq_len, attn_head_size] - tensor = tensor.permute(0, 2, 1, 3) - return tensor - - @classmethod - def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden dim - """ - # tensor [bs, num_attention_heads, seq_len, 
attn_head_size] - tensor = tensor.permute(0, 2, 1, 3).contiguous() - # -> [bs, seq_len, num_attention_heads, attn_head_size] - tensor = tensor.view( - tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size - ) - # -> [bs, seq_len, hidden_size] - return tensor - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] - # compute causal mask from causal mask buffer - batch_size, num_attention_heads, query_length, attn_head_size = query.size() - key_length = key.size(-2) - - query = query.reshape( - batch_size * num_attention_heads, query_length, attn_head_size - ) - key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size) - attn_scores = torch.zeros( - 1, - dtype=query.dtype, - device=key.device, - ).expand(batch_size * num_attention_heads, query_length, key_length) - attn_scores = torch.baddbmm( - attn_scores, - query, - key.transpose(1, 2), - beta=1.0, - alpha=self.inv_norm_factor, - ) - - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attn_scores.dtype - if input_dtype in [torch.float16, torch.bfloat16]: - attn_scores = attn_scores.to(torch.float) - attn_scores = torch.where( - attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores - ) - attn_scores = attn_scores.view( - batch_size, num_attention_heads, query_length, key_length - ) - - attn_weights = nn.functional.softmax(attn_scores, dim=-1) - attn_weights = attn_weights.to(value.dtype) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - return attn_output, attn_weights - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings, base=10000, device=None): - super().__init__() - self.true_inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 
2).float().to(device) / dim) - ) - self.register_buffer("inv_freq", self.true_inv_freq) - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - self.cos_cached = None - self.sin_cached = None - - @staticmethod - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - @staticmethod - def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device): - t = torch.arange( - max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype - ) - freqs = torch.einsum("i,j->ij", t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype) - - def forward(self, q, k, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if ( - seq_len > self.max_seq_len_cached - or self.cos_cached is None - or self.sin_cached is None - ): - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - self.cos_cached, self.sin_cached = self._create_cos_sin( - self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device - ) - return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids) - - -@torch.jit.script -def rotary_forward(q, k, cos, sin, position_ids): - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) - - chunk_size = q.shape[-1] // 2 - q1, q2 = q.split(chunk_size, -1) - q_rotated = torch.cat((-q2, q1), dim=-1) - k1, k2 = k.split(chunk_size, -1) - k_rotated = torch.cat((-k2, k1), dim=-1) - - q_embed = (q * cos) + (q_rotated * sin) - k_embed = (k * cos) + (k_rotated * sin) - return q_embed, k_embed - - -class GPTNeoXMLP(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.act = ( - ACT2FN[config.hidden_act] - if "gelu_fast" 
not in config.hidden_act - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - self.dense_h_to_4h = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True - ) - self.dense_4h_to_h = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True - ) - - def forward(self, hidden_states): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dense_4h_to_h(hidden_states) - return hidden_states - - -class GPTNeoXLayer(nn.Module): - def __init__(self, layer_id, prefix: str, config, weights): - super().__init__() - self.use_parallel_residual = config.use_parallel_residual - self.input_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.layers.{layer_id}.input_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.attention = GPTNeoXAttention( - config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights - ) - self.mlp = GPTNeoXMLP( - config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights - ) - - def forward( - self, - hidden_states, - position_ids, - attention_mask=None, - head_mask=None, - use_cache=False, - layer_past=None, - output_attentions=False, - ): - attention_layer_outputs = self.attention( - self.input_layernorm(hidden_states), - attention_mask=attention_mask, - position_ids=position_ids, - layer_past=layer_past, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attention_layer_outputs[ - 0 - ] # output_attn: attn_output, present, (attn_weights) - outputs = attention_layer_outputs[1:] - - if self.use_parallel_residual: - # pseudocode: - # x = x + attn(ln1(x)) + mlp(ln2(x)) - mlp_output = 
self.mlp(self.post_attention_layernorm(hidden_states)) - hidden_states = mlp_output + attn_output + hidden_states - else: - # pseudocode: - # x = x + attn(ln1(x)) - # x = x + mlp(ln2(x)) - attn_output = attn_output + hidden_states - mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) - hidden_states = mlp_output + attn_output - - if use_cache: - outputs = ( - hidden_states, - ) + outputs # hidden_states, present, (attn_weights) - else: - outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) - - return outputs - - -class GPTNeoXModel(GPTNeoXPreTrainedModel): - def __init__(self, prefix: str, config, weights): - super().__init__(config) - self.config = config - - self.num_attention_heads = config.num_attention_heads - - self.embed_in = TensorParallelEmbedding( - prefix=f"{prefix}.embed_in", weights=weights - ) - self.layers = nn.ModuleList( - [ - GPTNeoXLayer(layer_id, prefix, config, weights) - for layer_id in range(config.num_hidden_layers) - ] - ) - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.final_layer_norm", - weights=weights, - eps=config.layer_norm_eps, - ) - self.tp_world_size = weights.process_group.size() - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids=None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * self.config.num_hidden_layers) - else: - past_length = past_key_values[0][0].size(-2) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_length, seq_length + past_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_in(input_ids) - - hidden_states = inputs_embeds - - # Attention mask. 
- seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[-1] - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), device=hidden_states.device - ) - else: - attention_mask = attention_mask.to(hidden_states.device) - - causal_mask = prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - assert self.num_attention_heads % self.tp_world_size == 0 - block_size = self.num_attention_heads // self.tp_world_size - causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - presents = () if use_cache else None - all_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = layer( - hidden_states, - position_ids=position_ids, - attention_mask=causal_mask, - head_mask=head_mask[i], - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_attentions = all_attentions + (outputs[2 if use_cache else 1],) - - hidden_states = self.final_layer_norm(hidden_states) - # Add last hidden state - if output_hidden_states: - 
all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, presents, all_hidden_states, all_attentions] - if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_attentions, - ) - - -class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - - def __init__(self, prefix: str, config, weights): - super().__init__(config) - - if not prefix: - prefix = "gpt_neox" - else: - prefix = f"{prefix}.gpt_neox" - - self.gpt_neox = GPTNeoXModel(prefix, config, weights) - self.embed_out = SpeculativeHead.load( - config, prefix="embed_out", weights=weights - ) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are - only required when the model is used as a decoder in a Sequence to Sequence model. 
- - Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see - `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). 
- - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") - >>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") - >>> config.is_decoder = True - >>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.logits - ```""" - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - outputs = self.gpt_neox( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - lm_logits, speculative_logits = self.embed_out(hidden_states) - - lm_loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # we are doing next-token prediction; shift prediction scores and input ids by one - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) - - if not return_dict: - output = (lm_logits,) + outputs[1:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return ( - CausalLMOutputWithPast( - loss=lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - 
inputs_embeds=None, - **kwargs, - ): - input_shape = input_ids.shape - - # cut decoder_input_ids if past is used - if past_key_values and past_key_values[0] is not None: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_shape) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "attention_mask": attention_mask, - "past_key_values": past_key_values, - "position_ids": position_ids, - } - ) - - return model_inputs - - def _reorder_cache(self, past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past[:2] - ) - + layer_past[2:], - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py deleted file mode 100644 index db73ae84e..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ /dev/null @@ -1,864 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch OPT model.""" - -import random -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers import OPTConfig -from text_generation_server.layers import ( - FastLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - -EPS = 1e-5 - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. 
- """ - bsz, tgt_len = input_ids_shape - mask = torch.full( - (tgt_len, tgt_len), - torch.tensor(torch.finfo(dtype).min, device=device), - device=device, - ) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device - ), - mask, - ], - dim=-1, - ) - return mask[None, None, :, :].expand( - bsz, 1, tgt_len, tgt_len + past_key_values_length - ) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) - - -class OPTLearnedPositionalEmbedding(nn.Module): - """ - This module learns positional embeddings up to a fixed maximum size. 
- """ - - def __init__(self, prefix: str, weights): - super().__init__() - self.offset = 2 - self.weight = nn.Parameter( - weights.get_tensor( - f"{prefix if prefix else ''}decoder.embed_positions.weight" - ) - ) - - def forward( - self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 - ): - """`input_ids_shape` is expected to be [bsz x seqlen].""" - attention_mask = attention_mask.long() - - # create positions depending on attention_mask - positions = ( - torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask - ).long() - 1 - - # cut positions if `past_key_values_length` is > 0 - positions = positions[:, past_key_values_length:] - - return torch.nn.functional.embedding(positions + self.offset, self.weight) - - -class OPTAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config, - prefix, - weights, - is_decoder: bool = False, - bias: bool = True, - process_group=None, - ): - super().__init__() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - - self.hidden_size = hidden_size - self.num_heads = num_heads - self.dropout = config.dropout - self.head_dim = hidden_size // num_heads - - if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads})." 
- ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - process_group = weights.process_group - if self.num_heads % weights.process_group.size() != 0: - raise ValueError( - f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.num_heads = self.num_heads // process_group.size() - self.hidden_size = self.hidden_size // process_group.size() - - self.q_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias - ) - self.k_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias - ) - self.v_proj = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias - ) - self.out_proj = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return ( - tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif 
is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = ( - attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - + attention_mask - ) - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - # upcast to fp32 if the weights are in fp16. 
Please see https://github.com/huggingface/transformers/pull/17437 - if attn_weights.dtype == torch.float16: - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(torch.float16) - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view( - bsz, self.num_heads, tgt_len, src_len - ) - attn_weights = attn_weights_reshaped.view( - bsz * self.num_heads, tgt_len, src_len - ) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class OPTDecoderLayer(nn.Module): - def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): - super().__init__() - self.process_group = weights.process_group - self.hidden_size = config.hidden_size - self.self_attn = OPTAttention( - config, - prefix=f"{prefix}.self_attn", - weights=weights, - is_decoder=True, - bias=config.enable_bias, - ) - self.do_layer_norm_before = config.do_layer_norm_before - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - - self.self_attn_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS - ) - self.fc1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias - ) - self.fc2 = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias - ) - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - ) -> Tuple[ - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] - ]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. 
- output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Fully Connected - hidden_states_shape = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) - residual = hidden_states - - # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention - if self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - - hidden_states = (residual + hidden_states).view(hidden_states_shape) - - # 350m applies layer norm AFTER attention - if not self.do_layer_norm_before: - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if 
output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class OPTPreTrainedModel(PreTrainedModel): - config_class = OPTConfig - - -class OPTDecoder(OPTPreTrainedModel): - def __init__(self, prefix: str, config: OPTConfig, weights): - super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings - self.vocab_size = config.vocab_size - - prefix = prefix + "." if prefix else "" - - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}decoder.embed_tokens", weights=weights - ) - self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) - - if config.word_embed_proj_dim != config.hidden_size: - self.project_out = FastLinear.load( - config, - prefix=f"{prefix}decoder.project_out", - weights=weights, - bias=False, - ) - else: - self.project_out = None - - if config.word_embed_proj_dim != config.hidden_size: - self.project_in = FastLinear.load( - config, - prefix=f"{prefix}decoder.project_in", - weights=weights, - bias=False, - ) - else: - self.project_in = None - - # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility - # with checkpoints that have been fine-tuned before transformers v4.20.1 - # see https://github.com/facebookresearch/metaseq/pull/164 - if config.do_layer_norm_before and not config._remove_final_layer_norm: - self.final_layer_norm = nn.LayerNorm.load( - prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS - ) - else: - self.final_layer_norm = None - - self.layers = nn.ModuleList( - [ - OPTDecoderLayer( - layer_id, - prefix=f"{prefix}decoder.layers.{layer_id}", - config=config, - weights=weights, - ) - for layer_id in range(config.num_hidden_layers) - ] - ) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def 
_prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. 
- - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_shape - past_key_values_length = ( - past_key_values[0][0].shape[2] if past_key_values is not None else 0 - ) - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values_length + seq_length - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - batch_size, mask_seq_length, device=inputs_embeds.device - ) - causal_attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - pos_embeds = self.embed_positions(attention_mask, past_key_values_length) - - if self.project_in is not None: - inputs_embeds = self.project_in(inputs_embeds) - - hidden_states = inputs_embeds + pos_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # check if head_mask has a correct 
number of layers specified if desired - for attn_mask, mask_name in zip([head_mask], ["head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): - continue - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if self.final_layer_norm is not None: - hidden_states = self.final_layer_norm(hidden_states) - - if self.project_out is not None: - hidden_states = self.project_out(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class OPTModel(OPTPreTrainedModel): - def __init__(self, prefix: str, config: OPTConfig, weights): - super().__init__(config) - self.decoder = 
OPTDecoder(prefix, config, weights) - # Initialize weights and apply final processing - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs - - return BaseModelOutputWithPast( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - hidden_states=decoder_outputs.hidden_states, - attentions=decoder_outputs.attentions, - ) - - -class OPTForCausalLM(OPTPreTrainedModel): - def __init__(self, prefix, config, weights): - super().__init__(config) - if not prefix and any(s.startswith("model") for s in weights.routing.keys()): - prefix = "model" - - self.model = OPTModel(prefix, config, weights) - - self.lm_head = SpeculativeHead.load( - config, 
- prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens", - weights=weights, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - logits, speculative_logits = self.lm_head(outputs.last_hidden_state) - - loss = None - - return ( - CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ), - speculative_logits, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and 
past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) for past_state in layer_past - ), - ) - return reordered_past diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py deleted file mode 100644 index 3f2ed010f..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ /dev/null @@ -1,336 +0,0 @@ -# imlementation of the PhiModel and PhiForCausalLM classes - -import torch -import torch.distributed - -import math -from torch import nn -from typing import Optional, List, Tuple -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_outputs import CausalLMOutputWithPast - -from text_generation_server.layers import ( - TensorParallelRowLinear, - TensorParallelColumnLinear, - TensorParallelEmbedding, - SpeculativeHead, - FastLinear, -) - - -# PhiConfig is the configuration class for the PhiModel. 
-class PhiConfig(PretrainedConfig): - def __init__( - self, - vocab_size=51200, - n_positions=2048, - n_embd=2560, - n_layer=32, - n_inner=None, - n_head=32, - rotary_dim=32, - layer_norm_epsilon=1e-5, - tie_word_embeddings=False, - pad_vocab_size_multiple=64, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - no_bias=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_inner = n_inner - self.n_head = n_head - self.rotary_dim = rotary_dim - - self.layer_norm_epsilon = layer_norm_epsilon - self.tie_word_embeddings = tie_word_embeddings - self.pad_vocab_size_multiple = pad_vocab_size_multiple - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.no_bias = no_bias - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - -# RotaryEmbedding is a class that implements the rotary embedding. 
-class RotaryEmbedding(nn.Module): - def __init__(self, dim, max_seq_len): - super().__init__() - inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)] - inv_freq_len = len(inv_freq) - inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len) - t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1) - freqs = t.matmul(inv_freq) - self.sin = freqs.sin() - self.cos = freqs.cos() - - def apply_rotary_emb_qkv(self, qkv, seqlen_offset): - b_size, seqlen, three, _, _headdim = qkv.shape - if three != 3: - raise Exception("unexpected shape for qkv") - _, rotary_dim = self.cos.shape - rotary_dim = rotary_dim * 2 - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - q12 = torch.chunk(q_rot, 2, dim=-1) - k12 = torch.chunk(k_rot, 2, dim=-1) - q1, q2 = q12[0], q12[1] - k1, k2 = k12[0], k12[1] - c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1) - s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1) - q_rot = torch.cat( - [ - q1 * c - q2 * s, - q1 * s + q2 * c, - ], - dim=-1, - ) - k_rot = torch.cat( - [ - k1 * c - k2 * s, - k1 * s + k2 * c, - ], - dim=-1, - ) - q = torch.cat([q_rot, q_pass], dim=-1) - k = torch.cat([k_rot, k_pass], dim=-1) - v = qkv[:, :, 2] - return q, k, v - - -# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm. -class PhiCausalLMHead(nn.Module): - def __init__(self, config, weights): - super().__init__() - self.ln = nn.LayerNorm.load( - prefix="lm_head.ln", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.linear = SpeculativeHead.load( - config=config, prefix="lm_head.linear", weights=weights - ) - - def forward(self, hidden_states): - hidden_states = self.ln(hidden_states) - hidden_states = self.linear(hidden_states) - return hidden_states - - -# PhiMHA is a multi-head attention layer. 
This layer uses an attention mask to prevent tokens from attending to subsequent tokens. -class PhiMHA(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - self.Wqkv = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias - ) - self.out_proj = TensorParallelRowLinear.load( - config, - prefix=f"{prefix}.out_proj", - weights=weights, - bias=not config.no_bias, - ) - self.op_size = config.n_embd - self.head_dim = int(config.n_embd / config.n_head) - self.num_heads = config.n_head - self.rotary_emb = RotaryEmbedding( - config.rotary_dim, - config.n_positions, - ) - self.softmax_scale = 1.0 / math.sqrt(self.head_dim) - - def forward( - self, - hidden_states, - past_kv_cache, - attention_mask=None, - ): - b_size, seq_len, _n_embd = hidden_states.shape - qkv = self.Wqkv(hidden_states) - qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim) - seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1] - q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset) - - # if there is a kv_cache, then we need to concatenate - if past_kv_cache is not None: - prev_k, prev_v = past_kv_cache - k = torch.cat([prev_k, k], dim=1) - v = torch.cat([prev_v, v], dim=1) - - past_kv_cache = [k, v] - attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale) - - if attention_mask is not None: - seqlen_k = k.shape[1] - seqlen_q = q.shape[1] - causal_mask = torch.triu( - torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device), - 1, - ) - attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype) - - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0) - attn_output = ( - attn_output.view((b_size, self.num_heads, seq_len, self.head_dim)) - .transpose(1, 2) - .flatten(-2) - ) - return self.out_proj(attn_output), past_kv_cache - - -# PhiMLP is a multi-layer 
perceptron. It contains two linear layers with a gelu activation function. -class PhiMLP(nn.Module): - def __init__(self, prefix, config, weights): - super().__init__() - - self.n_inner = config.n_inner - self.fc1 = FastLinear.load( - config=config, - prefix=f"{prefix}.fc1", - weights=weights, - bias=False, - ) - self.fc2 = FastLinear.load( - config=config, - prefix=f"{prefix}.fc2", - weights=weights, - bias=False, - ) - self.activation = torch.nn.functional.gelu - - def forward(self, hidden_states): - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and an multi-layer perceptron. -class PhiBlock(nn.Module): - def __init__(self, layer_id, config, weights): - super().__init__() - self.layer_id = layer_id - self.layer_norm = nn.LayerNorm.load( - prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon - ) - self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights) - self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights) - - def forward( - self, - hidden_states, - kv_cache, - attention_mask, - ): - residual = hidden_states - hidden_states = self.layer_norm(hidden_states) - attn_outputs, past_kv_cache = self.mixer( - hidden_states, kv_cache, attention_mask - ) - feed_forward_hidden_states = self.mlp(hidden_states) - out = attn_outputs + feed_forward_hidden_states + residual - return out, past_kv_cache - - -# PhiModel implements the embedding layer and the transformer blocks. 
-class PhiModel(nn.Module): - def __init__(self, prefix: str, config, weights): - super().__init__() - self.tp_rank = weights.process_group.rank() - self.tp_world_size = weights.process_group.size() - self.embed_tokens = TensorParallelEmbedding( - prefix=f"{prefix}.embd.wte", weights=weights - ) - self.blocks = nn.ModuleList( - [ - PhiBlock(f"{prefix}.h.{layer_id}", config, weights) - for layer_id in range(config.n_layer) - ] - ) - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - hidden_states = self.embed_tokens(input_ids) - seq_len = hidden_states.shape[1] - mask = None if seq_len <= 1 else attention_mask - - past_key_values = ( - [None] * len(self.blocks) if past_key_values is None else past_key_values - ) - - for index, block in enumerate(self.blocks): - hidden_states, new_key_values = block( - hidden_states, past_key_values[index], mask - ) - past_key_values[index] = new_key_values - - return hidden_states, past_key_values - - -# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. 
-class PhiForCausalLM(torch.nn.Module): - def __init__(self, prefix: str, config, weights): - super().__init__() - - if not prefix: - prefix = "transformer" - else: - prefix = f"{prefix}.transformer" - - self.model = PhiModel(prefix, config, weights) - self.lm_head = PhiCausalLMHead(config, weights) - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, - attention_mask: Optional[torch.ByteTensor] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - labels: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - model_output = self.model( - input_ids, past_key_values, attention_mask, return_dict, use_cache - ) - logits = self.lm_head(model_output[0]) - - loss = None - if labels is not None: - loss = nn.CrossEntropyLoss()( - logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1) - ) - - if not return_dict: - return ( - ((loss,) + (logits,) + model_output[1:]) - if loss is not None - else (logits,) + model_output[1:] - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=model_output[1], - hidden_states=None, - attentions=None, - ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py deleted file mode 100644 index e6666acd3..000000000 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ /dev/null @@ -1,1227 +0,0 @@ -# coding=utf-8 -# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch T5 model.""" - -import copy -import math -import warnings -from typing import Optional, Tuple, Union - -from loguru import logger - -import torch -import torch.distributed -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqLMOutput, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import ( - is_torch_fx_proxy, -) -from transformers import T5Config -from text_generation_server.layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, - SpeculativeHead, -) - -# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316 -# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask -__HEAD_MASK_WARNING_MSG = """ -The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, -`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. -If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, -num_heads)`. 
-""" - - -class PartialTPEmbedding(nn.Module): - def __init__(self, prefix: str, weights): - super().__init__() - weight = weights.get_sharded(f"{prefix}.weight", dim=1) - self.weight = nn.Parameter(weight) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.embedding(input, self.weight) - - -@torch.jit.script -def layer_norm(hidden_states, weight, epsilon): - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 - - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + epsilon) - - # convert into half-precision if necessary - if weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(weight.dtype) - - return weight * hidden_states - - -class T5LayerNorm(nn.Module): - def __init__(self, prefix, weights, eps=1e-6): - """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 
- """ - super().__init__() - weight = weights.get_tensor(f"{prefix}.weight") - self.weight = nn.Parameter(weight) - self.variance_epsilon = torch.tensor(eps) - - def forward(self, hidden_states): - return layer_norm(hidden_states, self.weight, self.variance_epsilon) - - -try: - from apex.normalization import FusedRMSNorm - - T5LayerNorm = FusedRMSNorm # noqa - - logger.info( - "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" - ) -except ImportError: - # using the normal T5LayerNorm - pass -except Exception: - logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") - pass - -ALL_LAYERNORM_LAYERS.append(T5LayerNorm) - - -class T5DenseActDense(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - self.wi = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi", weights=weights, bias=False - ) - - ### XXX: T5 models do not handle well both f16 and quantization. - ### Overidding specifically this layer for that reason. 
- ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 - ### https://github.com/huggingface/transformers/issues/20287 - _q = config.quantize - _dtype = weights.dtype - weights.dtype = torch.float32 - config.quantize = None - self.wo_cast = (torch.float32, _dtype) - self.wo = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.wo", weights=weights, bias=False - ) - weights.dtype = _dtype - config.quantize = _q - - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ( - ACT2FN[config.dense_act_fn] - if "gelu" not in config.dense_act_fn - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - def forward(self, hidden_states): - hidden_states = self.wi(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states.to(dtype=self.wo_cast[0]) - hidden_states = self.wo(hidden_states) - # XXX: Recasting is already done within the layer norm. - # Casting back to float16 here modifies results - # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) - return hidden_states - - -class T5DenseGatedActDense(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - self.wi_0 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi_0", weights=weights, bias=False - ) - self.wi_1 = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.wi_1", weights=weights, bias=False - ) - ### XXX: T5 models do not handle well both f16 and quantization. - ### Overidding specifically this layer for that reason. 
- ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316 - ### https://github.com/huggingface/transformers/issues/20287 - _q = config.quantize - _dtype = weights.dtype - weights.dtype = torch.float32 - config.quantize = None - self.wo_cast = (torch.float32, _dtype) - self.wo = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.wo", weights=weights, bias=False - ) - weights.dtype = _dtype - config.quantize = _q - - self.dropout = nn.Dropout(config.dropout_rate) - self.act = ( - ACT2FN[config.dense_act_fn] - if "gelu" not in config.dense_act_fn - else lambda x: torch.nn.functional.gelu(x, approximate="tanh") - ) - - def forward(self, hidden_states): - hidden_gelu = self.act(self.wi_0(hidden_states)) - hidden_linear = self.wi_1(hidden_states) - hidden_states = hidden_gelu * hidden_linear - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states.to(dtype=self.wo_cast[0]) - hidden_states = self.wo(hidden_states) - # XXX: Recasting is already done within the layer norm. 
- # Casting back to float16 here modifies results - # hidden_states = hidden_states.to(dtype=self.wo_cast[1]) - return hidden_states - - -class T5LayerFF(nn.Module): - def __init__(self, config: T5Config, prefix, weights): - super().__init__() - if config.is_gated_act: - self.DenseReluDense = T5DenseGatedActDense( - config, prefix=f"{prefix}.DenseReluDense", weights=weights - ) - else: - self.DenseReluDense = T5DenseActDense( - config, prefix=f"{prefix}.DenseReluDense", weights=weights - ) - - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward(self, hidden_states): - forwarded_states = self.layer_norm(hidden_states) - forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = hidden_states + self.dropout(forwarded_states) - return hidden_states - - -class T5Attention(nn.Module): - def __init__( - self, config: T5Config, prefix, weights, has_relative_attention_bias=False - ): - super().__init__() - self.is_decoder = config.is_decoder - self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance - self.d_model = config.d_model - self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.dropout = config.dropout_rate - self.inner_dim = self.n_heads * self.key_value_proj_dim - - process_group = weights.process_group - # Mesh TensorFlow initialization to avoid scaling before softmax - assert self.n_heads % process_group.size() == 0 - self.q = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.q", weights=weights, bias=False - ) - self.k = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.k", weights=weights, bias=False - ) - self.v = TensorParallelColumnLinear.load( - config, prefix=f"{prefix}.v", weights=weights, bias=False 
- ) - self.o = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.o", weights=weights, bias=False - ) - if self.n_heads % weights.process_group.size() != 0: - raise ValueError( - f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} " - f"and `num_shards`: {weights.process_group.size()}" - ) - self.n_heads = self.n_heads // process_group.size() - self.inner_dim = self.inner_dim // process_group.size() - - if self.has_relative_attention_bias: - self.relative_attention_bias = PartialTPEmbedding( - prefix=f"{prefix}.relative_attention_bias", weights=weights - ) - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """ - Adapted from Mesh Tensorflow: - https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
- This should allow for more graceful generalization to longer sequences than the model has been trained on - - Args: - relative_position: an int32 Tensor - bidirectional: a boolean - whether the attention is bidirectional - num_buckets: an integer - max_distance: an integer - - Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ - relative_buckets = 0 - if bidirectional: - num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets - relative_position = torch.abs(relative_position) - else: - relative_position = -torch.min( - relative_position, torch.zeros_like(relative_position) - ) - # now relative_position is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = relative_position < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - relative_position_if_large = torch.min( - relative_position_if_large, - torch.full_like(relative_position_if_large, num_buckets - 1), - ) - - relative_buckets += torch.where( - is_small, relative_position, relative_position_if_large - ) - return relative_buckets - - def compute_bias(self, query_length, key_length, device=None): - """Compute binned relative position bias""" - if device is None: - device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[ - :, None - ] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ - None, : - ] - relative_position = ( - memory_position - context_position - ) # shape (query_length, key_length) - relative_position_bucket = self._relative_position_bucket( - relative_position, # 
shape (query_length, key_length) - bidirectional=(not self.is_decoder), - num_buckets=self.relative_attention_num_buckets, - max_distance=self.relative_attention_max_distance, - ) - values = self.relative_attention_bias( - relative_position_bucket - ) # shape (query_length, key_length, num_heads) - values = values.permute([2, 0, 1]).unsqueeze( - 0 - ) # shape (1, num_heads, query_length, key_length) - return values - - def forward( - self, - hidden_states, - mask=None, - key_value_states=None, - position_bias=None, - past_key_value=None, - layer_head_mask=None, - query_length=None, - use_cache=False, - output_attentions=False, - ): - """ - Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). - """ - # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) - - batch_size, seq_length = hidden_states.shape[:2] - - real_seq_length = seq_length - - if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. 
Got {len(past_key_value)} past states" - real_seq_length += ( - past_key_value[0].shape[2] if query_length is None else query_length - ) - - key_length = ( - real_seq_length if key_value_states is None else key_value_states.shape[1] - ) - - def shape(states): - """projection""" - return states.view( - batch_size, -1, self.n_heads, self.key_value_proj_dim - ).transpose(1, 2) - - def unshape(states): - """reshape""" - return ( - states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - ) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape( - self.q(hidden_states) - ) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, - self.k, - key_value_states, - past_key_value[0] if past_key_value is not None else None, - ) - value_states = project( - hidden_states, - self.v, - key_value_states, - past_key_value[1] if past_key_value is not None else None, - ) - - # compute scores - 
scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 - - if position_bias is None: - if not self.has_relative_attention_bias: - position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), - device=scores.device, - dtype=scores.dtype, - ) - else: - position_bias = self.compute_bias( - real_seq_length, key_length, device=scores.device - ) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - - if mask is not None: - position_bias = ( - position_bias + mask - ) # (batch_size, n_heads, seq_length, key_length) - - position_bias_masked = position_bias - - scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) - - # Mask heads if we want to - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask - - attn_output = unshape( - torch.matmul(attn_weights, value_states) - ) # (batch_size, seq_length, dim) - attn_output = self.o(attn_output) - - present_key_value_state = ( - (key_states, value_states) if (self.is_decoder and use_cache) else None - ) - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) - - if output_attentions: - outputs = outputs + (attn_weights,) - return outputs - - -class T5LayerSelfAttention(nn.Module): - def __init__(self, config, prefix, weights, has_relative_attention_bias=False): - super().__init__() - self.SelfAttention = T5Attention( - config, - prefix=f"{prefix}.SelfAttention", - weights=weights, - has_relative_attention_bias=has_relative_attention_bias, - ) - self.layer_norm = 
T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.SelfAttention( - normed_hidden_states, - mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = hidden_states + self.dropout(attention_output[0]) - outputs = (hidden_states,) + attention_output[ - 1: - ] # add attentions if we output them - return outputs - - -class T5LayerCrossAttention(nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.EncDecAttention = T5Attention( - config, - prefix=f"{prefix}.EncDecAttention", - weights=weights, - has_relative_attention_bias=False, - ) - self.layer_norm = T5LayerNorm( - prefix=f"{prefix}.layer_norm", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.dropout = nn.Dropout(config.dropout_rate) - - def forward( - self, - hidden_states, - key_value_states, - attention_mask=None, - position_bias=None, - layer_head_mask=None, - past_key_value=None, - use_cache=False, - query_length=None, - output_attentions=False, - ): - normed_hidden_states = self.layer_norm(hidden_states) - attention_output = self.EncDecAttention( - normed_hidden_states, - mask=attention_mask, - key_value_states=key_value_states, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - query_length=query_length, - output_attentions=output_attentions, - ) - layer_output = hidden_states + self.dropout(attention_output[0]) - outputs = (layer_output,) + attention_output[ - 1: - ] # add attentions if 
we output them - return outputs - - -class T5Block(nn.Module): - def __init__(self, config, prefix, weights, has_relative_attention_bias: bool): - super().__init__() - self.is_decoder = config.is_decoder - self.layer = nn.ModuleList() - self.layer.append( - T5LayerSelfAttention( - config, - prefix=f"{prefix}.layer.0", - weights=weights, - has_relative_attention_bias=has_relative_attention_bias, - ) - ) - if self.is_decoder: - i = 2 - self.layer.append( - T5LayerCrossAttention( - config, prefix=f"{prefix}.layer.1", weights=weights - ) - ) - else: - i = 1 - - self.layer.append( - T5LayerFF(config, prefix=f"{prefix}.layer.{i}", weights=weights) - ) - - def forward( - self, - hidden_states, - attention_mask=None, - position_bias=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - encoder_decoder_position_bias=None, - layer_head_mask=None, - cross_attn_layer_head_mask=None, - past_key_value=None, - use_cache=False, - output_attentions=False, - return_dict=True, - ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning( - "`past_key_values` is passed to the encoder. Please make sure this is intended." - ) - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - - self_attention_outputs = self.layer[0]( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states, present_key_value_state = self_attention_outputs[:2] - attention_outputs = self_attention_outputs[ - 2: - ] # Keep self-attention outputs and relative position weights - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - do_cross_attention = self.is_decoder and encoder_hidden_states is not None - if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - - cross_attention_outputs = self.layer[1]( - hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - position_bias=encoder_decoder_position_bias, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, - use_cache=use_cache, - output_attentions=output_attentions, - ) - hidden_states = cross_attention_outputs[0] - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = ( - present_key_value_state + cross_attention_outputs[1] - ) - - # Keep cross-attention outputs and relative position weights - attention_outputs = attention_outputs + cross_attention_outputs[2:] - - # Apply Feed Forward layer - hidden_states = self.layer[-1](hidden_states) - - # clamp inf values to enable fp16 training - if hidden_states.dtype == torch.float16: - clamp_value = torch.where( - torch.isinf(hidden_states).any(), - torch.finfo(hidden_states.dtype).max - 1000, - torch.finfo(hidden_states.dtype).max, - ) - hidden_states = torch.clamp( - hidden_states, min=-clamp_value, max=clamp_value - ) - - outputs = (hidden_states,) - - if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs - else: - outputs = outputs + attention_outputs - - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) - - -class 
class T5PreTrainedModel(PreTrainedModel):
    """
    Abstract base class providing weight-initialisation handling and a simple
    interface for downloading and loading pretrained models.
    """

    config_class = T5Config

    def _shift_right(self, input_ids):
        """Return ``input_ids`` shifted one position right along the last axis.

        The first position is filled with ``config.decoder_start_token_id`` and
        any ``-100`` entries in the result are replaced by
        ``config.pad_token_id``. Both config values must be set.
        """
        start_token = self.config.decoder_start_token_id
        pad_token = self.config.pad_token_id

        assert start_token is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
            " See T5 docs for more information"
        )

        # shift inputs to the right
        if not is_torch_fx_proxy(input_ids):
            shifted = input_ids.new_zeros(input_ids.shape)
            shifted[..., 1:] = input_ids[..., :-1].clone()
            shifted[..., 0] = start_token
        else:
            # Item assignment is not supported natively for proxies, so the
            # start column is built separately and concatenated in front.
            start_column = torch.full(input_ids.shape[:-1] + (1,), start_token)
            shifted = torch.cat([start_column, input_ids[..., :-1]], dim=-1)

        assert (
            pad_token is not None
        ), "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted.masked_fill_(shifted == -100, pad_token)

        return shifted
class T5Stack(T5PreTrainedModel):
    """A stack of ``config.num_layers`` T5 blocks, then a final layer norm and dropout.

    Only the first block is built with a relative attention bias; the position
    biases it returns are passed on to the following blocks.
    """

    def __init__(self, config, prefix, weights, embed_tokens):
        super().__init__(config)

        self.is_decoder = config.is_decoder

        self.embed_tokens = embed_tokens
        self.block = nn.ModuleList(
            [
                T5Block(
                    config,
                    prefix=f"{prefix}.block.{layer_id}",
                    weights=weights,
                    has_relative_attention_bias=(layer_id == 0),
                )
                for layer_id in range(config.num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(
            prefix=f"{prefix}.final_layer_norm",
            weights=weights,
            eps=config.layer_norm_epsilon,
        )
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run every block over the embedded inputs.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``use_cache``, ``output_attentions``, ``output_hidden_states`` and
        ``return_dict`` fall back to the config values when ``None``.

        Returns a ``BaseModelOutputWithPastAndCrossAttentions``, or - when
        ``return_dict`` is falsy - a tuple of its non-``None`` fields.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are supplied.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
            )

        if inputs_embeds is None:
            assert (
                self.embed_tokens is not None
            ), "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = (
            past_key_values[0][0].shape[2] + seq_length
            if past_key_values is not None
            else seq_length
        )

        if use_cache is True:
            assert (
                self.is_decoder
            ), f"`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, mask_seq_length, device=inputs_embeds.device
            )
        if (
            self.is_decoder
            and encoder_attention_mask is None
            and encoder_hidden_states is not None
        ):
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size,
                encoder_seq_length,
                device=inputs_embeds.device,
                dtype=torch.long,
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            # The mask was already defaulted above for this case; the check is
            # kept as a safeguard.
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=inputs_embeds.device
                )
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(
            cross_attn_head_mask, self.config.num_layers
        )
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        # Cross-attention entries exist in a block's output only when the
        # block actually ran cross-attention.
        has_cross_attention = self.is_decoder and encoder_hidden_states is not None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(
            zip(self.block, past_key_values)
        ):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            # The block omits the key-value slot whenever `use_cache` is falsy
            # (it tests truthiness), so pad on the same condition - including
            # `None` - to keep the indices below aligned.
            if not use_cache:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer store them
            position_bias = layer_outputs[2]
            if has_cross_attention:
                encoder_decoder_position_bias = layer_outputs[
                    4 if output_attentions else 3
                ]
            # append next layer key value states
            if use_cache:
                present_key_value_states += (present_key_value_state,)

            if output_attentions:
                all_attentions += (layer_outputs[3],)
                # Index 5 is only present when cross-attention ran; a decoder
                # called without encoder hidden states has no such entry.
                if has_cross_attention:
                    all_cross_attentions += (layer_outputs[5],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
class T5ForConditionalGeneration(T5PreTrainedModel):
    """T5 encoder-decoder with a speculative language-modelling head.

    The encoder and decoder are ``T5Stack`` instances built from deep copies
    of ``config`` and share one ``TensorParallelEmbedding``.
    """

    def __init__(self, config: T5Config, weights):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(
            config=encoder_config,
            prefix="encoder",
            weights=weights,
            embed_tokens=self.shared,
        )

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(
            config=decoder_config,
            prefix="decoder",
            weights=weights,
            embed_tokens=self.shared,
        )

        try:
            self.lm_head = SpeculativeHead.load(
                config, prefix="lm_head", weights=weights
            )
        except RuntimeError:
            # Some models like t5-small were saved with shared weights unlike flan
            # Since they are declared as the same arch we have no choice but hope
            # that this is OK instead of using a proper flag.
            self.lm_head = SpeculativeHead.load(
                config, prefix="shared", weights=weights
            )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode (unless ``encoder_outputs`` is given), decode and project to logits.

        With a truthy ``return_dict`` the result is the pair
        ``(Seq2SeqLMOutput, speculative_logits)``. Otherwise it is a flat tuple
        ``([loss,] logits, *decoder_outputs[1:], *encoder_outputs)``.

        NOTE(review): the flat-tuple path does not include
        ``speculative_logits``; confirm callers do not rely on it there.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # A bare `__HEAD_MASK_WARNING_MSG` inside a class body is
                # name-mangled to `_T5ForConditionalGeneration__HEAD_MASK_...`
                # and raises NameError, so fetch the module constant by its
                # string key. Assumes the constant is defined at module level
                # (outside this view) - TODO confirm.
                warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Wrap caller-supplied tuple outputs so attribute access works below.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if (
            labels is not None
            and decoder_input_ids is None
            and decoder_inputs_embeds is None
        ):
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        logits, speculative_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return (
            Seq2SeqLMOutput(
                loss=loss,
                logits=logits,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            ),
            speculative_logits,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the keyword arguments for one generation step.

        When ``past_key_values`` is given only the last column of
        ``input_ids`` is passed on as ``decoder_input_ids``. Extra ``kwargs``
        are ignored.
        """
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Return ``labels`` shifted right, for use as decoder input ids."""
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder every cached state along dimension 0 according to ``beam_idx``.

        Returns ``past_key_values`` unchanged (after logging a warning) when it
        is ``None``; otherwise a tuple of per-layer tuples of reordered states.
        """
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning(
                "You might want to consider setting `use_cache=True` to speed up decoding"
            )
            return past_key_values

        reordered_decoder_past = []
        for layer_past_states in past_key_values:
            # Select the beams along dimension 0 for each key / value state of
            # the layer, on whichever device that state lives.
            reordered_layer_past_states = tuple(
                layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
                for layer_past_state in layer_past_states
            )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past.append(reordered_layer_past_states)
        return tuple(reordered_decoder_past)