Extract KV cache filtering and concatenation into helper methods

This commit is contained in:
Joel Lamy-Poirier 2023-05-18 17:01:20 -04:00
parent 3c725314e1
commit 7c11ceba6c
No known key found for this signature in database
GPG Key ID: 82EE2141E842DFCF

View File

@ -189,6 +189,11 @@ class VectorizedCausalLMBatch(Batch):
self.position_ids = self.position_ids[keep_indices, sequence_slice]
self.attention_mask = self.attention_mask[keep_indices, sequence_slice]
self._filter_kv_caches(keep_indices, sequence_slice)
return self
def _filter_kv_caches(self, keep_indices, sequence_slice):
tensors_to_update = []
if self.past_key_values is not None:
if not isinstance(self.past_key_values, (list, tuple)):
@ -214,8 +219,6 @@ class VectorizedCausalLMBatch(Batch):
# Update tensors in-place to allow incremental garbage collection
tensor.data = tensor[kv_cache_slice]
return self
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(
@ -289,6 +292,32 @@ class VectorizedCausalLMBatch(Batch):
for batch in batches
)
kv_cache_seq_dim = batches[0].kv_cache_seq_dim
past_key_values=cls._concatenate_key_values(batches, start_indices, end_indices, left_indices)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
kv_cache_seq_dim=kv_cache_seq_dim,
max_tokens=max_tokens,
)
@classmethod
def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices):
device = batches[0].input_ids.device
batch_size = sum([len(batch.requests) for batch in batches])
kv_formats = None
for batch in batches:
if batch.past_key_values is None:
@ -358,28 +387,12 @@ class VectorizedCausalLMBatch(Batch):
else:
past_key_values[-1].append(kv_cache)
return cls(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
kv_cache_seq_dim=kv_cache_seq_dim,
max_tokens=max_tokens,
)
return
def __len__(self):
return len(self.requests)
class VectorizedCausalLM(Model):
def __init__(
self,