diff --git a/server/pyproject.toml b/server/pyproject.toml
index b5e9c3fd..7595752a 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "^0.15.0"
 huggingface-hub = "^0.19.3"
-transformers = "^4.38"
+transformers = { git = "https://github.com/huggingface/transformers", rev = "517a3e670d8fc11374895e870dd0dd041467c7fe" }
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 8df7e075..468615a7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -53,8 +53,7 @@ class CohereLayerNorm(nn.Module):
         self.eps = eps
 
     def forward(self, hidden_states):
-        # if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
-        if True:
+        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
             hidden_states = hidden_states.reshape(
                 -1, self.weight.shape[0], self.weight.shape[1]
             )
@@ -147,6 +146,93 @@ def _load_gqa(config, prefix: str, weights):
     )
 
 
+class CohereRotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+    ):
+        super().__init__()
+        self.scaling_factor = scaling_factor
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (
+            self.base
+            ** (
+                torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
+                / self.dim
+            )
+        )
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, device_type, position_ids):
+        # position_ids is a flat 1D tensor of token positions here, so a single inv_freq "batch" is enough
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None].float().expand(1, -1, 1)
+        )
+        position_ids_expanded = position_ids[None, None, :].float()
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = (
+            device_type
+            if isinstance(device_type, str) and device_type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (
+                inv_freq_expanded.float().to(position_ids.device)
+                @ position_ids_expanded.float()
+            ).transpose(1, 2)
+            emb = torch.repeat_interleave(freqs, 2, dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos[0], sin[0]
+
+
+def rotate_half(x):
+    # Split and rotate
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
+    return rot_x
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k.
+            For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
+            Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    dtype = q.dtype
+    q = q.float()
+    k = k.float()
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
+
+
 class FlashCohereAttention(torch.nn.Module):
     def __init__(
         self,
@@ -232,14 +318,14 @@ class FlashCohereAttention(torch.nn.Module):
         if self.use_qk_norm:
             query = query.reshape(-1, self.head_size)
             key = key.reshape(-1, self.head_size)
-            query = self.q_norm(query)
-            key = self.k_norm(key)
+            query = self.q_norm(query.contiguous())
+            key = self.k_norm(key.contiguous())
 
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_key_value_heads, self.head_size)
         value = value.view(-1, self.num_key_value_heads, self.head_size)
 
-        self.rotary_emb(query, key, cos, sin)
+        query, key = apply_rotary_pos_emb(query, key, cos, sin)
 
         paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
@@ -399,6 +485,11 @@ class FlashCohereModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+        self.rotary_true = CohereRotaryEmbedding(
+            self.head_size,
+            max_position_embeddings=config.max_position_embeddings,
+            base=config.rope_theta,
+        )
 
     def forward(
         self,
@@ -415,9 +506,7 @@ class FlashCohereModel(torch.nn.Module):
 
         # Get rotary cos and sin for this forward
        # Avoid to index in each layer
-        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-            position_ids, max_s, hidden_states.dtype
-        )
+        cos, sin = self.rotary_true(hidden_states.device.type, position_ids)
 
         residual = None
         for i, layer in enumerate(self.layers):
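
For reference, here is a standalone sketch (not part of the diff) that mirrors the interleaved rotary math the new CohereRotaryEmbedding and apply_rotary_pos_emb implement, on dummy tensors shaped the way FlashCohereAttention.forward sees them. All names and sizes below are illustrative only, not code from the PR.

import torch

num_tokens, num_heads, head_size, base = 5, 8, 64, 10000.0

# cos/sin are built as in CohereRotaryEmbedding.forward: one angle per
# (position, frequency) pair, repeated pairwise because Cohere rotates
# adjacent (interleaved) dimensions together.
position_ids = torch.arange(num_tokens, dtype=torch.float32)
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
freqs = torch.outer(position_ids, inv_freq)      # [num_tokens, head_size // 2]
emb = torch.repeat_interleave(freqs, 2, dim=-1)  # [num_tokens, head_size]
cos, sin = emb.cos(), emb.sin()

def rotate_half(x):
    # Interleaved rotation: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([-x2, x1], dim=-1).flatten(-2)

# Query/key as seen inside FlashCohereAttention.forward: batch and sequence
# are flattened into a single token dimension.
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)

# unsqueeze_dim=1 in apply_rotary_pos_emb broadcasts the [num_tokens, head_size]
# cos/sin over the heads dimension.
q_rot = q * cos[:, None, :] + rotate_half(q) * sin[:, None, :]
k_rot = k * cos[:, None, :] + rotate_half(k) * sin[:, None, :]
assert q_rot.shape == q.shape and k_rot.shape == k.shape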