diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 2bdd00de6..bc124f319 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware ## Supported Models +- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json new file mode 100644 index 000000000..03f903672 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.84375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.34375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.8359375, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.0859375, + "special": false, + "text": " is" + }, + { + "id": 254, + "logprob": -1.5390625, + "special": false, + "text": " the" + }, + { + "id": 1022, + "logprob": -1.1875, + "special": false, + "text": " first" + }, + { + "id": 3458, + "logprob": -0.35546875, + "special": false, + "text": " step" + }, + { + "id": 279, + "logprob": -0.8828125, + "special": false, + "text": " in" + }, + { + "id": 254, + "logprob": -0.71484375, + "special": false, + "text": " the" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is the first step in the" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json new file mode 100644 index 000000000..e84135cf5 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2143, + "logprob": -1.828125, + "special": false, + "text": " sent" + }, + { + "id": 10081, + "logprob": -0.36914062, + "special": false, + "text": " successfully" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 185, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 1380, + "logprob": -0.38671875, + "special": false, + "text": "We" + }, + { + "id": 543, + "logprob": -0.12695312, + "special": false, + "text": " will" + }, + { + "id": 752, + "logprob": -0.20117188, + "special": false, + "text": " get" + }, + { + "id": 279, + "logprob": 0.0, + "special": false, + "text": " in" + }, + { + "id": 5402, + "logprob": 0.0, + "special": false, + "text": " touch" + }, + { + "id": 366, + "logprob": 0.0, + "special": false, + "text": " with" + } + ], + "top_tokens": null + }, + "generated_text": "Test request sent successfully.\nWe will get in touch with" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json new file mode 100644 index 000000000..a4b9784a3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + } +] diff --git a/integration-tests/models/test_flash_deepseek_v2.py b/integration-tests/models/test_flash_deepseek_v2.py new file mode 100644 index 000000000..010e08c90 --- /dev/null +++ b/integration-tests/models/test_flash_deepseek_v2.py @@ -0,0 +1,63 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_deepseek_v2_handle(launcher): + with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_deepseek_v2(flash_deepseek_v2_handle): + await flash_deepseek_v2_handle.health(300) + return flash_deepseek_v2_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_load( + flash_deepseek_v2, generate_load, response_snapshot +): + responses = await generate_load( + flash_deepseek_v2, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 2f2b5ef68..f1f805290 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,14 +1,14 @@ -commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa +commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ git clone https://github.com/Narsil/vllm.git vllm; \ fi - cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build + cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build install-vllm-cuda: build-vllm-cuda - cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e . + cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e . build-vllm-rocm: if [ ! -d 'vllm' ]; then \ diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index 1a5acfa85..a6e07f453 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -34,6 +34,30 @@ class Exl2Weight(Weight): class Exl2WeightsLoader(WeightsLoader): """Loader for exl2-quantized weights.""" + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + def get_weights_col_packed( self, weights: Weights, @@ -43,46 +67,12 @@ class Exl2WeightsLoader(WeightsLoader): raise RuntimeError("Column-packed weights are not supported for exl") def get_weights_col(self, weights: Weights, prefix: str): - try: - q_weight = weights.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - "Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = weights.get_tensor(f"{prefix}.q_scale") - q_invperm = weights.get_tensor(f"{prefix}.q_invperm") - q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") - q_groups = weights.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): raise ValueError("get_multi_weights_col is not supported for exl2") def get_weights_row(self, weights: Weights, prefix: str): - try: - q_weight = weights.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - "Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = weights.get_tensor(f"{prefix}.q_scale") - q_invperm = weights.get_tensor(f"{prefix}.q_invperm") - q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") - q_groups = weights.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index c98dbefc4..6feca2755 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -134,6 +134,115 @@ class GPTQWeightsLoader(WeightsLoader): self.quantize = quantize self.sym = sym + def get_weights(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_tensor(f"{prefix}.g_idx") + scales = weights.get_tensor(f"{prefix}.scales") + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + def get_weights_col_packed( self, weights: Weights, diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index e7f017a4d..a913ff57a 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -33,6 +33,35 @@ class MarlinWeightsLoader(WeightsLoader): self.bits = bits self.is_marlin_24 = is_marlin_24 + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_tensor(f"{prefix}.B_24") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_tensor(f"{prefix}.B_meta") + s = weights.get_tensor(f"{prefix}.s") + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_tensor(f"{prefix}.B") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + s = weights.get_tensor(f"{prefix}.s") + weight = MarlinWeight(B=B, s=s) + + return weight + def get_weights_col_packed( self, weights: Weights, diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 8c354b82b..db78ee1cc 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -1,6 +1,7 @@ import os import torch from torch import nn +from loguru import logger from text_generation_server.utils.import_utils import SYSTEM @@ -97,6 +98,8 @@ class PositionRotaryEmbedding(nn.Module): ) elif rope_scaling["type"] == "yarn": scaling_factor = rope_scaling["factor"] + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], max_position_embeddings=rope_scaling[ @@ -109,6 +112,8 @@ class PositionRotaryEmbedding(nn.Module): attn_factor=1, beta_fast=32, beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) elif rope_scaling["type"] in ["su", "longrope"]: short_factor = torch.tensor( @@ -181,6 +186,8 @@ class PositionRotaryEmbedding(nn.Module): scaling_factor=scaling_factor, ) elif rope_scaling["type"] == "yarn": + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], max_position_embeddings=rope_scaling[ @@ -193,6 +200,8 @@ class PositionRotaryEmbedding(nn.Module): attn_factor=1, beta_fast=32, beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) else: raise NotImplementedError( @@ -346,10 +355,10 @@ def linear_ramp_mask(min, max, dim): return ramp_func -def get_mscale(scale=1): +def get_mscale(scale: float = 1.0, mscale: float = 1.0): if scale <= 1: return 1.0 - return 0.1 * math.log(scale) + 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): @@ -365,6 +374,8 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): attn_factor, beta_fast, beta_slow, + mscale: float, + mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) super().__init__(inv_freq, scaling_factor) @@ -375,8 +386,12 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow + self.mscale_all_dim = mscale_all_dim + self.scaling_factor = scaling_factor self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor + get_mscale(self.scaling_factor, mscale) + / get_mscale(self.scaling_factor, mscale_all_dim) + * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation def _update_cos_sin_cache(self, dtype, device, seqlen): @@ -387,7 +402,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): or self._cos_cached.device != device or self._cos_cached.dtype != dtype ): - if seqlen > self.max_position_embeddings: + if seqlen > self.max_position_embeddings or True: inv_freq_extrapolation = _create_inv_freq( self.dim, self.base, self.inv_freq.device ) @@ -400,6 +415,7 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): self.base, self.max_position_embeddings, ) + inv_freq_mask = ( 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation @@ -409,9 +425,6 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): ) self.inv_freq = inv_freq - self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor - ) # Get n-d magnitude scaling corrected for interpolation self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 32156a068..690a88876 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -61,6 +61,10 @@ FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( + FlashDeepseekV2ForCausalLM, + DeepseekV2Config, + ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, ) @@ -141,6 +145,11 @@ if MAMBA_AVAILABLE: class ModelType(enum.Enum): + DEEPSEEK_V2 = { + "type": "deepseek_v2", + "name": "Deepseek V2", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", + } IDEFICS2 = { "type": "idefics2", "name": "Idefics 2", @@ -459,7 +468,40 @@ def get_model( f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." ) - if model_type == MAMBA: + if model_type == DEEPSEEK_V2: + if FLASH_ATTENTION: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV2Config, + head_size=head_size, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2") + ) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == MAMBA: return Mamba( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py new file mode 100644 index 000000000..8ffbd143c --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -0,0 +1,983 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + attention, + paged_attention, + reshape_and_cache, +) +from text_generation_server.layers.attention.common import Seqlen +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weights +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + + +class DeepseekV2Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def _load_experts(config, prefix: str, mat: str, weights: Weights): + if config.quantize is not None: + raise NotImplementedError( + "Deepseek V2 does not support weight quantization yet." + ) + + assert mat in ["gate_proj", "up_proj", "down_proj"] + + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.moe_intermediate_size % world_size == 0 + ), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards" + + block_size = config.moe_intermediate_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty( + (config.n_routed_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) + + for i in range(config.n_routed_experts): + slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") + + if mat == "down_proj": + expert_slice = slice_[:, start:stop].t().contiguous() + else: + expert_slice = slice_[start:stop] + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) + return tensor + + +class DeepseekV2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + quantize=config.quantize, + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + quantize=config.quantize, + ) + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attention( + query, + key, + value, + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV2MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. + self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.shape[0] == 1 + and not self.quantize + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + return self.down_proj(out, reduce=reduce) + else: + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class BlockSparseMoE(nn.Module): + def __init__(self, prefix, config: DeepseekV2Config, weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + gate_proj = _load_experts( + config, f"{prefix}.experts", "gate_proj", weights + ).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + + up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view( + self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim + ) + + self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1) + + self.down_proj = ( + _load_experts(config, f"{prefix}.experts", "down_proj", weights) + .view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + out = ( + fused_experts( + x, + self.gate_up_proj, + self.down_proj, + topk_weights, + topk_ids, + inplace=True, + ) + * self.routed_scaling_factor + ) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DenseMoE(nn.Module): + def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + # + # Seems like no one quantizes the gate. + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.experts = [ + DeepseekV2MLP( + f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size + ) + for i in range(self.n_routed_experts) + ] + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + # gate_logits: (sequence_length, n_experts) + router_logits = self.gate(x) + + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + + out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + def moe_infer_gpu( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ): + weights = torch.zeros( + topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device + ) + weights.scatter_(1, topk_ids, topk_weight) + + out = x.new_zeros(x.shape[0], self.hidden_dim) + for i, expert in enumerate(self.experts): + # Add expert output to out with masking + out += expert(x, reduce=False) * weights[:, i].view(-1, 1) + return out + + +class DeepseekV2Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV2Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE + self.mlp = moe_cls(f"{prefix}.mlp", config, weights) + else: + self.mlp = DeepseekV2MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV2Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits + + +# Functions below are from vLLM: +# +# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397 +# +# Remove after we have synced our version with upstream. + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + if M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + return config + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + + import triton.language as tl + from vllm import _custom_ops as ops + from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_moe_configs, + invoke_fused_moe_kernel, + moe_align_block_size, + ) + + M, _ = hidden_states.shape + E, N, _ = w1.shape + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config( + M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None + ) + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E + ) + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + if inplace: + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2ca9eef33..2888f1f79 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -839,7 +839,9 @@ class FlashCausalLM(Model): default_dtype=torch.float16, aliases=None, # Used for Santacoder override of config - num_kv_heads=None, + num_kv_heads: Optional[int] = None, + # Deepseek V2 uses different QK and V dims. + head_size: Optional[int] = None, skip_special_tokens: bool = True, ): self.process_group, rank, world_size = initialize_torch_distributed() @@ -922,7 +924,11 @@ class FlashCausalLM(Model): else num_kv_heads ) assert self.num_kv_heads > 0 - self.head_size = config.hidden_size // config.num_attention_heads + + if head_size is None: + self.head_size = config.hidden_size // config.num_attention_heads + else: + self.head_size = head_size self.cuda_graphs = {} self.kv_cache = [] diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 6876b7009..91592df09 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -21,6 +21,13 @@ class WeightsLoader(ABC): with the format, etc. """ + @abstractmethod + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + ... + @abstractmethod def get_weights_col_packed( self, @@ -104,6 +111,9 @@ class DefaultWeightsLoader(WeightsLoader): and/or concatenation. """ + def get_weights(self, weights: "Weights", prefix: str): + return weights.get_tensor(f"{prefix}.weight") + def get_weights_col_packed( self, weights: "Weights", @@ -299,6 +309,9 @@ class Weights: return tensor + def get_weights(self, prefix: str): + return self.weights_loader.get_weights(self, prefix) + def get_weights_col_packed_qkv( self, prefix: str,