Fixing medusa off by ones.

Nicolas Patry 2023-12-08 16:28:04 +00:00
parent abc8d48d96
commit ba16994e8a
4 changed files with 36 additions and 29 deletions

View File

@@ -1,22 +1,25 @@
-build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
-build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
-build-vllm-cuda: BRANCH=main
-build-vllm-cuda: build-vllm
-
-build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
-build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
-build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
-build-vllm-rocm: build-vllm
-
-vllm:
+vllm-cuda:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
+	git clone https://github.com/vllm-project/vllm.git vllm
 
-build-vllm: vllm
-	cd vllm && git fetch && git checkout $(VLLM_COMMIT)
+build-vllm-cuda: vllm-cuda
+	cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 	cd vllm && python setup.py build
 
-install-vllm: build-vllm
+install-vllm-cuda: build-vllm-cuda
 	pip uninstall vllm -y || true
 	cd vllm && python setup.py install
+
+vllm-rocm:
+	# Clone vllm
+	pip install -U ninja packaging --no-cache-dir
+	git clone https://github.com/fxmarty/vllm-public.git vllm
+
+build-vllm-rocm: vllm-rocm
+	cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+	cd vllm && python setup.py build
+
+install-vllm-rocm: build-vllm-rocm
+	pip uninstall vllm -y || true
+	cd vllm && python setup.py install

View File

@@ -11,6 +11,8 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
+from loguru import logger
 from text_generation_server.models import Model
+from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -318,6 +320,7 @@ class FlashCausalLMBatch(Batch):
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        # logger.info(f"Filter {request_ids}")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -422,6 +425,8 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        if slot_indices.max().item() > slots.shape[0]:
+            import ipdb;ipdb.set_trace()
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
@@ -466,6 +471,7 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+        # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -495,6 +501,7 @@ class FlashCausalLMBatch(Batch):
                     )
                 ),
             )
+        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")

         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -781,6 +788,8 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
+        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
+        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}")
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -797,6 +806,8 @@ class FlashCausalLM(Model):
             batch.block_tables_tensor = block_tables_tensor
             batch.slots = slots

+        # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")
         try:
             out = self.forward(batch)
         except Exception as e:
@@ -820,20 +831,9 @@
         else:
             next_token_logits = out

-        # import datetime
-        # from loguru import logger
-        # start = datetime.datetime.now()
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
-        # took = datetime.datetime.now() - start
-        # logger.info(f"Next token chooser {batch.all_input_ids_tensor.shape} took {took}")
-        # if batch.all_input_ids_tensor.shape[1] < 2000 and took > datetime.timedelta(milliseconds=5):
-        #     next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-        #         batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits, verbose=True
-        #     )
-        # import ipdb;ipdb.set_trace()

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
@@ -1077,7 +1077,9 @@ class FlashCausalLM(Model):
             generations.append(generation)

             # Update values
-            batch.input_lengths[i] = input_length + 1
+            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            if batch.input_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -1090,6 +1092,5 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
-        batch.max_seqlen = batch.max_seqlen + 1

         return generations, batch
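
The three hunks above are the off-by-one fix itself: with Medusa speculation a request can accept several tokens in one decoding step, so bumping input_lengths[i] by a flat + 1 and max_seqlen once per step undercounts the sequences and lets slot and position indices run past what was allocated. A minimal standalone sketch of the corrected bookkeeping (the helper name and its arguments are illustrative, not the batch's real API):

from typing import List


def advance_lengths(input_lengths: List[int], max_seqlen: int, accepted: List[int]) -> int:
    """Sketch: advance each request by however many tokens it accepted this
    step (1 without speculation, up to 1 + n_speculative with Medusa) and
    grow max_seqlen to the longest request instead of adding 1 once per step."""
    for i, n_accepted in enumerate(accepted):
        input_lengths[i] += n_accepted        # was: input_length + 1
        if input_lengths[i] > max_seqlen:     # was: max_seqlen += 1 once, after the loop
            max_seqlen = input_lengths[i]
    return max_seqlen

When nothing is speculated every entry of accepted is 1 and the sketch reduces to the old behaviour, which is why the miscount only surfaces once Medusa heads start getting extra tokens accepted.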

View File

@@ -742,6 +742,9 @@ try:
             self._update_cos_sin_cache(dtype, position_ids.device, max_s)

+            if position_ids.max().item() >= max_s:
+                import ipdb;ipdb.set_trace()
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)

             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
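
The temporary guard added above documents the invariant the fix restores: every entry of position_ids must stay below the length the cos/sin cache was built for, otherwise the torch.index_select calls gather rotary rows for positions that were never cached. A self-contained version of that check (function name and arguments are illustrative, not the module's real method):

import torch


def check_rotary_bounds(position_ids: torch.Tensor, cos_cached: torch.Tensor) -> None:
    """Sketch: every position about to be embedded must fall inside the cached
    cos/sin tables; an overshoot means sequence lengths were miscounted upstream."""
    max_position = int(position_ids.max().item())
    if max_position >= cos_cached.shape[0]:
        raise ValueError(
            f"position {max_position} exceeds rotary cache of length {cos_cached.shape[0]}"
        )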

View File

@@ -1,7 +1,7 @@
 SPECULATE = None

-def get_speculate():
+def get_speculate() -> int:
     global SPECULATE
     return SPECULATE
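
For context, the module in this last hunk is a process-wide setting that generate_token reads through get_speculate(); the diff only tightens the getter's signature. A minimal sketch of the whole pattern (the setter and any annotations beyond the -> int are assumptions, not shown in the diff):

from typing import Optional

SPECULATE: Optional[int] = None  # number of speculative tokens requested per decoding step


def set_speculate(speculate: int) -> None:
    global SPECULATE
    SPECULATE = speculate


def get_speculate() -> int:
    global SPECULATE
    return SPECULATE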