Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)
Some work to get batching working.

parent b0cb4fa9d0
commit bdbccb774c
@@ -0,0 +1,58 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "stop_sequence",
    "generated_tokens": 5,
    "prefill": [
      {
        "id": 1,
        "logprob": null,
        "text": "<s>"
      },
      {
        "id": 4321,
        "logprob": -10.0625,
        "text": "Test"
      },
      {
        "id": 2009,
        "logprob": -12.28125,
        "text": "request"
      }
    ],
    "seed": 0,
    "tokens": [
      {
        "id": 5229,
        "logprob": -1.7587891,
        "special": false,
        "text": " failed"
      },
      {
        "id": 363,
        "logprob": -0.5175781,
        "special": false,
        "text": " for"
      },
      {
        "id": 1404,
        "logprob": 0.0,
        "special": false,
        "text": " user"
      },
      {
        "id": 376,
        "logprob": 0.0,
        "special": false,
        "text": " \""
      },
      {
        "id": 1688,
        "logprob": -0.20422363,
        "special": false,
        "text": "test"
      }
    ]
  },
  "generated_text": "Test request failed for user \"test"
}
@@ -0,0 +1,88 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {
        "id": 1,
        "logprob": null,
        "text": "<s>"
      },
      {
        "id": 4321,
        "logprob": -10.0625,
        "text": "Test"
      },
      {
        "id": 2009,
        "logprob": -12.28125,
        "text": "request"
      }
    ],
    "seed": null,
    "tokens": [
      {
        "id": 363,
        "logprob": -2.0878906,
        "special": false,
        "text": " for"
      },
      {
        "id": 278,
        "logprob": -3.4121094,
        "special": false,
        "text": " the"
      },
      {
        "id": 376,
        "logprob": -3.8457031,
        "special": false,
        "text": " \""
      },
      {
        "id": 2577,
        "logprob": -3.5566406,
        "special": false,
        "text": "Get"
      },
      {
        "id": 599,
        "logprob": -3.4746094,
        "special": false,
        "text": " all"
      },
      {
        "id": 4160,
        "logprob": -3.2363281,
        "special": false,
        "text": " users"
      },
      {
        "id": 29908,
        "logprob": -0.49023438,
        "special": false,
        "text": "\""
      },
      {
        "id": 16248,
        "logprob": -1.2402344,
        "special": false,
        "text": " endpoint"
      },
      {
        "id": 29889,
        "logprob": -0.88134766,
        "special": false,
        "text": "."
      },
      {
        "id": 13,
        "logprob": -0.41870117,
        "special": false,
        "text": "\n"
      }
    ]
  },
  "generated_text": " for the \"Get all users\" endpoint.\n"
}
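These snapshot files are plain JSON, so their shape can be sanity-checked with nothing beyond the standard library. A minimal sketch, assuming a hypothetical local path for the snapshot file:

import json

# Hypothetical path; adjust to wherever these snapshot files are stored.
with open("test_flash_medusa_simple.json") as f:
    snapshot = json.load(f)

details = snapshot["details"]
print(details["finish_reason"], details["generated_tokens"], details["seed"])

# For the snapshots above, generated_text ends with the concatenation of the per-token texts
# (with return_full_text=True the prompt is prepended, hence endswith rather than equality).
tail = "".join(tok["text"] for tok in details["tokens"])
assert snapshot["generated_text"].endswith(tail)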
58 integration-tests/models/test_flash_medusa.py Normal file
@@ -0,0 +1,58 @@
import pytest


@pytest.fixture(scope="module")
def flash_medusa_handle(launcher):
    with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_medusa(flash_medusa_handle):
    await flash_medusa_handle.health(300)
    return flash_medusa_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
    response = await flash_medusa.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 5
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
    responses = await generate_load(flash_medusa, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
@@ -450,26 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        speculative_ids: Optional[torch.Tensor]
    ) -> torch.Tensor:
        if speculative_ids is not None:
            speculative_length = speculative_ids.shape[1]
            new_length = speculative_length + 1
            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
            new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)

            # Add an extra block just in case
            block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
            # Copy the block tables for all members
            block_tables = block_tables.expand(new_length, -1).contiguous()
            slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
            input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
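A rough standalone sketch, with invented values, of what the block-table handling above does for a single sequence: one spare block id is appended in case the speculative tokens spill into a new block, and each of the new_length candidate rows then gets its own copy of the table.

import torch

# One sequence, 2 speculative ids -> 3 rows go through the model this step.
new_length = 3
block_tables = torch.tensor([[4, 7]])       # blocks already assigned to the sequence (made up)

# Append one spare block id, then give every row its own copy of the table.
block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)   # [[4, 7, 8]]
block_tables = block_tables.expand(new_length, -1).contiguous()             # shape (3, 3)
print(block_tables)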
@@ -520,7 +501,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
        speculative_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
@@ -531,7 +511,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
            slots,
            input_lengths,
            max_s,
            speculative_ids,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
@@ -3,6 +3,7 @@ import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
from loguru import logger

import numpy as np
@@ -46,7 +47,6 @@ class FlashCausalLMBatch(Batch):
    # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
    cu_seqlen_prefill: Optional[torch.Tensor]
    cu_seqlen_speculative: Optional[torch.Tensor]

    # Paged Attention values
@@ -123,7 +123,6 @@ class FlashCausalLMBatch(Batch):
        position_ids = []
        speculative_ids = []
        cu_seqlen_prefill = [0]
        cu_seqlen_speculative = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []
@@ -163,18 +162,9 @@ class FlashCausalLMBatch(Batch):
            tokenized_input = tokenized_input[-r.truncate :]

            # # TODO remove this
            # # Scaffolding to speculate some ids
            # speculate_ids = [1, 2]
            # tokenized_input.extend([1, 2])
            speculate_ids = []

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)
@@ -186,7 +176,6 @@ class FlashCausalLMBatch(Batch):
            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)
            cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))

            next_token_chooser_parameters.append(r.parameters)
@@ -199,7 +188,8 @@ class FlashCausalLMBatch(Batch):
            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1
            speculative_length = 2
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
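A small worked example of the block accounting above, using the 3-token "Test request" prompt from the snapshots and an assumed BLOCK_SIZE of 16:

import math

BLOCK_SIZE = 16            # assumed paged-attention block size for this example
input_length = 3           # <s> + "Test" + "request", as in the prefill above
max_new_tokens = 10
speculative_length = 2     # hard-coded in the hunk above

total_tokens = input_length + max_new_tokens - 1 + speculative_length   # 14
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)                    # 1
print(total_tokens, needed_blocks)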
@@ -268,9 +258,6 @@ class FlashCausalLMBatch(Batch):
        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )
        cu_seqlen_speculative = torch.tensor(
            cu_seqlen_speculative, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
@@ -296,6 +283,7 @@ class FlashCausalLMBatch(Batch):
            top_n_tokens, device=device, dtype=torch.int64
        )

        logger.info("FROM PB")
        return cls(
            batch_id=pb.id,
            requests=pb.requests,
@@ -303,7 +291,6 @@ class FlashCausalLMBatch(Batch):
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            cu_seqlen_speculative=cu_seqlen_speculative,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
@@ -331,6 +318,7 @@ class FlashCausalLMBatch(Batch):
    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        logger.info("FILTER")
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
@@ -344,6 +332,7 @@ class FlashCausalLMBatch(Batch):
        # Used to index into tensors
        indices = []
        batch_indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
@@ -371,14 +360,18 @@ class FlashCausalLMBatch(Batch):
        # Cumulative length
        cumulative_max_length = 0

        logger.info(f"Request ids {request_ids} {len(self.requests)}")
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            batch_indices.append(idx)
            S = 1 if self.speculative_ids is None else self.speculative_ids.shape[1] + 1
            indices.extend(range(idx * S, (idx + 1) * S))
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            logger.info(f"Input lengths {self.input_lengths} {idx} {S}")
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)
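A minimal sketch of the index bookkeeping this filter now needs once speculation is on: each surviving request owns S = speculative_length + 1 consecutive rows in the flat tensors, so per-request indices and per-row indices diverge. Values are invented.

# Two requests kept out of three, with 2 speculative ids per request (S = 3).
S = 3
kept = [0, 2]                      # positions of the surviving requests in the old batch

batch_indices = list(kept)         # one entry per request (chooser, top_n_tokens, ...)
row_indices = []                   # one entry per tensor row (input_ids, position_ids, ...)
for idx in kept:
    row_indices.extend(range(idx * S, (idx + 1) * S))

print(batch_indices)               # [0, 2]
print(row_indices)                 # [0, 1, 2, 6, 7, 8]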
@@ -429,14 +422,19 @@ class FlashCausalLMBatch(Batch):
        self.block_tables = None

        # Index into tensors
        logger.info(f"INPUT IDS {indices} {self.input_ids}")
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(batch_indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[batch_indices]

        logger.info(f"{indices} {self.speculative_ids}")
        speculative_ids = self.speculative_ids[batch_indices]
        logger.info(f"SPEC IDS {speculative_ids}")

        start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -472,12 +470,14 @@ class FlashCausalLMBatch(Batch):
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        logger.info(f"Concatenate {len(batches)}, {[b.input_ids.shape for b in batches]}")
        requests = []
        requests_idx_mapping = {}
@@ -488,7 +488,7 @@ class FlashCausalLMBatch(Batch):
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_batch_size += len(b.input_ids)
            total_slots += len(b.slots)
            blocks += b.blocks
            max_blocks = max(max_blocks, b.max_blocks)
@@ -536,6 +536,7 @@ class FlashCausalLMBatch(Batch):
        # Cumulative length
        cumulative_batch_size = 0
        cumulative_1 = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
@@ -546,23 +547,28 @@ class FlashCausalLMBatch(Batch):
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size
                    requests_idx_mapping[k] = v + cumulative_1

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            end_index = cumulative_batch_size + len(batch.input_ids)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            start_index1 = cumulative_1
            end_index1 = cumulative_1 + len(batch)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            logger.info(f"IN concat {batch.slot_indices}")
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor

            top_n_tokens_tensor[start_index1:end_index1] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
                start_index1:end_index1, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
@@ -584,11 +590,14 @@ class FlashCausalLMBatch(Batch):
            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_batch_size += len(batch.input_ids)
            cumulative_slots += len(batch.slots)
            cumulative_1 += len(batch)

        start_slots = torch.concat(start_slots)

        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
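A standalone sketch of the two cursors the concatenation above has to keep in sync: one advances by requests, the other by rows, and with speculation every request contributes speculative_length + 1 rows. Batch sizes are invented.

# Pretend we concatenate two batches with 2 and 3 requests, 2 speculative ids each (S = 3).
S = 3
batches = [2, 3]                       # requests per batch

row_cursor = 0                         # offsets into input_ids / position_ids / slot_indices
request_cursor = 0                     # offsets into per-request structures (top_n_tokens, mappings)
for num_requests in batches:
    num_rows = num_requests * S
    print(f"rows [{row_cursor}:{row_cursor + num_rows}), "
          f"requests [{request_cursor}:{request_cursor + num_requests})")
    row_cursor += num_rows
    request_cursor += num_requests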
@@ -629,6 +638,7 @@ class FlashCausalLMBatch(Batch):
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def __del__(self):
@@ -731,18 +741,28 @@ class FlashCausalLM(Model):
        return int(num_blocks * BLOCK_SIZE)

    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward

        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        return self.model.forward(
            input_ids=batch.input_ids,
            position_ids=batch.position_ids,
            cu_seqlen_prefill=batch.cu_seqlen_prefill,
            kv_cache=get_cache_manager().kv_cache,
            block_tables=batch.block_tables_tensor,
            slots=batch.slots[batch.slot_indices],
            input_lengths=batch.input_lengths_tensor,
            max_s=batch.max_seqlen,
            lm_head_indices=batch.prefill_head_indices,
            speculative_ids=batch.speculative_ids
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=lm_head_indices,
            # speculative_ids=batch.speculative_ids
        )

    @tracer.start_as_current_span("generate_token")
@@ -790,8 +810,6 @@ class FlashCausalLM(Model):
            next_token_logits = out

        # if next_token_logits.shape[0] == 3:
        #     import ipdb;ipdb.set_trace()
        next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
        )
@@ -799,20 +817,15 @@ class FlashCausalLM(Model):
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
        )
        logger.info(f"{batch.top_n_tokens_tensor.shape}, {batch.top_n_tokens}")

        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            if speculative_ids is not None:
                # length = len(batch) * (1 + speculative_length)
                length = len(batch)
            else:
                length = len(batch)
            # import ipdb;ipdb.set_trace()
            length = len(batch)
            next_position_ids = batch.position_ids.new_empty(length)
            # Keep only 1 slot index, TODO make sure we recover the speculated ids slots later
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
@@ -820,7 +833,6 @@ class FlashCausalLM(Model):
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            # import ipdb;ipdb.set_trace()
            next_position_ids = batch.position_ids

        # Cumulative length
@@ -875,29 +887,32 @@ class FlashCausalLM(Model):
                        start_index + 1 : start_index + out_length
                    ]

            # logger.info(f"Request ids {request_ids} {len(self.requests)}")
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # if accepted_ids[0] > 1:
        #     import ipdb;ipdb.set_trace()

        if len(accepted_ids) > 1:
            raise Exception("Implement the batched behavior")
        # if len(accepted_ids) > 1:
        #     raise Exception("Implement the batched behavior")

        # Set values in batch
        # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)

        for n_accepted_ids in accepted_ids:
            # TODO Make this batched
            batch.input_ids = next_input_ids[-1:]
            batch.speculative_ids = speculative_ids
            batch.position_ids = next_position_ids + n_accepted_ids
            batch.input_lengths_tensor += n_accepted_ids
            batch.slot_indices += n_accepted_ids
        accepted_ids = torch.tensor(accepted_ids, device=batch.input_ids.device, dtype=batch.input_ids.dtype)
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids

        logger.info(f"ACCEPTED IDS {accepted_ids} {batch.position_ids}")
        if accepted_ids.shape != batch.slot_indices:
            # This can happen after a concatenation
            # The slot indices are already modified for some speculative_ids
            B = batch.slot_indices.shape[0] // accepted_ids.shape[0]
            accepted_ids = accepted_ids.view(-1, 1).expand(-1, B).reshape(-1)
        batch.slot_indices += accepted_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
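A small standalone sketch of the cumsum indexing used above to pull, for every request, the last token it actually accepted; all values are invented.

import torch

# Three requests; they accepted 1, 3 and 2 candidate tokens respectively.
accepted_ids = torch.tensor([1, 3, 2])

# The chooser already returned only the accepted tokens, concatenated per request:
# [a0_0,  a1_0, a1_1, a1_2,  a2_0, a2_1]
next_input_ids = torch.tensor([11, 21, 22, 23, 31, 32])

# cumsum - 1 points at the last accepted token of each request.
last_accepted = accepted_ids.cumsum(dim=-1) - 1    # tensor([0, 3, 5])
print(next_input_ids[last_accepted])               # tensor([11, 23, 32])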
@@ -955,26 +970,22 @@ class FlashCausalLM(Model):
            next_token_texts = []
            for j in range(index, index + n_accepted_ids):
                # Generated token
                all_input_ids.append(next_token_ids[j])
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

            # Evaluate stopping criteria

            for next_token_id in _next_token_ids:
                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    stopped = True
                    break
            if not stop:
            else:
                stopped = False

            # Shard generations
@@ -1068,8 +1079,27 @@ class FlashCausalLM(Model):
            batch.prefill_cu_outlens = None
            batch.prefill_head_indices = None
            batch.prefill_next_token_indices = None
        if prefill:
            batch.max_seqlen += speculative_length
        batch.max_seqlen = batch.max_seqlen + 1

        # Model Forward
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            batch.input_ids = torch.cat([batch.input_ids.unsqueeze(-1), batch.speculative_ids], dim=1).view(-1)
            if batch.position_ids.shape[0] != B * new_length:
                arange = torch.arange(new_length).unsqueeze(0).to(device=batch.position_ids.device)
                batch.position_ids = (batch.position_ids.view((-1, 1)).expand(B, new_length) + arange).view(-1)
                batch.slot_indices = (batch.slot_indices.view((-1, 1)).expand(B, new_length) + arange.to(dtype=batch.slot_indices.dtype)).view(-1)
                batch.input_lengths_tensor = (batch.input_lengths_tensor.view((-1, 1)).expand(B, new_length) + arange.to(dtype=batch.input_lengths_tensor.dtype)).view(-1)
            batch.max_seqlen = batch.max_seqlen + speculative_length
            # Add an extra block just in case
            block_tables = torch.cat([batch.block_tables_tensor, batch.block_tables_tensor[:, -1:] + 1], dim=1)
            # Copy the block tables for all members
            # Contiguous because paged assumes contiguity
            batch.block_tables_tensor = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B * new_length, -1).contiguous()

            batch.lm_head_indices = batch.prefill_head_indices
            cu_seqlen_prefill = batch.cu_seqlen_prefill

        return generations, batch
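A standalone sketch, with made-up values, of the batched expansion performed just above: each per-row tensor is repeated new_length times per request and shifted by an arange.

import torch

B, speculative_length = 2, 2
new_length = speculative_length + 1

position_ids = torch.tensor([5, 9])             # next position of each request
arange = torch.arange(new_length).unsqueeze(0)  # [[0, 1, 2]]

expanded = (position_ids.view(-1, 1).expand(B, new_length) + arange).view(-1)
print(expanded)   # tensor([ 5,  6,  7,  9, 10, 11])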
@@ -225,21 +225,35 @@ class HeterogeneousNextTokenChooser:
            scores = warper(input_ids, scores)

        accepted_ids = []
        next_ids = self.choice(scores)
        from loguru import logger
        if speculated_ids is not None:
            validate_speculative = next_ids[:-1] == speculated_ids[0]
            index = 1
            for valid in validate_speculative.tolist():
                if valid:
                    index += 1
            # print(f"Validated {index - 1}")
            next_ids = next_ids[:index]
            scores = scores[:index]
            speculative_scores = speculative_scores[index - 1:index]
            accepted_ids.append(index)
            logger.info(f"CHOOSER {next_ids} {speculated_ids}")
            accepted_ids = []
            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)
            S = speculated_ids.shape[1] + 1
            indices = []
            for i in range(B):
                _next_ids = next_ids[i * S: (i + 1) * S]
                _speculated_ids = speculated_ids[i]
                validate_speculative = _next_ids[:-1] == _speculated_ids
                index = i * S
                accepted = 1
                # First is always valid
                indices.append(index)
                for valid in validate_speculative.tolist():
                    if valid:
                        index += 1
                        accepted += 1
                        indices.append(index)
                # print(f"Validated {accepted}")
                accepted_ids.append(accepted)
            accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype)
            next_ids = next_ids[indices]
            scores = scores[indices]
            speculative_scores = speculative_scores[accepted_ids.cumsum(dim=-1) - 1]
        else:
            accepted_ids.append(1)
            accepted_ids = torch.ones_like(next_ids)

        logprobs = torch.log_softmax(scores, -1)
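A minimal standalone sketch of the validation idea in the chooser above: the first sampled token is always kept, and each speculated id is accepted while it matches what the model actually sampled at the same position. Values are invented.

import torch

speculated_ids = torch.tensor([[7, 9, 4]])   # what the draft/Medusa step proposed last time
next_ids = torch.tensor([7, 9, 5, 8])        # what the base model sampled for the S + 1 rows

matches = next_ids[:-1] == speculated_ids[0]  # tensor([True, True, False])
accepted = 1                                  # the first sampled token is always kept
for valid in matches.tolist():
    if not valid:
        break
    accepted += 1

print(accepted)             # 3: one guaranteed token plus two validated speculative tokens
print(next_ids[:accepted])  # tensor([7, 9, 5])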