From ba16994e8a9678dd0c02fc58a116a710df5f8671 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 8 Dec 2023 16:28:04 +0000
Subject: [PATCH] Fixing medusa off by ones.

---
 server/Makefile-vllm                          | 33 ++++++++++---------
 .../models/flash_causal_lm.py                 | 27 +++++++--------
 server/text_generation_server/utils/layers.py |  3 ++
 .../text_generation_server/utils/speculate.py |  2 +-
 4 files changed, 36 insertions(+), 29 deletions(-)

diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index ddb648ea..c9c1d520 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,22 +1,25 @@
-build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
-build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
-build-vllm-cuda: BRANCH=main
-build-vllm-cuda: build-vllm
-
-build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
-build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
-build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
-build-vllm-rocm: build-vllm
-
-vllm:
+vllm-cuda:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
+	git clone https://github.com/vllm-project/vllm.git vllm
 
-build-vllm: vllm
-	cd vllm && git fetch && git checkout $(VLLM_COMMIT)
+build-vllm-cuda: vllm-cuda
+	cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 	cd vllm && python setup.py build
 
-install-vllm: build-vllm
+install-vllm-cuda: build-vllm-cuda
+	pip uninstall vllm -y || true
+	cd vllm && python setup.py install
+
+vllm-rocm:
+	# Clone vllm
+	pip install -U ninja packaging --no-cache-dir
+	git clone https://github.com/fxmarty/vllm-public.git vllm
+
+build-vllm-rocm: vllm-rocm
+	cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+	cd vllm && python setup.py build
+
+install-vllm-rocm: build-vllm-rocm
 	pip uninstall vllm -y || true
 	cd vllm && python setup.py install
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 260f5e68..b5ac1a16 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -11,6 +11,8 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
+from loguru import logger
+
 from text_generation_server.models import Model
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -318,6 +320,7 @@ class FlashCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        # logger.info(f"Filter {request_ids}")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -422,6 +425,8 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        if slot_indices.max().item() > slots.shape[0]:
+            import ipdb;ipdb.set_trace()
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
@@ -466,6 +471,7 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
@tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}") # Batch attributes requests = [] requests_idx_mapping = {} @@ -495,6 +501,7 @@ class FlashCausalLMBatch(Batch): ) ), ) + # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}") input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) @@ -781,6 +788,8 @@ class FlashCausalLM(Model): def generate_token( self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: + # logger.info(f"GENERATE {[r.id for r in batch.requests]}") + # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}") prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None @@ -797,6 +806,8 @@ class FlashCausalLM(Model): batch.block_tables_tensor = block_tables_tensor batch.slots = slots + # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}") + try: out = self.forward(batch) except Exception as e: @@ -820,20 +831,9 @@ class FlashCausalLM(Model): else: next_token_logits = out - # import datetime - # from loguru import logger - - # start = datetime.datetime.now() next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits ) - # took = datetime.datetime.now() - start - # logger.info(f"Next token chooser {batch.all_input_ids_tensor.shape} took {took}") - # if batch.all_input_ids_tensor.shape[1] < 2000 and took > datetime.timedelta(milliseconds=5): - # next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser( - # batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits, verbose=True - # ) - # import ipdb;ipdb.set_trace() batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs @@ -1077,7 +1077,9 @@ class FlashCausalLM(Model): generations.append(generation) # Update values - batch.input_lengths[i] = input_length + 1 + batch.input_lengths[i] = input_length + n_accepted_ids.item() + if batch.input_lengths[i] > batch.max_seqlen: + batch.max_seqlen = batch.input_lengths[i] batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids @@ -1090,6 +1092,5 @@ class FlashCausalLM(Model): batch.prefill_cu_outlens = None batch.prefill_head_indices = None batch.prefill_next_token_indices = None - batch.max_seqlen = batch.max_seqlen + 1 return generations, batch diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index a93ccd0e..71b7f48d 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -742,6 +742,9 @@ try: self._update_cos_sin_cache(dtype, position_ids.device, max_s) + if position_ids.max().item() >= max_s: + import ipdb;ipdb.set_trace() + cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid 
yet an other controlflow. diff --git a/server/text_generation_server/utils/speculate.py b/server/text_generation_server/utils/speculate.py index 229f5b8f..38a91972 100644 --- a/server/text_generation_server/utils/speculate.py +++ b/server/text_generation_server/utils/speculate.py @@ -1,7 +1,7 @@ SPECULATE = None -def get_speculate(): +def get_speculate() -> int: global SPECULATE return SPECULATE
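
Note (not part of the applied diff): the off-by-one being fixed comes from speculative (Medusa) decoding, where a single generate_token step can accept several tokens per request. Advancing input_lengths and max_seqlen by a constant 1 therefore undercounts sequence lengths, which is what the added position_ids and slot_indices guards are probing for. Below is a minimal sketch of the corrected bookkeeping, with illustrative names only (update_lengths and accepted_counts are not identifiers from the codebase):

# Sketch only: advance each request by the number of tokens it actually accepted
# in one decode step (speculation accepts between 1 and speculate + 1 tokens).
def update_lengths(input_lengths, accepted_counts, max_seqlen):
    for i, n_accepted in enumerate(accepted_counts):
        input_lengths[i] += n_accepted                   # was: += 1
        max_seqlen = max(max_seqlen, input_lengths[i])   # was: += 1 once per step
    return input_lengths, max_seqlen

# Example with speculate=2: request 0 accepts 3 tokens, request 1 accepts only 1.
lengths, max_seqlen = update_lengths([10, 7], [3, 1], max_seqlen=10)
assert lengths == [13, 8] and max_seqlen == 13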