From 7c11ceba6c7ed52305490a165a7331824db6c4d2 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Thu, 18 May 2023 17:01:20 -0400
Subject: [PATCH] Extract KV cache filtering and concatenation into helper
 methods

---
 .../models/vectorized_causal_lm.py | 53 ++++++++++++-------
 1 file changed, 33 insertions(+), 20 deletions(-)

diff --git a/server/text_generation_server/models/vectorized_causal_lm.py b/server/text_generation_server/models/vectorized_causal_lm.py
index a3f1b634..b2194b14 100644
--- a/server/text_generation_server/models/vectorized_causal_lm.py
+++ b/server/text_generation_server/models/vectorized_causal_lm.py
@@ -189,6 +189,11 @@ class VectorizedCausalLMBatch(Batch):
         self.position_ids = self.position_ids[keep_indices, sequence_slice]
         self.attention_mask = self.attention_mask[keep_indices, sequence_slice]
 
+        self._filter_kv_caches(keep_indices, sequence_slice)
+
+        return self
+
+    def _filter_kv_caches(self, keep_indices, sequence_slice):
         tensors_to_update = []
         if self.past_key_values is not None:
             if not isinstance(self.past_key_values, (list, tuple)):
@@ -214,8 +219,6 @@ class VectorizedCausalLMBatch(Batch):
             # Update tensors in-place to allow incremental garbage collection
             tensor.data = tensor[kv_cache_slice]
 
-        return self
-
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(
@@ -289,6 +292,32 @@ class VectorizedCausalLMBatch(Batch):
             for batch in batches
         )
 
+        kv_cache_seq_dim = batches[0].kv_cache_seq_dim
+        past_key_values = cls._concatenate_key_values(batches, start_indices, end_indices, left_indices)
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_chooser=next_token_chooser,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            kv_cache_seq_dim=kv_cache_seq_dim,
+            max_tokens=max_tokens,
+        )
+
+    @classmethod
+    def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices):
+        device = batches[0].input_ids.device
+        batch_size = sum(len(batch.requests) for batch in batches)
+
         kv_formats = None
         for batch in batches:
             if batch.past_key_values is None:
@@ -358,28 +387,12 @@ class VectorizedCausalLMBatch(Batch):
             else:
                 past_key_values[-1].append(kv_cache)
 
-        return cls(
-            batch_id=batches[0].batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_chooser=next_token_chooser,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            kv_cache_seq_dim=kv_cache_seq_dim,
-            max_tokens=max_tokens,
-        )
+        return past_key_values
+
     def __len__(self):
         return len(self.requests)
 
-
 class VectorizedCausalLM(Model):
     def __init__(
         self,
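
Note: a minimal sketch of the in-place trimming pattern used by _filter_kv_caches
above ("Update tensors in-place to allow incremental garbage collection"). The
helper name trim_kv_tensor, the (batch, heads, seq, head_dim) layout, and the
concrete shapes are hypothetical, not part of the patch:

    import torch

    def trim_kv_tensor(tensor: torch.Tensor, kv_cache_slice) -> torch.Tensor:
        # Reassigning `.data` swaps the storage under the existing tensor
        # object, so each layer's full-size KV buffer can be freed as soon
        # as it is trimmed, instead of staying alive until the whole filter
        # pass over every layer of the cache has finished.
        tensor.data = tensor[kv_cache_slice]
        return tensor

    # Hypothetical usage: keep batch rows 0 and 2 and the last 16 positions
    # of a cache shaped (batch, heads, seq, head_dim), mirroring keep_indices
    # combined with a sequence slice along kv_cache_seq_dim.
    cache = torch.randn(4, 12, 64, 64)
    keep_indices = torch.tensor([0, 2])
    kv_cache_slice = (keep_indices, slice(None), slice(-16, None))
    trim_kv_tensor(cache, kv_cache_slice)
    print(cache.shape)  # torch.Size([2, 12, 16, 64])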