diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json
new file mode 100644
index 00000000..e9b1c57a
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json
@@ -0,0 +1,58 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "stop_sequence",
+    "generated_tokens": 5,
+    "prefill": [
+      {
+        "id": 1,
+        "logprob": null,
+        "text": "<s>"
+      },
+      {
+        "id": 4321,
+        "logprob": -10.0625,
+        "text": "Test"
+      },
+      {
+        "id": 2009,
+        "logprob": -12.28125,
+        "text": "request"
+      }
+    ],
+    "seed": 0,
+    "tokens": [
+      {
+        "id": 5229,
+        "logprob": -1.7587891,
+        "special": false,
+        "text": " failed"
+      },
+      {
+        "id": 363,
+        "logprob": -0.5175781,
+        "special": false,
+        "text": " for"
+      },
+      {
+        "id": 1404,
+        "logprob": 0.0,
+        "special": false,
+        "text": " user"
+      },
+      {
+        "id": 376,
+        "logprob": 0.0,
+        "special": false,
+        "text": " \""
+      },
+      {
+        "id": 1688,
+        "logprob": -0.20422363,
+        "special": false,
+        "text": "test"
+      }
+    ]
+  },
+  "generated_text": "Test request failed for user \"test"
+}
diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json
new file mode 100644
index 00000000..80d4873a
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json
@@ -0,0 +1,354 @@
+[
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 1,
+          "logprob": null,
+          "text": "<s>"
+        },
+        {
+          "id": 4321,
+          "logprob": -10.0625,
+          "text": "Test"
+        },
+        {
+          "id": 2009,
+          "logprob": -12.28125,
+          "text": "request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 363,
+          "logprob": -2.0878906,
+          "special": false,
+          "text": " for"
+        },
+        {
+          "id": 278,
+          "logprob": -3.4082031,
+          "special": false,
+          "text": " the"
+        },
+        {
+          "id": 376,
+          "logprob": -3.8457031,
+          "special": false,
+          "text": " \""
+        },
+        {
+          "id": 2577,
+          "logprob": -3.5605469,
+          "special": false,
+          "text": "Get"
+        },
+        {
+          "id": 599,
+          "logprob": -3.4707031,
+          "special": false,
+          "text": " all"
+        },
+        {
+          "id": 4160,
+          "logprob": -3.2421875,
+          "special": false,
+          "text": " users"
+        },
+        {
+          "id": 29908,
+          "logprob": -0.49072266,
+          "special": false,
+          "text": "\""
+        },
+        {
+          "id": 16248,
+          "logprob": -1.2353516,
+          "special": false,
+          "text": " endpoint"
+        },
+        {
+          "id": 29889,
+          "logprob": -0.8833008,
+          "special": false,
+          "text": "."
+        },
+        {
+          "id": 13,
+          "logprob": -0.42089844,
+          "special": false,
+          "text": "\n"
+        }
+      ]
+    },
+    "generated_text": " for the \"Get all users\" endpoint.\n"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 1,
+          "logprob": null,
+          "text": "<s>"
+        },
+        {
+          "id": 4321,
+          "logprob": -10.0625,
+          "text": "Test"
+        },
+        {
+          "id": 2009,
+          "logprob": -12.28125,
+          "text": "request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 363,
+          "logprob": -2.0878906,
+          "special": false,
+          "text": " for"
+        },
+        {
+          "id": 278,
+          "logprob": -3.4082031,
+          "special": false,
+          "text": " the"
+        },
+        {
+          "id": 376,
+          "logprob": -3.8457031,
+          "special": false,
+          "text": " \""
+        },
+        {
+          "id": 2577,
+          "logprob": -3.5625,
+          "special": false,
+          "text": "Get"
+        },
+        {
+          "id": 599,
+          "logprob": -3.4726562,
+          "special": false,
+          "text": " all"
+        },
+        {
+          "id": 4160,
+          "logprob": -3.2382812,
+          "special": false,
+          "text": " users"
+        },
+        {
+          "id": 29908,
+          "logprob": -0.49047852,
+          "special": false,
+          "text": "\""
+        },
+        {
+          "id": 16248,
+          "logprob": -1.2412109,
+          "special": false,
+          "text": " endpoint"
+        },
+        {
+          "id": 29889,
+          "logprob": -0.87402344,
+          "special": false,
+          "text": "."
+        },
+        {
+          "id": 13,
+          "logprob": -0.41723633,
+          "special": false,
+          "text": "\n"
+        }
+      ]
+    },
+    "generated_text": " for the \"Get all users\" endpoint.\n"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 1,
+          "logprob": null,
+          "text": "<s>"
+        },
+        {
+          "id": 4321,
+          "logprob": -10.0625,
+          "text": "Test"
+        },
+        {
+          "id": 2009,
+          "logprob": -12.28125,
+          "text": "request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 363,
+          "logprob": -2.0878906,
+          "special": false,
+          "text": " for"
+        },
+        {
+          "id": 278,
+          "logprob": -3.4082031,
+          "special": false,
+          "text": " the"
+        },
+        {
+          "id": 376,
+          "logprob": -3.8457031,
+          "special": false,
+          "text": " \""
+        },
+        {
+          "id": 2577,
+          "logprob": -3.5605469,
+          "special": false,
+          "text": "Get"
+        },
+        {
+          "id": 599,
+          "logprob": -3.4707031,
+          "special": false,
+          "text": " all"
+        },
+        {
+          "id": 4160,
+          "logprob": -3.2421875,
+          "special": false,
+          "text": " users"
+        },
+        {
+          "id": 29908,
+          "logprob": -0.49072266,
+          "special": false,
+          "text": "\""
+        },
+        {
+          "id": 16248,
+          "logprob": -1.2353516,
+          "special": false,
+          "text": " endpoint"
+        },
+        {
+          "id": 29889,
+          "logprob": -0.8833008,
+          "special": false,
+          "text": "."
+        },
+        {
+          "id": 13,
+          "logprob": -0.42089844,
+          "special": false,
+          "text": "\n"
+        }
+      ]
+    },
+    "generated_text": " for the \"Get all users\" endpoint.\n"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 1,
+          "logprob": null,
+          "text": "<s>"
+        },
+        {
+          "id": 4321,
+          "logprob": -10.0625,
+          "text": "Test"
+        },
+        {
+          "id": 2009,
+          "logprob": -12.28125,
+          "text": "request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 363,
+          "logprob": -2.0878906,
+          "special": false,
+          "text": " for"
+        },
+        {
+          "id": 278,
+          "logprob": -3.4082031,
+          "special": false,
+          "text": " the"
+        },
+        {
+          "id": 376,
+          "logprob": -3.8457031,
+          "special": false,
+          "text": " \""
+        },
+        {
+          "id": 2577,
+          "logprob": -3.5605469,
+          "special": false,
+          "text": "Get"
+        },
+        {
+          "id": 599,
+          "logprob": -3.4707031,
+          "special": false,
+          "text": " all"
+        },
+        {
+          "id": 4160,
+          "logprob": -3.2421875,
+          "special": false,
+          "text": " users"
+        },
+        {
+          "id": 29908,
+          "logprob": -0.49072266,
+          "special": false,
+          "text": "\""
+        },
+        {
+          "id": 16248,
+          "logprob": -1.2353516,
+          "special": false,
+          "text": " endpoint"
+        },
+        {
+          "id": 29889,
+          "logprob": -0.8833008,
+          "special": false,
+          "text": "."
+        },
+        {
+          "id": 13,
+          "logprob": -0.42089844,
+          "special": false,
+          "text": "\n"
+        }
+      ]
+    },
+    "generated_text": " for the \"Get all users\" endpoint.\n"
+  }
+]
diff --git a/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json
new file mode 100644
index 00000000..eb449de3
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json
@@ -0,0 +1,88 @@
+{
+  "details": {
+    "best_of_sequences": null,
+    "finish_reason": "length",
+    "generated_tokens": 10,
+    "prefill": [
+      {
+        "id": 1,
+        "logprob": null,
+        "text": "<s>"
+      },
+      {
+        "id": 4321,
+        "logprob": -10.0625,
+        "text": "Test"
+      },
+      {
+        "id": 2009,
+        "logprob": -12.28125,
+        "text": "request"
+      }
+    ],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 363,
+        "logprob": -2.0878906,
+        "special": false,
+        "text": " for"
+      },
+      {
+        "id": 278,
+        "logprob": -3.4121094,
+        "special": false,
+        "text": " the"
+      },
+      {
+        "id": 376,
+        "logprob": -3.8457031,
+        "special": false,
+        "text": " \""
+      },
+      {
+        "id": 2577,
+        "logprob": -3.5566406,
+        "special": false,
+        "text": "Get"
+      },
+      {
+        "id": 599,
+        "logprob": -3.4746094,
+        "special": false,
+        "text": " all"
+      },
+      {
+        "id": 4160,
+        "logprob": -3.2363281,
+        "special": false,
+        "text": " users"
+      },
+      {
+        "id": 29908,
+        "logprob": -0.49023438,
+        "special": false,
+        "text": "\""
+      },
+      {
+        "id": 16248,
+        "logprob": -1.2402344,
+        "special": false,
+        "text": " endpoint"
+      },
+      {
+        "id": 29889,
+        "logprob": -0.88134766,
+        "special": false,
+        "text": "."
+      },
+      {
+        "id": 13,
+        "logprob": -0.41870117,
+        "special": false,
+        "text": "\n"
+      }
+    ]
+  },
+  "generated_text": " for the \"Get all users\" endpoint.\n"
+}
diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py
new file mode 100644
index 00000000..7cc797e4
--- /dev/null
+++ b/integration-tests/models/test_flash_medusa.py
@@ -0,0 +1,59 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_medusa_handle(launcher):
+    with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_medusa(flash_medusa_handle):
+    await flash_medusa_handle.health(300)
+    return flash_medusa_handle.client
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_medusa_simple(flash_medusa, response_snapshot):
+    response = await flash_medusa.generate(
+        "Test request", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
+    response = await flash_medusa.generate(
+        "Test request",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 5
+    assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
+    responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)
+
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
+    assert responses[0].generated_text == ' for the "Get all users" endpoint.\n'
+
+    assert responses == response_snapshot
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index b4fc86b7..b923fcea 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -155,6 +155,14 @@ struct Args {
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// The number of input_ids to speculate on
+    /// If using a medusa model, the heads will be picked up automatically
+    /// Otherwise, it will use n-gram speculation which is relatively free
+    /// in terms of compute, but the speedup heavily depends on the task.
+    #[clap(long, env)]
+    speculate: Option<usize>,
+
     /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
     #[clap(long, env, value_enum)]
     dtype: Option<Dtype>,
 
@@ -432,6 +440,11 @@ fn shard_manager(
         shard_args.push(quantize.to_string())
     }
 
+    if let Some(speculate) = speculate {
+        shard_args.push("--speculate".to_string());
+        shard_args.push(speculate.to_string())
+    }
+
     if let Some(dtype) = dtype {
         shard_args.push("--dtype".to_string());
         shard_args.push(dtype.to_string())
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 00d377f6..cb151173 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -32,6 +32,7 @@ def serve(
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
@@ -81,7 +82,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
     )
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index bceec157..21d32e00 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -77,15 +77,19 @@ except ImportError as e:
 if MISTRAL:
     __all__.append(FlashMistral)
 
 
+SPECULATE = None
+
 def get_model(
     model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    global SPECULATE
     if dtype is None:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
@@ -138,9 +142,18 @@ def get_model(
         medusa_config = config_dict
         model_id = config_dict["base_model_name_or_path"]
         revision = "main"
+        SPECULATE = config_dict["medusa_num_heads"]
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        method = "medusa"
+    else:
+        if speculate is not None:
+            SPECULATE = speculate
+        else:
+            SPECULATE = 2
+        method = "n-gram"
+    logger.info(f"Using speculation {method} with {SPECULATE} input ids.")
 
     model_type = config_dict["model_type"]
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 06dd3f5c..4aeb447d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -450,26 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
-        speculative_ids: Optional[torch.Tensor]
     ) -> torch.Tensor:
-        if speculative_ids is not None:
-            speculative_length = speculative_ids.shape[1]
-            new_length = speculative_length + 1
-            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
-            new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)
-
-            # Add an extra block just in case
-            block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
-            # Add Copy the block tables for all members
-            block_tables = block_tables.expand(new_length, -1).contiguous()
-            slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
-            input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
-            max_s = max_s + speculative_length
-
-
-            input_ids = new_input_ids
-            position_ids = new_position_ids
-
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
@@ -520,7 +501,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-        speculative_ids: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
@@ -531,7 +511,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
-            speculative_ids,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 52679be8..82f38564 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -46,7 +46,6 @@ class FlashCausalLMBatch(Batch):
 
     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
-    cu_seqlen_speculative: Optional[torch.Tensor]
 
     # Paged Attention values
 
@@ -123,7 +122,6 @@ class FlashCausalLMBatch(Batch):
         position_ids = []
        speculative_ids = []
         cu_seqlen_prefill = [0]
-        cu_seqlen_speculative = [0]
         needed_blocks_slots = []
         start_slots = []
         slot_indices = []
@@ -163,10 +161,6 @@ class FlashCausalLMBatch(Batch):
 
             tokenized_input = tokenized_input[-r.truncate :]
 
-            # # TODO remove this
-            # # Scaffolding to speculate some ids
-            # speculate_ids = [1, 2]
-            # tokenized_input.extend([1, 2])
             speculate_ids = []
 
@@ -186,7 +180,6 @@ class FlashCausalLMBatch(Batch):
 
             # Add cumulative lengths of all previous inputs
             cu_seqlen_prefill.append(cumulative_length + input_length)
-            cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
 
             next_token_chooser_parameters.append(r.parameters)
 
@@ -199,7 +192,9 @@ class FlashCausalLMBatch(Batch):
 
             # Paged attention
             # Remove one as the first token des not have a past
-            total_tokens = input_length + max_new_tokens - 1
+            from text_generation_server.models import SPECULATE
+            speculative_length = SPECULATE
+            total_tokens = input_length + max_new_tokens - 1 + speculative_length
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
             blocks += needed_blocks
             needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -268,10 +263,6 @@ class FlashCausalLMBatch(Batch):
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
-        cu_seqlen_speculative = torch.tensor(
-            cu_seqlen_speculative, device=device, dtype=torch.int32
-        )
-
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
@@ -303,7 +294,6 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
-            cu_seqlen_speculative=cu_seqlen_speculative,
             start_slots=start_slots,
             slot_indices=slot_indices,
             needed_blocks_slots=needed_blocks_slots,
@@ -437,6 +427,7 @@ class FlashCausalLMBatch(Batch):
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
+        speculative_ids = self.speculative_ids[indices]
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
@@ -472,6 +463,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=speculative_ids,
         )
 
     @classmethod
@@ -595,6 +587,8 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )
 
+        speculative_ids = None if batches[0].speculative_ids is None else torch.cat([b.speculative_ids for b in batches], dim=0)
+
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
@@ -629,6 +623,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=speculative_ids
         )
 
     def __del__(self):
@@ -732,17 +727,55 @@
     def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
+        if batch.speculative_ids is not None:
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
+
+            speculative_ids = batch.speculative_ids
+
+            B, speculative_length = speculative_ids.shape
+            new_length = speculative_length + 1
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
+            arange_int = arange.to(dtype=torch.int32)
+            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+ input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1) + + # Add Copy the block tables for all members + block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous() + max_s = max_s + speculative_length + + input_ids = new_input_ids + position_ids = new_position_ids + else: + input_ids=batch.input_ids + position_ids=batch.position_ids + cu_seqlen_prefill=batch.cu_seqlen_prefill + kv_cache=get_cache_manager().kv_cache + block_tables=batch.block_tables_tensor + slots=batch.slots[batch.slot_indices] + input_lengths=batch.input_lengths_tensor + max_s=batch.max_seqlen + lm_head_indices=batch.prefill_head_indices + return self.model.forward( - input_ids=batch.input_ids, - position_ids=batch.position_ids, - cu_seqlen_prefill=batch.cu_seqlen_prefill, - kv_cache=get_cache_manager().kv_cache, - block_tables=batch.block_tables_tensor, - slots=batch.slots[batch.slot_indices], - input_lengths=batch.input_lengths_tensor, - max_s=batch.max_seqlen, - lm_head_indices=batch.prefill_head_indices, - speculative_ids =batch.speculative_ids + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + input_lengths=input_lengths, + max_s=max_s, + lm_head_indices=lm_head_indices, ) @tracer.start_as_current_span("generate_token") @@ -792,8 +825,9 @@ class FlashCausalLM(Model): # if next_token_logits.shape[0] == 3: # import ipdb;ipdb.set_trace() + from text_generation_server.models import SPECULATE next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits + batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, SPECULATE, batch.speculative_ids, speculative_logits ) batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( @@ -807,14 +841,8 @@ class FlashCausalLM(Model): # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - if speculative_ids is not None: - # length = len(batch) * (1 + speculative_length) - length = len(batch) - else: - length = len(batch) - # import ipdb;ipdb.set_trace() + length = len(batch) next_position_ids = batch.position_ids.new_empty(length) - # Keep only 1 slot index, TODO make sure we recover the speculated ids slots later batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] # We do not need cu_seqlen_prefill anymore batch.cu_seqlen_prefill = None @@ -885,19 +913,17 @@ class FlashCausalLM(Model): # if accepted_ids[0] > 1: # import ipdb;ipdb.set_trace() - if len(accepted_ids) > 1: - raise Exception("Implemtent the batched behavior") + # if len(accepted_ids) > 1: + # raise Exception("Implemtent the batched behavior") # Set values in batch # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1) - for n_accepted_ids in accepted_ids: - # TODO Make this batched - batch.input_ids = next_input_ids[-1:] - batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + n_accepted_ids - batch.input_lengths_tensor += n_accepted_ids - batch.slot_indices += n_accepted_ids + batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.speculative_ids = speculative_ids + batch.position_ids = next_position_ids + accepted_ids + batch.input_lengths_tensor += accepted_ids + 
+        batch.slot_indices += accepted_ids
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
@@ -962,6 +988,7 @@
                     read_offset,
                 )
                 next_token_texts.append(next_token_text)
+                index += n_accepted_ids
 
                 # Evaluate stopping criteria
 
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 919e4625..aa0b1fe3 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -272,6 +272,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
             blocks=blocks,
             max_blocks=max_blocks,
             prefill_cache_indices=prefill_cache_indices,
+            speculative_ids=None
         )
 
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index fa831682..ebe066e3 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -132,6 +132,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
@@ -141,6 +142,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        speculate: Optional[int] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
@@ -157,7 +159,7 @@ def serve(
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -205,5 +207,5 @@ def serve(
         await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code)
     )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index c5e07cca..a9f0374a 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -147,6 +147,49 @@ class StoppingCriteria:
         )
 
+
+def longest_match(input_ids: List[int]) -> Optional[int]:
+    longest_match = 0
+    seed = input_ids[-1]
+    final_matches = []
+    current_matches = []
+    for i in range(1, len(input_ids)):
+        index = len(input_ids) - i - 1
+
+        _current_matches = []
+        for (_index, length) in current_matches:
+            if input_ids[index] == input_ids[len(input_ids) - length - 1]:
+                _current_matches.append((_index, length + 1))
+            elif length > longest_match:
+                longest_match = length
+                final_matches.append((_index, length))
+            else:
+                pass
+        current_matches = _current_matches
+
+        if input_ids[index] == seed:
+            current_matches.append((index, 1))
+    if not final_matches:
+        return 0
+    return final_matches[-1][0]
+
+
+def create_n_gram_speculation(input_ids: torch.Tensor, next_ids: torch.Tensor, accepted_ids: torch.Tensor, speculate: int):
+    B = accepted_ids.shape[0]
+    device = input_ids.device
+    dtype = input_ids.dtype
+    speculative_ids = torch.zeros((B, speculate), device=device, dtype=dtype)
+    input_ids = input_ids.tolist()
+
+    index = 0
+    for i, (_input_ids, n_accepted_ids) in enumerate(zip(input_ids, accepted_ids.tolist())):
+        _input_ids.extend(next_ids[index: index + n_accepted_ids].tolist())
+        index = longest_match(_input_ids) + 1
+        ids = _input_ids[index:index+speculate]
+        speculative_ids[i, :len(ids)] = torch.tensor(ids, device=device, dtype=dtype)
+        index += n_accepted_ids
+    return speculative_ids
+
 class HeterogeneousNextTokenChooser:
     def __init__(
         self,
@@ -215,7 +258,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculate: int, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
@@ -225,40 +268,51 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
 
-        accepted_ids = []
         next_ids = self.choice(scores)
         if speculated_ids is not None:
-            validate_speculative = next_ids[:-1] == speculated_ids[0]
-            index = 1
-            for valid in validate_speculative.tolist():
-                if valid:
-                    index += 1
-            # print(f"Validated {index - 1}")
-            next_ids = next_ids[:index]
-            scores = scores[:index]
-            speculative_scores = speculative_scores[index - 1:index]
-            accepted_ids.append(index)
+            accepted_ids = []
+            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
+            S = speculated_ids.shape[1] + 1
+            indices = []
+            for i in range(B):
+                _next_ids = next_ids[i*S: (i + 1)*S]
+                _speculated_ids = speculated_ids[i]
+                validate_speculative = _next_ids[:-1] == _speculated_ids
+                index = i * S
+                accepted = 1
+                # First is always valid
+                indices.append(index)
+                for valid in validate_speculative.tolist():
+                    if valid:
+                        index += 1
+                        accepted += 1
+                        indices.append(index)
+                    else:
+                        break
+                # if accepted > 1:
+                #     import ipdb;ipdb.set_trace()
+                accepted_ids.append(accepted)
+            accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
+            next_ids = next_ids[indices]
+            scores = scores[indices]
+            indices = torch.arange(B, device=input_ids.device) * S
+            if speculative_scores is not None:
+                speculative_scores = speculative_scores[indices + accepted_ids - 1]
         else:
-            accepted_ids.append(1)
+            accepted_ids = torch.ones_like(next_ids)
 
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
-        if speculative_scores is not None:
-            # length, spec_length, vocab_size = speculative_scores.shape
-            # speculative_scores = speculative_scores.view((-1, vocab_size))
-            # if self.watermark_processor is not None:
-            #     speculative_scores = self.watermark_processor(input_ids, speculative_scores)
-            # if self.repetition_processor is not None:
-            #     speculative_scores = self.repetition_processor(input_ids, speculative_scores)
-
-            # speculative_scores = speculative_scores.view((length, spec_length, vocab_size))
-            # for warper in self.warpers:
-            #     speculative_scores = warper(input_ids, speculative_scores)
-            speculative_ids = Greedy()(speculative_scores)
-            # # Ignore first head, it seems to be a regular head.
-            # speculative_ids = speculative_ids[:, 1:]
+        if speculate > 0:
+            if speculative_scores is not None:
+                # TODO This will only speculate the top score
+                # Medusa provided some scores
+                speculative_ids = Greedy()(speculative_scores)
+            else:
+                # n-gram
+                speculative_ids = create_n_gram_speculation(input_ids, next_ids, accepted_ids, speculate)
         else:
            speculative_ids = None
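Editor's note on the verification step above: the new `HeterogeneousNextTokenChooser.__call__` lays out each sequence's scores as the previous token plus `S` speculated positions, keeps a speculated token only while every earlier one matched, and records how many tokens were accepted per sequence. The sketch below is illustrative only and is not part of the patch; the helper name `accept_speculated`, the vectorized prefix check, and the toy tensors are assumptions made for the example, not code from the repository.

```python
# Illustrative sketch (not from the patch): acceptance of speculated ids.
import torch


def accept_speculated(next_ids: torch.Tensor, speculated_ids: torch.Tensor):
    """next_ids: flat [B * (S + 1)] ids chosen from the model's scores;
    speculated_ids: [B, S]. Returns (accepted count per sequence, kept flat indices)."""
    B, S = speculated_ids.shape
    rows = next_ids.view(B, S + 1)
    # A speculated position is valid only if every earlier speculated token matched.
    matches = rows[:, :-1] == speculated_ids               # [B, S]
    prefix_ok = torch.cumprod(matches.to(torch.int64), dim=1)
    accepted_ids = 1 + prefix_ok.sum(dim=1)                # the first token is always kept
    kept = [
        row * (S + 1) + torch.arange(n, device=next_ids.device)
        for row, n in enumerate(accepted_ids.tolist())
    ]
    return accepted_ids, torch.cat(kept)


if __name__ == "__main__":
    # Batch of 2, speculating 2 ids: row 0's first guess matches, row 1's does not.
    next_ids = torch.tensor([10, 11, 99, 20, 98, 97])
    speculated = torch.tensor([[10, 12], [21, 22]])
    accepted, kept = accept_speculated(next_ids, speculated)
    print(accepted.tolist(), kept.tolist())  # [2, 1] [0, 1, 3]
```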
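Editor's note on the n-gram fallback: when no Medusa heads are available, `create_n_gram_speculation` uses `longest_match` to find where the current suffix of the sequence occurred earlier and proposes the tokens that followed that occurrence as speculative candidates. The following standalone sketch shows the same idea in plain Python; the function name `propose_ngram_candidates` and the toy token history are invented for the example and are not part of the patch.

```python
# Illustrative sketch (not from the patch): n-gram speculation by suffix matching.
from typing import List


def propose_ngram_candidates(input_ids: List[int], speculate: int) -> List[int]:
    """Return up to `speculate` candidate ids copied from the continuation of the
    longest earlier occurrence of the current suffix (empty if nothing matches)."""
    n = len(input_ids)
    best_pos, best_len = None, 0
    for end in range(n - 1):  # candidate match position for the last token
        length = 0
        while (
            length <= end
            and length < n - 1
            and input_ids[end - length] == input_ids[n - 1 - length]
        ):
            length += 1
        if length > best_len:
            best_pos, best_len = end, length
    if best_pos is None:
        return []
    # The tokens that followed the matched span become the speculation.
    start = best_pos + 1
    return input_ids[start : start + speculate]


if __name__ == "__main__":
    history = [5, 8, 2, 9, 4, 7, 8, 2]
    print(propose_ngram_candidates(history, speculate=2))  # -> [9, 4]
```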