diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 96dd1ed1..dca3612f 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -67,6 +67,16 @@ if FLASH_ATTENTION: __all__.append(FlashLlama) __all__.append(IDEFICSSharded) +MISTRAL = True +try: + from text_generation_server.models.flash_mistral import FlashMistral +except ImportError as e: + logger.warning(f"Could not import Mistral model: {e}") + MISTRAL = False + +if MISTRAL: + __all__.append(FlashMistral) + def get_model( model_id: str, @@ -237,7 +247,18 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "opt": + if model_type == "mistral": + if MISTRAL: + return FlashMistral( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + raise NotImplementedError("Mistral model requires flash attention v2") + + if model_type == "opt": return OPTSharded( model_id, revision, @@ -246,7 +267,7 @@ def get_model( trust_remote_code=trust_remote_code, ) - elif model_type == "t5": + if model_type == "t5": return T5Sharded( model_id, revision, @@ -254,7 +275,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "idefics": + if model_type == "idefics": if FLASH_ATTENTION: return IDEFICSSharded( model_id, diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py new file mode 100644 index 00000000..2e6ae086 --- /dev/null +++ b/server/text_generation_server/models/cache_manager.py @@ -0,0 +1,135 @@ +import math +import torch + +from typing import Optional, List, Tuple + +BLOCK_SIZE: int = 16 +# Will be set in warmup +CACHE_MANAGER: Optional["CacheManager"] = None + + +class CacheManager: + def __init__( + self, + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + repeat_slots: bool, + dtype: torch.dtype, + device: torch.device, + ): + self.block_size = BLOCK_SIZE + self.num_blocks = num_blocks + self.repeat_slots = repeat_slots + + element_size = torch.tensor([], dtype=dtype).element_size() + x = self.block_size // element_size + + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, self.block_size, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, self.block_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") + self.slots = torch.arange( + 0, num_blocks * self.block_size, dtype=torch.int32 + ).view(num_blocks, self.block_size) + + def allocate( + self, + needed_blocks_slots: List[Tuple[int, int]], + blocks: int, + max_blocks: int, + device: torch.device, + ): + # Get free blocks indices by finding values in mask that are not set to 0 + free_block_indices = self.free_block_mask.nonzero() + assert ( + len(free_block_indices) >= blocks + ), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" + + # Slice by the number of required blocks + block_indices = free_block_indices[:blocks] + block_indices = block_indices.flatten() + + # Padded block tables + block_tables_tensor = torch.zeros( + (len(needed_blocks_slots), max_blocks), dtype=torch.int32 + ) + + # Allocate paged attention blocks + cumulative_blocks = 0 + slots = [] + block_tables = [] + for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): + # Get allocated blocks for this sequence + allocated_blocks = block_indices[ + cumulative_blocks : cumulative_blocks + needed_blocks + ] + # Get slots for the allocated blocks + all_slots = self.slots[allocated_blocks].flatten() + + # Repeat slots in the case of context sliding window + if needed_slots > len(all_slots) and self.repeat_slots: + repeats = math.ceil(needed_slots / len(all_slots)) + all_slots = all_slots.repeat(repeats) + + allocated_slots = all_slots[:needed_slots] + + slots.append(allocated_slots) + block_tables.append(allocated_blocks.tolist()) + block_tables_tensor[i, :needed_blocks] = allocated_blocks + cumulative_blocks += needed_blocks + + block_tables = block_tables + block_tables_tensor = block_tables_tensor.to(device) + slots = torch.concat(slots).to(device) + + # Allocate the required number of blocks by setting the mask to 0 + self.free_block_mask[block_indices] = 0 + + return block_tables, block_tables_tensor, slots + + def free(self, block_indices: Optional[List[int]]): + if block_indices is not None and block_indices: + # Reset mask + self.free_block_mask[block_indices] = 1 + + +def set_cache_manager( + num_blocks: int, + num_layers: int, + num_heads: int, + head_size: int, + repeat_slots: bool, + dtype: torch.dtype, + device: torch.device, +) -> CacheManager: + global CACHE_MANAGER + if CACHE_MANAGER is not None: + del CACHE_MANAGER + torch.cuda.empty_cache() + + CACHE_MANAGER = CacheManager( + num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device + ) + return CACHE_MANAGER + + +def get_cache_manager() -> CacheManager: + global CACHE_MANAGER + if CACHE_MANAGER is None: + raise RuntimeError("cache manager was not initialized") + + return CACHE_MANAGER diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py new file mode 100644 index 00000000..f721d51f --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -0,0 +1,532 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +# Flash attention imports +import dropout_layer_norm + +# vllm imports +import vllm_cache_ops +import vllm_attention_ops + +from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2 +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + get_linear, +) + +if not HAS_FLASH_ATTN_V2: + raise ImportError("Mistral model requires flash attn v2") + + +class MistralConfig(PretrainedConfig): + model_type = "mistral" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class MistralRMSNorm(nn.Module): + def __init__(self, prefix, weights, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + weight = weights.get_tensor(f"{prefix}.weight") + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 8192: + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states, residual + else: + # faster post attention rms norm + normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + None, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + True, # Activate RMSNorm + ) + if res is None: + res = hidden_states + + return normed_hidden_states, res + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=False, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + return TensorParallelColumnLinear( + get_linear(weight, bias=None, quantize=config.quantize) + ) + + +class MistralAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.max_past = ( + config.sliding_window if config.sliding_window is not None else 0 + ) + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) + + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + + vllm_cache_ops.reshape_and_cache( + kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + max_past=self.max_past, + ) + # Decode + else: + # kv_cache[1] => [num_blocks, num_heads, head_size, block_size] + block_size = kv_cache[1].shape[3] + vllm_attention_ops.single_query_cached_kv_attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + + +class MistralMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else "none", + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + + +class MistralLayer(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = MistralAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = MistralRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MistralRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + # faster post attention rms norm + normed_attn_res_output, attn_res = self.post_attention_layernorm( + attn_output, res + ) + + mlp_output = self.mlp(normed_attn_res_output) + + return mlp_output, attn_res + + +class MistralModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + MistralLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = MistralRMSNorm( + prefix="model.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashMistralForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = MistralModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="lm_head", + weights=weights, + ) + self.max_past = config.sliding_window + if self.max_past is None: + raise ValueError("max_past cannot be None") + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + else: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + max_s = min(self.max_past, max_s) + input_lengths = torch.clamp(input_lengths, max=self.max_past) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 34c7f633..cefa32d8 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -19,99 +19,17 @@ from text_generation_server.models.types import ( GeneratedText, TopTokens, ) +from text_generation_server.models.cache_manager import ( + get_cache_manager, + set_cache_manager, + BLOCK_SIZE, +) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION tracer = trace.get_tracer(__name__) -BLOCK_SIZE = 16 -# Will be set in warmup -CACHE_MANAGER: Optional["CacheManager"] = None - - -class CacheManager: - def __init__( - self, - num_blocks: int, - num_layers: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - ): - self.block_size = BLOCK_SIZE - self.num_blocks = num_blocks - - element_size = torch.tensor([], dtype=dtype).element_size() - x = self.block_size // element_size - - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, head_size // x, self.block_size, x), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, head_size, self.block_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") - self.slots = torch.arange( - 0, num_blocks * self.block_size, dtype=torch.int32 - ).view(num_blocks, self.block_size) - - def allocate(self, batch: "FlashCausalLMBatch"): - # Get free blocks indices by finding values in mask that are not set to 0 - free_block_indices = self.free_block_mask.nonzero() - assert ( - len(free_block_indices) >= batch.blocks - ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks" - - # Slice by the number of required blocks - block_indices = free_block_indices[: batch.blocks] - block_indices = block_indices.flatten() - - # Padded block tables - block_tables_tensor = torch.zeros( - (len(batch), batch.max_blocks), dtype=torch.int32 - ) - - # Allocate paged attention blocks - cumulative_blocks = 0 - slots = [] - block_tables = [] - for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots): - # Get allocated blocks for this sequence - allocated_blocks = block_indices[ - cumulative_blocks : cumulative_blocks + needed_blocks - ] - # Get slots for the allocated blocks - allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots] - - slots.append(allocated_slots) - block_tables.append(allocated_blocks.tolist()) - block_tables_tensor[i, :needed_blocks] = allocated_blocks - cumulative_blocks += needed_blocks - - batch.needed_blocks_slots = None - batch.block_tables = block_tables - batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device) - batch.slots = torch.concat(slots).to(batch.input_ids.device) - - # Allocate the required number of blocks by setting the mask to 0 - self.free_block_mask[block_indices] = 0 - - def free(self, block_indices: Optional[List[int]]): - if block_indices is not None and block_indices: - # Reset mask - self.free_block_mask[block_indices] = 1 - @dataclass class FlashCausalLMBatch(Batch): @@ -481,7 +399,6 @@ class FlashCausalLMBatch(Batch): max_blocks = max(max_blocks, len(request_block_table)) - global CACHE_MANAGER block_indices_to_free = [] # Iterate on all requests for i, r in enumerate(self.requests): @@ -489,7 +406,7 @@ class FlashCausalLMBatch(Batch): if r.id not in requests_idx_mapping.keys(): block_indices_to_free.extend(self.block_tables[i]) # Free blocks - CACHE_MANAGER.free(block_indices_to_free) + get_cache_manager().free(block_indices_to_free) # Needed to avoid dropping blocks when the batches will go out of scope self.block_tables = None @@ -508,7 +425,7 @@ class FlashCausalLMBatch(Batch): # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) - return FlashCausalLMBatch( + return type(self)( batch_id=self.batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, @@ -665,7 +582,7 @@ class FlashCausalLMBatch(Batch): b.block_tables = None del b - return FlashCausalLMBatch( + return cls( batch_id=batches[0].batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, @@ -698,9 +615,10 @@ class FlashCausalLMBatch(Batch): def __del__(self): if self.block_tables is not None and self.block_tables: - global CACHE_MANAGER # Free blocks - CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables))) + get_cache_manager().free( + list(itertools.chain.from_iterable(self.block_tables)) + ) def __len__(self): return len(self.requests) @@ -718,10 +636,12 @@ class FlashCausalLM(Model): device: torch.device, rank: int = 0, world_size: int = 1, + repeat_slots: bool = False, ): self.num_layers = num_layers self.num_kv_heads = num_kv_heads self.head_size = head_size + self.repeat_slots = repeat_slots super(FlashCausalLM, self).__init__( model=model, @@ -738,15 +658,14 @@ class FlashCausalLM(Model): return FlashCausalLMBatch def warmup(self, batch: FlashCausalLMBatch): - global CACHE_MANAGER - torch.cuda.empty_cache() try: - CACHE_MANAGER = CacheManager( + cache_manager = set_cache_manager( batch.blocks, self.num_layers, self.num_kv_heads, self.head_size, + self.repeat_slots, self.dtype, self.device, ) @@ -775,48 +694,36 @@ class FlashCausalLM(Model): num_blocks = ( int(free_memory // total_cache_size) # Add batch.blocks as we allocated it above, so it is included in the peak memory. - + CACHE_MANAGER.num_blocks + + cache_manager.num_blocks ) - del CACHE_MANAGER del batch - torch.cuda.empty_cache() + del cache_manager - CACHE_MANAGER = CacheManager( + set_cache_manager( num_blocks, self.num_layers, self.num_kv_heads, self.head_size, + self.repeat_slots, self.dtype, self.device, ) return int(num_blocks * BLOCK_SIZE) - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - cu_seqlen_prefill: Optional[torch.Tensor], - block_tables: torch.Tensor, - slots: torch.Tensor, - input_lengths: torch.Tensor, - max_s: int, - lm_head_indices: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - global CACHE_MANAGER - + def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]: # Model Forward return self.model.forward( - input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - kv_cache=CACHE_MANAGER.kv_cache, - block_tables=block_tables, - slots=slots, - input_lengths=input_lengths, - max_s=max_s, - lm_head_indices=lm_head_indices, + input_ids=batch.input_ids, + position_ids=batch.position_ids, + cu_seqlen_prefill=batch.cu_seqlen_prefill, + kv_cache=get_cache_manager().kv_cache, + block_tables=batch.block_tables_tensor, + slots=batch.slots[batch.slot_indices], + input_lengths=batch.input_lengths_tensor, + max_s=batch.max_seqlen, + lm_head_indices=batch.prefill_head_indices, ) @tracer.start_as_current_span("generate_token") @@ -828,19 +735,19 @@ class FlashCausalLM(Model): if batch.needed_blocks_slots: # Allocate blocks to this batch - CACHE_MANAGER.allocate(batch) + block_tables, block_tables_tensor, slots = get_cache_manager().allocate( + batch.needed_blocks_slots, + batch.blocks, + batch.max_blocks, + batch.input_ids.device, + ) + batch.needed_blocks_slots = None + batch.block_tables = block_tables + batch.block_tables_tensor = block_tables_tensor + batch.slots = slots try: - out = self.forward( - batch.input_ids, - batch.position_ids, - batch.cu_seqlen_prefill, - batch.block_tables_tensor, - batch.slots[batch.slot_indices], - batch.input_lengths_tensor, - batch.max_seqlen, - batch.prefill_head_indices, - ) + out = self.forward(batch) except Exception as e: del batch raise e diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py new file mode 100644 index 00000000..266ae8dd --- /dev/null +++ b/server/text_generation_server/models/flash_mistral.py @@ -0,0 +1,357 @@ +import math +import torch +import torch.distributed + +import numpy as np + +from dataclasses import dataclass +from opentelemetry import trace +from transformers import PreTrainedTokenizerBase +from transformers.models.llama import LlamaTokenizerFast +from typing import Optional, Tuple, Type + +from text_generation_server.pb import generate_pb2 +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE +from text_generation_server.models.cache_manager import ( + get_cache_manager, + set_cache_manager, +) +from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + MistralConfig, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, + HeterogeneousNextTokenChooser, + StoppingCriteria, +) + +tracer = trace.get_tracer(__name__) + +# Will be set in init +SLIDING_WINDOW: Optional[int] = None +SLIDING_WINDOW_BLOCKS: Optional[int] = None + + +# Adds windowing logic to FlashCausalLMBatch +@dataclass +class FlashMistralBatch(FlashCausalLMBatch): + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] = None + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "FlashCausalLMBatch": + global SLIDING_WINDOW + global SLIDING_WINDOW_BLOCKS + + batch_inputs = [] + max_truncation = 0 + for r in pb.requests: + batch_inputs.append(r.inputs) + max_truncation = max(max_truncation, r.truncate) + + batch_tokenized_inputs = tokenizer( + batch_inputs, truncation=True, max_length=max_truncation + )["input_ids"] + + position_ids = [] + cu_seqlen_prefill = [0] + needed_blocks_slots = [] + start_slots = [] + slot_indices = [] + prefill_cache_indices = [] + + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + requests_idx_mapping = {} + + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + + next_token_chooser_parameters = [] + stopping_criterias = [] + top_n_tokens = [] + + # Cumulative length + cumulative_length = 0 + cumulative_max_length = 0 + prefill_out_cumulative_length = 0 + + blocks = 0 + max_seqlen = 0 + max_length = 0 + max_blocks = 0 + + # Parse batch + for i, (r, tokenized_input) in enumerate( + zip(pb.requests, batch_tokenized_inputs) + ): + # request id -> idx in list mapping + requests_idx_mapping[r.id] = i + + tokenized_input = tokenized_input[-r.truncate :] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + + prefix_offsets.append(input_length - 5) + read_offsets.append(input_length) + + all_input_ids.append(tokenized_input) + + # Position ids + request_position_ids = torch.arange(0, input_length, dtype=torch.int32) + position_ids.append(request_position_ids) + + # Add cumulative lengths of all previous inputs + cu_seqlen_prefill.append(cumulative_length + input_length) + + next_token_chooser_parameters.append(r.parameters) + + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + max_new_tokens = stopping_criteria.max_new_tokens + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + + # Paged attention + # Remove one as the first token des not have a past + total_tokens = input_length + max_new_tokens - 1 + + # Needed blocks can not go over SLIDING_WINDOW_BLOCKS + needed_blocks = min( + math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS + ) + blocks += needed_blocks + + needed_blocks_slots.append((needed_blocks, total_tokens)) + start_slots.append(cumulative_max_length) + + request_slot_indices = torch.arange( + cumulative_max_length, + cumulative_max_length + input_length, + dtype=torch.int64, + ) + slot_indices.append(request_slot_indices) + + # Create tensor to slice into the kv tensor in prefill + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, input_length - SLIDING_WINDOW), + cumulative_length + input_length, + dtype=torch.int64, + ) + prefill_cache_indices.append(request_prefill_cache_indices) + + all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs + + if r.prefill_logprobs: + prefill_head_indices.append(request_position_ids + cumulative_length) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], dtype=torch.int32 + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + # Update + cumulative_length += input_length + cumulative_max_length += total_tokens + max_seqlen = max(max_seqlen, input_length) + max_blocks = max(max_blocks, needed_blocks) + max_length = max(max_length, input_length + max_new_tokens) + + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( + next_token_chooser_parameters, dtype, device + ) + start_slots = torch.tensor(start_slots, dtype=torch.int64) + + # Padded all_input_ids_tensor + all_input_ids_tensor = np.zeros( + (len(all_input_ids), max_length), dtype=np.int64 + ) + for i, input_ids in enumerate(all_input_ids): + all_input_ids_tensor[i, : len(input_ids)] = input_ids + + # Create tensors on device + all_input_ids_tensor = torch.tensor( + all_input_ids_tensor, dtype=torch.int64, device=device + ) + + if len(pb.requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + position_ids = torch.cat(position_ids) + slot_indices = torch.cat(slot_indices) + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + input_ids = all_input_ids[0] + position_ids = position_ids[0] + slot_indices = slot_indices[0] + prefill_cache_indices = prefill_cache_indices[0] + + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 + ) + + position_ids = position_ids.to(device) + slot_indices = slot_indices.to(device) + prefill_cache_indices = prefill_cache_indices.to(device) + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + input_lengths_tensor = torch.tensor( + input_lengths, dtype=torch.int32, device=device + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + start_slots=start_slots, + slot_indices=slot_indices, + needed_blocks_slots=needed_blocks_slots, + block_tables=None, + block_tables_tensor=None, + slots=None, + max_seqlen=max_seqlen, + prefill_head_indices=prefill_head_indices, + prefill_next_token_indices=prefill_next_token_indices, + prefill_cu_outlens=prefill_cu_outlens, + input_lengths=input_lengths, + input_lengths_tensor=input_lengths_tensor, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + all_input_ids=all_input_ids, + all_input_ids_tensor=all_input_ids_tensor, + next_token_chooser=next_token_chooser, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + blocks=blocks, + max_blocks=max_blocks, + prefill_cache_indices=prefill_cache_indices, + ) + + +class FlashMistral(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + global SLIDING_WINDOW + global SLIDING_WINDOW_BLOCKS + + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + raise NotImplementedError("FlashLlama is only available on GPU") + + tokenizer = LlamaTokenizerFast.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = MistralConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + + # Set context windows + SLIDING_WINDOW = config.sliding_window + SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE) + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id) + + model = FlashMistralForCausalLM(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(FlashMistral, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + repeat_slots=True, + ) + + @property + def batch_type(self) -> Type[FlashMistralBatch]: + return FlashMistralBatch + + def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: + # Model Forward + logits = self.model.forward( + input_ids=batch.input_ids, + position_ids=batch.position_ids, + cu_seqlen_prefill=batch.cu_seqlen_prefill, + kv_cache=get_cache_manager().kv_cache, + block_tables=batch.block_tables_tensor, + slots=batch.slots[batch.slot_indices], + input_lengths=batch.input_lengths_tensor, + max_s=batch.max_seqlen, + prefill_cache_indices=batch.prefill_cache_indices, + lm_head_indices=batch.prefill_head_indices, + ) + if batch.prefill_cache_indices is not None: + batch.prefill_cache_indices = None + return logits diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index c472d1fc..aa02b950 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -57,6 +57,7 @@ def attention( cu_seqlens, max_s, softmax_scale, + max_past=0, ): if HAS_FLASH_ATTN_V2: return flash_attn_2_cuda.varlen_fwd( @@ -72,11 +73,15 @@ def attention( softmax_scale, False, True, + max_past, False, None, ) if HAS_FLASH_ATTN: + if max_past != 0: + raise NotImplementedError("max_past is only available with flash attn v2") + # Flash attention v1 requires q, k and v to have the same number of heads if k.shape[1] != q.shape[1]: # MQA expand