mirror of https://github.com/huggingface/text-generation-inference.git
Fixing medusa off by ones.

This commit is contained in:
parent: abc8d48d96
commit: ba16994e8a
@@ -1,22 +1,25 @@
-build-vllm-cuda: REPOSITORY=https://github.com/vllm-project/vllm.git
-build-vllm-cuda: VLLM_COMMIT=f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
-build-vllm-cuda: BRANCH=main
-build-vllm-cuda: build-vllm
-
-build-vllm-rocm: REPOSITORY=https://github.com/fxmarty/vllm-public.git
-build-vllm-rocm: VLLM_COMMIT=ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
-build-vllm-rocm: BRANCH=rotary-no-positions-split-cos-sin
-build-vllm-rocm: build-vllm
-
-vllm:
+vllm-cuda:
 	# Clone vllm
 	pip install -U ninja packaging --no-cache-dir
-	git clone --single-branch --branch $(BRANCH) $(REPOSITORY) vllm
+	git clone https://github.com/vllm-project/vllm.git vllm

-build-vllm: vllm
-	cd vllm && git fetch && git checkout $(VLLM_COMMIT)
+build-vllm-cuda: vllm-cuda
+	cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 	cd vllm && python setup.py build

-install-vllm: build-vllm
+install-vllm-cuda: build-vllm-cuda
+	pip uninstall vllm -y || true
+	cd vllm && python setup.py install
+
+vllm-rocm:
+	# Clone vllm
+	pip install -U ninja packaging --no-cache-dir
+	git clone https://github.com/fxmarty/vllm-public.git vllm
+
+build-vllm-rocm: vllm-rocm
+	cd vllm && git fetch && git checkout ad9b7c4095ef54419a0533d254f2ad84bd2dfcae
+	cd vllm && python setup.py build
+
+install-vllm-rocm: build-vllm-rocm
 	pip uninstall vllm -y || true
 	cd vllm && python setup.py install
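Note (not part of the commit): the hunk above replaces the single parametrized vllm target with separate, hard-coded vllm-cuda and vllm-rocm clone/build/install targets. As a rough illustration only, a caller could pick the matching install target from Python; the helper below and the Makefile name are assumptions made for this sketch, not code from the repository.

import subprocess
import torch

def install_vllm(makefile: str = "Makefile-vllm") -> None:
    # torch.version.hip is a version string on ROCm builds of PyTorch and None on CUDA builds.
    target = "install-vllm-rocm" if torch.version.hip else "install-vllm-cuda"
    subprocess.run(["make", "-f", makefile, target], check=True)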
@@ -11,6 +11,8 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict

+from loguru import logger
+
 from text_generation_server.models import Model
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -318,6 +320,7 @@ class FlashCausalLMBatch(Batch):

     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        # logger.info(f"Filter {request_ids}")
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -422,6 +425,8 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        if slot_indices.max().item() > slots.shape[0]:
+            import ipdb;ipdb.set_trace()
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
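Note (not part of the commit): the added guard drops into ipdb when the largest entry of slot_indices exceeds the number of slots kept after filtering. The indexing invariant it is probing is sketched below with made-up tensors; this is not repository code.

import torch

slots = torch.arange(10)                 # 10 valid slot rows, indices 0..9
slot_indices = torch.tensor([0, 4, 9])

# Safe gathering requires every index to be strictly less than slots.shape[0].
assert slot_indices.max().item() < slots.shape[0]
selected = slots[slot_indices]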
@@ -466,6 +471,7 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+        # logger.info(f"Concatenate {[[r.id for r in batch.requests] for batch in batches]}")
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -495,6 +501,7 @@ class FlashCausalLMBatch(Batch):
                    )
                ),
            )
+        # logger.info(f"total slots {total_slots} {[b.slots.shape for b in batches]}")

         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -781,6 +788,8 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
+        # logger.info(f"GENERATE {[r.id for r in batch.requests]}")
+        # logger.info(f"GENERATE {batch.position_ids} {batch.max_seqlen} {batch.input_lengths} { batch.input_lengths_tensor}")
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

@@ -797,6 +806,8 @@ class FlashCausalLM(Model):
         batch.block_tables_tensor = block_tables_tensor
         batch.slots = slots

+        # logger.info(f"GENERATE {batch.slots.shape} {batch.slot_indices}")
+
         try:
             out = self.forward(batch)
         except Exception as e:
@@ -820,20 +831,9 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out

-        # import datetime
-        # from loguru import logger
-
-        # start = datetime.datetime.now()
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
-        # took = datetime.datetime.now() - start
-        # logger.info(f"Next token chooser {batch.all_input_ids_tensor.shape} took {took}")
-        # if batch.all_input_ids_tensor.shape[1] < 2000 and took > datetime.timedelta(milliseconds=5):
-        #     next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-        #         batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits, verbose=True
-        #     )
-        #     import ipdb;ipdb.set_trace()

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
@@ -1077,7 +1077,9 @@ class FlashCausalLM(Model):
             generations.append(generation)

             # Update values
-            batch.input_lengths[i] = input_length + 1
+            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            if batch.input_lengths[i] > batch.max_seqlen:
+                batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -1090,6 +1092,5 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
-        batch.max_seqlen = batch.max_seqlen + 1

         return generations, batch
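Note (not part of the commit): the two hunks above are the actual off-by-one fix. With speculative (Medusa) decoding a request can accept several tokens in one generate_token step, so each input length now advances by n_accepted_ids instead of 1, and max_seqlen is taken as the maximum of the updated lengths rather than being bumped by exactly one at the end. A minimal sketch of that bookkeeping, using plain Python lists instead of the batch tensors:

def update_lengths(input_lengths, accepted_ids):
    # Each request advances by however many speculative tokens it accepted.
    new_lengths = [length + accepted for length, accepted in zip(input_lengths, accepted_ids)]
    # The batch maximum must track the longest sequence, not "previous max + 1",
    # because requests may accept different numbers of tokens in the same step.
    return new_lengths, max(new_lengths)

# Example: request 0 accepts 1 token, request 1 accepts 3.
lengths, max_seqlen = update_lengths([10, 12], [1, 3])
assert lengths == [11, 15] and max_seqlen == 15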
@@ -742,6 +742,9 @@ try:

             self._update_cos_sin_cache(dtype, position_ids.device, max_s)

+            if position_ids.max().item() >= max_s:
+                import ipdb;ipdb.set_trace()
+
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
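Note (not part of the commit): the added guard breaks into ipdb whenever a position id reaches max_s, which would land one past the last row of the cos/sin caches that _update_cos_sin_cache sizes to cover max_s positions. The invariant is sketched below with assumed shapes; this is not the repository's rotary class.

import torch

max_s = 16
cos_cached = torch.randn(max_s, 64)       # cache holds rows 0..max_s-1
position_ids = torch.tensor([0, 3, 15])

assert position_ids.max().item() < max_s  # an off-by-one here would index past the cache
cos = torch.index_select(cos_cached, 0, position_ids)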
@@ -1,7 +1,7 @@

 SPECULATE = None

-def get_speculate():
+def get_speculate() -> int:
     global SPECULATE
     return SPECULATE

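Note (not part of the commit): the last hunk only adds a return annotation to get_speculate. The module is a small process-wide setting; a self-contained sketch of that pattern follows (set_speculate is assumed from the getter and the global, it is not shown in this hunk), and the int annotation only holds once the value has been set, since the global starts as None.

SPECULATE = None

def set_speculate(speculate: int):
    global SPECULATE
    SPECULATE = speculate

def get_speculate() -> int:
    global SPECULATE
    return SPECULATE

set_speculate(2)   # e.g. speculate two tokens per decoding step
assert get_speculate() == 2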