Speculative decoding + Mistral

Nicolas Patry 2023-12-04 13:42:28 +00:00
parent e7e07342bd
commit 7ed07bcc05
10 changed files with 169 additions and 93 deletions
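For context: speculative decoding lets a cheap draft mechanism (here, Medusa heads, with the depth hardcoded to SPECULATE = 2 below) propose a few tokens ahead; the main model then scores the current token plus all drafted tokens in a single forward pass and keeps only the prefix it agrees with. A minimal sketch of that greedy verification step, with illustrative names that are not the repo's API:

import torch

def verify_speculated(target_logits: torch.Tensor, speculated_ids: torch.Tensor) -> int:
    # target_logits: [S + 1, vocab] scores for the current position plus S drafted positions.
    # speculated_ids: [S] draft tokens. Returns how many tokens survive verification.
    target_ids = target_logits.argmax(dim=-1)  # what the main model would emit at each position
    accepted = 1  # the main model's first token is always valid
    for j in range(speculated_ids.shape[0]):
        if target_ids[j].item() == speculated_ids[j].item():
            accepted += 1  # draft agreed with the main model, keep going
        else:
            break  # the first mismatch invalidates everything after it
    return accepted

Worst case this degrades to one token per engine step (plus the wasted draft compute); best case it yields SPECULATE + 1 tokens for a single main-model pass.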

View File

@@ -57,22 +57,28 @@
"text": " learning"
},
{
"id": 313,
"logprob": -1.0712891,
"id": 508,
"logprob": -1.5087891,
"special": false,
"text": " ("
"text": " can"
},
{
"id": 15189,
"logprob": -0.7578125,
"special": false,
"text": "also"
},
{
"id": 2998,
"id": 367,
"logprob": 0.0,
"special": false,
"text": " known"
"text": " be"
},
{
"id": 2714,
"logprob": -0.6538086,
"special": false,
"text": " thought"
},
{
"id": 310,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 408,
@@ -81,18 +87,18 @@
"text": " as"
},
{
"id": 6483,
"id": 263,
"logprob": 0.0,
"special": false,
"text": " deep"
"text": " a"
},
{
"id": 19677,
"logprob": 0.0,
"id": 11306,
"logprob": -0.5488281,
"special": false,
"text": " neural"
"text": " subset"
}
]
},
"generated_text": "What is Deep Learning?\nDeep learning (also known as deep neural"
"generated_text": "What is Deep Learning?\nDeep learning can be thought of as a subset"
}

View File

@@ -17,7 +17,7 @@
},
{
"id": 338,
"logprob": -1.5488281,
"logprob": -1.5498047,
"text": "is"
},
{
@@ -27,12 +27,12 @@
},
{
"id": 29257,
"logprob": -1.2753906,
"logprob": -1.2734375,
"text": "Learning"
},
{
"id": 29973,
"logprob": -0.48046875,
"logprob": -0.48217773,
"text": "?"
}
],
@@ -40,19 +40,19 @@
"tokens": [
{
"id": 13,
"logprob": -1.1845703,
"logprob": -1.1875,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.5727539,
"logprob": -0.5708008,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.00010967255,
"logprob": -0.00010931492,
"special": false,
"text": "ep"
},
@@ -64,37 +64,37 @@
},
{
"id": 338,
"logprob": -0.04510498,
"logprob": -0.044433594,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"logprob": -0.018310547,
"special": false,
"text": " a"
},
{
"id": 11306,
"logprob": -0.45922852,
"logprob": -0.46044922,
"special": false,
"text": " subset"
},
{
"id": 310,
"logprob": -0.00020992756,
"logprob": -0.0002104044,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.0046539307,
"logprob": -0.004711151,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025844574,
"logprob": -0.00025820732,
"special": false,
"text": " learning"
},
@@ -112,7 +112,7 @@
}
]
},
"generated_text": "ep learning is a subset of machine learning that involves"
"generated_text": "\nDeep learning is a subset of machine learning that involves"
},
{
"details": {
@@ -127,12 +127,12 @@
},
{
"id": 1724,
"logprob": -10.734375,
"logprob": -10.7421875,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"logprob": -1.5498047,
"text": "is"
},
{
@@ -155,19 +155,19 @@
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"logprob": -1.1835938,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"logprob": -0.57470703,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"logprob": -0.00010788441,
"special": false,
"text": "ep"
},
@@ -179,13 +179,13 @@
},
{
"id": 338,
"logprob": -0.044433594,
"logprob": -0.04510498,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"logprob": -0.018585205,
"special": false,
"text": " a"
},
@@ -197,19 +197,19 @@
},
{
"id": 310,
"logprob": -0.0002104044,
"logprob": -0.00021457672,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"logprob": -0.004776001,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"logprob": -0.0002593994,
"special": false,
"text": " learning"
},
@@ -227,7 +227,7 @@
}
]
},
"generated_text": "ep learning is a subset of machine learning that involves"
"generated_text": "\nDeep learning is a subset of machine learning that involves"
},
{
"details": {
@@ -242,12 +242,12 @@
},
{
"id": 1724,
"logprob": -10.734375,
"logprob": -10.7421875,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"logprob": -1.5498047,
"text": "is"
},
{
@@ -270,19 +270,19 @@
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"logprob": -1.1835938,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"logprob": -0.57470703,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"logprob": -0.00010788441,
"special": false,
"text": "ep"
},
@@ -294,13 +294,13 @@
},
{
"id": 338,
"logprob": -0.044433594,
"logprob": -0.04510498,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"logprob": -0.018585205,
"special": false,
"text": " a"
},
@@ -312,19 +312,19 @@
},
{
"id": 310,
"logprob": -0.0002104044,
"logprob": -0.00021457672,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"logprob": -0.004776001,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"logprob": -0.0002593994,
"special": false,
"text": " learning"
},
@@ -342,7 +342,7 @@
}
]
},
"generated_text": "ep learning is a subset of machine learning that involves"
"generated_text": "\nDeep learning is a subset of machine learning that involves"
},
{
"details": {
@@ -357,12 +357,12 @@
},
{
"id": 1724,
"logprob": -10.734375,
"logprob": -10.7421875,
"text": "What"
},
{
"id": 338,
"logprob": -1.5488281,
"logprob": -1.5498047,
"text": "is"
},
{
@@ -385,19 +385,19 @@
"tokens": [
{
"id": 13,
"logprob": -1.1826172,
"logprob": -1.1835938,
"special": false,
"text": "\n"
},
{
"id": 2772,
"logprob": -0.56689453,
"logprob": -0.57470703,
"special": false,
"text": "De"
},
{
"id": 1022,
"logprob": -0.000108003616,
"logprob": -0.00010788441,
"special": false,
"text": "ep"
},
@@ -409,13 +409,13 @@
},
{
"id": 338,
"logprob": -0.044433594,
"logprob": -0.04510498,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.018295288,
"logprob": -0.018585205,
"special": false,
"text": " a"
},
@@ -427,19 +427,19 @@
},
{
"id": 310,
"logprob": -0.0002104044,
"logprob": -0.00021457672,
"special": false,
"text": " of"
},
{
"id": 4933,
"logprob": -0.004711151,
"logprob": -0.004776001,
"special": false,
"text": " machine"
},
{
"id": 6509,
"logprob": -0.00025892258,
"logprob": -0.0002593994,
"special": false,
"text": " learning"
},
@@ -457,6 +457,6 @@
}
]
},
"generated_text": "ep learning is a subset of machine learning that involves"
"generated_text": "\nDeep learning is a subset of machine learning that involves"
}
]

View File

@@ -111,5 +111,5 @@
}
]
},
"generated_text": "ep learning is a subset of machine learning that involves"
"generated_text": "\nDeep learning is a subset of machine learning that involves"
}

View File

@@ -43,7 +43,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
seed=0,
)
assert response.details.generated_tokens == 5
assert response.details.generated_tokens == 10
assert response == response_snapshot

View File

@@ -54,6 +54,6 @@ async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == 'ep learning is a subset of machine learning that involves'
assert responses[0].generated_text == '\nDeep learning is a subset of machine learning that involves'
assert responses == response_snapshot

View File

@@ -21,6 +21,7 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
)
assert response.details.generated_tokens == 10
assert response.generated_text == ": Let n = 10 - 1"
assert response == response_snapshot
@@ -55,6 +56,7 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
assert responses[0].generated_text == ": Let n = 10 - 1"
assert responses == response_snapshot

View File

@@ -101,6 +101,8 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")
SPECULATE = 2
if "facebook/galactica" in model_id:
return GalacticaSharded(
model_id,

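Note that SPECULATE enters as a hardcoded module-level constant (two drafted tokens per step) rather than a configurable flag, and downstream code imports it when sizing per-request budgets, as the flash_mistral.py hunk further down does. A toy version of that sizing, assuming the constant simply widens every paged-attention allocation:

SPECULATE = 2  # speculation depth, hardcoded in this commit

def total_token_budget(input_length: int, max_new_tokens: int) -> int:
    # Minus one because the first token does not have a past; plus SPECULATE
    # so verified draft tokens have KV-cache slots to land in.
    return input_length + max_new_tokens - 1 + SPECULATE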
View File

@@ -479,19 +479,19 @@ class FlashCausalLMBatch(Batch):
max_blocks = 0
max_length = 0
max_seqlen = 0
speculative_length = 0 if batches[0].speculative_ids is None else batches[0].speculative_ids.shape[1]
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1]
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
max_length,
max(
input_length
+ speculative_length
+ stopping_criteria.max_new_tokens
+ speculative_length
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
@@ -994,7 +994,8 @@ class FlashCausalLM(Model):
# Evaluate stopping criteria
for next_token_id in _next_token_ids:
left = 0
for j, next_token_id in enumerate(_next_token_ids):
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
@@ -1002,21 +1003,26 @@ class FlashCausalLM(Model):
if stop:
stopped = True
left = len(_next_token_ids) - 1 - j
break
if not stop:
else:
stopped = False
_next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
# Remove potentially accepted ids that do not respect
# the stopping_criteria
_ids = all_input_ids[:len(all_input_ids) - left]
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids)
_ids,
prefix_offset=len(_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
read_offset=len(_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)

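Because several tokens can now be committed per engine step, a stop condition may fire in the middle of the accepted block, and everything past the stopping token must be dropped before decoding the output text. The left-trimming above, as a standalone sketch (stop_id stands in for the real stopping_criteria callable):

def trim_to_stop(next_token_ids: list, stop_id: int):
    # Walk the accepted ids in order; on the first stop token, discard the tail.
    left = 0
    stopped = False
    for j, token_id in enumerate(next_token_ids):
        if token_id == stop_id:
            stopped = True
            left = len(next_token_ids) - 1 - j  # tokens accepted past the stop point
            break
    return next_token_ids[:len(next_token_ids) - left], stopped

For example, trim_to_stop([5, 9, 2, 7], stop_id=2) returns ([5, 9, 2], True), matching how _ids is truncated before self.decode_token.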
View File

@@ -132,7 +132,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
# Paged attention
# Remove one as the first token does not have a past
total_tokens = input_length + max_new_tokens - 1
from text_generation_server.models import SPECULATE
speculative_length = SPECULATE
total_tokens = input_length + max_new_tokens - 1 + speculative_length
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = min(
@@ -183,7 +185,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(max_length, input_length + max_new_tokens)
max_length = max(max_length, input_length + max_new_tokens + speculative_length)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
@@ -341,17 +343,55 @@ class FlashMistral(FlashCausalLM):
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
# Copy the block tables for all members
block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
logits = self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlen_prefill=batch.cu_seqlen_prefill,
kv_cache=get_cache_manager().kv_cache,
block_tables=batch.block_tables_tensor,
slots=batch.slots[batch.slot_indices],
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=batch.prefill_head_indices,
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None

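The forward path above flattens each request's current token together with its speculated ids so paged attention can score every candidate position in one pass: ids become a flat [B * (S + 1)] tensor, position ids and slots are offset by a per-position arange, and block tables are repeated for each candidate. A shape-level sketch of that expansion, with assumed toy values rather than the repo's API:

import torch

B, S = 2, 2                                        # batch size, drafted tokens per request
input_ids = torch.tensor([11, 22])                 # [B] current token per request
speculative_ids = torch.tensor([[5, 6], [7, 8]])   # [B, S]
position_ids = torch.tensor([10, 20])              # [B] next position per request

new_length = S + 1
# Interleave each request's current token with its draft: [11, 5, 6, 22, 7, 8]
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length).unsqueeze(0)     # [[0, 1, 2]]
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)

assert new_input_ids.tolist() == [11, 5, 6, 22, 7, 8]
assert new_position_ids.tolist() == [10, 11, 12, 20, 21, 22]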
View File

@@ -258,16 +258,34 @@ class HeterogeneousNextTokenChooser:
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if speculated_ids is not None:
B = scores.shape[0] // (speculated_ids.shape[1] + 1)
S = speculated_ids.shape[1] + 1
scores = scores.view(B, S, -1)
else:
B = scores.shape[0]
S = 1
scores = scores.view(B, S, -1)
for warper in self.warpers:
scores = warper(input_ids, scores)
all_next_ids = []
all_scores = []
for j in range(S):
_scores = scores[:, j]
if self.watermark_processor is not None:
_scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
next_ids = self.choice(scores)
next_ids = self.choice(_scores)
scores[:, j] = _scores
all_next_ids.append(next_ids.unsqueeze(1))
next_ids = torch.cat(all_next_ids, dim=1).reshape(B * S)
scores = scores.view(B * S, -1)
if speculated_ids is not None:
accepted_ids = []
B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
@@ -289,6 +307,9 @@ class HeterogeneousNextTokenChooser:
else:
break
accepted_ids.append(accepted)
from loguru import logger
logger.info(f"ACCEPTED IDS {accepted_ids}")
accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
next_ids = next_ids[indices]
scores = scores[indices]
@@ -298,7 +319,6 @@ class HeterogeneousNextTokenChooser:
else:
accepted_ids = torch.ones_like(next_ids)
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
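Putting the chooser changes together: scores arrive as [B * (S + 1), vocab], are reshaped so the logits processors and warpers run once per drafted position, and the chosen ids are then compared against the draft to count how many survive per request. A condensed, greedy-only sketch with illustrative names (the real chooser also applies watermark, repetition, and warper processors on each position's slice):

import torch

def choose_and_accept(scores: torch.Tensor, speculated_ids: torch.Tensor):
    # scores: [B * (S + 1), vocab] raw logits; speculated_ids: [B, S] draft tokens.
    S1 = speculated_ids.shape[1] + 1
    B = scores.shape[0] // S1
    scores = scores.view(B, S1, -1)
    next_ids = scores.argmax(dim=-1)  # greedy stand-in for the per-position choice loop, [B, S + 1]
    accepted_ids = []
    for b in range(B):
        accepted = 1  # the first chosen token always counts
        for j in range(speculated_ids.shape[1]):
            if next_ids[b, j].item() == speculated_ids[b, j].item():
                accepted += 1  # draft verified at this position
            else:
                break  # reject the remainder of the draft
        accepted_ids.append(accepted)
    return next_ids.reshape(B * S1), torch.tensor(accepted_ids)

For instance, with B = 1, S = 2 and a draft whose first token matches but whose second does not, accepted_ids comes back as tensor([2]).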