mirror of https://github.com/huggingface/text-generation-inference.git
Fixing medusa off by ones.

parent abc8d48d96
commit ba16994e8a
@@ -1,22 +1,25 @@
-build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
-build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
-build-vllm-cuda: BRANCH=main
-build-vllm-cuda: build-vllm
-
-build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
-build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
-build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
-build-vllm-rocm: build-vllm
-
-vllm:
+vllm-cuda:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
+	git clone https://github.com/vllm-project/vllm.git vllm
 
-build-vllm: vllm
-	cd vllm && git fetch && git checkout $(VLLM_COMMIT)
+build-vllm-cuda: vllm-cuda
+	cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 	cd vllm && python setup.py build
 
-install-vllm: build-vllm
+install-vllm-cuda: build-vllm-cuda
 	pip uninstall vllm -y || true
 	cd vllm && python setup.py install
+
+vllm-rocm:
+	# Clone vllm
+	pip install -U ninja packaging --no-cache-dir
+	git clone https://github.com/fxmarty/vllm-public.git vllm
+
+build-vllm-rocm: vllm-rocm
+	cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+	cd vllm && python setup.py build
+
+install-vllm-rocm: build-vllm-rocm
+	pip uninstall vllm -y || true
+	cd vllm && python setup.py install
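For orientation, the hunk above drops the variable-parameterized vllm targets in favour of explicit per-backend targets (build-vllm-cuda / install-vllm-cuda and build-vllm-rocm / install-vllm-rocm), each pinning its own repository and commit. A rough sketch of what the CUDA path runs, written as plain Python subprocess calls rather than make; the helper name and structure are illustrative and not part of the repo:

    # Sketch only: the clone / checkout / build / install steps of the
    # vllm-cuda, build-vllm-cuda and install-vllm-cuda targets.
    import subprocess

    def install_vllm_cuda(commit: str = "f8a1e39fae05ca610be8d5a78be9d40f5274e5fc") -> None:
        subprocess.run(["pip", "install", "-U", "ninja", "packaging", "--no-cache-dir"], check=True)
        subprocess.run(["git", "clone", "https://github.com/vllm-project/vllm.git", "vllm"], check=True)
        subprocess.run(["git", "fetch"], cwd="vllm", check=True)
        subprocess.run(["git", "checkout", commit], cwd="vllm", check=True)
        subprocess.run(["python", "setup.py", "build"], cwd="vllm", check=True)
        subprocess.run(["pip", "uninstall", "vllm", "-y"], check=False)  # mirrors `|| true` in the Makefile
        subprocess.run(["python", "setup.py", "install"], cwd="vllm", check=True)

The ROCm path is the same shape, pointing at the fxmarty/vllm-public fork and its pinned commit instead.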
@@ -11,6 +11,8 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
+from loguru import logger
+
 from text_generation_server.models import Model
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -318,6 +320,7 @@ class FlashCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        # logger.info(f"Filter {request_ids}")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -422,6 +425,8 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        if slot_indices.max().item() > slots.shape[0]:
+            import ipdb;ipdb.set_trace()
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
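The two added lines above are a debugging probe: if the re-packed slot indices point past the end of the filtered slots tensor, the server drops into ipdb. A non-interactive version of the same sanity check might look like the sketch below; names mirror the hunk, and note the hunk uses a strict >, while an index equal to slots.shape[0] would already be out of range:

    # Sketch only: raise instead of breaking into a debugger.
    import torch

    def check_slot_indices(slot_indices: torch.Tensor, slots: torch.Tensor) -> None:
        max_idx = int(slot_indices.max().item())
        if max_idx >= slots.shape[0]:  # valid indices are 0 .. slots.shape[0] - 1
            raise RuntimeError(f"slot index {max_idx} out of range for {slots.shape[0]} slots")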
@@ -466,6 +471,7 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+        # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -495,6 +501,7 @@ class FlashCausalLMBatch(Batch):
                 )
             ),
         )
+        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")
 
         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -781,6 +788,8 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
+        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
+        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}")
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None
 
@@ -797,6 +806,8 @@ class FlashCausalLM(Model):
             batch.block_tables_tensor = block_tables_tensor
             batch.slots = slots
 
+        # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")
+
         try:
             out = self.forward(batch)
         except Exception as e:
@@ -820,20 +831,9 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
-        # import datetime
-        # from loguru import logger
-
-        # start = datetime.datetime.now()
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
-        # took = datetime.datetime.now() - start
-        # logger.info(f"Next token chooser {batch.all_input_ids_tensor.shape} took {took}")
-        # if batch.all_input_ids_tensor.shape[1] < 2000 and took > datetime.timedelta(milliseconds=5):
-        #     next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-        #         batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits, verbose=True
-        #     )
-        # import ipdb;ipdb.set_trace()
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
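The chooser call kept above is where speculation pays off: it takes the full input-id window, the latest logits, the global speculate count, the previously speculated ids and the Medusa head logits, and returns the chosen ids together with accepted_ids, roughly how many of the speculated tokens each request actually accepted. A shape-only sketch of greedy acceptance counting for a single request, illustrative rather than the repo's next_token_chooser, which also handles sampling and per-request heterogeneity:

    # Sketch only: count how many previously speculated tokens the model confirms.
    import torch

    def count_accepted(next_token_logits: torch.Tensor, speculative_ids: torch.Tensor) -> int:
        # next_token_logits: [S + 1, vocab] logits for the base token plus S speculated slots
        # speculative_ids:   [S] token ids proposed by the speculation heads on the previous step
        greedy = next_token_logits.argmax(dim=-1)              # [S + 1]
        matches = (greedy[:-1] == speculative_ids).long()      # 1 where the guess was confirmed
        accepted = int(matches.cumprod(dim=0).sum().item())    # length of the agreeing prefix
        return accepted + 1                                    # the base token always counts

That accepted count is exactly what the length bookkeeping later in this file has to honour, which is where the off-by-ones were.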
@@ -1077,7 +1077,9 @@ class FlashCausalLM(Model):
             generations.append(generation)
 
             # Update values
-            batch.input_lengths[i] = input_length + 1
+            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            if batch.input_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -1090,6 +1092,5 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
-        batch.max_seqlen = batch.max_seqlen + 1
 
         return generations, batch
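These two hunks are the "off by ones" named in the commit title: with speculative decoding a request can gain n_accepted_ids tokens in a single step, so each request's length advances by that count and max_seqlen tracks the per-request maximum, instead of the old blanket +1 on both. A minimal sketch of the corrected bookkeeping, using stand-in names for the batch fields:

    # Sketch only: per-request length update when a step may accept several tokens.
    def advance_lengths(input_lengths: list, max_seqlen: int, n_accepted: list) -> int:
        for i, accepted in enumerate(n_accepted):
            input_lengths[i] += accepted          # was: += 1
            if input_lengths[i] > max_seqlen:
                max_seqlen = input_lengths[i]     # replaces the old blanket max_seqlen += 1
        return max_seqlen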
@@ -742,6 +742,9 @@ try:
 
             self._update_cos_sin_cache(dtype, position_ids.device, max_s)
 
+            if position_ids.max().item() >= max_s:
+                import ipdb;ipdb.set_trace()
+
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
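The guard added above checks the invariant the rotary cache relies on: _update_cos_sin_cache sizes the cos/sin tables for positions [0, max_s), so a position id at or past max_s would index past the cache, which is exactly what an off-by-one in the length bookkeeping would cause. A standalone sketch of the gather with that invariant asserted, using illustrative shapes rather than the repo's class:

    # Sketch only: gather per-token cos/sin from caches sized for positions [0, max_s).
    import torch

    def gather_rotary(cos_cached: torch.Tensor, sin_cached: torch.Tensor,
                      position_ids: torch.Tensor) -> tuple:
        # cos_cached, sin_cached: [max_s, rotary_dim // 2]; position_ids: [total_tokens]
        assert int(position_ids.max().item()) < cos_cached.shape[0], "position id past RoPE cache"
        cos = torch.index_select(cos_cached, 0, position_ids)
        sin = torch.index_select(sin_cached, 0, position_ids)
        return cos, sin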
@@ -1,7 +1,7 @@
 
 SPECULATE = None
 
-def get_speculate():
+def get_speculate() -> int:
     global SPECULATE
     return SPECULATE
 
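get_speculate just reads a module-level flag that the server configures once at startup; this hunk only adds the return annotation. A sketch of the whole pattern; the setter here is an assumption for illustration, since only get_speculate appears in the hunk:

    # Sketch only: module-level speculation flag, set once at startup, read on every step.
    SPECULATE = None

    def set_speculate(speculate: int) -> None:  # assumed counterpart, not shown in the hunk
        global SPECULATE
        SPECULATE = speculate

    def get_speculate() -> int:
        return SPECULATE  # None until set_speculate() has run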