OlivierDehaene 2023-04-18 17:51:41 +02:00
parent 9476170dda
commit 2ad7a63761
3 changed files with 288 additions and 247 deletions

Cargo.lock (generated)

@@ -42,42 +42,51 @@ dependencies = [
 [[package]]
 name = "anstream"
-version = "0.2.6"
+version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f"
+checksum = "9e579a7752471abc2a8268df8b20005e3eadd975f585398f17efcfd8d4927371"
 dependencies = [
  "anstyle",
  "anstyle-parse",
+ "anstyle-query",
  "anstyle-wincon",
- "concolor-override",
- "concolor-query",
+ "colorchoice",
  "is-terminal",
  "utf8parse",
 ]
 
 [[package]]
 name = "anstyle"
-version = "0.3.5"
+version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2"
+checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d"
 
 [[package]]
 name = "anstyle-parse"
-version = "0.1.1"
+version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116"
+checksum = "e765fd216e48e067936442276d1d57399e37bce53c264d6fefbe298080cb57ee"
 dependencies = [
  "utf8parse",
 ]
 
 [[package]]
-name = "anstyle-wincon"
-version = "0.2.0"
+name = "anstyle-query"
+version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa"
+checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b"
+dependencies = [
+ "windows-sys 0.48.0",
+]
+
+[[package]]
+name = "anstyle-wincon"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4bcd8291a340dd8ac70e18878bc4501dd7b4ff970cfa21c207d36ece51ea88fd"
 dependencies = [
  "anstyle",
- "windows-sys 0.45.0",
+ "windows-sys 0.48.0",
 ]
 
 [[package]]
@@ -105,7 +114,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
@@ -116,7 +125,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
@@ -127,9 +136,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
 [[package]]
 name = "axum"
-version = "0.6.13"
+version = "0.6.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6539e4565c365448d483967c6dee3eaecb8e87679a17806a831e82b05b903c18"
+checksum = "3b32c5ea3aabaf4deb5f5ced2d688ec0844c881c9e6c696a8b769a05fc691e62"
 dependencies = [
  "async-trait",
  "axum-core",
@@ -310,9 +319,9 @@ dependencies = [
 [[package]]
 name = "clap"
-version = "4.2.1"
+version = "4.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3"
+checksum = "9b802d85aaf3a1cdb02b224ba472ebdea62014fccfcb269b95a4d76443b5ee5a"
 dependencies = [
  "clap_builder",
  "clap_derive",
@@ -321,9 +330,9 @@ dependencies = [
 [[package]]
 name = "clap_builder"
-version = "4.2.1"
+version = "4.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f"
+checksum = "14a1a858f532119338887a4b8e1af9c60de8249cd7bafd68036a489e261e37b6"
 dependencies = [
  "anstream",
  "anstyle",
@@ -341,7 +350,7 @@ dependencies = [
  "heck",
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
@@ -351,19 +360,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1"
 
 [[package]]
-name = "concolor-override"
+name = "colorchoice"
 version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f"
-
-[[package]]
-name = "concolor-query"
-version = "0.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf"
-dependencies = [
- "windows-sys 0.45.0",
-]
+checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
 
 [[package]]
 name = "console"
@@ -794,7 +794,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
@@ -868,9 +868,9 @@ dependencies = [
 [[package]]
 name = "h2"
-version = "0.3.16"
+version = "0.3.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d"
+checksum = "17f8a914c2987b688368b5138aa05321db91f4090cf26118185672ad588bce21"
 dependencies = [
  "bytes",
  "fnv",
@@ -966,9 +966,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
 [[package]]
 name = "hyper"
-version = "0.14.25"
+version = "0.14.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899"
+checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4"
 dependencies = [
  "bytes",
  "futures-channel",
@@ -1364,7 +1364,7 @@ checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
@@ -1517,7 +1517,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
@@ -1787,9 +1787,9 @@ dependencies = [
 [[package]]
 name = "prost"
-version = "0.11.8"
+version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e48e50df39172a3e7eb17e14642445da64996989bc212b583015435d39a58537"
+checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
 dependencies = [
  "bytes",
  "prost-derive",
@@ -1797,9 +1797,9 @@ dependencies = [
 [[package]]
 name = "prost-build"
-version = "0.11.8"
+version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2c828f93f5ca4826f97fedcbd3f9a536c16b12cff3dbbb4a007f932bbad95b12"
+checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270"
 dependencies = [
  "bytes",
  "heck",
@@ -1819,9 +1819,9 @@ dependencies = [
 [[package]]
 name = "prost-derive"
-version = "0.11.8"
+version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ea9b0f8cbe5e15a8a042d030bd96668db28ecb567ec37d691971ff5731d2b1b"
+checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
 dependencies = [
  "anyhow",
  "itertools 0.10.5",
@@ -1832,9 +1832,9 @@ dependencies = [
 [[package]]
 name = "prost-types"
-version = "0.11.8"
+version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "379119666929a1afd7a043aa6cf96fa67a6dce9af60c88095a4686dbce4c9c88"
+checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
 dependencies = [
  "prost",
 ]
@@ -2153,14 +2153,14 @@ checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
 name = "serde_json"
-version = "1.0.95"
+version = "1.0.96"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744"
+checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
 dependencies = [
  "itoa",
  "ryu",
@@ -2330,9 +2330,9 @@ dependencies = [
 [[package]]
 name = "syn"
-version = "2.0.14"
+version = "2.0.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fcf316d5356ed6847742d036f8a39c3b8435cac10bd528a4bd461928a6ab34d5"
+checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2450,7 +2450,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
@@ -2578,7 +2578,7 @@ checksum = "61a573bdc87985e9d6ddeed1b3d864e8a302c847e40d647746df2f1de209d1ce"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]
@@ -2928,9 +2928,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
 [[package]]
 name = "utoipa"
-version = "3.2.1"
+version = "3.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "24e7ee17c9ef094b86e1e04170d90765bd76cb381921dacb4d3e175a267bdae6"
+checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98"
 dependencies = [
  "indexmap",
  "serde",
@@ -2940,14 +2940,14 @@ dependencies = [
 [[package]]
 name = "utoipa-gen"
-version = "3.2.1"
+version = "3.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "df6f458e5abc811d44aca28455efc4163fb7565a7af2aa32d17611f3d1d9794d"
+checksum = "7ea8ac818da7e746a63285594cce8a96f5e00ee31994e655bd827569cb8b137b"
 dependencies = [
  "proc-macro-error",
  "proc-macro2",
  "quote",
- "syn 2.0.14",
+ "syn 2.0.15",
 ]
 
 [[package]]


@@ -158,9 +158,7 @@ class CausalLMBatch(Batch):
             request_input_length = self.input_lengths[idx]
             input_lengths.append(request_input_length)
-            max_input_length = max(
-                max_input_length, request_input_length
-            )
+            max_input_length = max(max_input_length, request_input_length)
 
         # Replace metadata
         self.requests_idx_mapping = requests_idx_mapping
@@ -176,19 +174,12 @@ class CausalLMBatch(Batch):
         self.position_ids = self.position_ids[keep_indices]
         # Force past to be of dim [self_size, num_heads, ...] for easy indexing
         self.past_key_values = [
-            [
-                t.view(len(self), -1, *t.shape[-2:])[keep_indices]
-                for t in layer
-            ]
+            [t.view(len(self), -1, *t.shape[-2:])[keep_indices] for t in layer]
             for layer in self.past_key_values
         ]
-        self.requests = [self.requests[i] for i in keep_indices]
-        self.next_token_choosers = [
-            self.next_token_choosers[i] for i in keep_indices
-        ]
-        self.stopping_criterias = [
-            self.stopping_criterias[i] for i in keep_indices
-        ]
+        self.requests = requests
+        self.next_token_choosers = [self.next_token_choosers[i] for i in keep_indices]
+        self.stopping_criterias = [self.stopping_criterias[i] for i in keep_indices]
 
         return self
@@ -435,11 +426,9 @@ class CausalLM(Model):
             batch.past_key_values,
         )
 
-        # New values for next forward
-        next_batch_input_ids = []
-
         # Results
         generations: List[Generation] = []
+        stopped = True
 
         # Zipped iterator
         iterator = zip(
@@ -503,6 +492,7 @@
             else:
                 # Keep request in the batch
                 generated_text = None
+                stopped = False
 
             # Prefill
             if stopping_criteria.current_tokens == 1:
@@ -535,7 +525,7 @@
             generations.append(generation)
 
             # Update values
-            next_batch_input_ids.append(next_token_id)
+            batch.input_ids[i] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
             batch.offsets[i] = offset
@@ -544,8 +534,6 @@
         # Decrease right offset
         batch.padding_right_offset -= 1
 
-        # Create input_ids tensor
-        batch.input_ids = torch.cat(next_batch_input_ids, dim=0)
-
        # Update attention_mask as we added a new token to input_ids
         batch.attention_mask[:, -batch.padding_right_offset] = 1
@@ -555,4 +543,4 @@
         # Update past key values
         batch.past_key_values = past
 
-        return generations, batch
+        return generations, batch if not stopped else None
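
Note: with this change CausalLM.generate_token updates the batch in place (writing one next-token id per request instead of rebuilding an input_ids tensor) and returns None for the batch as soon as every request has stopped. A hypothetical driver loop under those assumptions is sketched below; the drive helper and the reliance on Generation.generated_text being None for still-running requests are illustrative, not the actual router code.

from typing import List


def drive(model, batch) -> List:
    # Illustrative only: call generate_token until it reports that every
    # request stopped (batch is None), filtering finished requests out of
    # the surviving batch between steps.
    completed = []
    while batch is not None:
        generations, batch = model.generate_token(batch)
        # A generation that carries generated_text has hit a stopping
        # criterion; the others are still in flight.
        completed.extend(g for g in generations if g.generated_text is not None)

        if batch is not None:
            active = [
                r
                for r, g in zip(batch.requests, generations)
                if g.generated_text is None
            ]
            if len(active) != len(batch):
                batch = batch.filter(active)
    return completed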


@@ -6,7 +6,7 @@ from torch.nn import functional as F
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
-from typing import Optional, Tuple, List, Type, Union
+from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -29,14 +29,16 @@ tracer = trace.get_tracer(__name__)
 class FlashCausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
+    # request id -> idx in list mapping
+    requests_idx_mapping: Dict[int, int]
 
     # Decoder values
-    input_ids: torch.Tensor
-    position_ids: torch.Tensor
+    input_ids: List[torch.Tensor]
+    position_ids: List[torch.Tensor]
     # cumulative sequence lengths
-    cu_seqlens: torch.Tensor
+    cu_seqlens: List[int]
     max_seqlen: int
-    past_key_values: Optional[torch.Tensor]
+    past_key_values: Optional[List[torch.Tensor]]
 
     # All tokens
     all_input_ids: List[List[int]]
@@ -62,7 +64,7 @@ class FlashCausalLMBatch(Batch):
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
-    ) -> "CausalLMBatch":
+    ) -> "FlashCausalLMBatch":
         input_ids = []
         position_ids = []
         cu_seqlens = [0]
@@ -73,6 +75,7 @@ class FlashCausalLMBatch(Batch):
         token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
+        requests_idx_mapping = {}
 
         next_token_choosers = []
         stopping_criterias = []
@@ -81,13 +84,18 @@ class FlashCausalLMBatch(Batch):
         cumulative_length = 0
 
         # Parse batch
-        for r in pb.requests:
+        for i, r in enumerate(pb.requests):
+            # request id -> idx in list mapping
+            requests_idx_mapping[r.id] = i
+
             tokenized_input = tokenizer(
                 r.inputs, truncation=True, max_length=r.truncate
             )["input_ids"]
+
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
+
             offsets.append(None)
             token_offsets.append(None)
             all_input_ids.append(tokenized_input)
@@ -96,7 +104,9 @@ class FlashCausalLMBatch(Batch):
             input_ids.append(tokenized_input)
 
             # Position ids
-            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
+            position_ids.append(
+                torch.arange(0, input_length, dtype=torch.int32, device=device)
+            )
 
             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
@@ -113,13 +123,10 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_length += input_length
 
-        input_ids = torch.concat(input_ids)
-        position_ids = torch.concat(position_ids)
-        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
-
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
@@ -134,60 +141,138 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
         )
 
+    @tracer.start_as_current_span("filter")
+    def filter(self, requests: List[generate_pb2.Request]) -> "FlashCausalLMBatch":
+        if len(requests) == 0:
+            raise ValueError("Batch must have at least one request")
+        # We assume that if len(requests) == len(self) then the requests are the same
+        if len(requests) == len(self):
+            return self
+
+        # Cumulative length
+        cumulative_length = 0
+
+        # New values after filtering
+        requests_idx_mapping = {}
+
+        input_ids = []
+        position_ids = []
+        cu_seqlens = [0]
+        max_seqlen = 0
+        past_key_values = []
+
+        all_input_ids = []
+        all_input_ids_tensor = []
+
+        input_lengths = []
+        offsets = []
+        token_offsets = []
+
+        next_token_choosers = []
+        stopping_criterias = []
+
+        for i, r in enumerate(requests):
+            idx = self.requests_idx_mapping[r.id]
+            requests_idx_mapping[r.id] = i
+
+            # Get length
+            request_input_length = self.input_lengths[idx]
+
+            input_ids.append(self.input_ids[idx])
+            position_ids.append(self.position_ids[idx])
+            cu_seqlens.append(cumulative_length + request_input_length)
+            max_seqlen = max(max_seqlen, request_input_length)
+            past_key_values.append(self.past_key_values[idx])
+
+            all_input_ids.append(self.all_input_ids[idx])
+            all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
+
+            input_lengths.append(request_input_length)
+            offsets.append(self.offsets[idx])
+            token_offsets.append(self.token_offsets[idx])
+
+            next_token_choosers.append(self.next_token_choosers[idx])
+            stopping_criterias.append(self.stopping_criterias[idx])
+
+            cumulative_length += request_input_length
+
+        return FlashCausalLMBatch(
+            batch_id=self.batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            all_input_ids=all_input_ids,
+            all_input_ids_tensor=all_input_ids_tensor,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+        )
+
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
         # Batch attributes
         requests = []
-        input_lengths = []
-        offsets = []
-        token_offsets = []
-        all_input_ids = []
-        all_input_ids_tensor = []
-        next_token_choosers = []
-        stopping_criterias = []
-
-        # Batch tensors
+        requests_idx_mapping = {}
+
         input_ids = []
         position_ids = []
-        cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
+        cu_seqlens = [0]
         max_seqlen = 0
         past_key_values = []
+
+        all_input_ids = []
+        all_input_ids_tensor = []
+
+        input_lengths = []
+        offsets = []
+        token_offsets = []
+
+        next_token_choosers = []
+        stopping_criterias = []
 
         # Cumulative length
-        cumulative_length = torch.tensor(0)
+        cumulative_batch_size = 0
+        cumulative_length = 0
 
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
+            # We need to offset the mapping for each batch by the cumulative batch size
+            for k, v in batch.requests_idx_mapping.items():
+                requests_idx_mapping[k] = v + cumulative_batch_size
+
+            input_ids.extend(batch.input_ids)
+            position_ids.extend(batch.position_ids)
+            # Add cumulative lengths of all previous inputs
+            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
+            max_seqlen = max(max_seqlen, batch.max_seqlen)
+            past_key_values.extend(batch.past_key_values)
+
+            all_input_ids.extend(batch.all_input_ids)
+            all_input_ids_tensor.extend(batch.all_input_ids_tensor)
+
             input_lengths.extend(batch.input_lengths)
             offsets.extend(batch.offsets)
             token_offsets.extend(batch.token_offsets)
-            all_input_ids.extend(batch.all_input_ids)
-            all_input_ids_tensor.extend(batch.all_input_ids_tensor)
+
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
 
-            # Add cumulative lengths of all previous inputs
-            cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)
-
-            input_ids.append(batch.input_ids)
-            position_ids.append(batch.position_ids)
-            past_key_values.append(batch.past_key_values)
-
-            max_seqlen = max(max_seqlen, batch.max_seqlen)
-
             # Update
             cumulative_length += batch.cu_seqlens[-1]
-
-        input_ids = torch.concat(input_ids)
-        position_ids = torch.concat(position_ids)
-        # Concat on dim=1 as first dim represents the model layers
-        past_key_values = torch.concat(past_key_values, dim=1)
-        cu_seqlens = torch.concat(cu_seqlens)
+            cumulative_batch_size += len(batch)
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
             requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
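
The requests_idx_mapping bookkeeping above is what lets filter and concatenate find a request's slot by id: filter renumbers the kept requests from zero, while concatenate shifts every incoming mapping by the number of requests already merged (cumulative_batch_size). A minimal standalone sketch of that offsetting with plain dicts and lists; merge_mappings and the sample ids are illustrative only, not TGI code.

from typing import Dict, List


def merge_mappings(mappings: List[Dict[int, int]], sizes: List[int]) -> Dict[int, int]:
    # Offset each per-batch "request id -> index" mapping by the number of
    # requests already merged, mirroring the loop in concatenate above.
    merged: Dict[int, int] = {}
    cumulative_batch_size = 0
    for mapping, size in zip(mappings, sizes):
        for request_id, idx in mapping.items():
            merged[request_id] = idx + cumulative_batch_size
        cumulative_batch_size += size
    return merged


# Two batches of sizes 2 and 3 holding request ids (7, 9) and (11, 12, 15).
merged = merge_mappings([{7: 0, 9: 1}, {11: 0, 12: 1, 15: 2}], [2, 3])
assert merged == {7: 0, 9: 1, 11: 2, 12: 3, 15: 4}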
@@ -269,38 +354,49 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        # Better to send to device here to avoid device issues in concatenate
-        position_ids = batch.position_ids.to(self.device, non_blocking=True)
-        cu_seqlens = batch.cu_seqlens.to(self.device)
+        # Shortcut when batch_size == 1
+        if len(batch) == 1:
+            input_ids = batch.input_ids[0].view(-1)
+            past_key_values = (
+                batch.past_key_values[0] if batch.past_key_values is not None else None
+            )
+        else:
+            # Concatenate tensors
+            input_ids = torch.cat(batch.input_ids).view(-1)
+            past_key_values = (
+                torch.cat(batch.past_key_values, dim=1)
+                if batch.past_key_values is not None
+                else None
+            )
+
+        # Concatenate when prefill, torch.tensor when decode
+        position_ids = (
+            torch.tensor(batch.position_ids, device=self.device)
+            if batch.past_key_values is not None
+            else torch.cat(batch.position_ids)
+        )
+        cu_seqlens = torch.tensor(
+            batch.cu_seqlens, device=self.device, dtype=torch.int32
+        )
 
         out, present = self.forward(
-            batch.input_ids,
+            input_ids,
             position_ids,
             cu_seqlens,
             batch.max_seqlen,
-            batch.past_key_values,
+            past_key_values,
         )
 
-        # List of indices to cache
-        next_batch_keep_indices = []
-
-        # New values for next forward
-        next_batch_input_ids = []
-        next_batch_position_ids = []
-        next_batch_cu_seqlens = [0]
-        next_batch_max_seqlen = 0
-        next_batch_past_key_values = []
-        next_batch_input_lengths = []
-        next_batch_offsets = []
-        next_batch_token_offsets = []
-        next_batch_all_input_ids = []
-        next_batch_all_input_ids_tensor = []
+        # Initialize past_key_values in prefill
+        if batch.past_key_values is None:
+            batch.past_key_values = [None] * len(batch)
 
         # Cumulative length
         cumulative_length = 0
 
         # Results
         generations: List[Generation] = []
+        stopped = True
 
         # Zipped iterator
         iterator = zip(
@@ -329,7 +425,8 @@ class FlashCausalLM(Model):
             start_index = cumulative_length
             end_index = cumulative_length + input_length
 
-            if batch.past_key_values is None:
+            prefill = stopping_criteria.current_tokens == 0
+            if prefill:
                 # Prefill mode
                 # out is of shape [cumulative_sequence_lengths, vocab_size]
                 logits = out[start_index:end_index]
@@ -348,7 +445,6 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             all_input_ids.append(next_token_id_item)
             all_input_ids_tensor[input_length] = next_token_id_item
-            new_input_length = input_length + 1
 
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id_item]
@@ -378,32 +474,23 @@
                 generated_text = GeneratedText(
                     output_text, stopping_criteria.current_tokens, reason, seed
                 )
+                # CAUTION: generation will be stopped so no need to pad
+                # This will make the next forward crash if the request does not get filtered
+                new_input_length = input_length
+                past = present[:, start_index:end_index]
             else:
-                # Keep request in the batch
-                next_batch_keep_indices.append(i)
+                stopped = False
                 generated_text = None
 
-                # Get sequence present
-                seq_present = present[:, start_index:end_index]
-                # Pad it for next iter attention
-                past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
-                next_batch_past_key_values.append(past)
-
-                next_batch_input_ids.append(next_token_id)
-                next_batch_position_ids.append(input_length)
-                # Cumulative sum
-                next_batch_cu_seqlens.append(
-                    next_batch_cu_seqlens[-1] + new_input_length
-                )
-                next_batch_input_lengths.append(new_input_length)
-                next_batch_offsets.append(offset)
-                next_batch_token_offsets.append(token_offset)
-                next_batch_all_input_ids.append(all_input_ids)
-                next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
-                next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
+                # Pad present for next iter attention
+                new_input_length = input_length + 1
+                past = torch.nn.functional.pad(
+                    present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
+                )
 
             # Prefill
-            if stopping_criteria.current_tokens == 1:
+            if prefill:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + logprobs.gather(
                     1, all_input_ids_tensor[1:input_length].unsqueeze(1)
@@ -433,52 +520,18 @@
             generations.append(generation)
             cumulative_length += input_length
 
+            # Update values
+            batch.input_ids[i] = next_token_id
+            batch.position_ids[i] = input_length
+            batch.input_lengths[i] = new_input_length
+            batch.offsets[i] = offset
+            batch.token_offsets[i] = token_offset
+            batch.all_input_ids[i] = all_input_ids
+            batch.all_input_ids_tensor[i] = all_input_ids_tensor
+            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
+            batch.past_key_values[i] = past
+            # Cumulative sum
+            batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
 
-        # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
-            return generations, None
-
-        # If we finished at least one generation, we need to evict the indices of the generations that finished
-        # from the values of the next batch
-        if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to requests, token_choosers and stopping_criterias that need to be cached
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
-
-        # Create final next batch tensors
-        next_batch_position_ids = torch.tensor(
-            next_batch_position_ids, dtype=torch.int32
-        )
-        next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
-        if len(next_batch_keep_indices) > 1:
-            next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
-            next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
-        else:
-            next_batch_input_ids = next_batch_input_ids[0].view(1)
-            next_batch_past_key_values = next_batch_past_key_values[0]
-
-        next_batch = FlashCausalLMBatch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            input_ids=next_batch_input_ids,
-            position_ids=next_batch_position_ids,
-            cu_seqlens=next_batch_cu_seqlens,
-            max_seqlen=next_batch_max_seqlen,
-            past_key_values=next_batch_past_key_values,
-            input_lengths=next_batch_input_lengths,
-            offsets=next_batch_offsets,
-            token_offsets=next_batch_token_offsets,
-            all_input_ids=next_batch_all_input_ids,
-            all_input_ids_tensor=next_batch_all_input_ids_tensor,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-        )
-
-        return generations, next_batch
+        # No need to return a batch if we know that all requests stopped
+        return generations, batch if not stopped else None
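
Note: since FlashCausalLMBatch now keeps input_ids, position_ids and cu_seqlens as plain Python lists, the flat tensors that flash attention expects are only materialized right before the forward pass: per-request tensors are concatenated during prefill, while in decode a single value per request is enough and torch.tensor suffices. A small shape-only sketch of that bookkeeping follows; the prompt lengths are made up, and this is not the model code itself.

import torch

# Two requests with prompt lengths 3 and 5 (illustrative values).
input_lengths = [3, 5]

# Prefill: one position tensor per request, concatenated into a flat tensor
# of cumulative length.
prefill_position_ids = [torch.arange(0, l, dtype=torch.int32) for l in input_lengths]
flat_positions = torch.cat(prefill_position_ids)
assert flat_positions.shape[0] == sum(input_lengths)

# cu_seqlens stays a plain list and is only turned into a tensor before the
# forward call.
cu_seqlens = [0]
for l in input_lengths:
    cu_seqlens.append(cu_seqlens[-1] + l)
assert cu_seqlens == [0, 3, 8]
assert torch.tensor(cu_seqlens, dtype=torch.int32).shape[0] == len(input_lengths) + 1

# Decode: each position entry has been replaced by the request's current
# length, so a flat tensor with one entry per request is enough.
decode_position_ids = torch.tensor(input_lengths, dtype=torch.int32)
assert decode_position_ids.shape[0] == len(input_lengths)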