mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00

Extract kv cache stuff

parent 3c725314e1
commit 7c11ceba6c
@@ -189,6 +189,11 @@ class VectorizedCausalLMBatch(Batch):
         self.position_ids = self.position_ids[keep_indices, sequence_slice]
         self.attention_mask = self.attention_mask[keep_indices, sequence_slice]

+        self._filter_kv_caches(keep_indices, sequence_slice)
+
+        return self
+
+    def _filter_kv_caches(self, keep_indices, sequence_slice):
         tensors_to_update = []
         if self.past_key_values is not None:
             if not isinstance(self.past_key_values, (list, tuple)):
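The hunk above ends the rewritten filter() with a call to a new private helper and opens that helper's definition; the slicing logic that follows it (and the in-place update in the next hunk) moves into the helper unchanged. A minimal sketch of the extracted logic as a standalone function, assembled from the fragments visible in this commit; the tuple construction for kv_cache_slice and the NotImplementedError branch are assumptions for illustration, not code confirmed by the diff:

    import torch

    def filter_kv_caches(past_key_values, kv_cache_seq_dim, keep_indices, sequence_slice):
        # Collect every cache tensor, whether past_key_values is a flat list
        # of tensors or a nested list of per-layer (key, value) pairs.
        tensors_to_update = []
        if past_key_values is not None:
            if not isinstance(past_key_values, (list, tuple)):
                raise NotImplementedError(f"Unsupported kv cache type: {type(past_key_values)}")
            for layer_kv in past_key_values:
                if isinstance(layer_kv, torch.Tensor):
                    tensors_to_update.append(layer_kv)
                else:
                    tensors_to_update.extend(layer_kv)

        # Keep the surviving requests (dim 0) and trim the sequence axis,
        # assuming kv_cache_seq_dim is the index of the sequence dimension.
        kv_cache_slice = (keep_indices,) + (slice(None),) * (kv_cache_seq_dim - 1) + (sequence_slice,)
        for tensor in tensors_to_update:
            # Update tensors in-place to allow incremental garbage collection
            tensor.data = tensor[kv_cache_slice]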
@@ -214,8 +219,6 @@ class VectorizedCausalLMBatch(Batch):
             # Update tensors in-place to allow incremental garbage collection
             tensor.data = tensor[kv_cache_slice]

-        return self
-
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(
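This hunk's context explains the one subtle piece of the helper: rebinding tensor.data swaps the storage under the existing tensor object, so each old cache slice becomes collectible as soon as the loop advances, instead of the entire old cache staying alive until every new tensor has been built. The hunk also drops the trailing return self, which now belongs to filter() itself (first hunk) rather than to the extracted helper. A small demonstration of the rebinding trick, with made-up shapes:

    import torch

    cache = torch.zeros(8, 16, 1024, 64)    # (batch, heads, seq, head_dim)
    keep_indices = torch.tensor([0, 3, 5])  # surviving requests

    # Advanced indexing copies into a smaller tensor; assigning the result to
    # .data rebinds the existing tensor object to the new storage, so the old
    # 8x16x1024x64 buffer can be freed immediately rather than after the loop
    # over all layers has finished.
    cache.data = cache[keep_indices, :, -512:]
    print(cache.shape)  # torch.Size([3, 16, 512, 64])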
@@ -289,6 +292,32 @@ class VectorizedCausalLMBatch(Batch):
             for batch in batches
         )

+        kv_cache_seq_dim = batches[0].kv_cache_seq_dim
+        past_key_values = cls._concatenate_key_values(batches, start_indices, end_indices, left_indices)
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            requests_idx_mapping=requests_idx_mapping,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
+            next_token_chooser=next_token_chooser,
+            stopping_criterias=stopping_criterias,
+            max_input_length=max_input_length,
+            kv_cache_seq_dim=kv_cache_seq_dim,
+            max_tokens=max_tokens,
+        )
+
+    @classmethod
+    def _concatenate_key_values(cls, batches, start_indices, end_indices, left_indices):
+        device = batches[0].input_ids.device
+        batch_size = sum([len(batch.requests) for batch in batches])
+
         kv_formats = None
         for batch in batches:
             if batch.past_key_values is None:
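Only the opening of _concatenate_key_values is visible in this excerpt (the target device, the merged batch_size, and the kv_formats detection loop); the per-layer merge itself sits below the fold. Based on those fragments and the `past_key_values[-1].append(kv_cache)` line in the next hunk, the core per-layer operation is presumably: allocate a zero tensor sized for the merged batch and the longest sequence, then copy each source batch's cache into its row range, flush right for left-padded attention. A hedged sketch of that pattern; the function name concat_kv_layer, the (batch, heads, seq, head_dim) layout, and the exact meaning of the index lists are assumptions for illustration:

    import torch

    def concat_kv_layer(caches, start_indices, end_indices, left_indices, batch_size, max_seq):
        # caches: one (batch, heads, seq, head_dim) tensor per source batch.
        heads, head_dim = caches[0].shape[1], caches[0].shape[3]
        merged = torch.zeros(
            batch_size, heads, max_seq, head_dim,
            dtype=caches[0].dtype, device=caches[0].device,
        )
        for cache, start, end, left in zip(caches, start_indices, end_indices, left_indices):
            # With left-padded attention, each batch's cache sits flush
            # against the right edge of the merged sequence window,
            # i.e. left = max_seq - cache.shape[2].
            merged[start:end, :, left:left + cache.shape[2]] = cache
        return merged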
@@ -358,28 +387,12 @@ class VectorizedCausalLMBatch(Batch):
             else:
                 past_key_values[-1].append(kv_cache)

-        return cls(
-            batch_id=batches[0].batch_id,
-            requests=requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
-            next_token_chooser=next_token_chooser,
-            stopping_criterias=stopping_criterias,
-            max_input_length=max_input_length,
-            kv_cache_seq_dim=kv_cache_seq_dim,
-            max_tokens=max_tokens,
-        )
+        return past_key_values

     def __len__(self):
         return len(self.requests)


 class VectorizedCausalLM(Model):
     def __init__(
         self,
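Net effect of the last two hunks: the cls(...) constructor call stays in concatenate(), where the previous hunk re-added it, while the moved tail of the method becomes _concatenate_key_values and now ends by returning only the merged cache. After the refactor, all knowledge of the kv-cache layout lives in exactly two helpers, _filter_kv_caches and _concatenate_key_values, keeping the cache-format details out of the batch-bookkeeping code.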