Fixing medusa off by ones.

Nicolas Patry 2023-12-08 16:28:04 +00:00
parent abc8d48d96
commit ba16994e8a
4 changed files with 36 additions and 29 deletions

View File

@@ -1,22 +1,25 @@
-build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
-build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
-build-vllm-cuda: BRANCH=main
-build-vllm-cuda: build-vllm
-
-build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
-build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
-build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
-build-vllm-rocm: build-vllm
-
-vllm:
+vllm-cuda:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
+	git clone https://github.com/vllm-project/vllm.git vllm
 
-build-vllm: vllm
-	cd vllm && git fetch && git checkout $(VLLM_COMMIT)
+build-vllm-cuda: vllm-cuda
+	cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 	cd vllm && python setup.py build
 
-install-vllm: build-vllm
+install-vllm-cuda: build-vllm-cuda
+	pip uninstall vllm -y || true
+	cd vllm && python setup.py install
+
+vllm-rocm:
+	# Clone vllm
+	pip install -U ninja packaging --no-cache-dir
+	git clone https://github.com/fxmarty/vllm-public.git vllm
+
+build-vllm-rocm: vllm-rocm
+	cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+	cd vllm && python setup.py build
+
+install-vllm-rocm: build-vllm-rocm
 	pip uninstall vllm -y || true
 	cd vllm && python setup.py install

View File

@@ -11,6 +11,8 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
+from loguru import logger
 from text_generation_server.models import Model
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -318,6 +320,7 @@ class FlashCausalLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        # logger.info(f"Filter {request_ids}")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -422,6 +425,8 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        if slot_indices.max().item() > slots.shape[0]:
+            import ipdb;ipdb.set_trace()
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
@ -466,6 +471,7 @@ class FlashCausalLMBatch(Batch):
@classmethod @classmethod
@tracer.start_as_current_span("concatenate") @tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
# logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
# Batch attributes # Batch attributes
requests = [] requests = []
requests_idx_mapping = {} requests_idx_mapping = {}
@@ -495,6 +501,7 @@ class FlashCausalLMBatch(Batch):
                 )
             ),
         )
+        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")
 
         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -781,6 +788,8 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
+        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
+        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}")
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -797,6 +806,8 @@ class FlashCausalLM(Model):
             batch.block_tables_tensor = block_tables_tensor
             batch.slots = slots
 
+        # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")
         try:
             out = self.forward(batch)
         except Exception as e:
@@ -820,20 +831,9 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
-        # import datetime
-        # from loguru import logger
-        # start = datetime.datetime.now()
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
-        # took = datetime.datetime.now() - start
-        # logger.info(f"Next token chooser {batch.all_input_ids_tensor.shape} took {took}")
-        # if batch.all_input_ids_tensor.shape[1] < 2000 and took > datetime.timedelta(milliseconds=5):
-        #     next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-        #         batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits, verbose=True
-        #     )
-        #     import ipdb;ipdb.set_trace()
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
@@ -1077,7 +1077,9 @@ class FlashCausalLM(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_lengths[i] = input_length + 1
+            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            if batch.input_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -1090,6 +1092,5 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
-        batch.max_seqlen = batch.max_seqlen + 1
 
         return generations, batch
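
The last two hunks above are the off-by-one fix the commit title refers to: with speculation (medusa) a request can accept several tokens in a single decoding step, so each request's input length has to advance by its `n_accepted_ids` and `max_seqlen` has to follow the longest request, instead of both being bumped by exactly one. A minimal sketch of that bookkeeping with plain Python lists (the helper name and numbers are illustrative only, not the server's real batch class):

from typing import List


def update_lengths(input_lengths: List[int], accepted: List[int], max_seqlen: int) -> int:
    # Advance each request by the number of tokens it accepted this step and
    # return the new maximum sequence length across the batch.
    for i, n_accepted in enumerate(accepted):
        input_lengths[i] += n_accepted
        if input_lengths[i] > max_seqlen:
            max_seqlen = input_lengths[i]
    return max_seqlen


# Request 0 (currently the longest) accepts 3 speculative tokens, request 1 accepts 1.
lengths = [12, 10]
new_max = update_lengths(lengths, accepted=[3, 1], max_seqlen=12)
assert lengths == [15, 11] and new_max == 15
# The removed code advanced every length by exactly 1 and did max_seqlen += 1,
# which under-counts whenever a request accepts more than one token.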

View File

@@ -742,6 +742,9 @@ try:
             self._update_cos_sin_cache(dtype, position_ids.device, max_s)
 
+            if position_ids.max().item() >= max_s:
+                import ipdb;ipdb.set_trace()
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
 
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
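
The hunk above adds a bounds check before `torch.index_select` gathers rows from the cached cos/sin tables, dropping into `ipdb` when a position id reaches past the cache. A standalone sketch of the same check, raising instead of opening a debugger (the tensor names and sizes here are made up for illustration):

import torch

max_s = 16
cos_cached = torch.randn(max_s, 64)  # one row of cached cos values per position

position_ids = torch.tensor([0, 3, 15])
if position_ids.max().item() >= cos_cached.shape[0]:
    # the diff breaks into ipdb here; a standalone script can just fail loudly
    raise IndexError("position_ids reach past the cached rotary table")
cos = torch.index_select(cos_cached, 0, position_ids)  # gather per-token cos rows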

View File

@@ -1,7 +1,7 @@
 SPECULATE = None
 
-def get_speculate():
+def get_speculate() -> int:
     global SPECULATE
     return SPECULATE
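
The final hunk only adds a return annotation to the module-level getter. For context, a minimal sketch of the module-global pattern it belongs to; the `set_speculate` helper is an assumption for illustration and does not appear in the hunk:

SPECULATE = None


def get_speculate() -> int:
    global SPECULATE
    return SPECULATE


def set_speculate(speculate: int) -> None:
    # assumed setter: stores how many speculative tokens may be proposed per step
    global SPECULATE
    SPECULATE = speculate


set_speculate(2)
assert get_speculate() == 2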