add speculative head

OlivierDehaene 2024-02-28 14:58:43 +01:00
parent d1d757e676
commit 725f0e350d
17 changed files with 431 additions and 355 deletions

View File

@@ -6,79 +6,79 @@
   "prefill": [
-    { "id": 2271, "text": "Test", "logprob": null },
-    { "id": 1681, "text": " request", "logprob": -7.0351562 }
+    { "id": 2271, "logprob": null, "text": "Test" },
+    { "id": 1681, "logprob": -8.8515625, "text": " request" }
   ],
   "seed": null,
   "tokens": [
-    { "id": 369, "text": " for", "logprob": -2.1914062, "special": false },
-    { "id": 279, "text": " the", "logprob": -2.6210938, "special": false },
-    { "id": 2701, "text": " following", "logprob": -3.6445312, "special": false },
-    { "id": 729, "text": " function", "logprob": -2.9648438, "special": false },
-    { "id": 271, "text": "\n\n", "logprob": -1.9111328, "special": false },
-    { "id": 31946, "text": "Inputs", "logprob": -1.6855469, "special": false },
-    { "id": 25, "text": ":", "logprob": -1.6093254e-05, "special": false },
-    { "id": 707, "text": " def", "logprob": -0.5678711, "special": false },
-    { "id": 1477, "text": " find", "logprob": -2.5917969, "special": false },
-    { "id": 6345, "text": "_max", "logprob": -1.8349609, "special": false }
+    { "id": 198, "logprob": -2.9023438, "special": false, "text": "\n" },
+    { "id": 2, "logprob": -2.9160156, "special": false, "text": "#" },
+    { "id": 4230, "logprob": -3.1035156, "special": false, "text": " Create" },
+    { "id": 264, "logprob": -1.1025391, "special": false, "text": " a" },
+    { "id": 1681, "logprob": -1.6914062, "special": false, "text": " request" },
+    { "id": 198, "logprob": -1.1953125, "special": false, "text": "\n" },
+    { "id": 2035, "logprob": -1.3203125, "special": false, "text": "request" },
+    { "id": 284, "logprob": -0.13537598, "special": false, "text": " =" },
+    { "id": 7388, "logprob": -1.2402344, "special": false, "text": " requests" },
+    { "id": 670, "logprob": -0.2775879, "special": false, "text": ".get" }
   ],
   "top_tokens": null
 },
-"generated_text": " for the following function\n\nInputs: def find_max"
+"generated_text": "\n# Create a request\nrequest = requests.get"
 }

View File

@@ -6,79 +6,79 @@
   "prefill": [
-    { "id": 2271, "text": "Test", "logprob": null },
-    { "id": 1681, "text": " request", "logprob": -7.0351562 }
+    { "id": 2271, "logprob": null, "text": "Test" },
+    { "id": 1681, "logprob": -8.8515625, "text": " request" }
   ],
   "seed": 0,
   "tokens": [
-    { "id": 311, "text": " to", "logprob": -1.4472656, "special": false },
-    { "id": 633, "text": " get", "logprob": -0.4741211, "special": false },
-    { "id": 264, "text": " a", "logprob": 0.0, "special": false },
-    { "id": 1140, "text": " list", "logprob": 0.0, "special": false },
-    { "id": 315, "text": " of", "logprob": 0.0, "special": false },
-    { "id": 678, "text": " all", "logprob": 0.0, "special": false },
-    { "id": 279, "text": " the", "logprob": -0.2590332, "special": false },
-    { "id": 3847, "text": " users", "logprob": -0.45239258, "special": false },
-    { "id": 304, "text": " in", "logprob": -0.12322998, "special": false },
-    { "id": 419, "text": " this", "logprob": -1.7275391, "special": false }
+    { "id": 311, "logprob": -1.4277344, "special": false, "text": " to" },
+    { "id": 279, "logprob": -0.65478516, "special": false, "text": " the" },
+    { "id": 2473, "logprob": -1.8300781, "special": false, "text": " service" },
+    { "id": 382, "logprob": -0.75, "special": false, "text": ".\n\n" },
+    { "id": 286, "logprob": -0.11621094, "special": false, "text": " " },
+    { "id": 549, "logprob": 0.0, "special": false, "text": " :" },
+    { "id": 689, "logprob": -0.48608398, "special": false, "text": "return" },
+    { "id": 25, "logprob": 0.0, "special": false, "text": ":" },
+    { "id": 5949, "logprob": -0.5756836, "special": false, "text": " Response" },
+    { "id": 504, "logprob": -0.24499512, "special": false, "text": " from" }
   ],
   "top_tokens": null
 },
-"generated_text": "Test request to get a list of all the users in this"
+"generated_text": "Test request to the service.\n\n :return: Response from"
 }

View File

@@ -7,81 +7,81 @@
   "prefill": [
-    { "id": 2271, "text": "Test", "logprob": null },
-    { "id": 1681, "text": " request", "logprob": -7.0351562 }
+    { "id": 2271, "logprob": null, "text": "Test" },
+    { "id": 1681, "logprob": -8.8515625, "text": " request" }
   ],
   "seed": null,
   "tokens": [
-    { "id": 369, "text": " for", "logprob": -2.1914062, "special": false },
-    { "id": 279, "text": " the", "logprob": -2.6210938, "special": false },
-    { "id": 2701, "text": " following", "logprob": -3.6445312, "special": false },
-    { "id": 729, "text": " function", "logprob": -2.9648438, "special": false },
-    { "id": 271, "text": "\n\n", "logprob": -1.9111328, "special": false },
-    { "id": 31946, "text": "Inputs", "logprob": -1.6855469, "special": false },
-    { "id": 25, "text": ":", "logprob": -1.6093254e-05, "special": false },
-    { "id": 707, "text": " def", "logprob": -0.5678711, "special": false },
-    { "id": 1477, "text": " find", "logprob": -2.5917969, "special": false },
-    { "id": 6345, "text": "_max", "logprob": -1.8349609, "special": false }
+    { "id": 198, "logprob": -2.9023438, "special": false, "text": "\n" },
+    { "id": 2, "logprob": -2.9140625, "special": false, "text": "#" },
+    { "id": 4230, "logprob": -3.1054688, "special": false, "text": " Create" },
+    { "id": 264, "logprob": -1.0966797, "special": false, "text": " a" },
+    { "id": 1681, "logprob": -1.6914062, "special": false, "text": " request" },
+    { "id": 198, "logprob": -1.1923828, "special": false, "text": "\n" },
+    { "id": 2035, "logprob": -1.3193359, "special": false, "text": "request" },
+    { "id": 284, "logprob": -0.13586426, "special": false, "text": " =" },
+    { "id": 7388, "logprob": -1.2412109, "special": false, "text": " requests" },
+    { "id": 670, "logprob": -0.2775879, "special": false, "text": ".get" }
   ],
   "top_tokens": null
 },
-"generated_text": " for the following function\n\nInputs: def find_max"
+"generated_text": "\n# Create a request\nrequest = requests.get"
 },
 {
   "details": {
@@ -91,81 +91,81 @@
   "prefill": [
-    { "id": 2271, "text": "Test", "logprob": null },
-    { "id": 1681, "text": " request", "logprob": -7.0351562 }
+    { "id": 2271, "logprob": null, "text": "Test" },
+    { "id": 1681, "logprob": -8.8515625, "text": " request" }
   ],
   "seed": null,
   "tokens": [
-    { "id": 369, "text": " for", "logprob": -2.1914062, "special": false },
-    { "id": 279, "text": " the", "logprob": -2.6210938, "special": false },
-    { "id": 2701, "text": " following", "logprob": -3.6445312, "special": false },
-    { "id": 729, "text": " function", "logprob": -2.9648438, "special": false },
-    { "id": 271, "text": "\n\n", "logprob": -1.9111328, "special": false },
-    { "id": 31946, "text": "Inputs", "logprob": -1.6855469, "special": false },
-    { "id": 25, "text": ":", "logprob": -1.6093254e-05, "special": false },
-    { "id": 707, "text": " def", "logprob": -0.5678711, "special": false },
-    { "id": 1477, "text": " find", "logprob": -2.5917969, "special": false },
-    { "id": 6345, "text": "_max", "logprob": -1.8349609, "special": false }
+    { "id": 198, "logprob": -2.9023438, "special": false, "text": "\n" },
+    { "id": 2, "logprob": -2.9140625, "special": false, "text": "#" },
+    { "id": 4230, "logprob": -3.1054688, "special": false, "text": " Create" },
+    { "id": 264, "logprob": -1.0966797, "special": false, "text": " a" },
+    { "id": 1681, "logprob": -1.6914062, "special": false, "text": " request" },
+    { "id": 198, "logprob": -1.1923828, "special": false, "text": "\n" },
+    { "id": 2035, "logprob": -1.3193359, "special": false, "text": "request" },
+    { "id": 284, "logprob": -0.13586426, "special": false, "text": " =" },
+    { "id": 7388, "logprob": -1.2412109, "special": false, "text": " requests" },
+    { "id": 670, "logprob": -0.2775879, "special": false, "text": ".get" }
   ],
   "top_tokens": null
 },
-"generated_text": " for the following function\n\nInputs: def find_max"
+"generated_text": "\n# Create a request\nrequest = requests.get"
 },
 {
   "details": {
@@ -175,81 +175,81 @@
   "prefill": [
-    { "id": 2271, "text": "Test", "logprob": null },
-    { "id": 1681, "text": " request", "logprob": -7.0351562 }
+    { "id": 2271, "logprob": null, "text": "Test" },
+    { "id": 1681, "logprob": -8.8515625, "text": " request" }
   ],
   "seed": null,
   "tokens": [
-    { "id": 369, "text": " for", "logprob": -2.1914062, "special": false },
-    { "id": 279, "text": " the", "logprob": -2.6210938, "special": false },
-    { "id": 2701, "text": " following", "logprob": -3.6445312, "special": false },
-    { "id": 729, "text": " function", "logprob": -2.9648438, "special": false },
-    { "id": 271, "text": "\n\n", "logprob": -1.9111328, "special": false },
-    { "id": 31946, "text": "Inputs", "logprob": -1.6855469, "special": false },
-    { "id": 25, "text": ":", "logprob": -1.6093254e-05, "special": false },
-    { "id": 707, "text": " def", "logprob": -0.5678711, "special": false },
-    { "id": 1477, "text": " find", "logprob": -2.5917969, "special": false },
-    { "id": 6345, "text": "_max", "logprob": -1.8349609, "special": false }
+    { "id": 198, "logprob": -2.9023438, "special": false, "text": "\n" },
+    { "id": 2, "logprob": -2.9140625, "special": false, "text": "#" },
+    { "id": 4230, "logprob": -3.1054688, "special": false, "text": " Create" },
+    { "id": 264, "logprob": -1.0966797, "special": false, "text": " a" },
+    { "id": 1681, "logprob": -1.6914062, "special": false, "text": " request" },
+    { "id": 198, "logprob": -1.1923828, "special": false, "text": "\n" },
+    { "id": 2035, "logprob": -1.3193359, "special": false, "text": "request" },
+    { "id": 284, "logprob": -0.13586426, "special": false, "text": " =" },
+    { "id": 7388, "logprob": -1.2412109, "special": false, "text": " requests" },
+    { "id": 670, "logprob": -0.2775879, "special": false, "text": ".get" }
   ],
   "top_tokens": null
 },
-"generated_text": " for the following function\n\nInputs: def find_max"
+"generated_text": "\n# Create a request\nrequest = requests.get"
 },
 {
   "details": {
@@ -259,80 +259,80 @@
   "prefill": [
-    { "id": 2271, "text": "Test", "logprob": null },
-    { "id": 1681, "text": " request", "logprob": -7.0351562 }
+    { "id": 2271, "logprob": null, "text": "Test" },
+    { "id": 1681, "logprob": -8.8515625, "text": " request" }
   ],
   "seed": null,
   "tokens": [
-    { "id": 369, "text": " for", "logprob": -2.1914062, "special": false },
-    { "id": 279, "text": " the", "logprob": -2.6210938, "special": false },
-    { "id": 2701, "text": " following", "logprob": -3.6445312, "special": false },
-    { "id": 729, "text": " function", "logprob": -2.9648438, "special": false },
-    { "id": 271, "text": "\n\n", "logprob": -1.9111328, "special": false },
-    { "id": 31946, "text": "Inputs", "logprob": -1.6855469, "special": false },
-    { "id": 25, "text": ":", "logprob": -1.6093254e-05, "special": false },
-    { "id": 707, "text": " def", "logprob": -0.5678711, "special": false },
-    { "id": 1477, "text": " find", "logprob": -2.5917969, "special": false },
-    { "id": 6345, "text": "_max", "logprob": -1.8349609, "special": false }
+    { "id": 198, "logprob": -2.9023438, "special": false, "text": "\n" },
+    { "id": 2, "logprob": -2.9140625, "special": false, "text": "#" },
+    { "id": 4230, "logprob": -3.1054688, "special": false, "text": " Create" },
+    { "id": 264, "logprob": -1.0966797, "special": false, "text": " a" },
+    { "id": 1681, "logprob": -1.6914062, "special": false, "text": " request" },
+    { "id": 198, "logprob": -1.1923828, "special": false, "text": "\n" },
+    { "id": 2035, "logprob": -1.3193359, "special": false, "text": "request" },
+    { "id": 284, "logprob": -0.13586426, "special": false, "text": " =" },
+    { "id": 7388, "logprob": -1.2412109, "special": false, "text": " requests" },
+    { "id": 670, "logprob": -0.2775879, "special": false, "text": ".get" }
   ],
   "top_tokens": null
 },
-"generated_text": " for the following function\n\nInputs: def find_max"
+"generated_text": "\n# Create a request\nrequest = requests.get"
 }
 ]

View File

@@ -3,7 +3,7 @@ import pytest

 @pytest.fixture(scope="module")
 def flash_qwen2_handle(launcher):
-    with launcher("Qwen/Qwen1.5-7B") as handle:
+    with launcher("Qwen/Qwen1.5-0.5B") as handle:
         yield handle
@@ -20,7 +20,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
     )

     assert response.details.generated_tokens == 10
-    assert response.generated_text == " for the following function\n\nInputs: def find_max"
+    assert response.generated_text == "\n# Create a request\nrequest = requests.get"
     assert response == response_snapshot
@@ -48,14 +48,12 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
 @pytest.mark.asyncio
 async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
-    responses = await generate_load(
-        flash_qwen2, "Test request", max_new_tokens=10, n=4
-    )
+    responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)

     assert len(responses) == 4
     assert all(
         [r.generated_text == responses[0].generated_text for r in responses]
     ), f"{[r.generated_text for r in responses]}"
-    assert responses[0].generated_text == ": Let n = 10 - 1"
+    assert responses[0].generated_text == "\n# Create a request\nrequest = requests.get"
     assert responses == response_snapshot

View File

@@ -332,27 +332,6 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-    elif model_type == "qwen2":
-        if FLASH_ATTENTION:
-            return FlashQwen2(
-                model_id,
-                revision,
-                quantize=quantize,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
-        elif sharded:
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")
-            )
-        else:
-            return CausalLM(
-                model_id,
-                revision,
-                quantize=quantize,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
     if model_type == "gemma":
         if FLASH_ATTENTION:
             return FlashGemma(
@@ -364,9 +343,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         elif sharded:
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
         else:
             return CausalLM(
                 model_id,
@@ -424,6 +401,17 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "mixtral":
         sliding_window = config_dict.get("sliding_window", -1)
@@ -438,6 +426,18 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "starcoder2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
@@ -450,6 +450,43 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+        elif sharded:
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
+            )
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
+    if model_type == "qwen2":
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
+            return FlashQwen2(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )

     if model_type == "opt":
         return OPTSharded(

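The get_model hunks above repeat one dispatch pattern for mistral, mixtral, starcoder2 and now qwen2: take the flash implementation when the sliding window is compatible with the installed attention kernels, refuse sharding without flash attention, and otherwise fall back to CausalLM with use_medusa threaded through. The following is a condensed, hypothetical sketch of that pattern; the helper name, signature, and placeholder flags are illustrative, not code from this repository.

# Placeholder flags standing in for the ones models/__init__.py already imports.
FLASH_ATTENTION = True
HAS_FLASH_ATTN_V2_CUDA = True
FLASH_ATT_ERROR_MESSAGE = "{} requires flash-attn"

def dispatch_sliding_window_model(flash_cls, fallback_cls, name, config_dict, sharded, **kwargs):
    """Hypothetical condensation of the repeated mistral/mixtral/starcoder2/qwen2 branches."""
    sliding_window = config_dict.get("sliding_window", -1)
    # Flash path: either no sliding window is configured, or a kernel that supports one is present.
    if ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION) or HAS_FLASH_ATTN_V2_CUDA:
        return flash_cls(**kwargs)
    if sharded:
        # Sharding these models without flash attention is not implemented.
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {name}"))
    # Non-flash fallback: a plain CausalLM, which keeps use_medusa in kwargs and rejects it itself.
    return fallback_cls(**kwargs)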
View File

@@ -486,6 +486,9 @@ class CausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        if use_medusa:
+            raise RuntimeError("Medusa decoding is not enabled for AutoModel")
+
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype

View File

@@ -870,7 +870,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         **deprecated_arguments,
-    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set

View File

@@ -11,7 +11,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -51,8 +51,14 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

+    w = [
+        weights.get_sharded(f"{p}.bias", dim=0)
+        for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
+    ]
+    bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
+
     return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
+        get_linear(weight, bias=bias, quantize=config.quantize)
     )
@@ -170,6 +176,7 @@ class Qwen2Attention(torch.nn.Module):
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

+
 class Qwen2MLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -212,7 +219,9 @@ class Qwen2Layer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"model.layers.{layer_id}"
-        self.self_attn = Qwen2Attention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
+        self.self_attn = Qwen2Attention(
+            prefix=f"{prefix}.self_attn", config=config, weights=weights
+        )
         self.mlp = Qwen2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -262,6 +271,7 @@ class Qwen2Layer(nn.Module):
         return mlp_output, attn_res

+
 class Qwen2Model(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
@@ -338,7 +348,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
         super().__init__()
         self.model = Qwen2Model(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,

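The modeling change that matters most here is the head swap: lm_head is now a SpeculativeHead, so a single call yields both the regular next-token logits and, when extra speculative weights are attached, a second tensor of speculative logits. Below is a minimal sketch of that contract, assuming a Medusa-style stack of extra heads; it illustrates the idea only and is not TGI's actual SpeculativeHead implementation.

import torch
from typing import Optional, Tuple

class SpeculativeHeadSketch(torch.nn.Module):
    """Minimal illustration: wraps the usual lm_head plus an optional stack of
    Medusa-style heads. Names and structure are assumptions, not TGI's real class."""

    def __init__(self, lm_head: torch.nn.Module, medusa_heads: Optional[torch.nn.ModuleList] = None):
        super().__init__()
        self.lm_head = lm_head
        self.medusa_heads = medusa_heads  # one small projection per speculated position

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        if self.medusa_heads is None:
            # No speculation: callers still unpack a 2-tuple, second element is None.
            return logits, None
        # Stack per-head logits along a new "speculative position" dimension.
        speculative_logits = torch.stack(
            [head(hidden_states) for head in self.medusa_heads], dim=-2
        )
        return logits, speculative_logits

This two-value contract is why every modeling file touched by the commit now unpacks its lm_head or embed_out call into a (logits, speculative_logits) pair, as the NeoX and OPT hunks below show.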
View File

@@ -721,7 +721,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
         )

         hidden_states = outputs[0]
-        lm_logits = self.embed_out(hidden_states)
+        lm_logits, speculative_logits = self.embed_out(hidden_states)

         lm_loss = None
         if labels is not None:
@@ -739,12 +739,15 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
             output = (lm_logits,) + outputs[1:]
             return ((lm_loss,) + output) if lm_loss is not None else output

-        return CausalLMOutputWithPast(
-            loss=lm_loss,
-            logits=lm_logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        return (
+            CausalLMOutputWithPast(
+                loss=lm_loss,
+                logits=lm_logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            ),
+            speculative_logits,
+        )

     def prepare_inputs_for_generation(

View File

@@ -792,16 +792,19 @@ class OPTForCausalLM(OPTPreTrainedModel):
             return_dict=return_dict,
         )

-        logits = self.lm_head(outputs[0]).contiguous()
+        logits, speculative_logits = self.lm_head(outputs)

         loss = None

-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
+        return (
+            CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            ),
+            speculative_logits,
+        )

     def prepare_inputs_for_generation(

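Both GPT-NeoX and OPT keep their Hugging Face-style output objects but wrap them in a tuple alongside the speculative logits. A hedged sketch of how a caller consumes the new shape (the wrapper function and its arguments are placeholders, mirroring how galactica.py and gpt_neox.py unpack their models further below):

def forward_with_speculation(model, input_ids, attention_mask, past_key_values):
    # Placeholder wrapper, not repository code.
    outputs, speculative_logits = model.forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
    )
    # CausalLMOutputWithPast is still the first element; speculative_logits is
    # None when no speculative head is loaded.
    return outputs.logits, speculative_logits, outputs.past_key_values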
View File

@@ -315,7 +315,7 @@ class BaseFlashMistral(FlashCausalLM):
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
-            raise NotImplementedError("FlashLlama is only available on GPU")
+            raise NotImplementedError("FlashMistral is only available on GPU")

         tokenizer = LlamaTokenizerFast.from_pretrained(
             model_id,

View File

@@ -1,12 +1,17 @@
+import math
 import torch
 import torch.distributed

 from opentelemetry import trace
-from transformers import AutoTokenizer
 from transformers.models.qwen2 import Qwen2Tokenizer
 from typing import Optional

-from text_generation_server.models import FlashCausalLM
+from text_generation_server.models.cache_manager import BLOCK_SIZE
+from text_generation_server.models.flash_mistral import (
+    BaseFlashMistral,
+    set_sliding_window,
+)
 from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
     Qwen2ForCausalLM,
 )
@@ -20,12 +25,13 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)

-class FlashQwen2(FlashCausalLM):
+class FlashQwen2(BaseFlashMistral):
     def __init__(
         self,
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -36,16 +42,11 @@ class FlashQwen2(FlashCausalLM):
         else:
             raise NotImplementedError("FlashQwen2 is only available on GPU")

-        try:
-            tokenizer = Qwen2Tokenizer.from_pretrained(
-                model_id,
-                revision=revision,
-                trust_remote_code=trust_remote_code,
-            )
-        except Exception:
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_id,
-                revision=revision,
-                trust_remote_code=trust_remote_code,
-            )
+        tokenizer = Qwen2Tokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
@@ -53,6 +54,13 @@ class FlashQwen2(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
+
+        # Set context windows
+        if config.sliding_window is not None:
+            set_sliding_window(
+                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
+            )

         torch.distributed.barrier(group=self.process_group)
@@ -63,8 +71,10 @@ class FlashQwen2(FlashCausalLM):
         model = Qwen2ForCausalLM(config, weights)

+        self.cuda_graphs = {}
+
         torch.distributed.barrier(group=self.process_group)
-        super(FlashQwen2, self).__init__(
+        super(BaseFlashMistral, self).__init__(
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
@@ -74,4 +84,5 @@ class FlashQwen2(FlashCausalLM):
             device=device,
             rank=rank,
             world_size=world_size,
+            sliding_window=config.sliding_window,
         )

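FlashQwen2 now inherits from BaseFlashMistral and registers the sliding window in units of KV-cache blocks via math.ceil(config.sliding_window / BLOCK_SIZE). A small worked example of that second argument to set_sliding_window; both numbers below are assumptions for illustration (check cache_manager.py and the model config for the real values):

import math

BLOCK_SIZE = 16          # assumed paged-attention block size
sliding_window = 32768   # illustrative sliding-window length from a Qwen1.5-style config

# Number of KV-cache blocks needed to cover the attention window, i.e. what
# set_sliding_window() receives alongside the raw token count.
blocks = math.ceil(sliding_window / BLOCK_SIZE)
print(blocks)  # 2048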
View File

@@ -38,7 +38,7 @@ class FlashStarcoder2(BaseFlashMistral):
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
-            raise NotImplementedError("FlashLlama is only available on GPU")
+            raise NotImplementedError("FlashStarcoder2 is only available on GPU")

         tokenizer = GPT2TokenizerFast.from_pretrained(
             model_id,

View File

@@ -167,6 +167,7 @@ class GalacticaSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -194,6 +195,7 @@ class GalacticaSharded(CausalLM):
         )
         config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -229,10 +231,10 @@ class GalacticaSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=True,
         )
-        return outputs.logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values

View File

@@ -24,6 +24,7 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -50,6 +51,7 @@ class GPTNeoxSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -75,7 +77,7 @@ class GPTNeoxSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -84,4 +86,4 @@ class GPTNeoxSharded(CausalLM):
         )
         logits = outputs.logits

-        return logits, outputs.past_key_values
+        return logits, speculative_logits, outputs.past_key_values

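With these changes, the sharded CausalLM subclasses return (logits, speculative_logits, past_key_values) from forward. The sketch below is a conceptual, hedged illustration of what speculative logits are for: each extra head proposes a draft token for a future position, and a later forward pass verifies the draft against the base model. It is not TGI's scheduler code, and the tensor shapes in the docstring are assumptions.

import torch

def propose_draft_tokens(logits: torch.Tensor, speculative_logits: torch.Tensor) -> torch.Tensor:
    """Conceptual sketch: greedy next token plus greedy picks from each speculative head.

    Assumed shapes:
      logits:             [batch, vocab]           last-position logits
      speculative_logits: [batch, n_heads, vocab]  one distribution per speculated position
    Returns a draft of shape [batch, 1 + n_heads]; a real server would verify the draft
    against the base model on the next step and keep the longest matching prefix.
    """
    next_token = logits.argmax(dim=-1, keepdim=True)   # [batch, 1]
    draft_tail = speculative_logits.argmax(dim=-1)     # [batch, n_heads]
    return torch.cat([next_token, draft_tail], dim=-1)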
View File

@@ -12,9 +12,13 @@ class RW(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        if use_medusa:
+            raise RuntimeError("Medusa decoding is not enabled for AutoModel")
+
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype

View File

@@ -536,6 +536,9 @@ class Seq2SeqLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        if use_medusa:
+            raise RuntimeError("Medusa decoding is not enabled for AutoModel")
+
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype