From c6bcadf883e7bd16fb74f509ba1dade1b635e96e Mon Sep 17 00:00:00 2001 From: Aaron Mihalik Date: Fri, 5 Jul 2024 03:46:41 -0400 Subject: [PATCH 01/24] Adding "longrope" for Phi-3 (#2172) (#2179) Adding "longrope" for phi-3 --- server/text_generation_server/layers/rotary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index b14005e6..87a61e82 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -110,7 +110,7 @@ class PositionRotaryEmbedding(nn.Module): beta_fast=32, beta_slow=1, ) - elif rope_scaling["type"] == "su": + elif rope_scaling["type"] in ["su", "longrope"]: short_factor = torch.tensor( rope_scaling["short_factor"], dtype=torch.float32, device=device ) From fb2f74e2b9f1639c0db1464d56c906a7df3d6865 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 5 Jul 2024 10:29:56 +0200 Subject: [PATCH 02/24] Refactor dead code - Removing all `flash_xxx.py` files. (#2166) * Refactor dead code. * First working step. * Remove a lot of duplicated code. * More dead code. * More cleanup. * Fix Santacoder test. * Fixing the simple tests. * Fixing sharding. * Fixes for VLM. * Fixing santacoder (num_kv_heads hardcoded). * Removing more dead code. * Fixing `config.n_head`. * Stopping earlier because of `` in idefics2. * Addresses comments. * Removing the dead code. * Fuse back mistral into FlashCausalLM. * Finish removal. * Fixing docs + causal_lm `batch_class`. * Fixing docs + causal.lm. * Add default to Gemma Causality. * Default value for gemma/gemma2. * Wrong default. --- docs/openapi.json | 2 +- docs/source/supported_models.md | 1 + .../test_flash_idefics2_two_images.json | 48 +-- integration-tests/models/test_idefics2.py | 2 +- server/tests/models/test_bloom.py | 8 +- server/tests/models/test_causal_lm.py | 2 +- server/tests/models/test_santacoder.py | 5 +- server/tests/models/test_seq2seq_lm.py | 2 +- .../text_generation_server/models/__init__.py | 393 +++++++++++------- server/text_generation_server/models/bloom.py | 73 ---- .../models/causal_lm.py | 106 ++++- .../custom_modeling/flash_gemma2_modeling.py | 2 +- .../custom_modeling/flash_gemma_modeling.py | 2 +- .../flash_santacoder_modeling.py | 3 +- .../models/custom_modeling/llava_next.py | 8 +- .../models/flash_causal_lm.py | 193 ++++++++- .../models/flash_cohere.py | 75 ---- .../models/flash_dbrx.py | 100 ----- .../models/flash_gemma.py | 83 ---- .../models/flash_gemma2.py | 83 ---- .../models/flash_gpt2.py | 82 ---- .../models/flash_llama.py | 171 -------- .../models/flash_mistral.py | 126 +----- .../models/flash_mixtral.py | 31 -- .../models/flash_neox.py | 82 ---- .../models/flash_phi.py | 111 ----- .../models/flash_qwen2.py | 93 ----- .../text_generation_server/models/flash_rw.py | 91 ---- .../models/flash_santacoder.py | 99 ----- .../models/flash_starcoder2.py | 84 ---- .../models/galactica.py | 80 ---- .../text_generation_server/models/gpt_neox.py | 89 ---- .../text_generation_server/models/idefics2.py | 51 --- .../models/llava_next.py | 46 -- server/text_generation_server/models/mpt.py | 105 ----- server/text_generation_server/models/opt.py | 86 ---- .../models/pali_gemma.py | 42 -- server/text_generation_server/models/phi.py | 69 --- server/text_generation_server/models/rw.py | 84 ---- .../models/santacoder.py | 77 ---- .../models/seq2seq_lm.py | 99 ++++- server/text_generation_server/models/t5.py | 115 ----- .../models/vlm_causal_lm.py | 36 +- 43 files changed, 689 insertions(+), 2451 deletions(-) delete mode 100644 server/text_generation_server/models/flash_cohere.py delete mode 100644 server/text_generation_server/models/flash_dbrx.py delete mode 100644 server/text_generation_server/models/flash_gemma.py delete mode 100644 server/text_generation_server/models/flash_gemma2.py delete mode 100644 server/text_generation_server/models/flash_gpt2.py delete mode 100644 server/text_generation_server/models/flash_llama.py delete mode 100644 server/text_generation_server/models/flash_mixtral.py delete mode 100644 server/text_generation_server/models/flash_neox.py delete mode 100644 server/text_generation_server/models/flash_phi.py delete mode 100644 server/text_generation_server/models/flash_qwen2.py delete mode 100644 server/text_generation_server/models/flash_rw.py delete mode 100644 server/text_generation_server/models/flash_santacoder.py delete mode 100644 server/text_generation_server/models/flash_starcoder2.py delete mode 100644 server/text_generation_server/models/gpt_neox.py delete mode 100644 server/text_generation_server/models/idefics2.py delete mode 100644 server/text_generation_server/models/llava_next.py delete mode 100644 server/text_generation_server/models/mpt.py delete mode 100644 server/text_generation_server/models/opt.py delete mode 100644 server/text_generation_server/models/phi.py delete mode 100644 server/text_generation_server/models/rw.py delete mode 100644 server/text_generation_server/models/santacoder.py delete mode 100644 server/text_generation_server/models/t5.py diff --git a/docs/openapi.json b/docs/openapi.json index 5e0399e0..9c9a8b1a 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "2.1.1-dev0" + "version": "2.1.2-dev0" }, "paths": { "/": { diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 1eeed39f..2bdd00de 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Gemma](https://huggingface.co/google/gemma-7b) +- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) - [Gemma2](https://huggingface.co/google/gemma2-9b) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct) diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json index bf2dc5a1..44ccea71 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json @@ -1,130 +1,124 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 20, + "finish_reason": "eos_token", + "generated_tokens": 19, "prefill": [], "seed": null, "tokens": [ { "id": 415, - "logprob": -0.039886475, + "logprob": -0.03665161, "special": false, "text": " The" }, { "id": 12072, - "logprob": -0.1430664, + "logprob": -0.13549805, "special": false, "text": " cow" }, { "id": 349, - "logprob": -0.056488037, + "logprob": -0.05819702, "special": false, "text": " is" }, { "id": 6328, - "logprob": -0.6855469, + "logprob": -0.6826172, "special": false, "text": " standing" }, { "id": 356, - "logprob": -0.1685791, + "logprob": -0.1607666, "special": false, "text": " on" }, { "id": 272, - "logprob": -0.50097656, + "logprob": -0.5073242, "special": false, "text": " the" }, { "id": 10305, - "logprob": -0.017303467, + "logprob": -0.016418457, "special": false, "text": " beach" }, { "id": 304, - "logprob": -1.3564453, + "logprob": -1.3916016, "special": false, "text": " and" }, { "id": 272, - "logprob": -0.017868042, + "logprob": -0.020217896, "special": false, "text": " the" }, { "id": 13088, - "logprob": -0.0027103424, + "logprob": -0.0028133392, "special": false, "text": " chicken" }, { "id": 349, - "logprob": -0.003156662, + "logprob": -0.003145218, "special": false, "text": " is" }, { "id": 6398, - "logprob": -0.37304688, + "logprob": -0.37060547, "special": false, "text": " sitting" }, { "id": 356, - "logprob": -0.034576416, + "logprob": -0.034851074, "special": false, "text": " on" }, { "id": 264, - "logprob": -0.29418945, + "logprob": -0.2878418, "special": false, "text": " a" }, { "id": 17972, - "logprob": -0.042877197, + "logprob": -0.046051025, "special": false, "text": " pile" }, { "id": 302, - "logprob": -0.00028443336, + "logprob": -0.00028848648, "special": false, "text": " of" }, { "id": 2445, - "logprob": -0.023223877, + "logprob": -0.025772095, "special": false, "text": " money" }, { "id": 28723, - "logprob": -0.018157959, + "logprob": -0.018127441, "special": false, "text": "." }, { "id": 32002, - "logprob": -0.00018393993, + "logprob": -0.00019824505, "special": true, "text": "" - }, - { - "id": 2, - "logprob": -1.1920929e-07, - "special": true, - "text": "" } ], "top_tokens": null diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py index 9aaf6d8a..c5f48da3 100644 --- a/integration-tests/models/test_idefics2.py +++ b/integration-tests/models/test_idefics2.py @@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot) response.generated_text == " The cow is standing on the beach and the chicken is sitting on a pile of money." ), f"{repr(response.generated_text)}" - assert response.details.generated_tokens == 20 + assert response.details.generated_tokens == 19 assert response == response_snapshot diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 32ee6686..08292920 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.utils import weight_hub_files, download_weights from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) @pytest.fixture(scope="session") @@ -16,7 +19,10 @@ def default_bloom(): revision = "main" filenames = weight_hub_files(model_id, revision, ".safetensors") download_weights(filenames, model_id, revision) - return BLOOMSharded(model_id) + return BLOOMSharded( + model_id, + model_class=BloomForCausalLM, + ) @pytest.fixture(scope="session") diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 6e6463bc..c000ef26 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch @pytest.fixture(scope="session") def default_causal_lm(): - return CausalLM("gpt2") + return CausalLM.fallback("gpt2") @pytest.fixture(scope="session") diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index cb2622d9..d5c91bff 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -1,13 +1,12 @@ import pytest from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.models.santacoder import SantaCoder +from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM @pytest.fixture(scope="session") def default_santacoder(): - return SantaCoder("bigcode/santacoder") + return CausalLM.fallback(model_id="bigcode/santacoder") @pytest.fixture diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 943c3b08..02666042 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -20,7 +20,7 @@ def mt0_small_tokenizer(): @pytest.fixture(scope="session") def default_seq2seq_lm(): - return Seq2SeqLM("bigscience/mt0-small") + return Seq2SeqLM.fallback("bigscience/mt0-small") @pytest.fixture diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5ea43290..15e74622 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -11,17 +11,26 @@ from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model -from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.bloom import BLOOMSharded -from text_generation_server.models.mpt import MPTSharded +from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast +from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM +from text_generation_server.models.custom_modeling.mpt_modeling import ( + MPTForCausalLM, +) +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) from text_generation_server.models.seq2seq_lm import Seq2SeqLM -from text_generation_server.models.rw import RW -from text_generation_server.models.opt import OPTSharded -from text_generation_server.models.galactica import GalacticaSharded -from text_generation_server.models.santacoder import SantaCoder -from text_generation_server.models.t5 import T5Sharded -from text_generation_server.models.gpt_neox import GPTNeoxSharded -from text_generation_server.models.phi import Phi +from text_generation_server.models.galactica import GalacticaCausalLMBatch +from text_generation_server.models.custom_modeling.neox_modeling import ( + GPTNeoxForCausalLM, +) +from text_generation_server.models.custom_modeling.phi_modeling import ( + PhiConfig, + PhiForCausalLM, +) +from text_generation_server.models.custom_modeling.t5_modeling import ( + T5ForConditionalGeneration, +) from text_generation_server.utils.import_utils import SYSTEM @@ -41,9 +50,6 @@ __all__ = [ "CausalLM", "GalacticaSharded", "Seq2SeqLM", - "SantaCoder", - "OPTSharded", - "T5Sharded", "get_model", ] @@ -53,38 +59,65 @@ FLASH_ATTENTION = True try: from text_generation_server.models.flash_causal_lm import FlashCausalLM - from text_generation_server.models.flash_rw import FlashRWSharded - from text_generation_server.models.flash_gpt2 import FlashGPT2 - from text_generation_server.models.flash_neox import FlashNeoXSharded - from text_generation_server.models.flash_llama import ( - FlashLlama, + from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, ) - from text_generation_server.models.flash_qwen2 import ( - FlashQwen2, + from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( + FlashCohereForCausalLM, ) - from text_generation_server.models.flash_cohere import ( - FlashCohere, + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, ) - from text_generation_server.models.flash_gemma import ( - FlashGemma, + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, ) - from text_generation_server.models.flash_gemma2 import ( - FlashGemma2, + from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( + FlashDbrxForCausalLM, + DbrxConfig, + ) + from text_generation_server.models.custom_modeling.flash_rw_modeling import ( + RWConfig, + FlashRWForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, ) from text_generation_server.models.pali_gemma import ( - PaliGemma, + PaliGemmaBatch, ) - from text_generation_server.models.flash_santacoder import ( - FlashSantacoderSharded, + from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, + ) + from text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, ) from text_generation_server.models.idefics import IDEFICSSharded - from text_generation_server.models.llava_next import LlavaNext - from text_generation_server.models.idefics2 import Idefics2 - from text_generation_server.models.flash_mistral import FlashMistral - from text_generation_server.models.flash_mixtral import FlashMixtral - from text_generation_server.models.flash_phi import FlashPhi - from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 - from text_generation_server.models.flash_dbrx import FlashDbrx + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) + + from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + FlashStarcoder2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( + FlashMixtralForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.idefics2 import ( + Idefics2ForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") @@ -93,21 +126,7 @@ except ImportError as e: if FLASH_ATTENTION: __all__.append(FlashCausalLM) - __all__.append(FlashGPT2) - __all__.append(FlashNeoXSharded) - __all__.append(FlashRWSharded) - __all__.append(FlashSantacoderSharded) - __all__.append(FlashLlama) __all__.append(IDEFICSSharded) - __all__.append(FlashMistral) - __all__.append(FlashMixtral) - __all__.append(FlashDbrx) - __all__.append(FlashPhi) - __all__.append(FlashQwen2) - __all__.append(FlashStarcoder2) - __all__.append(FlashGemma) - __all__.append(FlashGemma2) - __all__.append(FlashCohere) MAMBA_AVAILABLE = True try: @@ -148,6 +167,11 @@ class ModelType(enum.Enum): "name": "Gemma", "url": "https://huggingface.co/google/gemma-7b", } + PALIGEMMA = { + "type": "paligemma", + "name": "PaliGemma", + "url": "https://huggingface.co/google/paligemma-3b-pt-224", + } GEMMA2 = { "type": "gemma2", "name": "Gemma2", @@ -445,13 +469,16 @@ def get_model( ) if model_id.startswith("facebook/galactica"): - return GalacticaSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + # Yes galactica is just an OPT model. + model_class=OPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=GalacticaCausalLMBatch, ) if ( @@ -460,22 +487,26 @@ def get_model( and model_id.startswith("bigcode/") ): if FLASH_ATTENTION: - return FlashSantacoderSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashSantacoderForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, + num_kv_heads=1, ) elif sharded: raise NotImplementedError( FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") ) else: - return SantaCoder( - model_id, - revision, + return CausalLM.fallback( + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -483,38 +514,44 @@ def get_model( ) if model_type == BLOOM: - return BLOOMSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=BloomForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=CausalLMBatchKeysLast, ) elif model_type == MPT: - return MPTSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=MPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=CausalLMBatchKeysLast, ) elif model_type == GPT2: if FLASH_ATTENTION: try: - return FlashGPT2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPT2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) except RuntimeError as e: # Lots of legacy models with various weight names. logger.warning(f"Couldn't load flash gpt2 variant: {e}") - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -525,7 +562,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -535,25 +572,28 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: - return FlashNeoXSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTNeoXForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: - return GPTNeoxSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=GPTNeoxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -564,16 +604,18 @@ def get_model( elif model_type == PHI: if FLASH_ATTENTION: - return FlashPhi( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashPhiForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -588,9 +630,11 @@ def get_model( "Legacy phi-msft is not supported with Flash Attention" ) else: - return Phi( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=PhiForCausalLM, + config_class=PhiConfig, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -599,9 +643,10 @@ def get_model( elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: if FLASH_ATTENTION: - return FlashLlama( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -611,7 +656,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -621,18 +666,22 @@ def get_model( ) if model_type == GEMMA: if FLASH_ATTENTION: - return FlashGemma( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemmaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -642,18 +691,22 @@ def get_model( ) elif model_type == GEMMA2: if FLASH_ATTENTION: - return FlashGemma2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -664,18 +717,20 @@ def get_model( if model_type == COHERE: if FLASH_ATTENTION: - return FlashCohere( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashCohereForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -686,18 +741,23 @@ def get_model( if model_type == DBRX: if FLASH_ATTENTION: - return FlashDbrx( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashDbrxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Dbrx works better in bfloat16. + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DbrxConfig, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -711,27 +771,37 @@ def get_model( if FLASH_ATTENTION: if config_dict.get("alibi", False): raise NotImplementedError("sharded is not supported for this model") - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) else: - return RW( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -742,18 +812,20 @@ def get_model( if model_type == MISTRAL: if FLASH_ATTENTION: - return FlashMistral( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashMistralForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -764,18 +836,20 @@ def get_model( if model_type == MIXTRAL: if FLASH_ATTENTION: - return FlashMixtral( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashMixtralForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -786,19 +860,22 @@ def get_model( if model_type == STARCODER2: if FLASH_ATTENTION: - return FlashStarcoder2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashStarcoder2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError( FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -809,17 +886,20 @@ def get_model( if model_type == QWEN2: if FLASH_ATTENTION: - return FlashQwen2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=Qwen2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -829,9 +909,10 @@ def get_model( ) if model_type == OPT: - return OPTSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=OPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -839,13 +920,20 @@ def get_model( ) if model_type == T5: - return T5Sharded( - model_id, - revision, + return Seq2SeqLM( + model_id=model_id, + model_class=T5ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + aliases={ + "shared.weight": [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + }, ) if model_type == IDEFICS: if FLASH_ATTENTION: @@ -861,34 +949,45 @@ def get_model( raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == IDEFICS2: if FLASH_ATTENTION: - return Idefics2( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=Idefics2ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - if model_type == "paligemma": + if model_type == PALIGEMMA: if FLASH_ATTENTION: - return PaliGemma( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == LLAVA_NEXT: if FLASH_ATTENTION: - return LlavaNext( - model_id, - revision, + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -912,7 +1011,7 @@ def get_model( elif quantize == "exl2": raise NotImplementedError("exl2 quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -921,7 +1020,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: - return Seq2SeqLM( + return Seq2SeqLM.fallback( model_id, revision, quantize=quantize, @@ -933,7 +1032,7 @@ def get_model( auto_map = config_dict.get("auto_map", None) if trust_remote_code and auto_map is not None: if "AutoModelForCausalLM" in auto_map.keys(): - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -942,7 +1041,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if "AutoModelForSeq2SeqLM" in auto_map.keys(): - return Seq2SeqLM( + return Seq2SeqLM.fallback( model_id, revision, quantize=quantize, diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 17aa12e8..732b4c53 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -4,22 +4,12 @@ import torch.distributed from typing import Optional, Type from transformers import ( - AutoTokenizer, - AutoConfig, PreTrainedTokenizerBase, ) -from text_generation_server.models.custom_modeling.bloom_modeling import ( - BloomForCausalLM, -) from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) class BloomCausalLMBatch(CausalLMBatch): @@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch): class BLOOMSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - slow_but_exact=False, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) - config.pad_token_id = 3 - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - prefix="transformer", - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = BloomForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - @property def batch_type(self) -> Type[CausalLMBatch]: return BloomCausalLMBatch diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 10c64c66..cac36ebd 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,13 +1,25 @@ import torch import time +import torch.distributed from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase +from transformers import ( + AutoConfig, + AutoTokenizer, + AutoModelForCausalLM, + PreTrainedTokenizerBase, +) from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.models import Model from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, @@ -478,10 +490,87 @@ class CausalLMBatch(Batch): return len(self.requests) +@dataclass +class CausalLMBatchKeysLast(Batch): + keys_head_dim_last: bool = False + + class CausalLM(Model): def __init__( self, model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, + config_class=AutoConfig, + batch_class=CausalLMBatch, + ): + self.batch_class = batch_class + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = config_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = config.pad_token_id + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + model = model_class(config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -537,7 +626,12 @@ class CausalLM(Model): else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - super(CausalLM, self).__init__( + self = cls.__new__( + cls, + ) + self.batch_class = CausalLMBatch + super().__init__( + self, model_id=model_id, model=model, tokenizer=tokenizer, @@ -545,15 +639,11 @@ class CausalLM(Model): dtype=dtype, device=device, ) + return self @property def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) + return self.batch_class def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index cfa6b2fe..625baa91 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module): class FlashGemma2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 842df0d4..b7ce6307 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 30989a37..2bc305fe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -464,8 +464,9 @@ class FlashSantacoderModel(nn.Module): class FlashSantacoderForCausalLM(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() + config.transpose = config.architectures[0].startswith("GPT2") self.transformer = FlashSantacoderModel(config, weights) self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 6d38442c..567131ef 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module): self.config = config config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator - self.language_model = load_text_model( + self.text_model = load_text_model( prefix="language_model" if not prefix else f"{prefix}.language_model", config=config.text_config, weights=weights, @@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module): image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional[torch.Tensor] = None, ): - inputs_embeds = self.language_model.embed_tokens(input_ids) + inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" @@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module): input_ids, inputs_embeds, image_features ) - hidden_states = self.language_model.model( + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module): ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.language_model.lm_head(hidden_states) + logits, speculative_logits = self.text_model.lm_head(hidden_states) return logits, speculative_logits diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4f276ed4..c7f5f1f9 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -10,7 +10,12 @@ import numpy as np from loguru import logger from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase +from transformers import ( + PreTrainedTokenizerBase, + AutoConfig, + AutoTokenizer, + GenerationConfig, +) from typing import Iterable, Optional, Tuple, List, Type, Dict from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata @@ -21,6 +26,12 @@ from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, + hub, +) from text_generation_server.models.types import ( Batch, Tokens, @@ -799,29 +810,110 @@ class FlashCausalLMBatch(Batch): return len(self.requests) +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + + class FlashCausalLM(Model): def __init__( self, model_id: str, - model: torch.nn.Module, - tokenizer: PreTrainedTokenizerBase, - num_layers: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - rank: int = 0, - world_size: int = 1, - sliding_window: Optional[int] = None, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + lora_adapter_ids: Optional[list] = [], + tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + config_class: PreTrainedTokenizerBase = AutoConfig, + default_dtype=torch.float16, + aliases=None, + # Used for Santacoder override of config + num_kv_heads=None, + skip_special_tokens: bool = True, ): - self.num_layers = num_layers - self.num_kv_heads = num_kv_heads - self.head_size = head_size + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + raise NotImplementedError(f"{model_class} is only available on GPU") + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + try: + generation_config = GenerationConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + if isinstance(generation_config.eos_token_id, (list, set)): + # TODO Huge hack + tokenizer._eos_token_ids = set(generation_config.eos_token_id) + except Exception: + pass + + config = config_class.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.speculator = speculator + if getattr(config, "sliding_window", None) is not None: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device, dtype, process_group=self.process_group, aliases=aliases + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + prefix = "" + model = model_class(prefix, config, weights) + torch.distributed.barrier(group=self.process_group) + + # VLM models define the config we care about in their text_config + text_config = getattr(config, "text_config", None) + if text_config is not None: + config = text_config + self.num_layers = config.num_hidden_layers + # Validation is done in the model itself + if num_kv_heads is None: + num_kv_heads = getattr(config, "num_key_value_heads", None) + if num_kv_heads is None: + # Final overide for GPT2 + num_kv_heads = config.n_head + self.num_kv_heads = num_kv_heads // self.process_group.size() + self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {} self.kv_cache = [] - super(FlashCausalLM, self).__init__( + super().__init__( model_id=model_id, model=model, tokenizer=tokenizer, @@ -830,7 +922,7 @@ class FlashCausalLM(Model): device=device, rank=rank, world_size=world_size, - sliding_window=sliding_window, + sliding_window=config.sliding_window, ) @property @@ -1578,3 +1670,72 @@ class FlashCausalLM(Model): forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + + # This accounts for VLMs (e.g. LlavaNext, Idefics2) + # that have a language_model inside of the larger model. + if hasattr(self.model, "language_model"): + _model = self.model.language_model + elif hasattr(self.model, "text_model"): + _model = self.model.text_model + else: + _model = self.model + + for i, layer in enumerate(_model.model.layers): + layer_weights[(i, "q_proj")] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "k_proj")] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "v_proj")] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "o_proj")] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + # TODO: this is a hack to avoid the gate_proj for + # FlashStarcoder2 that doesnt have these layers + if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): + layer_weights[(i, "gate_proj")] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "up_proj")] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "down_proj")] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return ["q_proj", "v_proj"] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == "lm_head" else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py deleted file mode 100644 index 9f8bcb3f..00000000 --- a/server/text_generation_server/models/flash_cohere.py +++ /dev/null @@ -1,75 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer, AutoConfig - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( - FlashCohereForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashCohere(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - raise NotImplementedError("FlashCohere is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashCohereForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashCohere, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py deleted file mode 100644 index 2aba6a00..00000000 --- a/server/text_generation_server/models/flash_dbrx.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( - FlashDbrxForCausalLM, - DbrxConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashDbrx(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashDBRX is only available on GPU") - - try: - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - # FIXME: change back to model id once the tokenizer.json is merged - tokenizer = GPT2TokenizerFast.from_pretrained( - "Xenova/dbrx-instruct-tokenizer", - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = DbrxConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashDbrxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashDbrx, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py deleted file mode 100644 index 7e2b8780..00000000 --- a/server/text_generation_server/models/flash_gemma.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( - FlashGemmaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemmaForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py deleted file mode 100644 index 86cfc7e2..00000000 --- a/server/text_generation_server/models/flash_gemma2.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import PretrainedConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( - FlashGemma2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = PretrainedConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py deleted file mode 100644 index 323fcafa..00000000 --- a/server/text_generation_server/models/flash_gpt2.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from transformers.models.gpt2 import GPT2Tokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( - FlashGPT2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGPT2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGPT2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashGPT2ForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashGPT2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py deleted file mode 100644 index d996b9c3..00000000 --- a/server/text_generation_server/models/flash_llama.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from typing import Optional, Tuple, Dict, List - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_llama_modeling import ( - FlashLlamaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, - hub, -) - -tracer = trace.get_tracer(__name__) - -from text_generation_server.utils.import_utils import SYSTEM - -ADAPTER_LAYERS = [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", -] -ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} - - -class FlashLlama(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - lora_adapter_ids: Optional[list] = [], - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashLlama is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - try: - generation_config = GenerationConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - if isinstance(generation_config.eos_token_id, (list, set)): - # TODO Huge hack - tokenizer._eos_token_ids = set(generation_config.eos_token_id) - except Exception: - pass - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashLlamaForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashLlama, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. LlavaNext, Idefics2) - # that have a language_model inside of the larger model. - if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 0f5746de..2b2bd2e0 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -1,24 +1,7 @@ import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM -from text_generation_server.models.flash_causal_lm import set_sliding_window -from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( - FlashMistralForCausalLM, - MistralConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) ADAPTER_LAYERS = [ @@ -33,88 +16,7 @@ ADAPTER_LAYERS = [ ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} -class BaseFlashMistral(FlashCausalLM): - def __init__( - self, - model_cls, - model_id: str, - config_cls=AutoConfig, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - tokenizer_class=AutoTokenizer, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashMistral is only available on GPU") - - tokenizer = tokenizer_class.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = config_cls.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: - config.sliding_window = None - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = model_cls(prefix, config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - num_layers, num_kv_heads, head_size = self.get_layer_config(model) - super().__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.model.layers), - model.model.num_key_value_heads, - model.model.head_size, - ) - +class FlashMistral(FlashCausalLM): @property def supports_adapter_loading(self) -> bool: return True @@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM): # This accounts for VLMs (e.g. LlavaNext, Idefics2) # that have a language_model inside of the larger model. - if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): + if hasattr(self.model, "text_model"): _model = self.model.text_model else: _model = self.model @@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM): def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL - - -class FlashMistral(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashMistral, self).__init__( - config_cls=MistralConfig, - model_cls=FlashMistralForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py deleted file mode 100644 index 587d423f..00000000 --- a/server/text_generation_server/models/flash_mixtral.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from typing import Optional - -from text_generation_server.models.flash_mistral import BaseFlashMistral -from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( - MixtralConfig, - FlashMixtralForCausalLM, -) - - -class FlashMixtral(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashMixtral, self).__init__( - config_cls=MixtralConfig, - model_cls=FlashMixtralForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py deleted file mode 100644 index ac1fd573..00000000 --- a/server/text_generation_server/models/flash_neox.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - FlashGPTNeoXForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashNeoXSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashNeoX is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashGPTNeoXForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashNeoXSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.gpt_neox.layers), - num_kv_heads=model.gpt_neox.num_heads, - head_size=model.gpt_neox.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py deleted file mode 100644 index a530d1c3..00000000 --- a/server/text_generation_server/models/flash_phi.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_phi_modeling import ( - FlashPhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashPhi(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashPhi is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashPhiForCausalLM(config, weights) - if speculator: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - Path(speculator).exists() and Path(speculator).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - medusa_config = hf_hub_download( - speculator, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - speculator, revision=revision, filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(speculator) / "config.json") - medusa_head = str(Path(speculator) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) - - torch.distributed.barrier(group=self.process_group) - super(FlashPhi, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py deleted file mode 100644 index cd6078f1..00000000 --- a/server/text_generation_server/models/flash_qwen2.py +++ /dev/null @@ -1,93 +0,0 @@ -import math - -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - set_sliding_window, -) -from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( - Qwen2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashQwen2(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashQwen2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if config.sliding_window is not None: - set_sliding_window(config.sliding_window) - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = Qwen2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py deleted file mode 100644 index b1f75adc..00000000 --- a/server/text_generation_server/models/flash_rw.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_rw_modeling import ( - RWConfig, - FlashRWForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashRWSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashRW is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = RWConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - aliases={ - "lm_head.weight": ["transformer.word_embeddings.weight"], - "transformer.word_embeddings.weight": ["lm_head.weight"], - }, - ) - - config.quantize = quantize - config.speculator = speculator - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashRWForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashRWSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=model.transformer.cache_size, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py deleted file mode 100644 index e1a7b36e..00000000 --- a/server/text_generation_server/models/flash_santacoder.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List -import json -import os - -from huggingface_hub import hf_hub_download -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( - FlashSantacoderForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashSantacoderSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashSantacoderSharded is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=True, - ) - config.quantize = quantize - config.speculator = speculator - config.transpose = config.architectures[0].startswith("GPT2") - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={"transformer.wte.weight": ["lm_head.weight"]}, - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashSantacoderForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashSantacoderSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=1, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py deleted file mode 100644 index 369e9e4c..00000000 --- a/server/text_generation_server/models/flash_starcoder2.py +++ /dev/null @@ -1,84 +0,0 @@ -import math - -import torch - -from typing import Optional - -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - set_sliding_window, -) -from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( - Starcoder2Config, - FlashStarcoder2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -# Starcoder2 has the same base as Mistral -class FlashStarcoder2(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - raise NotImplementedError("FlashStarcoder2 is only available on GPU") - - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = Starcoder2Config.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if config.sliding_window is not None: - set_sliding_window(config.sliding_window) - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashStarcoder2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 30c92d90..2d43244a 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch): padding_right_offset=padding_right_offset, max_tokens=max_tokens, ) - - -class GalacticaSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - tokenizer.pad_token_id = config.pad_token_id - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = OPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return GalacticaCausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py deleted file mode 100644 index c37cfb7d..00000000 --- a/server/text_generation_server/models/gpt_neox.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.neox_modeling import ( - GPTNeoxForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class GPTNeoxSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.pad_token = tokenizer.eos_token - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = GPTNeoxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py deleted file mode 100644 index 314c0500..00000000 --- a/server/text_generation_server/models/idefics2.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.idefics2 import ( - Idefics2ForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class Idefics2(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - # XXX: Extremely important to cap resolution in order to limit - # VRAM usage. - size={"longest_edge": 448, "shortest_edge": 378}, - ) - super().__init__( - model_cls=Idefics2ForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py deleted file mode 100644 index effe8b91..00000000 --- a/server/text_generation_server/models/llava_next.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class LlavaNext(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - super().__init__( - model_cls=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.language_model.model.layers), - model.language_model.model.num_key_value_heads, - model.language_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.language_model, "max_past", None) diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py deleted file mode 100644 index 1e79b25f..00000000 --- a/server/text_generation_server/models/mpt.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import torch.distributed - -from pathlib import Path -from typing import Optional, Type -from opentelemetry import trace -from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase -from huggingface_hub import hf_hub_download -import json - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.custom_modeling.mpt_modeling import ( - MPTForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class MPTCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) - batch.keys_head_dim_last = False - return batch - - -class MPTSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.pad_token = tokenizer.eos_token - - # If model_id is a local path, load the file directly - local_path = Path(model_id, "config.json") - if local_path.exists(): - filename = str(local_path.resolve()) - else: - filename = hf_hub_download( - model_id, revision=revision, filename="config.json" - ) - with open(filename, "r") as f: - config = json.load(f) - config = PretrainedConfig(**config) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - config.quantize = quantize - model = MPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return MPTCausalLMBatch diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py deleted file mode 100644 index 6d7d07f5..00000000 --- a/server/text_generation_server/models/opt.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM -from text_generation_server.models import CausalLM -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class OPTSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - tokenizer.pad_token_id = config.pad_token_id - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = OPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index a167e467..3994ac70 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch): else: image_inputs = None return batch_tokenized_inputs, image_inputs - - -class PaliGemma(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - - super().__init__( - config_cls=AutoConfig, - model_cls=PaliGemmaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - @property - def batch_type(self): - return PaliGemmaBatch - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py deleted file mode 100644 index 93d42b2b..00000000 --- a/server/text_generation_server/models/phi.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch.distributed - -from transformers import AutoConfig, AutoTokenizer -from typing import Optional, List, Tuple - -from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.phi_modeling import ( - PhiConfig, - PhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class Phi(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, _rank, _world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - config = PhiConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - tokenizer.bos_token_id = config.bos_token_id - tokenizer.eos_token_id = config.eos_token_id - tokenizer.pad_token = tokenizer.eos_token - - config.quantize = quantize - config.speculator = speculator - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - model = PhiForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py deleted file mode 100644 index 37ca277b..00000000 --- a/server/text_generation_server/models/rw.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch - -from transformers import AutoTokenizer, AutoModelForCausalLM -from typing import List, Optional, Tuple - -from text_generation_server.models import CausalLM - - -class RW(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if speculator: - raise RuntimeError("Medusa decoding is not enabled for AutoModel") - - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - device_map=( - "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None - ), - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: - model = model.cuda() - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - # Model Forward - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py deleted file mode 100644 index caddbe19..00000000 --- a/server/text_generation_server/models/santacoder.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional, List -from transformers import AutoTokenizer, AutoModelForCausalLM - -from text_generation_server.models import CausalLM - -FIM_PREFIX = "" -FIM_MIDDLE = "" -FIM_SUFFIX = "" -FIM_PAD = "" -EOD = "<|endoftext|>" - - -class SantaCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.add_special_tokens( - { - "additional_special_tokens": [ - EOD, - FIM_PREFIX, - FIM_MIDDLE, - FIM_SUFFIX, - FIM_PAD, - ], - "pad_token": EOD, - } - ) - with device: - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index d454d804..dbaf1253 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,11 +1,22 @@ import torch +import torch.distributed import time from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase +from transformers import ( + AutoTokenizer, + AutoModelForSeq2SeqLM, + PreTrainedTokenizerBase, + AutoConfig, +) from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model @@ -531,6 +542,80 @@ class Seq2SeqLM(Model): def __init__( self, model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, + trust_remote_code: bool = False, + config_class=AutoConfig, + tokenizer_class=AutoTokenizer, + aliases=None, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + config = config_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + tokenizer.bos_token_id = config.decoder_start_token_id + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases=aliases, + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + model = model_class(config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -574,7 +659,11 @@ class Seq2SeqLM(Model): ) tokenizer.bos_token_id = model.config.decoder_start_token_id - super(Seq2SeqLM, self).__init__( + self = cls.__new__( + cls, + ) + super().__init__( + self, model_id=model_id, model=model, tokenizer=tokenizer, @@ -582,16 +671,12 @@ class Seq2SeqLM(Model): dtype=dtype, device=device, ) + return self @property def batch_type(self) -> Type[Seq2SeqLMBatch]: return Seq2SeqLMBatch - def decode(self, decoder_ids: List[int]) -> str: - return self.tokenizer.decode( - decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) - def forward( self, input_ids, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py deleted file mode 100644 index adef664c..00000000 --- a/server/text_generation_server/models/t5.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.distributed - -from typing import List, Optional, Tuple - -from transformers import ( - AutoTokenizer, - AutoConfig, -) - -from text_generation_server.models import Seq2SeqLM -from text_generation_server.models.custom_modeling.t5_modeling import ( - T5ForConditionalGeneration, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class T5Sharded(Seq2SeqLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.bos_token_id = config.decoder_start_token_id - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={ - "shared.weight": [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - ] - }, - ) - - model = T5ForConditionalGeneration(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(Seq2SeqLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask: Optional, - encoder_last_hidden_state: Optional, - past_key_values: Optional = None, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], - ]: - # Model Forward - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_last_hidden_state, - past_key_values=past_key_values, - use_cache=True, - ) - - return ( - outputs.logits, - speculative_logits, - outputs.encoder_last_hidden_state, - outputs.past_key_values, - ) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 1cdf37ea..ace48805 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch -from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, +from text_generation_server.models.flash_causal_lm import ( + FlashCausalLMBatch, + FlashCausalLM, ) +from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch): return batch -class VlmCausalLM(BaseFlashMistral): +class VlmCausalLM(FlashCausalLM): + def __init__( + self, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=VlmCausalLMBatch, + revision, + trust_remote_code: bool, + **kwargs, + ): + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + super().__init__(model_id=model_id, **kwargs) + @property def batch_type(self) -> Type[VlmCausalLMBatch]: - return VlmCausalLMBatch + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) def forward( self, From 853d4eb9cf51fced975a428de15428fb4860a449 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 5 Jul 2024 09:25:29 +0000 Subject: [PATCH 03/24] Hotfixing after refactor. --- .../custom_modeling/flash_santacoder_modeling.py | 16 ++++++++-------- server/text_generation_server/models/model.py | 4 +++- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2bc305fe..daef43cc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -355,7 +355,7 @@ class Block(nn.Module): self.ln_2 = FastLayerNorm.load( prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon ) - self.attn = FlashMQAttention( + self.self_attn = FlashMQAttention( prefix=f"{prefix}.attn", config=config, weights=weights, @@ -378,7 +378,7 @@ class Block(nn.Module): max_s, ): hidden_states, residual = self.ln_1(hidden_states, residual) - hidden_states = self.attn( + hidden_states = self.self_attn( hidden_states, cu_seqlen_prefill, kv_cache, @@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module): reduce=False, ) - self.h = nn.ModuleList( + self.layers = nn.ModuleList( [ Block( layer_id, @@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module): prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon ) - self.head_size = self.h[0].attn.head_size - self.num_heads = self.h[0].attn.num_heads + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads def forward( self, @@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module): torch.distributed.all_reduce(hidden_states, group=self.process_group) residual = None - for i, layer in enumerate(self.h): + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, residual, @@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module): def __init__(self, prefix, config, weights): super().__init__() config.transpose = config.architectures[0].startswith("GPT2") - self.transformer = FlashSantacoderModel(config, weights) + self.model = FlashSantacoderModel(config, weights) self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) @@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.transformer( + hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index c90fd38a..09130b85 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -60,7 +60,7 @@ class Model(ABC): self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( LayerAdapterWeights ) - self.target_to_layer = self.adapter_target_to_layer() + self.target_to_layer = None self.loaded_adapters = set() self.static_adapter_id = adapter_id @@ -187,6 +187,8 @@ class Model(ABC): into model. Otherwise, the adapter weights are applied during the forward pass and stored separately from the base model parameters. """ + if self.target_to_layer is None: + self.target_to_layer = self.adapter_target_to_layer() if adapter_index in self.loaded_adapters: # Adapter already loaded return From b67d46336e34ca9bddc1a077fb8467086ac522cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 5 Jul 2024 12:22:45 +0200 Subject: [PATCH 04/24] Fix Starcoder2 after refactor (#2189) --- .../flash_starcoder2_modeling.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index a0273c37..2b346283 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -417,14 +417,14 @@ class Starcoder2Layer(nn.Module): class Starcoder2Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ @@ -437,7 +437,7 @@ class Starcoder2Model(torch.nn.Module): ] ) self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( - prefix="model.norm", weights=weights, eps=config.norm_epsilon + prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon ) self.gradient_checkpointing = False @@ -489,10 +489,15 @@ class Starcoder2Model(torch.nn.Module): class FlashStarcoder2ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() - self.model = Starcoder2Model(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = Starcoder2Model(prefix, config, weights) try: self.lm_head = SpeculativeHead.load( config, @@ -502,7 +507,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): except RuntimeError: self.lm_head = SpeculativeHead.load( config, - prefix="model.embed_tokens", + prefix=f"{prefix}.embed_tokens", weights=weights, ) From 67ef0649cf35e43358518bba44c276713d8bb2eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 5 Jul 2024 14:12:16 +0200 Subject: [PATCH 05/24] GPTQ CI improvements (#2151) * Add more representative Llama GPTQ test The Llama GPTQ test is updated to use a model with the commonly-used quantizer config format and activation sorting. The old test is kept around (but renamed) since it tests the format produced by `text-generation-server quantize`. * Add support for manually triggering a release build --- .github/workflows/build.yaml | 7 +- .github/workflows/ci_build.yaml | 11 +- .../test_flash_llama_gptq.json | 89 ++--- .../test_flash_llama_gptq_all_params.json | 85 ++--- .../test_flash_llama_gptq_load.json | 356 ++++++++--------- .../test_server_gptq_quantized.json | 89 +++++ ...test_server_gptq_quantized_all_params.json | 89 +++++ .../test_server_gptq_quantized_load.json | 358 ++++++++++++++++++ .../models/test_flash_llama_gptq.py | 4 +- 9 files changed, 805 insertions(+), 283 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json create mode 100644 integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index b0049701..a665d9b0 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -11,6 +11,11 @@ on: # - rocm # - intel required: true + release-tests: + description: "Run release integration tests" + required: true + default: false + type: boolean jobs: build-and-push: @@ -148,7 +153,7 @@ jobs: runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: - PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }} + PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == 'true') && '--release' || '' }} steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index 754c4850..d62297e4 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -20,7 +20,14 @@ on: - "Dockerfile_amd" - "Dockerfile_intel" branches: - - 'main' + - "main" + workflow_dispatch: + inputs: + release-tests: + description: "Run release integration tests" + required: true + default: false + type: boolean jobs: build: @@ -33,4 +40,6 @@ jobs: uses: ./.github/workflows/build.yaml # calls the one above ^ with: hardware: ${{ matrix.hardware }} + # https://github.com/actions/runner/issues/2206 + release-tests: ${{ inputs.release-tests == true }} secrets: inherit diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json index 7797cc6c..0f99d259 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json @@ -5,85 +5,80 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.7890625, "text": "Test" }, { - "id": 2009, - "logprob": -9.625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3359375, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8779297, + "id": 262, + "logprob": -1.6230469, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2744141, + "id": 3270, + "logprob": -2.046875, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1425781, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.9238281, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6933594, + "id": 13204, + "logprob": -0.076660156, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4648438, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15600586, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.8027344, + "id": 3019, + "logprob": -0.10821533, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.23022461, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.0069885254, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.02218628, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json index fa2fd4a2..4152b5b3 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -5,85 +5,80 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.6015625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": 0, "tokens": [ { - "id": 29899, - "logprob": -1.5625, + "id": 13, + "logprob": -2.2539062, "special": false, - "text": "-" + "text": "." }, { - "id": 1454, - "logprob": -0.20410156, + "id": 578, + "logprob": -0.15563965, "special": false, - "text": "for" + "text": " The" }, { - "id": 29899, + "id": 3622, + "logprob": -0.8203125, + "special": false, + "text": " server" + }, + { + "id": 706, "logprob": 0.0, "special": false, - "text": "-" + "text": " has" }, { - "id": 9342, + "id": 539, "logprob": 0.0, "special": false, - "text": "comment" + "text": " not" }, { - "id": 29901, + "id": 3686, "logprob": 0.0, "special": false, - "text": ":" + "text": " yet" }, { - "id": 396, - "logprob": -0.27685547, - "special": false, - "text": " #" - }, - { - "id": 29906, - "logprob": -0.4970703, - "special": false, - "text": "2" - }, - { - "id": 29900, - "logprob": -0.80615234, - "special": false, - "text": "0" - }, - { - "id": 29896, + "id": 3288, "logprob": 0.0, "special": false, - "text": "1" + "text": " sent" }, { - "id": 29955, - "logprob": -1.0751953, + "id": 904, + "logprob": 0.0, "special": false, - "text": "7" + "text": " any" + }, + { + "id": 828, + "logprob": 0.0, + "special": false, + "text": " data" + }, + { + "id": 382, + "logprob": -1.5517578, + "special": false, + "text": ".\n\n" } ], "top_tokens": null }, - "generated_text": "Test request-for-comment: #2017" + "generated_text": "Test request. The server has not yet sent any data.\n\n" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json index 594b7351..75e90303 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json @@ -6,87 +6,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.828125, "text": "Test" }, { - "id": 2009, - "logprob": -9.609375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3300781, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8740234, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2646484, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.7158203, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4667969, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15344238, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.81591797, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22973633, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.007045746, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021957397, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -95,87 +90,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.59375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3378906, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8779297, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2636719, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6992188, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4589844, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15344238, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.79052734, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22937012, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.007041931, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.022140503, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -184,87 +174,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.609375, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3261719, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.8730469, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2587891, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6894531, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.46875, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.1541748, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.80322266, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22912598, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.0070495605, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021606445, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" }, { "details": { @@ -273,86 +258,81 @@ "generated_tokens": 10, "prefill": [ { - "id": 1, + "id": 2323, "logprob": null, - "text": "" - }, - { - "id": 4321, - "logprob": -9.84375, "text": "Test" }, { - "id": 2009, - "logprob": -9.6015625, - "text": "request" + "id": 1715, + "logprob": -11.34375, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 13, - "logprob": -2.3320312, + "id": 198, + "logprob": -2.5742188, "special": false, "text": "\n" }, { - "id": 3057, - "logprob": -1.875, + "id": 262, + "logprob": -1.6220703, "special": false, - "text": "Test" + "text": " " }, { - "id": 2009, - "logprob": -1.2646484, + "id": 3270, + "logprob": -2.0410156, + "special": false, + "text": " \"\"\"\n" + }, + { + "id": 262, + "logprob": -0.015281677, + "special": false, + "text": " " + }, + { + "id": 422, + "logprob": -2.1445312, + "special": false, + "text": " if" + }, + { + "id": 1715, + "logprob": -0.92333984, "special": false, "text": " request" }, { - "id": 13, - "logprob": -1.6884766, + "id": 13204, + "logprob": -0.07672119, "special": false, - "text": "\n" + "text": ".method" }, { - "id": 3057, - "logprob": -1.4589844, + "id": 624, + "logprob": -0.021987915, "special": false, - "text": "Test" + "text": " ==" }, { - "id": 2009, - "logprob": -0.15185547, + "id": 364, + "logprob": -0.39208984, "special": false, - "text": " request" + "text": " '" }, { - "id": 13, - "logprob": -0.79833984, + "id": 3019, + "logprob": -0.10638428, "special": false, - "text": "\n" - }, - { - "id": 3057, - "logprob": -0.22827148, - "special": false, - "text": "Test" - }, - { - "id": 2009, - "logprob": -0.006996155, - "special": false, - "text": " request" - }, - { - "id": 13, - "logprob": -0.021560669, - "special": false, - "text": "\n" + "text": "POST" } ], "top_tokens": null }, - "generated_text": "\nTest request\nTest request\nTest request\n" + "generated_text": "\n \"\"\"\n if request.method == 'POST" } ] diff --git a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json new file mode 100644 index 00000000..69c1f47d --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.8359375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6171875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3417969, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8730469, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2626953, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7060547, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4482422, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15246582, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.796875, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22766113, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007045746, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021759033, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" +} diff --git a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json new file mode 100644 index 00000000..9b5ee9ee --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.7890625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.625, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 29899, + "logprob": -1.4980469, + "special": false, + "text": "-" + }, + { + "id": 1454, + "logprob": -0.19433594, + "special": false, + "text": "for" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 9342, + "logprob": 0.0, + "special": false, + "text": "comment" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 396, + "logprob": -0.27392578, + "special": false, + "text": " #" + }, + { + "id": 29906, + "logprob": -0.49389648, + "special": false, + "text": "2" + }, + { + "id": 29900, + "logprob": -0.81103516, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": 0.0, + "special": false, + "text": "1" + }, + { + "id": 29955, + "logprob": -1.0800781, + "special": false, + "text": "7" + } + ], + "top_tokens": null + }, + "generated_text": "Test request-for-comment: #2017" +} diff --git a/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json new file mode 100644 index 00000000..df975635 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.8828125, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.5859375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3359375, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8623047, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2451172, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.6923828, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4492188, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15197754, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.8022461, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22583008, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007095337, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021652222, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.796875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3476562, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8789062, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2734375, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.703125, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4677734, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.15454102, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.7973633, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.23278809, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.006980896, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.022033691, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.9296875, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.5703125, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3203125, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8486328, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2480469, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7060547, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4511719, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.1529541, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.81396484, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22180176, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007133484, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021835327, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.84375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6171875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -2.3261719, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.8691406, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -1.2597656, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -1.7070312, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.4550781, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.1538086, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.79345703, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22924805, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.0070266724, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021942139, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": "\nTest request\nTest request\nTest request\n" + } +] diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py index 135f4b05..94a48e49 100644 --- a/integration-tests/models/test_flash_llama_gptq.py +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -3,7 +3,9 @@ import pytest @pytest.fixture(scope="module") def flash_llama_gptq_handle(launcher): - with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle: + with launcher( + "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="gptq" + ) as handle: yield handle From 05c094fcfae4d869e12910f637b4dc9d7a9e0421 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 5 Jul 2024 16:07:48 +0200 Subject: [PATCH 06/24] Consistently take `prefix` in model constructors (#2191) * Consistently take `prefix` in model constructors * Release test check fix * Misc refactor-related fixes --- .github/workflows/build.yaml | 2 +- .../text_generation_server/models/__init__.py | 3 +- .../models/causal_lm.py | 3 +- .../models/custom_modeling/bloom_modeling.py | 2 +- .../models/custom_modeling/clip.py | 6 +-- .../custom_modeling/flash_cohere_modeling.py | 22 ++++++---- .../custom_modeling/flash_dbrx_modeling.py | 18 +++++--- .../custom_modeling/flash_gemma2_modeling.py | 10 ++--- .../custom_modeling/flash_gemma_modeling.py | 12 +++--- .../custom_modeling/flash_gpt2_modeling.py | 8 ++-- .../custom_modeling/flash_llama_modeling.py | 4 +- .../custom_modeling/flash_mistral_modeling.py | 8 ++-- .../custom_modeling/flash_mixtral_modeling.py | 10 ++--- .../custom_modeling/flash_neox_modeling.py | 16 +++++--- .../custom_modeling/flash_phi_modeling.py | 18 +++++--- .../custom_modeling/flash_qwen2_modeling.py | 20 +++++---- .../custom_modeling/flash_rw_modeling.py | 34 ++++++++------- .../flash_santacoder_modeling.py | 21 ++++++---- .../models/custom_modeling/mpt_modeling.py | 20 +++++---- .../models/custom_modeling/neox_modeling.py | 28 ++++++++----- .../models/custom_modeling/opt_modeling.py | 41 ++++++++++++------- .../models/custom_modeling/phi_modeling.py | 16 +++++--- .../models/flash_causal_lm.py | 19 +++++---- 23 files changed, 210 insertions(+), 131 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index a665d9b0..8213887f 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -153,7 +153,7 @@ jobs: runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: - PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == 'true') && '--release' || '' }} + PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }} steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 15e74622..58131a3a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -16,6 +16,7 @@ from text_generation_server.models.custom_modeling.opt_modeling import OPTForCau from text_generation_server.models.custom_modeling.mpt_modeling import ( MPTForCausalLM, ) +from text_generation_server.models.bloom import BloomCausalLMBatch from text_generation_server.models.custom_modeling.bloom_modeling import ( BloomForCausalLM, ) @@ -522,7 +523,7 @@ def get_model( speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, - batch_class=CausalLMBatchKeysLast, + batch_class=BloomCausalLMBatch, ) elif model_type == MPT: return CausalLM( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index cac36ebd..868a3cc0 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -553,7 +553,8 @@ class CausalLM(Model): if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) - model = model_class(config, weights) + prefix = "" + model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super().__init__( diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 0d8a1b59..77b89c5b 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -816,7 +816,7 @@ class BloomModel(BloomPreTrainedModel): class BloomForCausalLM(BloomPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.transformer = BloomModel(config, weights) diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py index 56618bf1..27b9ff1c 100644 --- a/server/text_generation_server/models/custom_modeling/clip.py +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module): class CLIPTextTransformer(nn.Module): - def __init__(self, config: CLIPTextConfig): + def __init__(self, prefix: str, config: CLIPTextConfig): super().__init__() self.config = config embed_dim = config.hidden_size @@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel): _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] - def __init__(self, config: CLIPTextConfig): + def __init__(self, prefix, config: CLIPTextConfig): super().__init__(config) - self.text_model = CLIPTextTransformer(config) + self.text_model = CLIPTextTransformer(prefix, config) # Initialize weights and apply final processing self.post_init() diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e088f9aa..f993fe72 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -363,9 +363,9 @@ class CohereMLP(nn.Module): class FlashCohereLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashCohereAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -416,18 +416,19 @@ class FlashCohereLayer(nn.Module): class FlashCohereModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ FlashCohereLayer( + prefix, layer_id, config, weights, @@ -436,7 +437,7 @@ class FlashCohereModel(torch.nn.Module): ] ) self.norm = FastLayerNorm.load_no_bias( - prefix="model.norm", weights=weights, eps=config.layer_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps ) self.gradient_checkpointing = False @@ -486,10 +487,15 @@ class FlashCohereModel(torch.nn.Module): class FlashCohereForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = FlashCohereModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashCohereModel(prefix, config, weights) try: self.lm_head = SpeculativeHead.load( config, @@ -499,7 +505,7 @@ class FlashCohereForCausalLM(torch.nn.Module): except RuntimeError: self.lm_head = SpeculativeHead.load( config, - prefix="model.embed_tokens", + prefix=f"{prefix}.embed_tokens", weights=weights, ) self.logit_scale = config.logit_scale diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index aea7f399..e469495f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -593,9 +593,9 @@ class DenseMoE(nn.Module): class DbrxLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.blocks.{layer_id}" + prefix = f"{prefix}.blocks.{layer_id}" self.attn = DbrxNormAttentionNorm( prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights @@ -637,16 +637,17 @@ class DbrxLayer(nn.Module): class DbrxModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.wte", weights=weights + prefix=f"{prefix}.wte", weights=weights ) self.layers = nn.ModuleList( [ DbrxLayer( + prefix, layer_id, config, weights, @@ -655,7 +656,7 @@ class DbrxModel(torch.nn.Module): ] ) self.norm = FastLayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=1e-5 + prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5 ) self.head_size = self.layers[0].attn.self_attn.head_size @@ -702,9 +703,14 @@ class DbrxModel(torch.nn.Module): class FlashDbrxForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + self.model = DbrxModel(config, weights) self.lm_head = SpeculativeHead.load( config, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 625baa91..beff08b3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig): class Gemma2FastRMSNorm(FastRMSNorm): @classmethod - def load(cls, prefix, weights, eps=1e-6): + def load(cls, prefix: str, weights, eps=1e-6): dtype = weights.dtype weights.dtype = torch.float32 weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm): return hidden_states.to(self.dtype), residual -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -305,7 +305,7 @@ class Gemma2MLP(nn.Module): class FlashGemma2Layer(nn.Module): - def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool): + def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool): super().__init__() self.self_attn = FlashGemma2Attention( prefix=f"{prefix}.self_attn", @@ -376,7 +376,7 @@ class FlashGemma2Layer(nn.Module): class FlashGemma2Model(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module): class FlashGemma2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, *, causal: bool = True): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index b7ce6307..14b62b00 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig): class GemmaFastRMSNorm(FastRMSNorm): @classmethod - def load(cls, prefix, weights, eps=1e-6): + def load(cls, prefix: str, weights, eps=1e-6): dtype = weights.dtype weights.dtype = torch.float32 weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm): return hidden_states.to(self.dtype), residual -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -261,7 +261,7 @@ class FlashGemmaAttention(torch.nn.Module): class GemmaMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -299,7 +299,7 @@ class GemmaMLP(nn.Module): class FlashGemmaLayer(nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.self_attn = FlashGemmaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal @@ -354,7 +354,7 @@ class FlashGemmaLayer(nn.Module): class FlashGemmaModel(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, *, causal: bool = True): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 9f800146..d5dc25cf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -261,7 +261,7 @@ class FlashGPT2Attention(torch.nn.Module): class GPT2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() act = config.activation_function self.act = ( @@ -298,7 +298,7 @@ class GPT2MLP(nn.Module): class FlashGPT2Layer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.self_attn = FlashGPT2Attention( prefix=f"{prefix}.attn", config=config, weights=weights @@ -350,7 +350,7 @@ class FlashGPT2Layer(nn.Module): class FlashGPT2Model(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group @@ -414,7 +414,7 @@ class FlashGPT2Model(torch.nn.Module): class FlashGPT2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 77a7e2d5..78832341 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -54,7 +54,7 @@ if SYSTEM == "rocm": raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") -def load_attention(config, prefix, weights, layer_id): +def load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. bias = getattr(config, "attention_bias", False) head_size = config.hidden_size // config.num_attention_heads @@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 396969cd..8028dbe8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module): class MistralMLP(nn.Module): - def __init__(self, prefix, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -328,7 +328,7 @@ class MistralMLP(nn.Module): class MistralLayer(nn.Module): - def __init__(self, prefix, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.self_attn = MistralAttention( prefix=f"{prefix}.self_attn", @@ -392,7 +392,7 @@ class MistralLayer(nn.Module): class MistralModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group @@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module): class FlashMistralForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, name=None): + def __init__(self, prefix: str, config, weights, name=None): if name is None: name = "model" super().__init__() diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2d6a7f97..429793ea 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights): ) -def _load_experts(config, prefix, mat, weights): +def _load_experts(config, prefix: str, mat, weights): if config.quantize is not None: raise NotImplementedError("Mixtral does not support weight quantization yet.") @@ -475,7 +475,7 @@ class DenseMoE(nn.Module): class MixtralLayer(nn.Module): - def __init__(self, prefix, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() prefix = f"{prefix}.layers.{layer_id}" @@ -536,7 +536,7 @@ class MixtralLayer(nn.Module): class MixtralModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( @@ -610,7 +610,7 @@ class MixtralModel(torch.nn.Module): class FlashMixtralForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.model = MixtralModel(prefix, config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 33aebc2b..0eca181b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights + prefix=f"{prefix}.embed_in", weights=weights ) self.layers = nn.ModuleList( @@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): ] ) self.final_layer_norm = FastLayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=config.layer_norm_eps, ) @@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__(config) - self.gpt_neox = FlashGPTNeoXModel(config, weights) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights) self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index f237ea37..7401bc27 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -258,9 +258,9 @@ class PhiMLP(nn.Module): class FlashPhiLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashPhiAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -307,18 +307,19 @@ class FlashPhiLayer(nn.Module): class FlashPhiModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ FlashPhiLayer( + prefix, layer_id, config, weights, @@ -378,10 +379,15 @@ class FlashPhiModel(torch.nn.Module): class FlashPhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = FlashPhiModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashPhiModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 1cc6a613..a98709c5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module): class Qwen2Layer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = Qwen2Attention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module): class Qwen2Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ Qwen2Layer( + prefix, layer_id, config, weights, @@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module): ] ) self.norm = FastRMSNorm.load( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps ) self.gradient_checkpointing = False @@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module): class Qwen2ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = Qwen2Model(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = Qwen2Model(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index e7614232..d12ed567 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module): def __init__( self, config, - prefix, + prefix: str, weights, ): super().__init__() @@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module): def __init__( self, config, - prefix, + prefix: str, weights, ): super().__init__() @@ -358,7 +358,7 @@ class FlashRWLargeAttention(torch.nn.Module): class FlashMLP(nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix: str, weights): super().__init__() self.act = torch.nn.functional.gelu @@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module): def __init__( self, layer_id, + prefix: str, config, weights, ): @@ -388,7 +389,7 @@ class FlashRWLayer(nn.Module): parallel_attn = config.parallel_attn self.parallel_attn = parallel_attn - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.input_layernorm = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", @@ -479,7 +480,7 @@ class FlashRWLayer(nn.Module): class FlashRWLayerNorm(nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix: str, weights): super().__init__() self.num_ln = config.num_ln_in_parallel_attn @@ -518,9 +519,9 @@ class FlashRWLayerNorm(nn.Module): class FlashRWLargeLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, prefix: str, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.ln_layer = FlashRWLayerNorm(config, prefix, weights) @@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.word_embeddings = TensorParallelEmbedding( - prefix="transformer.word_embeddings", weights=weights + prefix=f"{prefix}.word_embeddings", weights=weights ) if config.new_decoder_architecture: self.h = nn.ModuleList( [ - FlashRWLargeLayer(layer_id, config, weights) + FlashRWLargeLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) @@ -599,14 +600,14 @@ class FlashRWModel(FlashRWPreTrainedModel): else: self.h = nn.ModuleList( [ - FlashRWLayer(layer_id, config, weights) + FlashRWLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = self.h[0].self_attention.num_heads_kv self.ln_f = FastLayerNorm.load( - prefix="transformer.ln_f", + prefix=f"{prefix}.ln_f", weights=weights, eps=config.layer_norm_epsilon, ) @@ -653,10 +654,15 @@ class FlashRWModel(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) - self.transformer = FlashRWModel(config, weights) + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.transformer = FlashRWModel(prefix, config, weights) self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index daef43cc..21a22046 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -346,9 +346,9 @@ class MLP(nn.Module): class Block(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.ln_1 = FastLayerNorm.load( prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon ) @@ -396,18 +396,18 @@ class Block(nn.Module): class FlashSantacoderModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config self.process_group = weights.process_group self.wte = TensorParallelEmbedding( - prefix="transformer.wte", + prefix=f"{prefix}.wte", weights=weights, reduce=False, ) self.wpe = TensorParallelEmbedding( - prefix="transformer.wpe", + prefix=f"{prefix}.wpe", weights=weights, reduce=False, ) @@ -415,6 +415,7 @@ class FlashSantacoderModel(nn.Module): self.layers = nn.ModuleList( [ Block( + prefix, layer_id, config, weights, @@ -466,10 +467,16 @@ class FlashSantacoderModel(nn.Module): class FlashSantacoderForCausalLM(nn.Module): def __init__(self, prefix, config, weights): super().__init__() + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + config.transpose = config.architectures[0].startswith("GPT2") - self.model = FlashSantacoderModel(config, weights) + self.model = FlashSantacoderModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="transformer.wte", weights=weights + config, prefix=f"{prefix}.wte", weights=weights ) def forward( diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index f7981bf5..fb09a8f1 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel): class MPTModel(MPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): # config._validate_config() super().__init__(config) self.world_size = weights.process_group.size() @@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel): f"Requested norm type ({config.norm_type}) is not implemented within this repo." ) - self.wte = TensorParallelEmbedding("transformer.wte", weights) + self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) if not self.alibi: - self.wpe = TensorParallelEmbedding("transformer.wpe", weights) + self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) self.blocks = nn.ModuleList( [ - MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) + MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) for i in range(config.n_layers) ] ) @@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel): class MPTForCausalLM(MPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") - self.transformer = MPTModel(config, weights) + self.transformer = MPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="transformer.wte", weights=weights + config, prefix=f"{prefix}.wte", weights=weights ) self.logit_scale = None if config.logit_scale is not None: diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index fcad32fa..8998778f 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module): class GPTNeoXLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, prefix: str, config, weights): super().__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = nn.LayerNorm.load( - prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", + prefix=f"{prefix}.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps, ) self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", + prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps, ) self.attention = GPTNeoXAttention( - config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights + config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights ) self.mlp = GPTNeoXMLP( - config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights + config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights ) def forward( @@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module): class GPTNeoXModel(GPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.num_attention_heads = config.num_attention_heads self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights + prefix=f"{prefix}.embed_in", weights=weights ) self.layers = nn.ModuleList( [ - GPTNeoXLayer(layer_id, config, weights) + GPTNeoXLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) self.final_layer_norm = nn.LayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=config.layer_norm_eps, ) @@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) - self.gpt_neox = GPTNeoXModel(config, weights) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = GPTNeoXModel(prefix, config, weights) self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 9b2d01e0..5ab02959 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module): This module learns positional embeddings up to a fixed maximum size. """ - def __init__(self, weights): + def __init__(self, prefix: str, weights): super().__init__() self.offset = 2 self.weight = nn.Parameter( - weights.get_tensor("model.decoder.embed_positions.weight") + weights.get_tensor(f"{prefix}.decoder.embed_positions.weight") ) def forward( @@ -311,11 +311,11 @@ class OPTAttention(nn.Module): class OPTDecoderLayer(nn.Module): - def __init__(self, layer_id: int, config: OPTConfig, weights): + def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"model.decoder.layers.{layer_id}" + prefix = f"{prefix}.decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel): class OPTDecoder(OPTPreTrainedModel): - def __init__(self, config: OPTConfig, weights): + def __init__(self, prefix: str, config: OPTConfig, weights): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.layerdrop @@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel): self.vocab_size = config.vocab_size self.embed_tokens = TensorParallelEmbedding( - prefix="model.decoder.embed_tokens", weights=weights + prefix=f"{prefix}.decoder.embed_tokens", weights=weights ) - self.embed_positions = OPTLearnedPositionalEmbedding(weights) + self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) if config.word_embed_proj_dim != config.hidden_size: self.project_out = FastLinear.load( - config, prefix="model.decoder.project_out", weights=weights, bias=False + config, + prefix=f"{prefix}.decoder.project_out", + weights=weights, + bias=False, ) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: self.project_in = FastLinear.load( - config, prefix="model.decoder.project_in", weights=weights, bias=False + config, + prefix=f"{prefix}.decoder.project_in", + weights=weights, + bias=False, ) else: self.project_in = None @@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel): # see https://github.com/facebookresearch/metaseq/pull/164 if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm.load( - prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS + prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS ) else: self.final_layer_norm = None self.layers = nn.ModuleList( [ - OPTDecoderLayer(layer_id, config, weights) + OPTDecoderLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) @@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel): class OPTModel(OPTPreTrainedModel): - def __init__(self, config: OPTConfig, weights): + def __init__(self, prefix: str, config: OPTConfig, weights): super().__init__(config) - self.decoder = OPTDecoder(config, weights) + self.decoder = OPTDecoder(prefix, config, weights) # Initialize weights and apply final processing def forward( @@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel): class OPTForCausalLM(OPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__(config) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + self.model = OPTModel(config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="model.decoder.embed_tokens", weights=weights + config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights ) def forward( diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index 04b470eb..b4d56db1 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -248,16 +248,16 @@ class PhiBlock(nn.Module): # PhiModel implements the embedding layer and the transformer blocks. class PhiModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.tp_rank = weights.process_group.rank() self.tp_world_size = weights.process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.embd.wte", weights=weights + prefix=f"{prefix}.embd.wte", weights=weights ) self.blocks = nn.ModuleList( [ - PhiBlock(f"transformer.h.{layer_id}", config, weights) + PhiBlock(f"{prefix}.h.{layer_id}", config, weights) for layer_id in range(config.n_layer) ] ) @@ -289,9 +289,15 @@ class PhiModel(nn.Module): # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. class PhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = PhiModel(config, weights) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.model = PhiModel(prefix, config, weights) self.lm_head = PhiCausalLMHead(config, weights) def forward( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c7f5f1f9..e66011a1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -878,10 +878,6 @@ class FlashCausalLM(Model): ) config.quantize = quantize config.speculator = speculator - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: - config.sliding_window = None torch.distributed.barrier(group=self.process_group) @@ -900,13 +896,22 @@ class FlashCausalLM(Model): text_config = getattr(config, "text_config", None) if text_config is not None: config = text_config + + if getattr(config, "sliding_window", None) is not None: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + self.num_layers = config.num_hidden_layers # Validation is done in the model itself if num_kv_heads is None: - num_kv_heads = getattr(config, "num_key_value_heads", None) + # Order is important here. + for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: + num_kv_heads = getattr(config, "num_attention_heads", None) + if num_kv_heads is not None: + break if num_kv_heads is None: - # Final overide for GPT2 - num_kv_heads = config.n_head + raise ValueError("Cannot get the number of key/value heads") self.num_kv_heads = num_kv_heads // self.process_group.size() self.head_size = config.hidden_size // config.num_attention_heads From 521d0d990f1624d4821d6a1763805df312306fa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?icyboy=E2=84=A2?= Date: Mon, 8 Jul 2024 15:01:14 +0800 Subject: [PATCH 07/24] fix dbrx & opt model prefix bug (#2201) * Update idefics_causal_lm.py Fix syntax issues * fix dbrx & opt model prefix bug --- .../models/custom_modeling/flash_dbrx_modeling.py | 2 +- .../models/custom_modeling/opt_modeling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index e469495f..41aa5859 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -711,7 +711,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): else: prefix = f"{prefix}.transformer" - self.model = DbrxModel(config, weights) + self.model = DbrxModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 5ab02959..84a1c069 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -757,7 +757,7 @@ class OPTForCausalLM(OPTPreTrainedModel): else: prefix = f"{prefix}.model" - self.model = OPTModel(config, weights) + self.model = OPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights From cce475a9491bc011f8f0e89a2ef8d3a1bef88e74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 09:52:12 +0200 Subject: [PATCH 08/24] hotfix: Fix number of KV heads (#2202) Fix number of KV heads --- server/text_generation_server/models/flash_causal_lm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index e66011a1..42b2f686 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -906,8 +906,8 @@ class FlashCausalLM(Model): # Validation is done in the model itself if num_kv_heads is None: # Order is important here. - for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: - num_kv_heads = getattr(config, "num_attention_heads", None) + for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]: + num_kv_heads = getattr(config, attr, None) if num_kv_heads is not None: break if num_kv_heads is None: From 153fcf7739bfa1e943389c817a87eb06e685b9a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 11:19:48 +0200 Subject: [PATCH 09/24] Fix incorrect cache allocation with multi-query (#2203) We wouldn't allocate any memory in multi-query (1 KV head). Fixes Starcoder et al. --- server/text_generation_server/models/flash_causal_lm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 42b2f686..5c086a73 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -912,7 +912,12 @@ class FlashCausalLM(Model): break if num_kv_heads is None: raise ValueError("Cannot get the number of key/value heads") - self.num_kv_heads = num_kv_heads // self.process_group.size() + self.num_kv_heads = ( + num_kv_heads // self.process_group.size() + if num_kv_heads > 1 + else num_kv_heads + ) + assert self.num_kv_heads > 0 self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {} From 5c7c9f13903f09636aaf99210710bf07002cdb87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 8 Jul 2024 13:22:38 +0200 Subject: [PATCH 10/24] Falcon/DBRX: get correct number of key-value heads (#2205) --- server/text_generation_server/models/__init__.py | 4 ++++ .../models/custom_modeling/flash_dbrx_modeling.py | 12 ++++++++++++ .../models/custom_modeling/flash_rw_modeling.py | 1 + .../text_generation_server/models/flash_causal_lm.py | 11 +++++------ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 58131a3a..ba980195 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -797,6 +797,10 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, config_class=RWConfig, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 41aa5859..44411687 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig): class DbrxConfig(PretrainedConfig): + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + } + def __init__( self, d_model: int = 2048, @@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig): **kwargs, ) + @property + def num_key_value_heads(self): + # We can't use the attribute map, since this the number of KV + # heads is not top-level. + return self.attn_config.kv_n_heads + def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index d12ed567..4813e2df 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig): attribute_map = { "num_hidden_layers": "n_layer", "num_attention_heads": "n_head", + "num_key_value_heads": "n_head_kv", } def __init__( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5c086a73..bf1fda4a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -905,13 +905,12 @@ class FlashCausalLM(Model): self.num_layers = config.num_hidden_layers # Validation is done in the model itself if num_kv_heads is None: - # Order is important here. - for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]: - num_kv_heads = getattr(config, attr, None) - if num_kv_heads is not None: - break + num_kv_heads = getattr(config, "num_key_value_heads", None) + # GPT-2 workaround if num_kv_heads is None: - raise ValueError("Cannot get the number of key/value heads") + num_kv_heads = getattr(config, "n_head", None) + if num_kv_heads is None: + raise ValueError("Cannot get the number of key/value heads") self.num_kv_heads = ( num_kv_heads // self.process_group.size() if num_kv_heads > 1 From 07e240ca37f48b8bce5169c96e49cb63c0714fea Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 8 Jul 2024 21:57:06 +0800 Subject: [PATCH 11/24] add doc for intel gpus (#2181) Signed-off-by: Wang, Yi A --- docs/source/_toctree.yml | 2 ++ docs/source/architecture.md | 1 + docs/source/installation_intel.md | 19 +++++++++++++++++++ docs/source/quicktour.md | 2 +- 4 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 docs/source/installation_intel.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index c9b4efd9..119c5662 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -11,6 +11,8 @@ title: Using TGI with Intel Gaudi - local: installation_inferentia title: Using TGI with AWS Inferentia + - local: installation_intel + title: Using TGI with Intel GPUs - local: installation title: Installation from source - local: supported_models diff --git a/docs/source/architecture.md b/docs/source/architecture.md index a8418817..28c84f62 100644 --- a/docs/source/architecture.md +++ b/docs/source/architecture.md @@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin - By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference). - A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ. +- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ. - The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi). - A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference). - A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference). diff --git a/docs/source/installation_intel.md b/docs/source/installation_intel.md new file mode 100644 index 00000000..f9fda863 --- /dev/null +++ b/docs/source/installation_intel.md @@ -0,0 +1,19 @@ +# Using TGI with Intel GPUs + +TGI optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html), [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html), the recommended usage is through Docker. + + +On a server powered by Intel GPUs, TGI can be launched with the following command: + +```bash +model=teknium/OpenHermes-2.5-Mistral-7B +volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run + +docker run --rm --privileged --cap-add=sys_nice \ + --device=/dev/dri \ + --ipc=host --shm-size 1g --net host -v $volume:/data \ + ghcr.io/huggingface/text-generation-inference:latest-intel \ + --model-id $model --cuda-graphs 0 +``` + +The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide. diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index c546bc03..f056baad 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \ ### Supported hardware -TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on. +TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on. ## Consuming TGI From 16d9e505fddf6b5ff349545374c85e75ab193184 Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Mon, 8 Jul 2024 15:59:16 +0200 Subject: [PATCH 12/24] fix: python deserialization (#2178) --- clients/python/text_generation/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index a56edaca..e36dd470 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): role: str content: Optional[str] = None - tool_calls: Optional[ChoiceDeltaToolCall] + tool_calls: Optional[ChoiceDeltaToolCall] = None class Choice(BaseModel): From 58effe78b5cc69355dad406f44cfe773cb4ed40d Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 8 Jul 2024 22:03:59 +0800 Subject: [PATCH 13/24] =?UTF-8?q?update=20to=20metrics=200.23.0=20or=20cou?= =?UTF-8?q?ld=20work=20with=20metrics-exporter-promethe=E2=80=A6=20(#2190)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit update to metrics 0.23.0 or could work with metrics-exporter-prometheus 0.15.1 Signed-off-by: Wang, Yi A --- Cargo.lock | 28 ++----------- router/Cargo.toml | 2 +- router/src/infer/mod.rs | 8 ++-- router/src/infer/v2/queue.rs | 8 ++-- router/src/infer/v2/scheduler.rs | 61 ++++++++++++++++------------ router/src/infer/v3/queue.rs | 8 ++-- router/src/infer/v3/scheduler.rs | 61 ++++++++++++++++------------ router/src/server.rs | 68 ++++++++++++++------------------ router/src/validation.rs | 4 +- 9 files changed, 119 insertions(+), 129 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 090e2e80..a8a04c71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1935,17 +1935,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "metrics" -version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5" -dependencies = [ - "ahash", - "metrics-macros", - "portable-atomic", -] - [[package]] name = "metrics" version = "0.23.0" @@ -1969,7 +1958,7 @@ dependencies = [ "hyper-util", "indexmap 2.2.6", "ipnet", - "metrics 0.23.0", + "metrics", "metrics-util", "quanta", "thiserror", @@ -1977,17 +1966,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "metrics-macros" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.68", -] - [[package]] name = "metrics-util" version = "0.17.0" @@ -1997,7 +1975,7 @@ dependencies = [ "crossbeam-epoch", "crossbeam-utils", "hashbrown 0.14.5", - "metrics 0.23.0", + "metrics", "num_cpus", "quanta", "sketches-ddsketch", @@ -3834,7 +3812,7 @@ dependencies = [ "init-tracing-opentelemetry", "itertools 0.10.5", "jsonschema", - "metrics 0.21.1", + "metrics", "metrics-exporter-prometheus", "minijinja", "minijinja-contrib", diff --git a/router/Cargo.toml b/router/Cargo.toml index 5855ac86..60fb5c9d 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -24,7 +24,7 @@ futures = "0.3.28" hf-hub = { workspace = true } itertools = "0.10" jsonschema = { version = "0.17.1", features = ["draft202012"] } -metrics = "0.21.1" +metrics = "0.23.0" metrics-exporter-prometheus = { version = "0.15.1", features = [] } nohash-hasher = "0.2.0" opentelemetry = { version = "0.20.0", features = ["rt-tokio"] } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 49282eb9..f3b10450 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -91,14 +91,14 @@ impl Infer { .limit_concurrent_requests .try_acquire_owned() .map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "overloaded"); + metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1); tracing::error!("{err}"); err })?; // Validate request let valid_request = self.validation.validate(request).await.map_err(|err| { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); err })?; @@ -140,7 +140,7 @@ impl Infer { .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .apply(messages, grammar_with_prompt) .map_err(|e| { - metrics::increment_counter!("tgi_request_failure", "err" => "template"); + metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); e }) @@ -214,7 +214,7 @@ impl Infer { }) } else { let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); Err(err) } diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 93cf9469..0b51645a 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -111,7 +111,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -124,7 +124,7 @@ async fn queue_task( let next_batch = state.next_batch(min_size, max_size, prefill_token_budget, token_budget); response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); }), } } @@ -226,7 +226,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -336,7 +336,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index e4c3de26..97379bc5 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -148,8 +148,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -170,9 +170,11 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); } entries.iter_mut().for_each(|(_, entry)| { @@ -219,8 +221,8 @@ pub(crate) async fn batching_task( .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); } } } @@ -234,7 +236,7 @@ async fn prefill( ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { @@ -248,11 +250,15 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration","method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -261,7 +267,7 @@ async fn prefill( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); None } } @@ -276,7 +282,7 @@ async fn decode( ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { @@ -291,13 +297,18 @@ async fn decode( let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -307,7 +318,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); None } } @@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); err }).unwrap_or(true); if stopped { @@ -381,7 +392,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); return Ok(true); } @@ -407,7 +418,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index ba65b9b6..894d9cab 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -126,7 +126,7 @@ async fn queue_task( match cmd { QueueCommand::Append(entry, span) => { span.in_scope(|| state.append(*entry)); - metrics::increment_gauge!("tgi_queue_size", 1.0); + metrics::gauge!("tgi_queue_size").increment(1.0); } QueueCommand::NextBatch { min_size, @@ -141,7 +141,7 @@ async fn queue_task( .instrument(span) .await; response_sender.send(next_batch).unwrap(); - metrics::gauge!("tgi_queue_size", state.entries.len() as f64); + metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64); } } } @@ -248,7 +248,7 @@ impl State { // Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); tracing::debug!("Dropping entry"); continue; } @@ -399,7 +399,7 @@ impl State { // Increment batch id self.next_batch_id += 1; - metrics::histogram!("tgi_batch_next_size", batch.size as f64); + metrics::histogram!("tgi_batch_next_size").record(batch.size as f64); Some((batch_entries, batch, next_batch_span)) } diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 543ce89f..26cd9584 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -154,8 +154,8 @@ pub(crate) async fn batching_task( let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; let mut batches = vec![batch]; - metrics::gauge!("tgi_batch_current_size", batch_size as f64); - metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64); + metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); + metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); let min_size = if waiting_tokens >= max_waiting_tokens { // If we didn't onboard any new requests since >= max_waiting_tokens, we try @@ -176,9 +176,11 @@ pub(crate) async fn batching_task( { // Tracking metrics if min_size.is_some() { - metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure"); + metrics::counter!("tgi_batch_concat", "reason" => "backpressure") + .increment(1); } else { - metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded"); + metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded") + .increment(1); } entries.iter_mut().for_each(|(_, entry)| { @@ -225,8 +227,8 @@ pub(crate) async fn batching_task( .await; waiting_tokens += 1; } - metrics::gauge!("tgi_batch_current_size", 0.0); - metrics::gauge!("tgi_batch_current_max_tokens", 0.0); + metrics::gauge!("tgi_batch_current_size").set(0.0); + metrics::gauge!("tgi_batch_current_max_tokens").set(0.0); } } } @@ -240,7 +242,7 @@ async fn prefill( ) -> Option { let start_time = Instant::now(); let batch_id = batch.id; - metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1); match client.prefill(batch).await { Ok((generations, next_batch, timings)) => { @@ -254,11 +256,15 @@ async fn prefill( // Filter next batch and remove requests that were stopped let next_batch = filter_batch(client, next_batch, entries).await; - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill"); + metrics::histogram!("tgi_batch_forward_duration","method" => "prefill") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -267,7 +273,7 @@ async fn prefill( generation_health.store(false, Ordering::SeqCst); let _ = client.clear_cache(Some(batch_id)).await; send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill"); + metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1); None } } @@ -282,7 +288,7 @@ async fn decode( ) -> Option { let start_time = Instant::now(); let batch_ids: Vec = batches.iter().map(|b| b.id).collect(); - metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode"); + metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1); match client.decode(batches).await { Ok((generations, next_batch, timings)) => { @@ -297,13 +303,18 @@ async fn decode( let next_batch = filter_batch(client, next_batch, entries).await; if let Some(concat_duration) = timings.concat { - metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode"); + metrics::histogram!("tgi_batch_concat_duration", "method" => "decode") + .record(concat_duration.as_secs_f64()); } - metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode"); - metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode"); + metrics::histogram!("tgi_batch_forward_duration", "method" => "decode") + .record(timings.forward.as_secs_f64()); + metrics::histogram!("tgi_batch_decode_duration", "method" => "decode") + .record(timings.decode.as_secs_f64()); + metrics::histogram!("tgi_batch_filter_duration", "method" => "decode") + .record(start_filtering_time.elapsed().as_secs_f64()); + metrics::histogram!("tgi_batch_inference_duration", "method" => "decode") + .record(start_time.elapsed().as_secs_f64()); + metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1); next_batch } // If we have an error, we discard the whole batch @@ -313,7 +324,7 @@ async fn decode( let _ = client.clear_cache(Some(id)).await; } send_errors(err, entries); - metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode"); + metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1); None } } @@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); err }).unwrap_or(true); if stopped { @@ -387,7 +398,7 @@ fn send_responses( ) -> Result>>> { // Return directly if the channel is disconnected if entry.response_tx.is_closed() { - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1); return Ok(true); } @@ -413,7 +424,7 @@ fn send_responses( // Create last Token let tokens_ = generation.tokens.expect("Non empty tokens in generation"); let n = tokens_.ids.len(); - metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64); + metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64); let mut iterator = tokens_ .ids .into_iter() @@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap) { // Create and enter a span to link this function back to the entry let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let err = InferError::GenerationError(error.to_string()); - metrics::increment_counter!("tgi_request_failure", "err" => "generation"); + metrics::counter!("tgi_request_failure", "err" => "generation").increment(1); tracing::error!("{err}"); // unwrap_or is valid here as we don't care if the receiver is gone. diff --git a/router/src/server.rs b/router/src/server.rs index db8b16ad..4af8962e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -185,7 +185,7 @@ pub(crate) async fn generate_internal( span: tracing::Span, ) -> Result<(HeaderMap, Json), (StatusCode, Json)> { let start_time = Instant::now(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); // Do not long ultra long inputs, like image payloads. tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]); @@ -301,25 +301,15 @@ pub(crate) async fn generate_internal( ); // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); - metrics::histogram!( - "tgi_request_validation_duration", - validation_time.as_secs_f64() - ); - metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!( - "tgi_request_inference_duration", - inference_time.as_secs_f64() - ); - metrics::histogram!( - "tgi_request_mean_time_per_token_duration", - time_per_token.as_secs_f64() - ); - metrics::histogram!( - "tgi_request_generated_tokens", - response.generated_text.generated_tokens as f64 - ); + metrics::counter!("tgi_request_success").increment(1); + metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration") + .record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens") + .record(response.generated_text.generated_tokens as f64); // Send response let mut output_text = response.generated_text.text; @@ -399,7 +389,7 @@ async fn generate_stream_internal( span: tracing::Span, ) -> (HeaderMap, impl Stream>) { let start_time = Instant::now(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); tracing::debug!("Input: {}", req.inputs); @@ -427,12 +417,12 @@ async fn generate_stream_internal( let best_of = req.parameters.best_of.unwrap_or(1); if best_of != 1 { let err = InferError::from(ValidationError::BestOfStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } else if req.parameters.decoder_input_details { let err = InferError::from(ValidationError::PrefillDetailsStream); - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } else { @@ -500,13 +490,13 @@ async fn generate_stream_internal( span.record("seed", format!("{:?}", generated_text.seed)); // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time.as_secs_f64()); - metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64()); - metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64()); - metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64()); - metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64()); - metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); + metrics::counter!("tgi_request_success").increment(1); + metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64()); + metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64()); + metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64()); + metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64()); + metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64()); + metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64); // StreamResponse end_reached = true; @@ -553,7 +543,7 @@ async fn generate_stream_internal( // Skip if we already sent an error if !end_reached && !error { let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1); tracing::error!("{err}"); yield Ok(Event::from(err)); } @@ -604,7 +594,7 @@ async fn completions( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { max_tokens, @@ -625,7 +615,7 @@ async fn completions( // if suffix is present throw an error if req.suffix.is_some() { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -637,7 +627,7 @@ async fn completions( } if req.prompt.0.len() > info.max_client_batch_size { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -1009,7 +999,7 @@ async fn chat_completions( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); let ChatRequest { logprobs, max_tokens, @@ -1039,7 +1029,7 @@ async fn chat_completions( // response_format and tools are mutually exclusive if response_format.is_some() && tools.as_ref().is_some() { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); return Err(( StatusCode::UNPROCESSABLE_ENTITY, Json(ErrorResponse { @@ -1053,7 +1043,7 @@ async fn chat_completions( let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { Ok(grammar) => grammar, Err(err) => { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -1082,7 +1072,7 @@ async fn chat_completions( let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { Ok(inputs) => inputs, Err(err) => { - metrics::increment_counter!("tgi_request_failure", "err" => "validation"); + metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tracing::error!("{err}"); return Err(( StatusCode::UNPROCESSABLE_ENTITY, @@ -1280,7 +1270,7 @@ async fn vertex_compatibility( Json(req): Json, ) -> Result)> { let span = tracing::Span::current(); - metrics::increment_counter!("tgi_request_count"); + metrics::counter!("tgi_request_count").increment(1); // check that theres at least one instance if req.instances.is_empty() { diff --git a/router/src/validation.rs b/router/src/validation.rs index 12cf2ab3..07ad14c9 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -157,7 +157,7 @@ impl Validation { )); } - metrics::histogram!("tgi_request_input_length", input_length as f64); + metrics::histogram!("tgi_request_input_length").record(input_length as f64); Ok((inputs, input_length, max_new_tokens)) } // Return inputs without validation @@ -384,7 +384,7 @@ impl Validation { ignore_eos_token: false, }; - metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64); + metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64); Ok(ValidGenerateRequest { inputs, From 87ebb6477bfc2a573f5ca7fa196fa87454dc6dc4 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 8 Jul 2024 10:06:49 -0400 Subject: [PATCH 14/24] feat: use model name as adapter id in chat endpoints (#2128) --- router/src/lib.rs | 4 ++-- router/src/server.rs | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 165b2ad2..080c029a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -384,7 +384,7 @@ pub struct CompletionRequest { /// UNUSED #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// The prompt to generate completions for. #[schema(example = "What is Deep Learning?")] @@ -731,7 +731,7 @@ impl ChatCompletionChunk { pub(crate) struct ChatRequest { #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API. - pub model: String, + pub model: Option, /// A list of messages comprising the conversation so far. #[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")] diff --git a/router/src/server.rs b/router/src/server.rs index 4af8962e..4b52710d 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -597,6 +597,7 @@ async fn completions( metrics::counter!("tgi_request_count").increment(1); let CompletionRequest { + model, max_tokens, seed, stop, @@ -665,7 +666,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, - ..Default::default() + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), }, }) .collect(); @@ -1001,6 +1002,7 @@ async fn chat_completions( let span = tracing::Span::current(); metrics::counter!("tgi_request_count").increment(1); let ChatRequest { + model, logprobs, max_tokens, messages, @@ -1106,7 +1108,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar, - ..Default::default() + adapter_id: model.filter(|m| *m != "tgi").map(String::from), }, }; From 4c50b6d04bbf4db0d61ae6a04c9f44662b608c52 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 8 Jul 2024 17:52:10 +0200 Subject: [PATCH 15/24] Fix nccl regression on PyTorch 2.3 upgrade (#2099) * fix nccl issue * add note in dockerfile * use v2.22.3 that also fixes @samsamoa's repro * poetry actually can't handle the conflict between torch and nccl * set LD_PRELOAD --- Dockerfile | 7 ++++++- server/Makefile | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index d4772b4a..3f2e8ef0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,9 @@ RUN cargo build --profile release-opt # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install +# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 ARG PYTORCH_VERSION=2.3.0 + ARG PYTHON_VERSION=3.10 # Keep in sync with `server/pyproject.toml ARG CUDA_VERSION=12.1 @@ -241,7 +243,10 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install -r requirements_cuda.txt && \ - pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir + pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \ + pip install nvidia-nccl-cu12==2.22.3 + +ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 # Deps before the binaries # The binaries change on every build given we burn the SHA into them diff --git a/server/Makefile b/server/Makefile index 0099c56a..d701c819 100644 --- a/server/Makefile +++ b/server/Makefile @@ -35,5 +35,5 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements_cuda.txt --without-hashes + poetry export -o requirements_cuda.txt --without-hashes --with cuda poetry export -o requirements_rocm.txt --without-hashes From 5e2a305880f6d356a0a94db338b1c3db8d9db89a Mon Sep 17 00:00:00 2001 From: Guillaume LEGENDRE Date: Mon, 8 Jul 2024 18:13:32 +0200 Subject: [PATCH 16/24] Fix buildx cache + change runner type (#2176) * Update build.yaml * Update build.yaml * change to S3 cache * change to CPU Runners * remove comments --- .github/workflows/build.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 8213887f..3705a4c7 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ jobs: group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true # TODO see with @Glegendre to get CPU runner here instead - runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci] + runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci] permissions: contents: write packages: write @@ -135,9 +135,9 @@ jobs: GIT_SHA=${{ env.GITHUB_SHA }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} - labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} - cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min - cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min + labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} + cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min + cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min - name: Final id: final run: | From fe710af25f9297afca1ef2d974a0def654775bb7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 9 Jul 2024 11:13:48 +0200 Subject: [PATCH 17/24] Adding sanity check to openapi docs. --- .github/workflows/autodocs.yaml | 5 ++ Cargo.lock | 8 ++-- docs/openapi.json | 64 +++++++++++++++++++++++-- docs/source/basic_tutorials/launcher.md | 19 +------- router/src/server.rs | 6 ++- update_doc.py | 20 ++++++-- 6 files changed, 91 insertions(+), 31 deletions(-) diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml index 8af0b95d..e0a759c5 100644 --- a/.github/workflows/autodocs.yaml +++ b/.github/workflows/autodocs.yaml @@ -30,6 +30,10 @@ jobs: id: install-router run: cargo install --path router/ + - uses: actions/setup-node@v4 + with: + node-version: 22 + - name: Set up Python uses: actions/setup-python@v2 with: @@ -37,4 +41,5 @@ jobs: - name: Check that documentation is up-to-date run: | + npm install -g swagger-ui python update_doc.py --check diff --git a/Cargo.lock b/Cargo.lock index a8a04c71..ffc98baa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3740,7 +3740,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "2.1.1-dev0" +version = "2.1.2-dev0" dependencies = [ "average", "clap", @@ -3761,7 +3761,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "2.1.1-dev0" +version = "2.1.2-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -3779,7 +3779,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "2.1.1-dev0" +version = "2.1.2-dev0" dependencies = [ "clap", "ctrlc", @@ -3798,7 +3798,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "2.1.1-dev0" +version = "2.1.2-dev0" dependencies = [ "async-stream", "axum 0.7.5", diff --git a/docs/openapi.json b/docs/openapi.json index 9c9a8b1a..f368f30f 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -809,7 +809,6 @@ "ChatRequest": { "type": "object", "required": [ - "model", "messages" ], "properties": { @@ -854,7 +853,8 @@ "model": { "type": "string", "description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "n": { "type": "integer", @@ -1116,7 +1116,6 @@ "CompletionRequest": { "type": "object", "required": [ - "model", "prompt" ], "properties": { @@ -1138,7 +1137,8 @@ "model": { "type": "string", "description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", - "example": "mistralai/Mistral-7B-Instruct-v0.2" + "example": "mistralai/Mistral-7B-Instruct-v0.2", + "nullable": true }, "prompt": { "$ref": "#/components/schemas/Prompt" @@ -1708,6 +1708,62 @@ } } }, + "MessageChunk": { + "oneOf": [ + { + "type": "object", + "required": [ + "text", + "type" + ], + "properties": { + "text": { + "type": "string" + }, + "type": { + "type": "string", + "enum": [ + "text" + ] + } + } + }, + { + "type": "object", + "required": [ + "image_url", + "type" + ], + "properties": { + "image_url": { + "$ref": "#/components/schemas/Url" + }, + "type": { + "type": "string", + "enum": [ + "image_url" + ] + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, + "MessageContent": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/MessageChunk" + } + } + ] + }, "PrefillToken": { "type": "object", "required": [ diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 5e40146f..1e5b6fd2 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -62,9 +62,7 @@ Options: Possible values: - awq: 4 bit quantization. Requires a specific AWQ quantized model: . Should replace GPTQ models wherever possible because of the better latency - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from - - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: . Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: . text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels - - marlin: 4 bit quantization. Requires a specific Marlin quantized model: - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model @@ -126,7 +124,7 @@ Options: ## MAX_TOP_N_TOKENS ```shell --max-top-n-tokens - This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking + This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking [env: MAX_TOP_N_TOKENS=] [default: 5] @@ -336,13 +334,6 @@ Options: --otlp-endpoint [env: OTLP_ENDPOINT=] -``` -## OTLP_SERVICE_NAME -```shell - --otlp-service-name - [env: OTLP_SERVICE_NAME=] - [default: text-generation-inference.router] - ``` ## CORS_ALLOW_ORIGIN ```shell @@ -416,14 +407,6 @@ Options: [env: MAX_CLIENT_BATCH_SIZE=] [default: 4] -``` -## LORA_ADAPTERS -```shell - --lora-adapters - Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request - - [env: LORA_ADAPTERS=] - ``` ## HELP ```shell diff --git a/router/src/server.rs b/router/src/server.rs index 4b52710d..8cc09af3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -13,8 +13,8 @@ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, - Usage, Validation, + Message, MessageChunk, MessageContent, PrefillToken, SimpleToken, StreamDetails, + StreamResponse, Token, TokenizeResponse, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -1446,6 +1446,8 @@ pub async fn run( GrammarType, ChatRequest, Message, + MessageContent, + MessageChunk, ChatCompletionComplete, ChatCompletionChoice, ChatCompletionDelta, diff --git a/update_doc.py b/update_doc.py index 1ff94a2c..03b5c792 100644 --- a/update_doc.py +++ b/update_doc.py @@ -155,7 +155,7 @@ def check_openapi(check: bool): filename, ], capture_output=True, - ).stdout.decode() + ).stdout.decode("utf-8") os.remove(tmp_filename) if diff: @@ -164,11 +164,25 @@ def check_openapi(check: bool): "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it" ) - return True else: os.rename(tmp_filename, filename) print("OpenAPI documentation updated.") - return True + errors = subprocess.run( + [ + "swagger-cli", + # allow for trailing whitespace since it's not significant + # and the precommit hook will remove it + "validate", + filename, + ], + capture_output=True, + ).stderr.decode("utf-8") + if errors: + print(errors) + raise Exception( + f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" + ) + return True def main(): From f5ba9bfd52c859852aed93fe2b54b7e1a7fc0bc9 Mon Sep 17 00:00:00 2001 From: vinkamath <42322982+vinkamath@users.noreply.github.com> Date: Tue, 9 Jul 2024 02:22:08 -0700 Subject: [PATCH 18/24] Fixed README ToC (#2196) Co-authored-by: Vinayak Kamath --- README.md | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 4c1c1e29..4287c119 100644 --- a/README.md +++ b/README.md @@ -20,19 +20,20 @@ to power Hugging Chat, the Inference API and Inference Endpoint. ## Table of contents -- [Get Started](#get-started) - - [API Documentation](#api-documentation) - - [Using a private or gated model](#using-a-private-or-gated-model) - - [A note on Shared Memory](#a-note-on-shared-memory-shm) - - [Distributed Tracing](#distributed-tracing) - - [Local Install](#local-install) - - [CUDA Kernels](#cuda-kernels) -- [Optimized architectures](#optimized-architectures) -- [Run Mistral](#run-a-model) - - [Run](#run) - - [Quantization](#quantization) -- [Develop](#develop) -- [Testing](#testing) + - [Get Started](#get-started) + - [Docker](#docker) + - [API documentation](#api-documentation) + - [Using a private or gated model](#using-a-private-or-gated-model) + - [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm) + - [Distributed Tracing](#distributed-tracing) + - [Architecture](#architecture) + - [Local install](#local-install) + - [Optimized architectures](#optimized-architectures) + - [Run locally](#run-locally) + - [Run](#run) + - [Quantization](#quantization) + - [Develop](#develop) + - [Testing](#testing) Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as: From 4c976fb4064f95b6604745b81c91c6b7bbd20072 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 9 Jul 2024 17:23:48 +0200 Subject: [PATCH 19/24] Updating the self check (#2209) * Updating the self check * Fix. * Revert the CLI . * cli. * Space. * Revert cargo update. --- .github/workflows/autodocs.yaml | 2 +- .github/workflows/build.yaml | 2 +- docs/openapi.json | 88 ++++++++++++++++++++++++- docs/source/basic_tutorials/launcher.md | 19 +++++- router/src/lib.rs | 2 +- router/src/server.rs | 19 ++++-- update_doc.py | 4 +- 7 files changed, 123 insertions(+), 13 deletions(-) diff --git a/.github/workflows/autodocs.yaml b/.github/workflows/autodocs.yaml index e0a759c5..e10b232c 100644 --- a/.github/workflows/autodocs.yaml +++ b/.github/workflows/autodocs.yaml @@ -41,5 +41,5 @@ jobs: - name: Check that documentation is up-to-date run: | - npm install -g swagger-ui + npm install -g swagger-cli python update_doc.py --check diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3705a4c7..cd9f19ba 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -135,7 +135,7 @@ jobs: GIT_SHA=${{ env.GITHUB_SHA }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} - labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} + labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min - name: Final diff --git a/docs/openapi.json b/docs/openapi.json index f368f30f..3e7050ab 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -492,12 +492,12 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/Completion" + "$ref": "#/components/schemas/CompletionFinal" } }, "text/event-stream": { "schema": { - "$ref": "#/components/schemas/CompletionCompleteChunk" + "$ref": "#/components/schemas/Chunk" } } } @@ -1324,6 +1324,17 @@ } } }, + "FunctionName": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string" + } + } + }, "GenerateParameters": { "type": "object", "properties": { @@ -1764,6 +1775,16 @@ } ] }, + "OutputMessage": { + "oneOf": [ + { + "$ref": "#/components/schemas/TextMessage" + }, + { + "$ref": "#/components/schemas/ToolCallMessage" + } + ] + }, "PrefillToken": { "type": "object", "required": [ @@ -1890,6 +1911,23 @@ } } }, + "TextMessage": { + "type": "object", + "required": [ + "role", + "content" + ], + "properties": { + "content": { + "type": "string", + "example": "My name is David and I" + }, + "role": { + "type": "string", + "example": "user" + } + } + }, "Token": { "type": "object", "required": [ @@ -1962,6 +2000,41 @@ } } }, + "ToolCallDelta": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "$ref": "#/components/schemas/DeltaToolCall" + } + } + }, + "ToolCallMessage": { + "type": "object", + "required": [ + "role", + "tool_calls" + ], + "properties": { + "role": { + "type": "string", + "example": "assistant" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolCall" + } + } + } + }, "ToolType": { "oneOf": [ { @@ -1985,6 +2058,17 @@ } ] }, + "Url": { + "type": "object", + "required": [ + "url" + ], + "properties": { + "url": { + "type": "string" + } + } + }, "Usage": { "type": "object", "required": [ diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 1e5b6fd2..5e40146f 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -62,7 +62,9 @@ Options: Possible values: - awq: 4 bit quantization. Requires a specific AWQ quantized model: . Should replace GPTQ models wherever possible because of the better latency - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from + - exl2: Variable bit quantization. Requires a specific EXL2 quantized model: . Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1) - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: . text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels + - marlin: 4 bit quantization. Requires a specific Marlin quantized model: - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model @@ -124,7 +126,7 @@ Options: ## MAX_TOP_N_TOKENS ```shell --max-top-n-tokens - This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking + This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the the `n` most likely tokens at each generation step, instead of just the sampled token. This information can be used for downstream tasks like for classification or ranking [env: MAX_TOP_N_TOKENS=] [default: 5] @@ -334,6 +336,13 @@ Options: --otlp-endpoint [env: OTLP_ENDPOINT=] +``` +## OTLP_SERVICE_NAME +```shell + --otlp-service-name + [env: OTLP_SERVICE_NAME=] + [default: text-generation-inference.router] + ``` ## CORS_ALLOW_ORIGIN ```shell @@ -407,6 +416,14 @@ Options: [env: MAX_CLIENT_BATCH_SIZE=] [default: 4] +``` +## LORA_ADAPTERS +```shell + --lora-adapters + Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request + + [env: LORA_ADAPTERS=] + ``` ## HELP ```shell diff --git a/router/src/lib.rs b/router/src/lib.rs index 080c029a..f856406d 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -848,7 +848,7 @@ pub enum ToolType { Function { function: FunctionName }, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] pub struct FunctionName { pub name: String, } diff --git a/router/src/server.rs b/router/src/server.rs index 8cc09af3..4e5af99c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -11,10 +11,11 @@ use crate::kserve::{ }; use crate::validation::ValidationError; use crate::{ - BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, - GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, MessageChunk, MessageContent, PrefillToken, SimpleToken, StreamDetails, - StreamResponse, Token, TokenizeResponse, Usage, Validation, + BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, + GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, + HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken, + SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse, + ToolCallDelta, ToolCallMessage, Url, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -562,8 +563,8 @@ request_body = CompletionRequest, responses( (status = 200, description = "Generated Chat Completion", content( -("application/json" = Completion), -("text/event-stream" = CompletionCompleteChunk), +("application/json" = CompletionFinal), +("text/event-stream" = Chunk), )), (status = 424, description = "Generation Error", body = ErrorResponse, example = json ! ({"error": "Request failed during generation"})), @@ -1448,6 +1449,12 @@ pub async fn run( Message, MessageContent, MessageChunk, + Url, + FunctionName, + OutputMessage, + TextMessage, + ToolCallMessage, + ToolCallDelta, ChatCompletionComplete, ChatCompletionChoice, ChatCompletionDelta, diff --git a/update_doc.py b/update_doc.py index 03b5c792..bfa7e4e9 100644 --- a/update_doc.py +++ b/update_doc.py @@ -177,7 +177,9 @@ def check_openapi(check: bool): ], capture_output=True, ).stderr.decode("utf-8") - if errors: + # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where + # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969 + if not errors.startswith("Swagger schema validation failed."): print(errors) raise Exception( f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}" From 8511669cb29115bdf0bc2da5328e69d041030996 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 9 Jul 2024 20:04:03 +0200 Subject: [PATCH 20/24] Move quantized weight handling out of the `Weights` class (#2194) Quantized weights were loaded in the `Weights` class, but this was getting quite unwieldy, where every higher level method to load weights was a long conditional to cover all the different quantizers. This change moves loading of quantized weights out of the `Weights` class. This is done by defining a simple `WeightsLoader` interface that is implemented by `Exl2WeightsLoader`, `GPTQWeightsLoader`, and `MarlinWeightsLoader`. These implementations are in the quantizers' respective modules. The `Weights` class provides the low-level load operations (such as loading tensors or sharded tensors), but delegates loads that need quantizer-specific weight processing to a loader. The loaders still use the low-level functionality provided by `Weights`. I initially tried making a hierarchy where a class like `GPTQWeights` would inherit from `Weights`. But it is not very flexible (e.g. does not work well with the new weight storage mock used in tests) and the implicit indirections made the code harder to follow. --- server/tests/utils/test_layers.py | 8 +- server/tests/utils/test_weights.py | 162 ++-- server/text_generation_server/layers/exl2.py | 60 ++ .../layers/gptq/__init__.py | 354 ++++++++- .../layers/gptq/quantize.py | 3 + .../text_generation_server/layers/marlin.py | 130 +++- .../layers/tensor_parallel.py | 17 +- .../models/causal_lm.py | 12 +- .../custom_modeling/flash_cohere_modeling.py | 1 - .../custom_modeling/flash_gemma2_modeling.py | 1 - .../custom_modeling/flash_gemma_modeling.py | 1 - .../custom_modeling/flash_gpt2_modeling.py | 7 +- .../custom_modeling/flash_mixtral_modeling.py | 1 - .../custom_modeling/flash_neox_modeling.py | 4 +- .../custom_modeling/flash_phi_modeling.py | 1 - .../custom_modeling/flash_rw_modeling.py | 2 +- .../flash_santacoder_modeling.py | 19 +- .../flash_starcoder2_modeling.py | 1 - .../models/flash_causal_lm.py | 11 +- .../text_generation_server/models/idefics.py | 5 + server/text_generation_server/models/mamba.py | 12 +- .../models/seq2seq_lm.py | 5 + .../utils/quantization.py | 119 +++ .../text_generation_server/utils/weights.py | 691 +++--------------- 24 files changed, 896 insertions(+), 731 deletions(-) create mode 100644 server/text_generation_server/utils/quantization.py diff --git a/server/tests/utils/test_layers.py b/server/tests/utils/test_layers.py index 9a8da0d6..1e3aaf6b 100644 --- a/server/tests/utils/test_layers.py +++ b/server/tests/utils/test_layers.py @@ -2,6 +2,7 @@ import torch from text_generation_server.layers import ( TensorParallelEmbedding, ) +from text_generation_server.utils.weights import DefaultWeightsLoader class ProcessGroup: @@ -42,7 +43,12 @@ class Weights: def test_weight_hub_files_offline_error(): vocab_size = 17 - weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256) + weights = Weights( + rank=0, + world_size=1, + vocab_size=vocab_size, + hidden_dim=256, + ) embeddings = TensorParallelEmbedding("", weights) input_ids = torch.arange(vocab_size) diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 8f88b1f8..36b27be8 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -1,13 +1,47 @@ import pytest import torch -from text_generation_server.utils.weights import Weights -from text_generation_server.layers.gptq import GPTQWeight -from text_generation_server.layers.exl2 import Exl2Weight -from text_generation_server.layers.marlin import MarlinWeight +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + Weights, + WeightsLoader, +) +from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader +from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader +from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader from types import SimpleNamespace from typing import List, Optional, Dict, Union from pathlib import Path + +@pytest.fixture +def gptq_weights_loader(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="gptq", + quantize="gptq", + sym=True, + ) + + +@pytest.fixture +def gptq_weights_loader_awq(): + return GPTQWeightsLoader( + bits=4, + groupsize=-1, + desc_act=False, + quant_method="awq", + quantize="awq", + sym=True, + ) + + +@pytest.fixture +def marlin_weights_loader(): + return MarlinWeightsLoader(bits=4, is_marlin_24=False) + + dummy_file_system = { "test_weights": { "layer.0.weight": torch.tensor( @@ -58,7 +92,7 @@ dummy_file_system = { dtype=torch.float32, ), }, - "test_get_multi_weights_row": { + "test_get_weights_row": { "weight.weight": torch.tensor( [ [1, 2], @@ -101,7 +135,7 @@ dummy_file_system = { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16), }, - "test_get_multi_weights_row_gptq": { + "test_get_weights_row_gptq": { "weight.qweight": torch.tensor( [ [1, 2], @@ -200,7 +234,7 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_multi_weights_row_exl2": { + "test_get_weights_row_exl2": { "weight.q_weight": torch.tensor( [ [1, 2], @@ -245,7 +279,7 @@ dummy_file_system = { "weight.q_scale_max": torch.tensor([100], dtype=torch.float16), "weight.q_groups": torch.tensor([4], dtype=torch.int16), }, - "test_get_multi_weights_row_marlin": { + "test_get_weights_row_marlin": { "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32), "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16), }, @@ -308,6 +342,7 @@ class MockWeights(Weights): dummy_fs, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, + weights_loader: Optional[WeightsLoader] = None, ): routing = {} self.dummy_fs = dummy_fs @@ -327,6 +362,9 @@ class MockWeights(Weights): self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = ( + DefaultWeightsLoader() if weights_loader is None else weights_loader + ) self._handles = {} def _get_handle(self, filename: Union[Path, str]): @@ -412,12 +450,10 @@ def test_get_weights_col_packed(): ) prefix = "weight" - quantize = None block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size(): ) prefix = "weight" - quantize = None block_sizes = 2 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr(): ) prefix = "weight" - quantize = None block_sizes = [1, 1] w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -519,11 +551,9 @@ def test_get_multi_weights_col(): ) prefixes = ["weight", "weight"] - quantize = None w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -545,10 +575,10 @@ def test_get_multi_weights_col(): ) -def test_get_multi_weights_row(): +def test_get_weights_row(): weights = MockWeights( [ - "test_get_multi_weights_row", + "test_get_weights_row", ], device="cpu", dtype=torch.float32, @@ -557,11 +587,9 @@ def test_get_multi_weights_row(): ) prefix = "weight" - quantize = None - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) assert torch.allclose( @@ -576,7 +604,7 @@ def test_get_multi_weights_row(): # test_get_weights_col -def test_get_weights_col_awq(): +def test_get_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -585,14 +613,13 @@ def test_get_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -617,7 +644,7 @@ def test_get_weights_col_awq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_gtpq(): +def test_get_weights_col_gtpq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_gptq", @@ -626,14 +653,13 @@ def test_get_weights_col_gtpq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -664,14 +690,13 @@ def test_get_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) scaled_scale_max = 0.3906 * 256 @@ -692,7 +717,7 @@ def test_get_weights_col_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_marlin(): +def test_get_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_marlin", @@ -701,14 +726,13 @@ def test_get_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_weights_col( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( @@ -723,7 +747,7 @@ def test_get_weights_col_marlin(): # test_get_weights_col_packed -def test_get_weights_col_packed_awq(): +def test_get_weights_col_packed_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" block_sizes = 1 w = weights.get_weights_col_packed( prefix=prefix, - quantize=quantize, block_sizes=block_sizes, ) @@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_weights_col_packed_gptq(): +def test_get_weights_col_packed_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_gptq", @@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_weights_col_packed_marlin(): +def test_get_weights_col_packed_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_weights_col_packed_marlin", @@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin(): # test_get_multi_weights_col -def test_get_multi_weights_col_awq(): +def test_get_multi_weights_col_awq(gptq_weights_loader_awq): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefixes = ["weight"] - quantize = "awq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" try: w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) except ValueError as e: assert e.args[0] == "get_multi_weights_col is not supported for exl2" -def test_get_multi_weights_col_gptq(): +def test_get_multi_weights_col_gptq(gptq_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_gptq", @@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq(): dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefixes = ["weight"] - quantize = "gptq" w = weights.get_multi_weights_col( prefixes=prefixes, - quantize=quantize, dim=0, ) @@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_col_marlin(): +def test_get_multi_weights_col_marlin(marlin_weights_loader): weights = MockWeights( [ "test_get_multi_weights_col_marlin", @@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin(): dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" w = weights.get_multi_weights_col( prefixes=[prefix], - quantize=quantize, dim=0, ) @@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin(): assert torch.allclose(w.s, expected_weight.s), "s mismatch" -# test_get_multi_weights_row +# test_get_weights_row -def test_get_multi_weights_row_awq(): +def test_get_weights_row_awq(gptq_weights_loader_awq): weights = MockWeights( [ - "test_get_multi_weights_row_gptq", + "test_get_weights_row_gptq", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader_awq, ) prefix = "weight" - quantize = "awq" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_exl2(): +def test_get_weights_row_exl2(): weights = MockWeights( [ - "test_get_multi_weights_row_exl2", + "test_get_weights_row_exl2", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=Exl2WeightsLoader(), ) prefix = "weight" - quantize = "exl2" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) print(w) @@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2(): assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch" -def test_get_multi_weights_row_gptq(): +def test_get_weights_row_gptq(gptq_weights_loader): weights = MockWeights( [ - "test_get_multi_weights_row_gptq", + "test_get_weights_row_gptq", ], device="cpu", dtype=torch.float32, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=gptq_weights_loader, ) prefix = "weight" - quantize = "gptq" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = GPTQWeight( @@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq(): assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" -def test_get_multi_weights_row_marlin(): +def test_get_weights_row_marlin(marlin_weights_loader): weights = MockWeights( [ - "test_get_multi_weights_row_marlin", + "test_get_weights_row_marlin", ], device="cpu", dtype=torch.float16, process_group=dummy_process_group, dummy_fs=dummy_file_system, + weights_loader=marlin_weights_loader, ) prefix = "weight" - quantize = "marlin" - w = weights.get_multi_weights_row( + w = weights.get_weights_row( prefix=prefix, - quantize=quantize, ) expected_weight = MarlinWeight( diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index f6cb729e..55cba1cc 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -1,6 +1,9 @@ import torch +from typing import List, Union from dataclasses import dataclass +from text_generation_server.utils.weights import WeightsLoader, Weights + @dataclass class Exl2Weight: @@ -21,3 +24,60 @@ class Exl2Weight: @property def device(self) -> torch.device: return self.q_weight.device + + +class Exl2WeightsLoader(WeightsLoader): + """Loader for exl2-quantized weights.""" + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + raise RuntimeError("Column-packed weights are not supported for exl") + + def get_weights_col(self, weights: Weights, prefix: str): + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + raise ValueError("get_multi_weights_col is not supported for exl2") + + def get_weights_row(self, weights: Weights, prefix: str): + try: + q_weight = weights.get_tensor(f"{prefix}.q_weight") + except RuntimeError: + raise RuntimeError( + "Cannot load `exl2`-quantized weight, make sure the model is already quantized." + ) + + q_scale = weights.get_tensor(f"{prefix}.q_scale") + q_invperm = weights.get_tensor(f"{prefix}.q_invperm") + q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") + q_groups = weights.get_tensor(f"{prefix}.q_groups") + + return Exl2Weight( + q_weight=q_weight, + q_scale=q_scale, + q_invperm=q_invperm, + q_scale_max=q_scale_max, + q_groups=q_groups, + ) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index 56080145..efcb3118 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,20 +1,14 @@ from dataclasses import dataclass +from loguru import logger import os -from typing import Optional +from typing import List, Optional, Union +from safetensors import SafetensorError +from text_generation_server.utils.weights import Weights, WeightsLoader import torch from text_generation_server.utils.import_utils import ( SYSTEM, ) - - -@dataclass -class GPTQParams: - bits: int - checkpoint_format: Optional[str] - groupsize: int - desc_act: bool - quant_method: str - sym: bool +from text_generation_server.utils.log import log_once @dataclass @@ -69,3 +63,341 @@ elif CAN_EXLLAMA: pass from text_generation_server.layers.gptq.quant_linear import QuantLinear + + +class GPTQWeightsLoader(WeightsLoader): + """ + Loader for GPTQ- and AWQ-quantized weights. + """ + + def __init__( + self, + *, + bits: int, + desc_act: bool, + groupsize: int, + quant_method: str, + quantize: str, + sym: bool, + ): + self.bits = bits + self.desc_act = desc_act + self.groupsize = groupsize + self.quant_method = quant_method + self.quantize = quantize + self.sym = sym + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = weights.get_packed_sharded( + f"{prefix}.qweight", dim=1, block_sizes=block_sizes + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized." + ) + scales = weights.get_packed_sharded( + f"{prefix}.scales", dim=1, block_sizes=block_sizes + ) + scales = scales.to(dtype=weights.dtype) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + g_idx = weights.get_tensor(f"{prefix}.g_idx") + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = weights.get_packed_sharded( + f"{prefix}.qzeros", dim=1, block_sizes=block_sizes + ) + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=False, + ) + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + try: + qweight = torch.cat( + [weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight, make sure the model is already quantized" + ) + + scales = torch.cat( + [weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + qzeros = torch.cat( + [weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 + ) + + from text_generation_server.layers.gptq import HAS_EXLLAMA + + use_exllama = ( + self.bits == 4 + and HAS_EXLLAMA + and self.quantize == "gptq" + and not self.desc_act + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def get_weights_row(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + if self.desc_act or self.groupsize == -1: + scales = weights.get_tensor(f"{prefix}.scales") + else: + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + + sharded_in_features = weights.process_group.size() > 1 + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=sharded_in_features, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0) + else: + g_idx = None + + if weights.process_group.size() > 1: + if g_idx is not None: + if ( + not torch.equal( + g_idx.cpu(), + torch.tensor( + [i // self.groupsize for i in range(g_idx.shape[0])], + dtype=torch.int32, + ), + ) + and not (g_idx == 0).all() + ): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + if use_exllama and self.groupsize != -1: + qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0) + scales = weights.get_sharded(f"{prefix}.scales", dim=0) + else: + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + + def _get_gptq_params(self, weights: Weights): + try: + self.bits = weights.get_tensor("gptq_bits").item() + self.groupsize = weights.get_tensor("gptq_groupsize").item() + self.desc_act = False + self.sym = False + self.quant_method = "gptq" + except (SafetensorError, RuntimeError) as e: + pass diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index 8d029817..c65d5e78 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear from loguru import logger from typing import Optional +from text_generation_server.utils.weights import DefaultWeightsLoader + DEV = torch.device("cuda:0") @@ -891,6 +893,7 @@ def quantize( dtype=torch.float16, process_group=process_group, aliases={"embed_tokens.weight": ["lm_head.weight"]}, + weights_loader=DefaultWeightsLoader(), ) hooks = [] for name, module in model.named_modules(): diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a1af67a3..ecb88e76 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union +from text_generation_server.utils.weights import Weights, WeightsLoader import torch import torch.nn as nn -from text_generation_server.layers.gptq import GPTQParams from text_generation_server.utils.import_utils import SYSTEM try: @@ -24,16 +24,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128] MARLIN_TILE_SIZE = 16 -def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool: +class MarlinWeightsLoader(WeightsLoader): + """Loader for Marlin-quantized weights.""" + + def __init__(self, *, bits: int, is_marlin_24: bool): + self.bits = bits + self.is_marlin_24 = is_marlin_24 + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + if self.is_marlin_24: + B = weights.get_packed_sharded( + f"{prefix}.B_24", dim=1, block_sizes=block_sizes + ) + B_meta = weights.get_packed_sharded( + f"{prefix}.B_meta", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + B = weights.get_packed_sharded( + f"{prefix}.B", dim=1, block_sizes=block_sizes + ) + s = weights.get_packed_sharded( + f"{prefix}.s", dim=1, block_sizes=block_sizes + ) + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + + B_meta = torch.cat( + [weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 + ) + + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = torch.cat( + [weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 + ) + except RuntimeError: + raise RuntimeError( + f"Cannot load `marlin` weight, make sure the model is already quantized" + ) + s = torch.cat( + [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 + ) + + weight = MarlinWeight(B=B, s=s) + + return weight + + def get_weights_row(self, weights: Weights, prefix: str): + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_sharded(f"{prefix}.B_24", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0) + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_sharded(f"{prefix}.B", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0] + if num_groups == 1: + # The number of groups is 1 when groupsize == -1. share + # scales between all shards in this case. + s = weights.get_tensor(f"{prefix}.s") + else: + s = weights.get_sharded(f"{prefix}.s", dim=0) + weight = MarlinWeight(B=B, s=s) + + return weight + + +def can_use_gptq_marlin( + *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool +) -> bool: return ( SYSTEM == "cuda" and marlin_kernels is not None and has_sm_8_0 and quantize == "gptq" - and gptq_params.quant_method == "gptq" - and gptq_params.bits in GPTQ_MARLIN_BITS - and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES - and gptq_params.sym + and quant_method == "gptq" + and bits in GPTQ_MARLIN_BITS + and groupsize in GPTQ_MARLIN_GROUP_SIZES + and sym ) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 038de258..011f105b 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer): weight = weights.get_tensor(f"{prefix}.weight") except: # ...otherwise they are quantized. - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) should_gather = weights.process_group.size() > 1 elif weights.process_group.size() > 1: try: @@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_gate_up(cls, config, prefix: str, weights, bias: bool): """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_gate_up( - prefix, quantize=config.quantize - ) + weight = weights.get_weights_col_packed_gate_up(prefix) if bias: raise NotImplementedError("packed_gate_up only implemented without bias") else: @@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer): """Specific method when the QKV was joined after the fact""" weight = weights.get_weights_col_packed_qkv( prefix, - quantize=config.quantize, num_heads=num_heads, num_key_value_heads=num_key_value_heads, ) @@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: @@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer): if config.quantize == "exl2": linears = [] for prefix in prefixes: - weight = weights.get_weights_col(prefix, config.quantize) + weight = weights.get_weights_col(prefix) b = weights.get_tensor(f"{prefix}.bias") if bias else None linears.append(get_linear(weight, b, config.quantize)) linear = LayerConcat(linears) else: - weight = weights.get_multi_weights_col( - prefixes, quantize=config.quantize, dim=dim - ) + weight = weights.get_multi_weights_col(prefixes, dim=dim) if bias: b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] bias = torch.cat(b, dim=dim) @@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer): @classmethod def load(cls, config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 868a3cc0..0ea82b1e 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -20,6 +20,7 @@ from text_generation_server.utils import ( from text_generation_server.models import Model from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, @@ -546,12 +547,17 @@ class CausalLM(Model): tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) prefix = "" model = model_class(prefix, config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index f993fe72..25719b99 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index beff08b3..a3ce5521 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 14b62b00..34a7efa2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index d5dc25cf..cbfcb1b8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights): # Weights weight = weights.get_weights_col_packed_qkv( f"{prefix}.c_attn", - config.quantize, config.num_attention_heads, config.num_attention_heads, ) @@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool): """load_row, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) else: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T @@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_col(config, prefix: str, weights, bias: bool): """load_col, but with transposed weight matrices.""" if config.quantize == "gptq": - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=1 - ) + weight = weights.get_multi_weights_col([prefix], dim=1) else: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 429793ea..49c0e903 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 0eca181b..85dcb2a6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import ( def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process @@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): - weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0) + weight = weights.get_multi_weights_col([prefix], dim=0) if isinstance(weight, torch.Tensor): # Only on non quantized versions weight = ( diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 7401bc27..6c508264 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 4813e2df..65b40fed 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -23,7 +23,7 @@ from text_generation_server.layers.attention import ( def load_row(config, prefix: str, weights, bias: bool): - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 21a22046..77b9d49c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,6 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -81,11 +82,13 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - gptq_params = weights._get_gptq_params() - if gptq_params.quant_method == "gptq": + loader = weights.weights_loader + assert isinstance(loader, GPTQWeightsLoader) + loader._get_gptq_params(weights) + if loader.quant_method == "gptq": g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = g_idx.to(device=weights.device) - elif gptq_params.quant_method == "awq": + elif loader.quant_method == "awq": g_idx = None from text_generation_server.layers.awq.conversion_utils import ( fast_awq_to_gptq, @@ -100,8 +103,8 @@ def _load_multi_mqa_gptq( qzeros=qzeros, scales=scales, g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, + bits=loader.bits, + groupsize=loader.groupsize, use_exllama=HAS_EXLLAMA, ) @@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T else: - weight = weights.get_multi_weights_col( - [prefix], quantize=config.quantize, dim=0 - ) + weight = weights.get_multi_weights_col([prefix], dim=0) if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) @@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool): if config.transpose: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T else: - weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + weight = weights.get_weights_row(prefix) if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 2b346283..19556f78 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights): weight = weights.get_multi_weights_col( prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - quantize=config.quantize, dim=0, ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bf1fda4a..2ca9eef3 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -50,6 +50,7 @@ from text_generation_server.models.globals import ( from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.import_utils import ( @@ -881,12 +882,16 @@ class FlashCausalLM(Model): torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader(quantize, model_id, revision) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device, dtype, process_group=self.process_group, aliases=aliases + filenames, + device, + dtype, + process_group=self.process_group, + aliases=aliases, + weights_loader=weights_loader, ) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) prefix = "" model = model_class(prefix, config, weights) diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index f2955bd0..0deab6ce 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -23,6 +23,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) +from text_generation_server.utils.quantization import get_loader class IDEFICSSharded(IdeficsCausalLM): @@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM): trust_remote_code=trust_remote_code, ) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( @@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM): device=device, dtype=dtype, process_group=self.process_group, + weights_loader=weights_loader, ) model = IdeficsForVisionText2Text(config, weights) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 9189b45c..4ed9722c 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -28,6 +28,7 @@ from text_generation_server.models.types import ( GeneratedText, ) from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens, Sampling from dataclasses import dataclass from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -448,8 +449,17 @@ class Mamba(Model): config.quantize = quantize config.speculator = speculator torch.distributed.barrier(group=self.process_group) + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, + device, + dtype, + process_group=self.process_group, + weights_loader=weights_loader, + ) model = MambaModel(config, weights) torch.distributed.barrier(group=self.process_group) super(Mamba, self).__init__( diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index dbaf1253..fa8b5025 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -18,6 +18,7 @@ from text_generation_server.utils import ( Weights, ) from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.quantization import get_loader from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model from text_generation_server.models.types import ( @@ -586,6 +587,9 @@ class Seq2SeqLM(Model): ) tokenizer.bos_token_id = config.decoder_start_token_id + weights_loader = get_loader( + quantize=quantize, model_id=model_id, revision=revision + ) torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( @@ -594,6 +598,7 @@ class Seq2SeqLM(Model): dtype=dtype, process_group=self.process_group, aliases=aliases, + weights_loader=weights_loader, ) if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py new file mode 100644 index 00000000..07975bea --- /dev/null +++ b/server/text_generation_server/utils/quantization.py @@ -0,0 +1,119 @@ +from typing import Optional +import os +import json +from dataclasses import dataclass + +from huggingface_hub import hf_hub_download + +from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader + + +@dataclass +class _QuantizerConfig: + bits: int + checkpoint_format: Optional[str] + desc_act: bool + groupsize: int + quant_method: str + sym: bool + + +# We should probably do this with Pytantic JSON deserialization, +# but for now we'll stay close to the old _set_gptq_params. +def _get_quantizer_config(model_id, revision): + bits = 4 + groupsize = -1 + quant_method = "gptq" + checkpoint_format = None + sym = True + desc_act = False + + filename = "config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download(model_id, filename=filename, revision=revision) + with open(filename, "r") as f: + data = json.load(f) + bits = data["quantization_config"]["bits"] + groupsize = data["quantization_config"]["group_size"] + # Order is important here, desc_act is missing on some real models + quant_method = data["quantization_config"]["quant_method"] + checkpoint_format = data["quantization_config"].get("checkpoint_format") + sym = data["quantization_config"]["sym"] + desc_act = data["quantization_config"]["desc_act"] + except Exception: + filename = "quantize_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["bits"] + groupsize = data["group_size"] + sym = data["sym"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + filename = "quant_config.json" + try: + if os.path.exists(os.path.join(model_id, filename)): + filename = os.path.join(model_id, filename) + else: + filename = hf_hub_download( + model_id, filename=filename, revision=revision + ) + with open(filename, "r") as f: + data = json.load(f) + bits = data["w_bit"] + groupsize = data["q_group_size"] + desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + quant_method = "awq" + except Exception: + pass + + return _QuantizerConfig( + bits=bits, + groupsize=groupsize, + quant_method=quant_method, + checkpoint_format=checkpoint_format, + sym=sym, + desc_act=desc_act, + ) + + +def get_loader( + quantize: Optional[str], model_id: str, revision: Optional[str] +) -> WeightsLoader: + quantizer_config = _get_quantizer_config(model_id, revision) + if quantize in {"awq", "gptq"}: + from text_generation_server.layers.gptq import GPTQWeightsLoader + + return GPTQWeightsLoader( + bits=quantizer_config.bits, + desc_act=quantizer_config.desc_act, + groupsize=quantizer_config.groupsize, + quant_method=quantizer_config.quant_method, + quantize=quantize, + sym=quantizer_config.sym, + ) + elif quantize == "exl2": + from text_generation_server.layers.exl2 import Exl2WeightsLoader + + return Exl2WeightsLoader() + elif quantize == "marlin": + from text_generation_server.layers.marlin import MarlinWeightsLoader + + return MarlinWeightsLoader( + bits=quantizer_config.bits, + is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", + ) + else: + return DefaultWeightsLoader() diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 3731fd24..1a62fb3b 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,13 +1,88 @@ -import os +from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, List, Optional, Union -from safetensors import safe_open, SafetensorError +from safetensors import safe_open import torch -from loguru import logger -from huggingface_hub import hf_hub_download -import json -from text_generation_server.layers.gptq import GPTQParams -from text_generation_server.utils.log import log_once + + +class WeightsLoader(ABC): + """ + Instances of this type implement higher-level weight loading. + + At a low-level, every weight is stored in the Safetensors format. + The interpretation of weights may be different however, for instance + could be packed, quantized weights. Loaders are responsible for + interpreting the raw tensors, sharding tensors in a manner compatible + with the format, etc. + """ + + @abstractmethod + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + """ + Get the packed weights at the given prefix with column-splitting for + tensor parallelism. This method should be used when multiple different + weights are packed into a tensor, for instance, query/key/value + weights or a gate/up projection. + + The `block_sizes` determines the proportions of the packed tensors. + The columns are split in equally sized blocks when `block_sizes` is an + `int`, or in blocks proportional given to the sizes. For instance + `[2, 1, 1]` will divide an input with dimensionality `1024` in + `[512, 256, 256]`. + """ + ... + + def get_weights_col(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply column-splitting for tensor + paralllism. + """ + return weights.get_multi_weights_col([prefix], 0) + + @abstractmethod + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + """ + Get the weights at the given prefixes, column-split them for tensor + parallelim, and then concatenate the weights along the given dimension. + """ + ... + + @abstractmethod + def get_weights_row(self, weights: "Weights", prefix: str): + """ + Get the weights at the given prefix and apply row-splitting for tensor + parallism. + """ + ... + + +class DefaultWeightsLoader(WeightsLoader): + """ + Loader that uses tensors as-is with the exception of applying sharding + and/or concatenation. + """ + + def get_weights_col_packed( + self, + weights: "Weights", + prefix: str, + block_sizes: Union[int, List[int]], + ): + return weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + return torch.cat(w, dim=dim) + + def get_weights_row(self, weights: "Weights", prefix: str): + return weights.get_sharded(f"{prefix}.weight", dim=1) class Weights: @@ -17,6 +92,7 @@ class Weights: device, dtype, process_group, + weights_loader: WeightsLoader, aliases: Optional[Dict[str, List[str]]] = None, prefix: Optional[str] = None, ): @@ -37,6 +113,7 @@ class Weights: self.dtype = dtype self.process_group = process_group self.prefix = prefix + self.weights_loader = weights_loader self._handles = {} def _get_handle(self, filename): @@ -181,295 +258,27 @@ class Weights: num_key_value_heads: int, ): return self.get_weights_col_packed( - prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads] + prefix, [num_heads, num_key_value_heads, num_key_value_heads] ) - def get_weights_col_packed_gate_up(self, prefix: str, quantize: str): - return self.get_weights_col_packed(prefix, quantize, 2) + def get_weights_col_packed_gate_up(self, prefix: str): + return self.get_weights_col_packed(prefix, 2) - def get_weights_col_packed( - self, prefix: str, quantize: str, block_sizes: Union[int, List[int]] - ): + def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]): """ - Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being - already alternating Q,K,V within the main tensor. - The columns are split in equally sized blocks when blocks is an `int`, or in blocks proportional given to the sizes. For instance `[2, 1, 1]` will divide an input with dimensionality `1024` in `[512, 256, 256]`. This is convenient for e.g. splitting QKV without knowing the storage details of quantized weights. """ - if quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) + return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes) - try: - qweight = self.get_packed_sharded( - f"{prefix}.qweight", dim=1, block_sizes=block_sizes - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized." - ) - scales = self.get_packed_sharded( - f"{prefix}.scales", dim=1, block_sizes=block_sizes - ) - scales = scales.to(dtype=self.dtype) + def get_weights_col(self, prefix: str): + return self.weights_loader.get_weights_col(self, prefix) - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - g_idx = self.get_tensor(f"{prefix}.g_idx") - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = self.get_packed_sharded( - f"{prefix}.qzeros", dim=1, block_sizes=block_sizes - ) - if quantize == "gptq" and gptq_params.quant_method == "gptq": - g_idx = self.get_tensor(f"{prefix}.g_idx") - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=False, - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - repack_gptq_for_marlin, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - B = self.get_packed_sharded( - f"{prefix}.B_24", dim=1, block_sizes=block_sizes - ) - B_meta = self.get_packed_sharded( - f"{prefix}.B_meta", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - B = self.get_packed_sharded( - f"{prefix}.B", dim=1, block_sizes=block_sizes - ) - s = self.get_packed_sharded( - f"{prefix}.s", dim=1, block_sizes=block_sizes - ) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes - ) - return weight - - def get_weights_col(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - return self.get_multi_weights_col([prefix], quantize, 0) - - def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int): - if quantize == "exl2": - raise ValueError("get_multi_weights_col is not supported for exl2") - elif quantize in ["gptq", "awq"]: - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - try: - qweight = torch.cat( - [self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - scales = torch.cat( - [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=False, - ) - - qzeros = torch.cat( - [self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1 - ) - - from text_generation_server.layers.gptq import HAS_EXLLAMA - - use_exllama = ( - gptq_params.bits == 4 - and HAS_EXLLAMA - and quantize == "gptq" - and not gptq_params.desc_act - ) - - if quantize == "gptq" and gptq_params.quant_method == "gptq": - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - elif quantize == "gptq" and gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - else: - g_idx = None - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - - B_meta = torch.cat( - [self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1 - ) - - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = torch.cat( - [self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1 - ) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight, make sure the model is already quantized" - ) - s = torch.cat( - [self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1 - ) - - weight = MarlinWeight(B=B, s=s) - - else: - w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - weight = torch.cat(w, dim=dim) - - return weight + def get_multi_weights_col(self, prefixes: List[str], dim: int): + return self.weights_loader.get_multi_weights_col(self, prefixes, dim) def get_tensor_shard(self, var, dim): world_size = self.process_group.size() @@ -487,318 +296,8 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_multi_weights_row(self, prefix: str, quantize: str): - if quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - try: - q_weight = self.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - f"Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = self.get_tensor(f"{prefix}.q_scale") - q_invperm = self.get_tensor(f"{prefix}.q_invperm") - q_scale_max = self.get_tensor(f"{prefix}.q_scale_max") - q_groups = self.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) - - elif quantize == "gptq": - from text_generation_server.layers.marlin import ( - can_use_gptq_marlin, - repack_gptq_for_marlin, - ) - - gptq_params = self._get_gptq_params() - if can_use_gptq_marlin(gptq_params, quantize): - log_once(logger.info, "Using GPTQ-Marlin kernels") - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" - ) - - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if gptq_params.desc_act or gptq_params.groupsize == -1: - scales = self.get_tensor(f"{prefix}.scales") - else: - scales = self.get_sharded(f"{prefix}.scales", dim=0) - - sharded_in_features = self.process_group.size() > 1 - - return repack_gptq_for_marlin( - qweight=qweight, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - desc_act=gptq_params.desc_act, - groupsize=gptq_params.groupsize, - sym=gptq_params.sym, - sharded_infeatures=sharded_in_features, - ) - - use_exllama = True - if gptq_params.bits != 4: - use_exllama = False - - if gptq_params.desc_act: - log_once(logger.warning, "Disabling exllama because desc_act=True") - use_exllama = False - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - - if gptq_params.quant_method == "gptq": - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - elif gptq_params.quant_method == "awq": - g_idx = None - - if self.process_group.size() > 1: - if g_idx is not None: - if ( - not torch.equal( - g_idx.cpu(), - torch.tensor( - [ - i // gptq_params.groupsize - for i in range(g_idx.shape[0]) - ], - dtype=torch.int32, - ), - ) - and not (g_idx == 0).all() - ): - # Exllama implementation does not support row tensor parallelism with act-order, as - # it would require to reorder input activations that are split unto several GPUs - use_exllama = False - - from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, - CAN_EXLLAMA, - GPTQWeight, - ) - - if use_exllama: - if not HAS_EXLLAMA: - if CAN_EXLLAMA: - log_once( - logger.warning, - "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", - ) - use_exllama = False - else: - log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - - if use_exllama and gptq_params.groupsize != -1: - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - else: - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") - - if use_exllama and g_idx is not None: - g_idx = g_idx - g_idx[0] - - if gptq_params.quant_method == "awq": - log_once( - logger.info, "Converting AWQ model to Exllama/GPTQ packing format." - ) - from text_generation_server.layers.awq.conversion_utils import ( - fast_awq_to_gptq, - ) - - qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) - if use_exllama: - g_idx = None - else: - g_idx = ( - torch.arange( - qweight.shape[0] * (32 // gptq_params.bits), - device=qweight.device, - ) - // gptq_params.groupsize - ).to(dtype=torch.int32) - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "awq": - from text_generation_server.layers.gptq import GPTQWeight - - gptq_params = self._get_gptq_params() - - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `awq` weight, make sure the model is already quantized" - ) - - qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) - scales = self.get_sharded(f"{prefix}.scales", dim=0) - g_idx = None - use_exllama = False - - weight = GPTQWeight( - qweight=qweight, - qzeros=qzeros, - scales=scales, - g_idx=g_idx, - bits=gptq_params.bits, - groupsize=gptq_params.groupsize, - use_exllama=use_exllama, - ) - elif quantize == "marlin": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlin24Weight, - MarlinWeight, - ) - - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: - try: - B = self.get_sharded(f"{prefix}.B_24", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." - ) - - B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0) - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - - gptq_params = self._get_gptq_params() - weight = GPTQMarlin24Weight( - B=B, B_meta=B_meta, s=s, bits=gptq_params.bits - ) - else: - try: - B = self.get_sharded(f"{prefix}.B", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `marlin` weight, make sure the model is already quantized." - ) - - num_groups = self._get_slice(f"{prefix}.s").get_shape()[0] - if num_groups == 1: - # The number of groups is 1 when groupsize == -1. share - # scales between all shards in this case. - s = self.get_tensor(f"{prefix}.s") - else: - s = self.get_sharded(f"{prefix}.s", dim=0) - weight = MarlinWeight(B=B, s=s) - else: - weight = self.get_sharded(f"{prefix}.weight", dim=1) - return weight - - def _get_gptq_params(self) -> GPTQParams: - try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = False - sym = False - quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - try: - bits = self.gptq_bits - groupsize = self.gptq_groupsize - checkpoint_format = getattr(self, "gptq_checkpoint_format", None) - desc_act = getattr(self, "gptq_desc_act", False) - quant_method = getattr(self, "quant_method", "gptq") - sym = getattr(self, "sym", True) - except Exception: - raise e - - return GPTQParams( - bits=bits, - checkpoint_format=checkpoint_format, - desc_act=desc_act, - groupsize=groupsize, - quant_method=quant_method, - sym=sym, - ) - - def _set_gptq_params(self, model_id, revision): - filename = "config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["quantization_config"]["bits"] - self.gptq_groupsize = data["quantization_config"]["group_size"] - # Order is important here, desc_act is missing on some real models - self.quant_method = data["quantization_config"]["quant_method"] - self.gptq_checkpoint_format = data["quantization_config"].get( - "checkpoint_format" - ) - self.gptq_sym = data["quantization_config"]["sym"] - self.gptq_desc_act = data["quantization_config"]["desc_act"] - except Exception: - filename = "quantize_config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["bits"] - self.gptq_groupsize = data["group_size"] - self.gptq_sym = data["sym"] - self.gptq_desc_act = data["desc_act"] - if "version" in data and data["version"] == "GEMM": - self.quant_method = "awq" - except Exception: - filename = "quant_config.json" - try: - if os.path.exists(os.path.join(model_id, filename)): - filename = os.path.join(model_id, filename) - else: - filename = hf_hub_download( - model_id, filename=filename, revision=revision - ) - with open(filename, "r") as f: - data = json.load(f) - self.gptq_bits = data["w_bit"] - self.gptq_groupsize = data["q_group_size"] - self.gptq_desc_act = data["desc_act"] - if "version" in data and data["version"] == "GEMM": - self.quant_method = "awq" - except Exception: - pass + def get_weights_row(self, prefix: str): + return self.weights_loader.get_weights_row(self, prefix) def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: From cb150eb2956534f35852813523c2a677f43793f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 11 Jul 2024 16:03:26 +0200 Subject: [PATCH 21/24] Add support for FP8 on compute capability >=8.0, <8.9 (#2213) Use FP8 GPTQ-Marlin kernels to enable FP8 support on CUDA GPUs with compute capability >=8.0 and <8.9. Co-authored-by: Florian Zimmermeister --- server/marlin/marlin_kernels/__init__.pyi | 15 + server/marlin/marlin_kernels/ext.cpp | 2 + server/marlin/marlin_kernels/ext.hh | 5 + server/marlin/marlin_kernels/fp8_marlin.cu | 1308 +++++++++++++++++ server/marlin/setup.py | 1 + server/text_generation_server/layers/fp8.py | 19 + .../text_generation_server/layers/linear.py | 4 +- .../text_generation_server/layers/marlin.py | 115 +- 8 files changed, 1465 insertions(+), 4 deletions(-) create mode 100644 server/marlin/marlin_kernels/fp8_marlin.cu diff --git a/server/marlin/marlin_kernels/__init__.pyi b/server/marlin/marlin_kernels/__init__.pyi index 663984d0..53464719 100644 --- a/server/marlin/marlin_kernels/__init__.pyi +++ b/server/marlin/marlin_kernels/__init__.pyi @@ -59,3 +59,18 @@ def marlin_gemm( Matrix multiplication using Marlin kernels. """ ... + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return torch.ops._C.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) diff --git a/server/marlin/marlin_kernels/ext.cpp b/server/marlin/marlin_kernels/ext.cpp index 37eccef6..04e1530f 100644 --- a/server/marlin/marlin_kernels/ext.cpp +++ b/server/marlin/marlin_kernels/ext.cpp @@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gptq_marlin_repack", &gptq_marlin_repack, "Repack GPTQ parameters for Marlin"); m.def("marlin_gemm", &marlin_gemm, "Marlin gemm"); + // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. + m.def("fp8_marlin_gemm", &fp8_marlin_gemm); } diff --git a/server/marlin/marlin_kernels/ext.hh b/server/marlin/marlin_kernels/ext.hh index d1caaab7..102c058e 100644 --- a/server/marlin/marlin_kernels/ext.hh +++ b/server/marlin/marlin_kernels/ext.hh @@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &workspace, int64_t size_m, int64_t size_n, int64_t size_k); +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k); + #endif diff --git a/server/marlin/marlin_kernels/fp8_marlin.cu b/server/marlin/marlin_kernels/fp8_marlin.cu new file mode 100644 index 00000000..aaef67e5 --- /dev/null +++ b/server/marlin/marlin_kernels/fp8_marlin.cu @@ -0,0 +1,1308 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include "./gptq_marlin.cuh" +#include "./gptq_marlin_dtypes.cuh" + +using namespace gptq_marlin; + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace fp8_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) {} + +} // namespace fp8_marlin + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16 +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + typename ScalarType::FragB frag_b; + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + return frag_b; +} + +template <> +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + // Final MASK value: 0x7F007F00 + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + typename ScalarType::FragB frag_b; + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); + frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + + constexpr int pack_factor = 32 / num_bits; + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + int slice_k_start = tb_k * slice_row; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We scale a `half2` tile in row-major layout for column-wise quantization. + int s_sh_rd = + 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + ((scalar_t2*)sh)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + + thread_block_reduce(); + + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + + start_pipes(); + } + } + } +} + + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \ + locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, + +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, + int group_size) { + int tb_n = th_config.thread_n; + + // Get max scale groups per thread-block + // Fixed for channelwise + int tb_groups = 1; + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, + prob_k, num_bits, group_size); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) + +template +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, + int num_groups, int group_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par) { + TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; + } else { + // Auto config + exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, + group_size, max_shared_mem); + } + + TORCH_CHECK( + exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = -1; + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + + int* locks = (int*)workspace; + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + // Define kernel configurations + if (false) { + } + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace fp8_marlin + +torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify A + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + // Channelwise only for FP8 + TORCH_CHECK(b_scales.size(0) == 1) + num_groups = b_scales.size(0); + + // Verify workspace size + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); + + int dev = a.get_device(); + if (a.scalar_type() == at::ScalarType::Half) { + fp8_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + fp8_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), size_m, + size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, + dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); + } else { + TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); + } + + return c; +} + +#endif diff --git a/server/marlin/setup.py b/server/marlin/setup.py index aed84e9e..cc38bccf 100644 --- a/server/marlin/setup.py +++ b/server/marlin/setup.py @@ -9,6 +9,7 @@ setup( CUDAExtension( name="marlin_kernels", sources=[ + "marlin_kernels/fp8_marlin.cu", "marlin_kernels/gptq_marlin.cu", "marlin_kernels/gptq_marlin_repack.cu", "marlin_kernels/marlin_cuda_kernel.cu", diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index dd61d081..b76af8f1 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,4 +1,23 @@ +from enum import Enum, auto + import torch +from text_generation_server.utils.import_utils import SYSTEM + + +def get_fp8_linear() -> torch.nn.Module: + """ + Return an FP8 linear `Module` that is compatible with the current system. + """ + + if SYSTEM == "cuda": + major, minor = torch.cuda.get_device_capability() + if major == 8 and minor < 9: + from text_generation_server.layers.marlin import GPTQMarlinFP8Linear + + return GPTQMarlinFP8Linear + + # On other systems let Torch decide if the hardware supports FP8. + return Fp8Linear def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index e94e5465..babd86b0 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize): "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" ) elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Linear + from text_generation_server.layers.fp8 import get_fp8_linear - linear = Fp8Linear(weight, bias) + linear = get_fp8_linear()(weight, bias) elif quantize == "bitsandbytes": try: from text_generation_server.layers.bnb import ( diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index ecb88e76..9777a47e 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from typing import List, Optional, Tuple, Union -from text_generation_server.utils.weights import Weights, WeightsLoader import torch import torch.nn as nn - +from loguru import logger +from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weights, WeightsLoader try: import marlin_kernels @@ -455,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module): return C +class GPTQMarlinFP8Linear(nn.Module): + """ + FP8 GPTQ-Marlin linear layer. + """ + + def __init__( + self, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> None: + super().__init__() + + _check_marlin_kernels() + assert marlin_kernels is not None + + log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + + qweight, scale = fp8_quantize(weight) + scale = scale.to(torch.float16) + qweight, scales = repack_fp8_for_marlin(qweight, scale) + + in_features = qweight.shape[0] * MARLIN_TILE_SIZE + out_features = scales.shape[1] + _check_valid_shape(in_features=in_features, out_features=out_features) + + self.qweight = qweight + self.scales = scales + self.bias = bias if bias is not None else None + + self.workspace = torch.zeros( + out_features // 64 * 16, dtype=torch.int, device=qweight.device + ) + + def forward(self, A: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + A_flat = A.view(-1, A.shape[-1]) + C = marlin_kernels.fp8_marlin_gemm( + A_flat, + self.qweight, + self.scales, + self.workspace, + 8, + A_flat.shape[0], + self.scales.shape[1], + A_flat.shape[1], + ) + C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) + + if self.bias is not None: + C += self.bias + + return C + + +def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements). + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + + if fp8_tensor.shape[0] % 4 != 0: + raise ValueError( + f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" + ) + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = torch.zeros( + fp8_tensor.shape[0] // 4, + fp8_tensor.shape[1], + dtype=torch.int32, + device=fp8_tensor.device, + ) + + for i in range(4): + packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) + + return packed + + +def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): + """ + Repack FP8 tensor for GPTQ-Marlin. + """ + + out_features, in_features = weight.shape + + # Torch linear layers weights with shape [out_features, in_features], + # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], + # so transpose before packing. + qweight = pack_fp8_as_int32(weight.t()) + + perm = torch.empty(0, dtype=torch.int, device=qweight.device) + repacked = marlin_kernels.gptq_marlin_repack( + qweight, perm, in_features, out_features, 8 + ) + + scales = scale.reshape(1, 1).repeat(1, out_features) + scales = permute_scales(scales) + + return repacked, scales + + @dataclass class MarlinWeight: """ From d789de329a087301d651ee943e0d76e0dbf5ced5 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 11 Jul 2024 10:42:58 -0400 Subject: [PATCH 22/24] fix: append DONE message to chat stream (#2221) * fix: append DONE message to chat stream * fix: update completions endpoint --- router/src/server.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/router/src/server.rs b/router/src/server.rs index 4e5af99c..d3a280ca 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -812,6 +812,10 @@ async fn completions( } }; + let stream = stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { @@ -1171,6 +1175,11 @@ async fn chat_completions( span, ) .await; + + let response_stream = response_stream.chain(futures::stream::once(async { + Ok(Event::default().data("[DONE]")) + })); + let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); Ok((headers, sse).into_response()) } else { From c46eaf707b6a45860b04d37351884b25c4c63772 Mon Sep 17 00:00:00 2001 From: SeongBeomLEE <2712qwer@gmail.com> Date: Fri, 12 Jul 2024 17:04:51 +0900 Subject: [PATCH 23/24] [fix] Modifying base in yarn embedding (#2212) --- server/text_generation_server/layers/rotary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 87a61e82..8c354b82 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -102,7 +102,7 @@ class PositionRotaryEmbedding(nn.Module): max_position_embeddings=rope_scaling[ "original_max_position_embeddings" ], - base=10000.0, + base=base, device=inv_freq.device, scaling_factor=scaling_factor, extrapolation_factor=1, From dbb23fbfa868ad8f961c60896e346fad3d2ab440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 12 Jul 2024 12:20:12 +0200 Subject: [PATCH 24/24] Use symmetric quantization in the `quantize` subcommand (#2120) Packing of asymmetric quantization is broken, all (q)zeros values of `0` get reset to `1`, resulting in a loss of accuracy. So instead use symmetric quantization. To be able to distinguish models with symmetric and asymmetric quantization, a new config tensor `gptq_sym` is added. If this tensor is not present, we assume `sym=False`. --- server/text_generation_server/cli.py | 1 + .../text_generation_server/layers/gptq/__init__.py | 12 ++++++++---- .../text_generation_server/layers/gptq/quantize.py | 3 +++ server/text_generation_server/utils/weights.py | 7 +++++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 68ae95dd..71ad18f7 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -353,6 +353,7 @@ def quantize( upload_to_model_id=upload_to_model_id, percdamp=percdamp, act_order=act_order, + sym=True, ) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index efcb3118..aaa7a68a 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -393,11 +393,15 @@ class GPTQWeightsLoader(WeightsLoader): ) def _get_gptq_params(self, weights: Weights): - try: + if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): self.bits = weights.get_tensor("gptq_bits").item() self.groupsize = weights.get_tensor("gptq_groupsize").item() self.desc_act = False - self.sym = False + # `server quantize` used asymmetric quantization unconditionally + # before the `gptq_sym` setting tensor was added. + self.sym = ( + weights.get_tensor("gptq_sym").item() + if weights._has_tensor("gptq_sym") + else False + ) self.quant_method = "gptq" - except (SafetensorError, RuntimeError) as e: - pass diff --git a/server/text_generation_server/layers/gptq/quantize.py b/server/text_generation_server/layers/gptq/quantize.py index c65d5e78..0271d913 100644 --- a/server/text_generation_server/layers/gptq/quantize.py +++ b/server/text_generation_server/layers/gptq/quantize.py @@ -871,6 +871,7 @@ def quantize( upload_to_model_id: Optional[str], percdamp: float, act_order: bool, + sym: bool, ): print("loading model") config = AutoConfig.from_pretrained( @@ -946,6 +947,7 @@ def quantize( percdamp=percdamp, act_order=act_order, hooks=hooks, + sym=sym, ) print(time.time() - tick) @@ -957,6 +959,7 @@ def quantize( state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()} state_dict["gptq_bits"] = torch.LongTensor([bits]) state_dict["gptq_groupsize"] = torch.LongTensor([groupsize]) + state_dict["gptq_sym"] = torch.BoolTensor([sym]) max_shard_size = "10GB" shards, index = shard_checkpoint( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 1a62fb3b..50a9167a 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -146,6 +146,13 @@ class Weights: slice_ = f.get_slice(tensor_name) return slice_ + def _has_tensor(self, tensor_name: str): + try: + self.get_filename(tensor_name) + except Exception: + return False + return True + def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape()