OlivierDehaene 2023-04-18 17:51:41 +02:00
parent 9476170dda
commit 2ad7a63761
3 changed files with 288 additions and 247 deletions

Cargo.lock (generated)

@ -42,42 +42,51 @@ dependencies = [
[[package]]
name = "anstream"
version = "0.2.6"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f"
checksum = "9e579a7752471abc2a8268df8b20005e3eadd975f585398f17efcfd8d4927371"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"concolor-override",
"concolor-query",
"colorchoice",
"is-terminal",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "0.3.5"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2"
checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d"
[[package]]
name = "anstyle-parse"
version = "0.1.1"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116"
checksum = "e765fd216e48e067936442276d1d57399e37bce53c264d6fefbe298080cb57ee"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-wincon"
version = "0.2.0"
name = "anstyle-query"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa"
checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b"
dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "anstyle-wincon"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bcd8291a340dd8ac70e18878bc4501dd7b4ff970cfa21c207d36ece51ea88fd"
dependencies = [
"anstyle",
"windows-sys 0.45.0",
"windows-sys 0.48.0",
]
[[package]]
@ -105,7 +114,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -116,7 +125,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -127,9 +136,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.6.13"
version = "0.6.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6539e4565c365448d483967c6dee3eaecb8e87679a17806a831e82b05b903c18"
checksum = "3b32c5ea3aabaf4deb5f5ced2d688ec0844c881c9e6c696a8b769a05fc691e62"
dependencies = [
"async-trait",
"axum-core",
@ -310,9 +319,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.2.1"
version = "4.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3"
checksum = "9b802d85aaf3a1cdb02b224ba472ebdea62014fccfcb269b95a4d76443b5ee5a"
dependencies = [
"clap_builder",
"clap_derive",
@ -321,9 +330,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.2.1"
version = "4.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f"
checksum = "14a1a858f532119338887a4b8e1af9c60de8249cd7bafd68036a489e261e37b6"
dependencies = [
"anstream",
"anstyle",
@ -341,7 +350,7 @@ dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -351,19 +360,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1"
[[package]]
name = "concolor-override"
name = "colorchoice"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f"
[[package]]
name = "concolor-query"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf"
dependencies = [
"windows-sys 0.45.0",
]
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "console"
@ -794,7 +794,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -868,9 +868,9 @@ dependencies = [
[[package]]
name = "h2"
version = "0.3.16"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d"
checksum = "17f8a914c2987b688368b5138aa05321db91f4090cf26118185672ad588bce21"
dependencies = [
"bytes",
"fnv",
@ -966,9 +966,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
[[package]]
name = "hyper"
version = "0.14.25"
version = "0.14.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899"
checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4"
dependencies = [
"bytes",
"futures-channel",
@ -1364,7 +1364,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -1517,7 +1517,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -1787,9 +1787,9 @@ dependencies = [
[[package]]
name = "prost"
version = "0.11.8"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e48e50df39172a3e7eb17e14642445da64996989bc212b583015435d39a58537"
checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
dependencies = [
"bytes",
"prost-derive",
@ -1797,9 +1797,9 @@ dependencies = [
[[package]]
name = "prost-build"
version = "0.11.8"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c828f93f5ca4826f97fedcbd3f9a536c16b12cff3dbbb4a007f932bbad95b12"
checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270"
dependencies = [
"bytes",
"heck",
@ -1819,9 +1819,9 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.11.8"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ea9b0f8cbe5e15a8a042d030bd96668db28ecb567ec37d691971ff5731d2b1b"
checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
dependencies = [
"anyhow",
"itertools 0.10.5",
@ -1832,9 +1832,9 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.11.8"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "379119666929a1afd7a043aa6cf96fa67a6dce9af60c88095a4686dbce4c9c88"
checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
dependencies = [
"prost",
]
@ -2153,14 +2153,14 @@ checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
name = "serde_json"
version = "1.0.95"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"itoa",
"ryu",
@ -2330,9 +2330,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.14"
version = "2.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcf316d5356ed6847742d036f8a39c3b8435cac10bd528a4bd461928a6ab34d5"
checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822"
dependencies = [
"proc-macro2",
"quote",
@ -2450,7 +2450,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -2578,7 +2578,7 @@ checksum = "61a573bdc87985e9d6ddeed1b3d864e8a302c847e40d647746df2f1de209d1ce"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]
@ -2928,9 +2928,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
[[package]]
name = "utoipa"
version = "3.2.1"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24e7ee17c9ef094b86e1e04170d90765bd76cb381921dacb4d3e175a267bdae6"
checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98"
dependencies = [
"indexmap",
"serde",
@ -2940,14 +2940,14 @@ dependencies = [
[[package]]
name = "utoipa-gen"
version = "3.2.1"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df6f458e5abc811d44aca28455efc4163fb7565a7af2aa32d17611f3d1d9794d"
checksum = "7ea8ac818da7e746a63285594cce8a96f5e00ee31994e655bd827569cb8b137b"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.14",
"syn 2.0.15",
]
[[package]]

server/text_generation_server/models/causal_lm.py

@ -58,10 +58,10 @@ class CausalLMBatch(Batch):
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch":
inputs = []
next_token_choosers = []
@ -158,9 +158,7 @@ class CausalLMBatch(Batch):
request_input_length = self.input_lengths[idx]
input_lengths.append(request_input_length)
max_input_length = max(
max_input_length, request_input_length
)
max_input_length = max(max_input_length, request_input_length)
# Replace metadata
self.requests_idx_mapping = requests_idx_mapping
@ -176,19 +174,12 @@ class CausalLMBatch(Batch):
self.position_ids = self.position_ids[keep_indices]
# Force past to be of dim [self_size, num_heads, ...] for easy indexing
self.past_key_values = [
[
t.view(len(self), -1, *t.shape[-2:])[keep_indices]
for t in layer
]
[t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
for layer in self.past_key_values
]
self.requests = [self.requests[i] for i in keep_indices]
self.next_token_choosers = [
self.next_token_choosers[i] for i in keep_indices
]
self.stopping_criterias = [
self.stopping_criterias[i] for i in keep_indices
]
self.requests = requests
self.next_token_choosers = [self.next_token_choosers[i] for i in keep_indices]
self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
return self
@ -263,17 +254,17 @@ class CausalLMBatch(Batch):
# and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length
batch_left_offset = (
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
batch.attention_mask.shape[1]
- batch.max_input_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
start_index:end_index,
left_offset:-padding_right_offset,
] = batch.attention_mask[
:,
batch_left_offset: -batch.padding_right_offset,
]
batch_left_offset : -batch.padding_right_offset,
]
# Create empty tensor
# position_ids is always of shape [batch_size, 1]
@ -319,22 +310,22 @@ class CausalLMBatch(Batch):
# We slice the past keys and values to remove the padding from previous batches
if batch.keys_head_dim_last:
past_key_values[j][0][
start_index:end_index,
:,
-(batch.max_input_length - 1):,
:,
] = past_keys[:, :, -(batch.max_input_length - 1):, :]
start_index:end_index,
:,
-(batch.max_input_length - 1) :,
:,
] = past_keys[:, :, -(batch.max_input_length - 1) :, :]
else:
past_key_values[j][0][
start_index:end_index,
:,
:,
-(batch.max_input_length - 1):,
] = past_keys[:, :, :, -(batch.max_input_length - 1):]
start_index:end_index,
:,
:,
-(batch.max_input_length - 1) :,
] = past_keys[:, :, :, -(batch.max_input_length - 1) :]
past_key_values[j][1][
start_index:end_index, :, -(batch.max_input_length - 1):, :
] = past_values[:, :, -(batch.max_input_length - 1):, :]
start_index:end_index, :, -(batch.max_input_length - 1) :, :
] = past_values[:, :, -(batch.max_input_length - 1) :, :]
start_index += len(batch)
@ -363,11 +354,11 @@ class CausalLMBatch(Batch):
class CausalLM(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
decode_buffer: int = 3,
self,
model_id: str,
revision: Optional[str] = None,
quantize: bool = False,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -409,7 +400,7 @@ class CausalLM(Model):
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
outputs = self.model.forward(
@ -423,7 +414,7 @@ class CausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: CausalLMBatch
self, batch: CausalLMBatch
) -> Tuple[List[Generation], CausalLMBatch]:
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
@ -435,11 +426,9 @@ class CausalLM(Model):
batch.past_key_values,
)
# New values for next forward
next_batch_input_ids = []
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
@ -455,14 +444,14 @@ class CausalLM(Model):
# For each member of the batch
for i, (
request,
input_length,
offset,
token_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
request,
input_length,
offset,
token_offset,
logits,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Select next token
next_token_id, logprobs = next_token_chooser(
@ -489,7 +478,7 @@ class CausalLM(Model):
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens:, 0]
all_input_ids[-stopping_criteria.current_tokens :, 0]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@ -503,6 +492,7 @@ class CausalLM(Model):
else:
# Keep request in the batch
generated_text = None
stopped = False
# Prefill
if stopping_criteria.current_tokens == 1:
@ -535,7 +525,7 @@ class CausalLM(Model):
generations.append(generation)
# Update values
next_batch_input_ids.append(next_token_id)
batch.input_ids[i] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
@ -544,8 +534,6 @@ class CausalLM(Model):
# Decrease right offset
batch.padding_right_offset -= 1
# Create input_ids tensor
batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
@ -555,4 +543,4 @@ class CausalLM(Model):
# Update past key values
batch.past_key_values = past
return generations, batch
return generations, batch if not stopped else None
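
Note on the new control flow in CausalLM.generate_token: the batch is now updated in place (input_ids, all_input_ids, attention_mask, past_key_values) and the method returns None in place of the batch once every request in it has stopped. A minimal sketch of a driver loop under that contract; the names run_until_finished, model and batch are illustrative and not taken from this repository:

from typing import Any, List


def run_until_finished(model: Any, batch: Any) -> List[Any]:
    # Hypothetical caller: generate_token mutates `batch` in place and
    # returns None instead of a batch once all of its requests have stopped.
    all_generations: List[Any] = []
    while batch is not None:
        generations, batch = model.generate_token(batch)
        all_generations.extend(generations)
    return all_generations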

server/text_generation_server/models/flash_causal_lm.py

@ -6,7 +6,7 @@ from torch.nn import functional as F
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
from text_generation_server.models.types import (
@ -29,14 +29,16 @@ tracer = trace.get_tracer(__name__)
class FlashCausalLMBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
# request id -> idx in list mapping
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
input_ids: List[torch.Tensor]
position_ids: List[torch.Tensor]
# cumulative sequence lengths
cu_seqlens: torch.Tensor
cu_seqlens: List[int]
max_seqlen: int
past_key_values: Optional[torch.Tensor]
past_key_values: Optional[List[torch.Tensor]]
# All tokens
all_input_ids: List[List[int]]
@ -62,7 +64,7 @@ class FlashCausalLMBatch(Batch):
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch":
) -> "FlashCausalLMBatch":
input_ids = []
position_ids = []
cu_seqlens = [0]
@ -73,6 +75,7 @@ class FlashCausalLMBatch(Batch):
token_offsets = []
all_input_ids = []
all_input_ids_tensor = []
requests_idx_mapping = {}
next_token_choosers = []
stopping_criterias = []
@ -81,13 +84,18 @@ class FlashCausalLMBatch(Batch):
cumulative_length = 0
# Parse batch
for r in pb.requests:
for i, r in enumerate(pb.requests):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenizer(
r.inputs, truncation=True, max_length=r.truncate
)["input_ids"]
input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
offsets.append(None)
token_offsets.append(None)
all_input_ids.append(tokenized_input)
@ -96,7 +104,9 @@ class FlashCausalLMBatch(Batch):
input_ids.append(tokenized_input)
# Position ids
position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
position_ids.append(
torch.arange(0, input_length, dtype=torch.int32, device=device)
)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
@ -113,13 +123,10 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += input_length
input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
@ -134,60 +141,138 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
)
@tracer.start_as_current_span("filter")
def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
if len(requests) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
if len(requests) == len(self):
return self
# Cumulative length
cumulative_length = 0
# New values after filtering
requests_idx_mapping = {}
input_ids = []
position_ids = []
cu_seqlens = [0]
max_seqlen = 0
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
offsets = []
token_offsets = []
next_token_choosers = []
stopping_criterias = []
for i, r in enumerate(requests):
idx = self.requests_idx_mapping[r.id]
requests_idx_mapping[r.id] = i
# Get length
request_input_length = self.input_lengths[idx]
input_ids.append(self.input_ids[idx])
position_ids.append(self.position_ids[idx])
cu_seqlens.append(cumulative_length + request_input_length)
max_seqlen = max(max_seqlen, request_input_length)
past_key_values.append(self.past_key_values[idx])
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criterias.append(self.stopping_criterias[idx])
cumulative_length += request_input_length
return FlashCausalLMBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
)
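
The filter method above relies on requests_idx_mapping (request id -> position in the batch lists) to pull each surviving request's state and to rebuild cu_seqlens from zero. A self-contained toy sketch of that bookkeeping, not the repository's code; filter_lengths and the example ids are made up for illustration:

from typing import Dict, List, Tuple


def filter_lengths(
    input_lengths: List[int],
    requests_idx_mapping: Dict[int, int],
    kept_ids: List[int],
) -> Tuple[List[int], List[int], Dict[int, int]]:
    # Keep only the surviving request ids and re-accumulate cu_seqlens from 0.
    new_lengths: List[int] = []
    new_mapping: Dict[int, int] = {}
    cu_seqlens: List[int] = [0]
    for new_idx, request_id in enumerate(kept_ids):
        old_idx = requests_idx_mapping[request_id]  # id -> index in the old batch
        length = input_lengths[old_idx]
        new_lengths.append(length)
        new_mapping[request_id] = new_idx  # id -> index in the filtered batch
        cu_seqlens.append(cu_seqlens[-1] + length)  # re-based cumulative lengths
    return new_lengths, cu_seqlens, new_mapping


# Requests 7 and 9 survive out of a batch whose mapping was {3: 0, 7: 1, 9: 2}
lengths, cu, mapping = filter_lengths([5, 3, 8], {3: 0, 7: 1, 9: 2}, [7, 9])
assert lengths == [3, 8] and cu == [0, 3, 11] and mapping == {7: 0, 9: 1}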
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
# Batch attributes
requests = []
input_lengths = []
offsets = []
token_offsets = []
all_input_ids = []
all_input_ids_tensor = []
next_token_choosers = []
stopping_criterias = []
requests_idx_mapping = {}
# Batch tensors
input_ids = []
position_ids = []
cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
cu_seqlens = [0]
max_seqlen = 0
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
offsets = []
token_offsets = []
next_token_choosers = []
stopping_criterias = []
# Cumulative length
cumulative_length = torch.tensor(0)
cumulative_batch_size = 0
cumulative_length = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
# We need to offset the mapping for each batch by the cumulative batch size
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
input_ids.extend(batch.input_ids)
position_ids.extend(batch.position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
max_seqlen = max(max_seqlen, batch.max_seqlen)
past_key_values.extend(batch.past_key_values)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets)
token_offsets.extend(batch.token_offsets)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)
input_ids.append(batch.input_ids)
position_ids.append(batch.position_ids)
past_key_values.append(batch.past_key_values)
max_seqlen = max(max_seqlen, batch.max_seqlen)
# Update
cumulative_length += batch.cu_seqlens[-1]
input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids)
# Concat on dim=1 as first dim represents the model layers
past_key_values = torch.concat(past_key_values, dim=1)
cu_seqlens = torch.concat(cu_seqlens)
cumulative_batch_size += len(batch)
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
@ -269,38 +354,49 @@ class FlashCausalLM(Model):
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
# Better to send to device here to avoid device issues in concatenate
position_ids = batch.position_ids.to(self.device, non_blocking=True)
cu_seqlens = batch.cu_seqlens.to(self.device)
# Shortcut when batch_size == 1
if len(batch) == 1:
input_ids = batch.input_ids[0].view(-1)
past_key_values = (
batch.past_key_values[0] if batch.past_key_values is not None else None
)
else:
# Concatenate tensors
input_ids = torch.cat(batch.input_ids).view(-1)
past_key_values = (
torch.cat(batch.past_key_values, dim=1)
if batch.past_key_values is not None
else None
)
# Concatenate when prefill, torch.tensor when decode
position_ids = (
torch.tensor(batch.position_ids, device=self.device)
if batch.past_key_values is not None
else torch.cat(batch.position_ids)
)
cu_seqlens = torch.tensor(
batch.cu_seqlens, device=self.device, dtype=torch.int32
)
out, present = self.forward(
batch.input_ids,
input_ids,
position_ids,
cu_seqlens,
batch.max_seqlen,
batch.past_key_values,
past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_ids = []
next_batch_position_ids = []
next_batch_cu_seqlens = [0]
next_batch_max_seqlen = 0
next_batch_past_key_values = []
next_batch_input_lengths = []
next_batch_offsets = []
next_batch_token_offsets = []
next_batch_all_input_ids = []
next_batch_all_input_ids_tensor = []
# Initialize past_key_values in prefill
if batch.past_key_values is None:
batch.past_key_values = [None] * len(batch)
# Cumulative length
cumulative_length = 0
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
@ -329,7 +425,8 @@ class FlashCausalLM(Model):
start_index = cumulative_length
end_index = cumulative_length + input_length
if batch.past_key_values is None:
prefill = stopping_criteria.current_tokens == 0
if prefill:
# Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size]
logits = out[start_index:end_index]
@ -348,7 +445,6 @@ class FlashCausalLM(Model):
# Append next token to all tokens
all_input_ids.append(next_token_id_item)
all_input_ids_tensor[input_length] = next_token_id_item
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id_item]
@ -378,32 +474,23 @@ class FlashCausalLM(Model):
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
# CAUTION: generation will be stopped so no need to pad
# This will make the next forward crash if the request does not get filtered
new_input_length = input_length
past = present[:, start_index:end_index]
else:
# Keep request in the batch
next_batch_keep_indices.append(i)
stopped = False
generated_text = None
# Get sequence present
seq_present = present[:, start_index:end_index]
# Pad it for next iter attention
past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
next_batch_past_key_values.append(past)
next_batch_input_ids.append(next_token_id)
next_batch_position_ids.append(input_length)
# Cumulative sum
next_batch_cu_seqlens.append(
next_batch_cu_seqlens[-1] + new_input_length
# Pad present for next iter attention
new_input_length = input_length + 1
past = torch.nn.functional.pad(
present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
)
next_batch_input_lengths.append(new_input_length)
next_batch_offsets.append(offset)
next_batch_token_offsets.append(token_offset)
next_batch_all_input_ids.append(all_input_ids)
next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
# Prefill
if stopping_criteria.current_tokens == 1:
if prefill:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(
1, all_input_ids_tensor[1:input_length].unsqueeze(1)
@ -433,52 +520,18 @@ class FlashCausalLM(Model):
generations.append(generation)
cumulative_length += input_length
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generations, None
# Update values
batch.input_ids[i] = next_token_id
batch.position_ids[i] = input_length
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.all_input_ids[i] = all_input_ids
batch.all_input_ids_tensor[i] = all_input_ids_tensor
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
batch.past_key_values[i] = past
# Cumulative sum
batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to requests, token_choosers and stopping_criterias that need to be cached
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Create final next batch tensors
next_batch_position_ids = torch.tensor(
next_batch_position_ids, dtype=torch.int32
)
next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
if len(next_batch_keep_indices) > 1:
next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
else:
next_batch_input_ids = next_batch_input_ids[0].view(1)
next_batch_past_key_values = next_batch_past_key_values[0]
next_batch = FlashCausalLMBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
position_ids=next_batch_position_ids,
cu_seqlens=next_batch_cu_seqlens,
max_seqlen=next_batch_max_seqlen,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
offsets=next_batch_offsets,
token_offsets=next_batch_token_offsets,
all_input_ids=next_batch_all_input_ids,
all_input_ids_tensor=next_batch_all_input_ids_tensor,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
)
return generations, next_batch
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None
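
With cu_seqlens kept as a plain Python list, concatenate reduces to offset arithmetic: request indices are shifted by the number of requests already merged, and cumulative sequence lengths by the number of tokens already merged. A toy sketch of that arithmetic, not the repository's code; concat_cu_seqlens and the example batches are made up for illustration:

from typing import Dict, List, Tuple


def concat_cu_seqlens(
    batches: List[Tuple[Dict[int, int], List[int]]],
) -> Tuple[Dict[int, int], List[int]]:
    # Each batch is (requests_idx_mapping, cu_seqlens). Shift indices by the
    # number of requests merged so far, lengths by the tokens merged so far.
    merged_mapping: Dict[int, int] = {}
    merged_cu: List[int] = [0]
    cumulative_batch_size = 0
    cumulative_length = 0
    for mapping, cu_seqlens in batches:
        for request_id, idx in mapping.items():
            merged_mapping[request_id] = idx + cumulative_batch_size
        merged_cu.extend(l + cumulative_length for l in cu_seqlens[1:])
        cumulative_length += cu_seqlens[-1]
        cumulative_batch_size += len(mapping)
    return merged_mapping, merged_cu


# Two batches: ids {1, 2} with lengths [4, 2] and id {5} with length [3]
mapping, cu = concat_cu_seqlens([({1: 0, 2: 1}, [0, 4, 6]), ({5: 0}, [0, 3])])
assert mapping == {1: 0, 2: 1, 5: 2} and cu == [0, 4, 6, 9]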