From 725f0e350d6627af8ab304c7efac21501ffb3666 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:58:43 +0100 Subject: [PATCH] add speculative head --- .../test_flash_qwen2/test_flash_qwen2.json | 96 ++--- .../test_flash_qwen2_all_params.json | 94 ++--- .../test_flash_qwen2_load.json | 376 +++++++++--------- integration-tests/models/test_flash_qwen2.py | 10 +- .../text_generation_server/models/__init__.py | 85 ++-- .../models/causal_lm.py | 3 + .../models/custom_modeling/bloom_modeling.py | 2 +- .../custom_modeling/flash_qwen2_modeling.py | 20 +- .../models/custom_modeling/neox_modeling.py | 17 +- .../models/custom_modeling/opt_modeling.py | 17 +- .../models/flash_mistral.py | 2 +- .../models/flash_qwen2.py | 43 +- .../models/flash_starcoder2.py | 2 +- .../models/galactica.py | 6 +- .../text_generation_server/models/gpt_neox.py | 6 +- server/text_generation_server/models/rw.py | 4 + .../models/seq2seq_lm.py | 3 + 17 files changed, 431 insertions(+), 355 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json index 2e3906ce..7219f9e6 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json @@ -5,80 +5,80 @@ "generated_tokens": 10, "prefill": [ { - "id": 2271, - "text": "Test", - "logprob": null + "id": 2271, + "logprob": null, + "text": "Test" }, { - "id": 1681, - "text": " request", - "logprob": -7.0351562 + "id": 1681, + "logprob": -8.8515625, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "text": " for", - "logprob": -2.1914062, - "special": false + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" }, { - "id": 279, - "text": " the", - "logprob": -2.6210938, - "special": false + "id": 2, + "logprob": -2.9160156, + "special": false, + "text": "#" }, { - "id": 2701, - "text": " following", - "logprob": -3.6445312, - "special": false + "id": 4230, + "logprob": -3.1035156, + "special": false, + "text": " Create" }, { - "id": 729, - "text": " function", - "logprob": -2.9648438, - "special": false + "id": 264, + "logprob": -1.1025391, + "special": false, + "text": " a" }, { - "id": 271, - "text": "\n\n", - "logprob": -1.9111328, - "special": false + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" }, { - "id": 31946, - "text": "Inputs", - "logprob": -1.6855469, - "special": false + "id": 198, + "logprob": -1.1953125, + "special": false, + "text": "\n" }, { - "id": 25, - "text": ":", - "logprob": -1.6093254e-05, - "special": false + "id": 2035, + "logprob": -1.3203125, + "special": false, + "text": "request" }, { - "id": 707, - "text": " def", - "logprob": -0.5678711, - "special": false + "id": 284, + "logprob": -0.13537598, + "special": false, + "text": " =" }, { - "id": 1477, - "text": " find", - "logprob": -2.5917969, - "special": false + "id": 7388, + "logprob": -1.2402344, + "special": false, + "text": " requests" }, { - "id": 6345, - "text": "_max", - "logprob": -1.8349609, - "special": false + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" } - ], + ], "top_tokens": null }, - "generated_text": " for the following function\n\nInputs: def find_max" + "generated_text": "\n# Create a request\nrequest = requests.get" } diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json index bdaab6f2..4a2936af 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json @@ -5,80 +5,80 @@ "generated_tokens": 10, "prefill": [ { - "id": 2271, - "text": "Test", - "logprob": null + "id": 2271, + "logprob": null, + "text": "Test" }, { - "id": 1681, - "text": " request", - "logprob": -7.0351562 + "id": 1681, + "logprob": -8.8515625, + "text": " request" } ], "seed": 0, "tokens": [ { - "id": 311, - "text": " to", - "logprob": -1.4472656, - "special": false + "id": 311, + "logprob": -1.4277344, + "special": false, + "text": " to" }, { - "id": 633, - "text": " get", - "logprob": -0.4741211, - "special": false + "id": 279, + "logprob": -0.65478516, + "special": false, + "text": " the" }, { - "id": 264, - "text": " a", - "logprob": 0.0, - "special": false + "id": 2473, + "logprob": -1.8300781, + "special": false, + "text": " service" }, { - "id": 1140, - "text": " list", - "logprob": 0.0, - "special": false + "id": 382, + "logprob": -0.75, + "special": false, + "text": ".\n\n" }, { - "id": 315, - "text": " of", - "logprob": 0.0, - "special": false + "id": 286, + "logprob": -0.11621094, + "special": false, + "text": " " }, { - "id": 678, - "text": " all", - "logprob": 0.0, - "special": false + "id": 549, + "logprob": 0.0, + "special": false, + "text": " :" }, { - "id": 279, - "text": " the", - "logprob": -0.2590332, - "special": false + "id": 689, + "logprob": -0.48608398, + "special": false, + "text": "return" }, { - "id": 3847, - "text": " users", - "logprob": -0.45239258, - "special": false + "id": 25, + "logprob": 0.0, + "special": false, + "text": ":" }, { - "id": 304, - "text": " in", - "logprob": -0.12322998, - "special": false + "id": 5949, + "logprob": -0.5756836, + "special": false, + "text": " Response" }, { - "id": 419, - "text": " this", - "logprob": -1.7275391, - "special": false + "id": 504, + "logprob": -0.24499512, + "special": false, + "text": " from" } ], "top_tokens": null }, - "generated_text": "Test request to get a list of all the users in this" + "generated_text": "Test request to the service.\n\n :return: Response from" } diff --git a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json index 6f0b21e9..4786ff24 100644 --- a/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json +++ b/integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json @@ -6,82 +6,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 2271, - "text": "Test", - "logprob": null + "id": 2271, + "logprob": null, + "text": "Test" }, { - "id": 1681, - "text": " request", - "logprob": -7.0351562 + "id": 1681, + "logprob": -8.8515625, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "text": " for", - "logprob": -2.1914062, - "special": false + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" }, { - "id": 279, - "text": " the", - "logprob": -2.6210938, - "special": false + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" }, { - "id": 2701, - "text": " following", - "logprob": -3.6445312, - "special": false + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" }, { - "id": 729, - "text": " function", - "logprob": -2.9648438, - "special": false + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" }, { - "id": 271, - "text": "\n\n", - "logprob": -1.9111328, - "special": false + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" }, { - "id": 31946, - "text": "Inputs", - "logprob": -1.6855469, - "special": false + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" }, { - "id": 25, - "text": ":", - "logprob": -1.6093254e-05, - "special": false + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" }, { - "id": 707, - "text": " def", - "logprob": -0.5678711, - "special": false + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" }, { - "id": 1477, - "text": " find", - "logprob": -2.5917969, - "special": false + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" }, { - "id": 6345, - "text": "_max", - "logprob": -1.8349609, - "special": false + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" } ], "top_tokens": null }, - "generated_text": " for the following function\n\nInputs: def find_max" + "generated_text": "\n# Create a request\nrequest = requests.get" }, { "details": { @@ -90,82 +90,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 2271, - "text": "Test", - "logprob": null + "id": 2271, + "logprob": null, + "text": "Test" }, { - "id": 1681, - "text": " request", - "logprob": -7.0351562 + "id": 1681, + "logprob": -8.8515625, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "text": " for", - "logprob": -2.1914062, - "special": false + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" }, { - "id": 279, - "text": " the", - "logprob": -2.6210938, - "special": false + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" }, { - "id": 2701, - "text": " following", - "logprob": -3.6445312, - "special": false + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" }, { - "id": 729, - "text": " function", - "logprob": -2.9648438, - "special": false + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" }, { - "id": 271, - "text": "\n\n", - "logprob": -1.9111328, - "special": false + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" }, { - "id": 31946, - "text": "Inputs", - "logprob": -1.6855469, - "special": false + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" }, { - "id": 25, - "text": ":", - "logprob": -1.6093254e-05, - "special": false + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" }, { - "id": 707, - "text": " def", - "logprob": -0.5678711, - "special": false + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" }, { - "id": 1477, - "text": " find", - "logprob": -2.5917969, - "special": false + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" }, { - "id": 6345, - "text": "_max", - "logprob": -1.8349609, - "special": false + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" } ], "top_tokens": null }, - "generated_text": " for the following function\n\nInputs: def find_max" + "generated_text": "\n# Create a request\nrequest = requests.get" }, { "details": { @@ -174,82 +174,82 @@ "generated_tokens": 10, "prefill": [ { - "id": 2271, - "text": "Test", - "logprob": null + "id": 2271, + "logprob": null, + "text": "Test" }, { - "id": 1681, - "text": " request", - "logprob": -7.0351562 + "id": 1681, + "logprob": -8.8515625, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "text": " for", - "logprob": -2.1914062, - "special": false + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" }, { - "id": 279, - "text": " the", - "logprob": -2.6210938, - "special": false + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" }, { - "id": 2701, - "text": " following", - "logprob": -3.6445312, - "special": false + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" }, { - "id": 729, - "text": " function", - "logprob": -2.9648438, - "special": false + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" }, { - "id": 271, - "text": "\n\n", - "logprob": -1.9111328, - "special": false + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" }, { - "id": 31946, - "text": "Inputs", - "logprob": -1.6855469, - "special": false + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" }, { - "id": 25, - "text": ":", - "logprob": -1.6093254e-05, - "special": false + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" }, { - "id": 707, - "text": " def", - "logprob": -0.5678711, - "special": false + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" }, { - "id": 1477, - "text": " find", - "logprob": -2.5917969, - "special": false + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" }, { - "id": 6345, - "text": "_max", - "logprob": -1.8349609, - "special": false + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" } ], "top_tokens": null }, - "generated_text": " for the following function\n\nInputs: def find_max" + "generated_text": "\n# Create a request\nrequest = requests.get" }, { "details": { @@ -258,81 +258,81 @@ "generated_tokens": 10, "prefill": [ { - "id": 2271, - "text": "Test", - "logprob": null + "id": 2271, + "logprob": null, + "text": "Test" }, { - "id": 1681, - "text": " request", - "logprob": -7.0351562 + "id": 1681, + "logprob": -8.8515625, + "text": " request" } ], "seed": null, "tokens": [ { - "id": 369, - "text": " for", - "logprob": -2.1914062, - "special": false + "id": 198, + "logprob": -2.9023438, + "special": false, + "text": "\n" }, { - "id": 279, - "text": " the", - "logprob": -2.6210938, - "special": false + "id": 2, + "logprob": -2.9140625, + "special": false, + "text": "#" }, { - "id": 2701, - "text": " following", - "logprob": -3.6445312, - "special": false + "id": 4230, + "logprob": -3.1054688, + "special": false, + "text": " Create" }, { - "id": 729, - "text": " function", - "logprob": -2.9648438, - "special": false + "id": 264, + "logprob": -1.0966797, + "special": false, + "text": " a" }, { - "id": 271, - "text": "\n\n", - "logprob": -1.9111328, - "special": false + "id": 1681, + "logprob": -1.6914062, + "special": false, + "text": " request" }, { - "id": 31946, - "text": "Inputs", - "logprob": -1.6855469, - "special": false + "id": 198, + "logprob": -1.1923828, + "special": false, + "text": "\n" }, { - "id": 25, - "text": ":", - "logprob": -1.6093254e-05, - "special": false + "id": 2035, + "logprob": -1.3193359, + "special": false, + "text": "request" }, { - "id": 707, - "text": " def", - "logprob": -0.5678711, - "special": false + "id": 284, + "logprob": -0.13586426, + "special": false, + "text": " =" }, { - "id": 1477, - "text": " find", - "logprob": -2.5917969, - "special": false + "id": 7388, + "logprob": -1.2412109, + "special": false, + "text": " requests" }, { - "id": 6345, - "text": "_max", - "logprob": -1.8349609, - "special": false + "id": 670, + "logprob": -0.2775879, + "special": false, + "text": ".get" } ], "top_tokens": null }, - "generated_text": " for the following function\n\nInputs: def find_max" + "generated_text": "\n# Create a request\nrequest = requests.get" } ] diff --git a/integration-tests/models/test_flash_qwen2.py b/integration-tests/models/test_flash_qwen2.py index e07ed553..2963aeb4 100644 --- a/integration-tests/models/test_flash_qwen2.py +++ b/integration-tests/models/test_flash_qwen2.py @@ -3,7 +3,7 @@ import pytest @pytest.fixture(scope="module") def flash_qwen2_handle(launcher): - with launcher("Qwen/Qwen1.5-7B") as handle: + with launcher("Qwen/Qwen1.5-0.5B") as handle: yield handle @@ -20,7 +20,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot): ) assert response.details.generated_tokens == 10 - assert response.generated_text == " for the following function\n\nInputs: def find_max" + assert response.generated_text == "\n# Create a request\nrequest = requests.get" assert response == response_snapshot @@ -48,14 +48,12 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): @pytest.mark.asyncio async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot): - responses = await generate_load( - flash_qwen2, "Test request", max_new_tokens=10, n=4 - ) + responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4) assert len(responses) == 4 assert all( [r.generated_text == responses[0].generated_text for r in responses] ), f"{[r.generated_text for r in responses]}" - assert responses[0].generated_text == ": Let n = 10 - 1" + assert responses[0].generated_text == "\n# Create a request\nrequest = requests.get" assert responses == response_snapshot diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 32c763eb..e7b0b9e2 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -332,27 +332,6 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == "qwen2": - if FLASH_ATTENTION: - return FlashQwen2( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif sharded: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2") - ) - else: - return CausalLM( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) if model_type == "gemma": if FLASH_ATTENTION: return FlashGemma( @@ -364,9 +343,7 @@ def get_model( trust_remote_code=trust_remote_code, ) elif sharded: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate") - ) + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) else: return CausalLM( model_id, @@ -424,6 +401,17 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "mixtral": sliding_window = config_dict.get("sliding_window", -1) @@ -438,6 +426,18 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + if model_type == "starcoder2": sliding_window = config_dict.get("sliding_window", -1) if ( @@ -450,6 +450,43 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") + ) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + + if model_type == "qwen2": + sliding_window = config_dict.get("sliding_window", -1) + if ( + (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION + ) or HAS_FLASH_ATTN_V2_CUDA: + return FlashQwen2( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif sharded: + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) + else: + return CausalLM( + model_id, + revision, + quantize=quantize, + use_medusa=use_medusa, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "opt": return OPTSharded( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index bbcef210..93ec6ba4 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -486,6 +486,9 @@ class CausalLM(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if use_medusa: + raise RuntimeError("Medusa decoding is not enabled for AutoModel") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 10b40483..c8f02bca 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -870,7 +870,7 @@ class BloomForCausalLM(BloomPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 42a2cbde..94023b33 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -11,7 +11,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -51,8 +51,14 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + w = [ + weights.get_sharded(f"{p}.bias", dim=0) + for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + ] + bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device) + return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) + get_linear(weight, bias=bias, quantize=config.quantize) ) @@ -170,6 +176,7 @@ class Qwen2Attention(torch.nn.Module): return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + class Qwen2MLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() @@ -212,7 +219,9 @@ class Qwen2Layer(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() prefix = f"model.layers.{layer_id}" - self.self_attn = Qwen2Attention(prefix=f"{prefix}.self_attn", config=config, weights=weights) + self.self_attn = Qwen2Attention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -262,6 +271,7 @@ class Qwen2Layer(nn.Module): return mlp_output, attn_res + class Qwen2Model(torch.nn.Module): def __init__(self, config, weights): super().__init__() @@ -286,7 +296,7 @@ class Qwen2Model(torch.nn.Module): ) self.gradient_checkpointing = False - + self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads @@ -338,7 +348,7 @@ class Qwen2ForCausalLM(torch.nn.Module): super().__init__() self.model = Qwen2Model(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index 2550d2d1..1b060060 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -721,7 +721,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): ) hidden_states = outputs[0] - lm_logits = self.embed_out(hidden_states) + lm_logits, speculative_logits = self.embed_out(hidden_states) lm_loss = None if labels is not None: @@ -739,12 +739,15 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): output = (lm_logits,) + outputs[1:] return ((lm_loss,) + output) if lm_loss is not None else output - return CausalLMOutputWithPast( - loss=lm_loss, - logits=lm_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + return ( + CausalLMOutputWithPast( + loss=lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, ) def prepare_inputs_for_generation( diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index de5e95af..7a5cf917 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -792,16 +792,19 @@ class OPTForCausalLM(OPTPreTrainedModel): return_dict=return_dict, ) - logits = self.lm_head(outputs[0]).contiguous() + logits, speculative_logits = self.lm_head(outputs) loss = None - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + return ( + CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, ) def prepare_inputs_for_generation( diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index fd5c18e0..8149c1b0 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -315,7 +315,7 @@ class BaseFlashMistral(FlashCausalLM): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype else: - raise NotImplementedError("FlashLlama is only available on GPU") + raise NotImplementedError("FlashMistral is only available on GPU") tokenizer = LlamaTokenizerFast.from_pretrained( model_id, diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 791ebecd..c3c63516 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -1,12 +1,17 @@ +import math + import torch import torch.distributed from opentelemetry import trace -from transformers import AutoTokenizer from transformers.models.qwen2 import Qwen2Tokenizer from typing import Optional -from text_generation_server.models import FlashCausalLM +from text_generation_server.models.cache_manager import BLOCK_SIZE +from text_generation_server.models.flash_mistral import ( + BaseFlashMistral, + set_sliding_window, +) from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2ForCausalLM, ) @@ -20,12 +25,13 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -class FlashQwen2(FlashCausalLM): +class FlashQwen2(BaseFlashMistral): def __init__( self, model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -36,23 +42,25 @@ class FlashQwen2(FlashCausalLM): else: raise NotImplementedError("FlashQwen2 is only available on GPU") - try: - tokenizer = Qwen2Tokenizer.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - except Exception: - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) + tokenizer = Qwen2Tokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) config = Qwen2Config.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa + + # Set context windows + if config.sliding_window is not None: + set_sliding_window( + config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE) + ) torch.distributed.barrier(group=self.process_group) @@ -63,8 +71,10 @@ class FlashQwen2(FlashCausalLM): model = Qwen2ForCausalLM(config, weights) + self.cuda_graphs = {} + torch.distributed.barrier(group=self.process_group) - super(FlashQwen2, self).__init__( + super(BaseFlashMistral, self).__init__( model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -74,4 +84,5 @@ class FlashQwen2(FlashCausalLM): device=device, rank=rank, world_size=world_size, + sliding_window=config.sliding_window, ) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 2f6ae757..68e726d8 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -38,7 +38,7 @@ class FlashStarcoder2(BaseFlashMistral): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype else: - raise NotImplementedError("FlashLlama is only available on GPU") + raise NotImplementedError("FlashStarcoder2 is only available on GPU") tokenizer = GPT2TokenizerFast.from_pretrained( model_id, diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 3607c285..a46f86be 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -167,6 +167,7 @@ class GalacticaSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -194,6 +195,7 @@ class GalacticaSharded(CausalLM): ) config.quantize = quantize tokenizer.pad_token_id = config.pad_token_id + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -229,10 +231,10 @@ class GalacticaSharded(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True, ) - return outputs.logits, outputs.past_key_values + return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 45df4839..1c4cfe7d 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -24,6 +24,7 @@ class GPTNeoxSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -50,6 +51,7 @@ class GPTNeoxSharded(CausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -75,7 +77,7 @@ class GPTNeoxSharded(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -84,4 +86,4 @@ class GPTNeoxSharded(CausalLM): ) logits = outputs.logits - return logits, outputs.past_key_values + return logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 22ab093e..92c93542 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -12,9 +12,13 @@ class RW(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if use_medusa: + raise RuntimeError("Medusa decoding is not enabled for AutoModel") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index fae9a2df..e55a661c 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -536,6 +536,9 @@ class Seq2SeqLM(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): + if use_medusa: + raise RuntimeError("Medusa decoding is not enabled for AutoModel") + if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float16 if dtype is None else dtype