diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json
index ad4c6c30..4dd815b3 100644
--- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json
+++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json
@@ -57,22 +57,28 @@
       "text": " learning"
     },
     {
-      "id": 313,
-      "logprob": -1.0712891,
+      "id": 508,
+      "logprob": -1.5087891,
       "special": false,
-      "text": " ("
+      "text": " can"
     },
     {
-      "id": 15189,
-      "logprob": -0.7578125,
-      "special": false,
-      "text": "also"
-    },
-    {
-      "id": 2998,
+      "id": 367,
       "logprob": 0.0,
       "special": false,
-      "text": " known"
+      "text": " be"
+    },
+    {
+      "id": 2714,
+      "logprob": -0.6538086,
+      "special": false,
+      "text": " thought"
+    },
+    {
+      "id": 310,
+      "logprob": 0.0,
+      "special": false,
+      "text": " of"
     },
     {
       "id": 408,
@@ -81,18 +87,18 @@
       "text": " as"
     },
     {
-      "id": 6483,
+      "id": 263,
       "logprob": 0.0,
       "special": false,
-      "text": " deep"
+      "text": " a"
     },
     {
-      "id": 19677,
-      "logprob": 0.0,
+      "id": 11306,
+      "logprob": -0.5488281,
       "special": false,
-      "text": " neural"
+      "text": " subset"
     }
   ]
 },
-  "generated_text": "What is Deep Learning?\nDeep learning (also known as deep neural"
+  "generated_text": "What is Deep Learning?\nDeep learning can be thought of as a subset"
 }
diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json
index 82a7b9e1..6698f5f4 100644
--- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json
+++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json
@@ -17,7 +17,7 @@
       },
       {
         "id": 338,
-        "logprob": -1.5488281,
+        "logprob": -1.5498047,
         "text": "is"
       },
       {
@@ -27,12 +27,12 @@
       },
       {
         "id": 29257,
-        "logprob": -1.2753906,
+        "logprob": -1.2734375,
         "text": "Learning"
       },
       {
         "id": 29973,
-        "logprob": -0.48046875,
+        "logprob": -0.48217773,
         "text": "?"
       }
     ],
@@ -40,19 +40,19 @@
     "tokens": [
       {
         "id": 13,
-        "logprob": -1.1845703,
+        "logprob": -1.1875,
         "special": false,
         "text": "\n"
       },
       {
         "id": 2772,
-        "logprob": -0.5727539,
+        "logprob": -0.5708008,
         "special": false,
         "text": "De"
       },
       {
         "id": 1022,
-        "logprob": -0.00010967255,
+        "logprob": -0.00010931492,
         "special": false,
         "text": "ep"
       },
@@ -64,37 +64,37 @@
       },
       {
         "id": 338,
-        "logprob": -0.04510498,
+        "logprob": -0.044433594,
         "special": false,
         "text": " is"
       },
       {
         "id": 263,
-        "logprob": -0.018295288,
+        "logprob": -0.018310547,
         "special": false,
         "text": " a"
       },
       {
         "id": 11306,
-        "logprob": -0.45922852,
+        "logprob": -0.46044922,
         "special": false,
         "text": " subset"
       },
       {
         "id": 310,
-        "logprob": -0.00020992756,
+        "logprob": -0.0002104044,
         "special": false,
         "text": " of"
       },
       {
         "id": 4933,
-        "logprob": -0.0046539307,
+        "logprob": -0.004711151,
         "special": false,
         "text": " machine"
       },
       {
         "id": 6509,
-        "logprob": -0.00025844574,
+        "logprob": -0.00025820732,
         "special": false,
         "text": " learning"
       },
@@ -112,7 +112,7 @@
       }
     ]
   },
-  "generated_text": "ep learning is a subset of machine learning that involves"
+  "generated_text": "\nDeep learning is a subset of machine learning that involves"
 },
 {
   "details": {
@@ -127,12 +127,12 @@
       },
       {
         "id": 1724,
-        "logprob": -10.734375,
+        "logprob": -10.7421875,
         "text": "What"
       },
       {
         "id": 338,
-        "logprob": -1.5488281,
+        "logprob": -1.5498047,
         "text": "is"
       },
       {
@@ -155,19 +155,19 @@
     "tokens": [
       {
         "id": 13,
-        "logprob": -1.1826172,
+        "logprob": -1.1835938,
         "special": false,
         "text": "\n"
       },
       {
         "id": 2772,
-        "logprob": -0.56689453,
+        "logprob": -0.57470703,
         "special": false,
         "text": "De"
       },
       {
         "id": 1022,
-        "logprob": -0.000108003616,
+        "logprob": -0.00010788441,
         "special": false,
         "text": "ep"
       },
@@ -179,13 +179,13 @@
       },
       {
         "id": 338,
-        "logprob": -0.044433594,
+        "logprob": -0.04510498,
         "special": false,
         "text": " is"
       },
       {
         "id": 263,
-        "logprob": -0.018295288,
+        "logprob": -0.018585205,
         "special": false,
         "text": " a"
       },
@@ -197,19 +197,19 @@
       },
       {
         "id": 310,
-        "logprob": -0.0002104044,
+        "logprob": -0.00021457672,
         "special": false,
         "text": " of"
       },
       {
         "id": 4933,
-        "logprob": -0.004711151,
+        "logprob": -0.004776001,
         "special": false,
         "text": " machine"
       },
       {
         "id": 6509,
-        "logprob": -0.00025892258,
+        "logprob": -0.0002593994,
         "special": false,
         "text": " learning"
       },
@@ -227,7 +227,7 @@
       }
     ]
   },
-  "generated_text": "ep learning is a subset of machine learning that involves"
+  "generated_text": "\nDeep learning is a subset of machine learning that involves"
 },
 {
   "details": {
@@ -242,12 +242,12 @@
       },
       {
         "id": 1724,
-        "logprob": -10.734375,
+        "logprob": -10.7421875,
         "text": "What"
       },
       {
         "id": 338,
-        "logprob": -1.5488281,
+        "logprob": -1.5498047,
         "text": "is"
       },
       {
@@ -270,19 +270,19 @@
     "tokens": [
       {
         "id": 13,
-        "logprob": -1.1826172,
+        "logprob": -1.1835938,
         "special": false,
         "text": "\n"
       },
       {
         "id": 2772,
-        "logprob": -0.56689453,
+        "logprob": -0.57470703,
         "special": false,
         "text": "De"
       },
       {
         "id": 1022,
-        "logprob": -0.000108003616,
+        "logprob": -0.00010788441,
         "special": false,
         "text": "ep"
       },
@@ -294,13 +294,13 @@
       },
       {
         "id": 338,
-        "logprob": -0.044433594,
+        "logprob": -0.04510498,
         "special": false,
         "text": " is"
       },
       {
         "id": 263,
-        "logprob": -0.018295288,
+        "logprob": -0.018585205,
         "special": false,
         "text": " a"
       },
@@ -312,19 +312,19 @@
       },
       {
         "id": 310,
-        "logprob": -0.0002104044,
+        "logprob": -0.00021457672,
         "special": false,
         "text": " of"
       },
       {
         "id": 4933,
-        "logprob": -0.004711151,
+        "logprob": -0.004776001,
         "special": false,
         "text": " machine"
       },
       {
         "id": 6509,
-        "logprob": -0.00025892258,
+        "logprob": -0.0002593994,
         "special": false,
         "text": " learning"
       },
@@ -342,7 +342,7 @@
       }
     ]
   },
-  "generated_text": "ep learning is a subset of machine learning that involves"
+  "generated_text": "\nDeep learning is a subset of machine learning that involves"
 },
 {
   "details": {
@@ -357,12 +357,12 @@
       },
       {
         "id": 1724,
-        "logprob": -10.734375,
+        "logprob": -10.7421875,
         "text": "What"
       },
       {
         "id": 338,
-        "logprob": -1.5488281,
+        "logprob": -1.5498047,
         "text": "is"
       },
       {
@@ -385,19 +385,19 @@
     "tokens": [
       {
         "id": 13,
-        "logprob": -1.1826172,
+        "logprob": -1.1835938,
         "special": false,
         "text": "\n"
      },
       {
         "id": 2772,
-        "logprob": -0.56689453,
+        "logprob": -0.57470703,
         "special": false,
         "text": "De"
       },
       {
         "id": 1022,
-        "logprob": -0.000108003616,
+        "logprob": -0.00010788441,
         "special": false,
         "text": "ep"
       },
@@ -409,13 +409,13 @@
       },
       {
         "id": 338,
-        "logprob": -0.044433594,
+        "logprob": -0.04510498,
         "special": false,
         "text": " is"
       },
       {
         "id": 263,
-        "logprob": -0.018295288,
+        "logprob": -0.018585205,
         "special": false,
         "text": " a"
       },
@@ -427,19 +427,19 @@
       },
       {
         "id": 310,
-        "logprob": -0.0002104044,
+        "logprob": -0.00021457672,
         "special": false,
         "text": " of"
       },
       {
         "id": 4933,
-        "logprob": -0.004711151,
+        "logprob": -0.004776001,
         "special": false,
         "text": " machine"
       },
       {
         "id": 6509,
-        "logprob": -0.00025892258,
+        "logprob": -0.0002593994,
         "special": false,
         "text": " learning"
       },
@@ -457,6 +457,6 @@
       }
     ]
   },
-  "generated_text": "ep learning is a subset of machine learning that involves"
+  "generated_text": "\nDeep learning is a subset of machine learning that involves"
 }
 ]
diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json
index 0a1e3198..cd3cb53a 100644
--- a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json
+++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json
@@ -111,5 +111,5 @@
     }
   ]
 },
-  "generated_text": "ep learning is a subset of machine learning that involves"
+  "generated_text": "\nDeep learning is a subset of machine learning that involves"
 }
diff --git a/integration-tests/models/test_flash_llama.py b/integration-tests/models/test_flash_llama.py
index c69314ff..a9536bc4 100644
--- a/integration-tests/models/test_flash_llama.py
+++ b/integration-tests/models/test_flash_llama.py
@@ -43,7 +43,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
         seed=0,
     )
 
-    assert response.details.generated_tokens == 5
+    assert response.details.generated_tokens == 10
     assert response == response_snapshot
 
 
diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py
index b48914b8..e9dcf6d9 100644
--- a/integration-tests/models/test_flash_medusa.py
+++ b/integration-tests/models/test_flash_medusa.py
@@ -54,6 +54,6 @@ async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
-    assert responses[0].generated_text == 'ep learning is a subset of machine learning that involves'
+    assert responses[0].generated_text == '\nDeep learning is a subset of machine learning that involves'
 
     assert responses == response_snapshot
diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py
index 63cb09b5..7d21afd9 100644
--- a/integration-tests/models/test_flash_mistral.py
+++ b/integration-tests/models/test_flash_mistral.py
@@ -21,6 +21,7 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
     )
 
     assert response.details.generated_tokens == 10
+    assert response.generated_text == ": Let n = 10 - 1"
     assert response == response_snapshot
 
 
@@ -55,6 +56,7 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
     )
 
     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
+    assert responses[0].generated_text == ": Let n = 10 - 1"
 
     assert responses == response_snapshot
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 21d32e00..5000025c 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -101,6 +101,9 @@ def get_model(
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
+    global SPECULATE  # module-level, so `from text_generation_server.models import SPECULATE` resolves
+    SPECULATE = 2
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
             model_id,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index faa446d2..952068ec 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -479,19 +479,19 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
         max_length = 0
         max_seqlen = 0
-        speculative_length = 0 if batches[0].speculative_ids is None else batches[0].speculative_ids.shape[1]
         for b in batches:
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
+            speculative_length = 0 if b.speculative_ids is None else b.speculative_ids.shape[1]
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
                 max_length,
                 max(
                     input_length
-                    + speculative_length
                     + stopping_criteria.max_new_tokens
+                    + speculative_length
                     - stopping_criteria.current_tokens
                     for input_length, stopping_criteria in zip(
                         b.input_lengths, b.stopping_criterias
@@ -994,7 +994,8 @@ class FlashCausalLM(Model):
 
             # Evaluate stopping criteria
-            for next_token_id in _next_token_ids:
+            left = 0
+            for j, next_token_id in enumerate(_next_token_ids):
                 stop, reason = stopping_criteria(
                     next_token_id,
                     next_token_text,
                 )
@@ -1002,21 +1003,26 @@ class FlashCausalLM(Model):
                 if stop:
                     stopped = True
+                    left = len(_next_token_ids) - 1 - j
                     break
-            if not stop:
+            else:
                 stopped = False
 
+            _next_token_ids = _next_token_ids[:len(_next_token_ids) - left]
             # Shard generations
             # All generations will be appended in the rust sharded client
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
+                    # Remove potentially accepted ids that do not respect
+                    # the stopping_criteria
+                    _ids = all_input_ids[:len(all_input_ids) - left]
                     output_text, _, _ = self.decode_token(
-                        all_input_ids,
-                        prefix_offset=len(all_input_ids)
+                        _ids,
+                        prefix_offset=len(_ids)
                         - stopping_criteria.current_tokens
                         - 1,
-                        read_offset=len(all_input_ids)
+                        read_offset=len(_ids)
                         - stopping_criteria.current_tokens,
                         skip_special_tokens=True,
                     )
 
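The stopping-criteria hunk above is the subtle part of this change: with speculation, several tokens can be appended in a single step, so a stop condition may fire in the middle of the accepted window and everything after that token must be dropped. A minimal standalone sketch of the truncation rule (plain Python, with a hypothetical `stops` callback rather than the TGI `StoppingCriteria` object):

from typing import Callable, List, Tuple

def truncate_on_stop(
    next_token_ids: List[int],
    stops: Callable[[int], bool],
) -> Tuple[List[int], bool]:
    # Mirrors the `left` bookkeeping in the hunk above: keep every accepted id
    # up to and including the first one that triggers a stop, drop the rest.
    for j, token_id in enumerate(next_token_ids):
        if stops(token_id):
            left = len(next_token_ids) - 1 - j
            return next_token_ids[: len(next_token_ids) - left], True
    return next_token_ids, False

# With an EOS id of 2, the trailing 338 is discarded:
assert truncate_on_stop([13, 2772, 2, 338], lambda t: t == 2) == ([13, 2772, 2], True)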
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index aa0b1fe3..e6ada2c9 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -132,7 +132,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
             # Paged attention
             # Remove one as the first token does not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            from text_generation_server.models import SPECULATE
+            speculative_length = SPECULATE
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length
 
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
             needed_blocks = min(
@@ -183,7 +185,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens)
+            max_length = max(max_length, input_length + max_new_tokens + speculative_length)
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -341,17 +343,55 @@ class FlashMistral(FlashCausalLM):
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
 
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Copy the block tables for all members
+            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+        else:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
         logits = self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
             prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=batch.prefill_head_indices,
+            lm_head_indices=lm_head_indices,
         )
         if batch.prefill_cache_indices is not None:
             batch.prefill_cache_indices = None
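The speculative branch of `forward` above unrolls each sequence into `1 + speculative_length` rows, so one forward pass scores the current token and every Medusa candidate. The same tensor arithmetic in isolation, with made-up shapes and token ids (assumed values, not taken from the model):

import torch

B, speculative_length = 2, 2          # two sequences, two speculated ids each
new_length = speculative_length + 1

input_ids = torch.tensor([5, 9])                      # current token per sequence
speculative_ids = torch.tensor([[11, 12], [21, 22]])  # Medusa candidates
position_ids = torch.tensor([7, 3])

# One row per (sequence, candidate position): ids are interleaved, and the
# positions continue each sequence's arithmetic progression.
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
arange = torch.arange(new_length).unsqueeze(0)
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).reshape(-1)

print(new_input_ids)     # tensor([ 5, 11, 12,  9, 21, 22])
print(new_position_ids)  # tensor([7, 8, 9, 3, 4, 5])

Slots, input lengths, and block tables are expanded the same way in the diff, so the paged-attention metadata stays row-aligned with the widened batch.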
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index ffbbf40c..4c77b660 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -258,16 +258,34 @@
         self.device = device
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
-        if self.watermark_processor is not None:
-            scores = self.watermark_processor(input_ids, scores)
-        if self.repetition_processor is not None:
-            scores = self.repetition_processor(input_ids, scores)
+        if speculated_ids is not None:
+            B = scores.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            scores = scores.view(B, S, -1)
+        else:
+            B = scores.shape[0]
+            S = 1
+            scores = scores.view(B, S, -1)
 
-        for warper in self.warpers:
-            scores = warper(input_ids, scores)
+        all_next_ids = []
+        for j in range(S):
+            _scores = scores[:, j]
+            if self.watermark_processor is not None:
+                _scores = self.watermark_processor(input_ids, _scores)
+            if self.repetition_processor is not None:
+                _scores = self.repetition_processor(input_ids, _scores)
+
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
 
-        next_ids = self.choice(scores)
+            next_ids = self.choice(_scores)
+            scores[:, j] = _scores
+            all_next_ids.append(next_ids.unsqueeze(1))
+        next_ids = torch.cat(all_next_ids, dim=1).reshape(B * S)
+        scores = scores.view(B * S, -1)
+
         if speculated_ids is not None:
             accepted_ids = []
             B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
@@ -289,6 +307,9 @@
                 else:
                     break
             accepted_ids.append(accepted)
+
+        from loguru import logger
+        logger.info(f"ACCEPTED IDS {accepted_ids}")
         accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
         next_ids = next_ids[indices]
         scores = scores[indices]
@@ -298,7 +319,7 @@
         else:
             accepted_ids = torch.ones_like(next_ids)
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
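The acceptance loop in `__call__` above applies the usual speculative-decoding rule: a speculated id is kept only while it matches what the full model chose at the same position, and the one freshly sampled token that follows the matching prefix is always emitted. A self-contained sketch of that rule for a single sequence (the helper name is ours, not part of tokens.py):

import torch

def count_accepted(next_ids_row: torch.Tensor, speculated_row: torch.Tensor) -> int:
    # next_ids_row holds the model's choice at the current position and at
    # each of the S - 1 speculated positions; the first sampled token is
    # always kept, so at least one token is emitted per step.
    accepted = 1
    for j in range(speculated_row.shape[0]):
        if next_ids_row[j].item() == speculated_row[j].item():
            accepted += 1
        else:
            break
    return accepted

# The model agrees with the first speculated id (42) but not the second (13),
# so this step emits 2 tokens instead of 1.
assert count_accepted(torch.tensor([42, 7, 99]), torch.tensor([42, 13])) == 2

Because the rejected suffix is discarded, the output distribution is unchanged; speculation only trades extra compute per step for fewer steps when the Medusa heads guess well.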