This commit is contained in:
Joel Lamy-Poirier 2023-05-25 15:52:14 -04:00
parent a515fbde4c
commit 72eefa3612
No known key found for this signature in database
GPG Key ID: 82EE2141E842DFCF

View File

@ -449,7 +449,7 @@ class FlashCausalLM(Model):
pre_allocate_past_size=pre_allocate_past_size,
)
def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, use_cache: Optional[torch.dtype]):
def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, cache_dtype: Optional[torch.dtype]):
diff = max_input_length - max(batch.input_lengths)
for i in range(len(batch)):
batch.input_lengths[i] += diff
@ -459,7 +459,19 @@ class FlashCausalLM(Model):
# TODO: Bug!?!
batch.stopping_criterias[i].current_tokens += diff
if use_cache:
if cache_dtype is None:
batch.max_seqlen = max(batch.input_lengths)
batch.all_input_ids_tensor=[]
batch.input_ids = torch.tensor(
np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
)
batch.position_ids = torch.tensor(
np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
)
batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32)
batch.past_key_values=None
else:
assert len(batch.all_input_ids_tensor)>0, "Must run prefill first"
batch.input_ids.fill_(self.tokenizer.pad_token_id)
batch.position_ids += diff
@ -475,20 +487,8 @@ class FlashCausalLM(Model):
batch.past_key_values.shape[0],
batch.past_key_values.shape[1] + len(batch.requests),
*batch.past_key_values.shape[2:],
), device=batch.past_key_values.device, dtype= batch.past_key_values.dtype
), device=batch.past_key_values.device, dtype= cache_dtype
)
else:
batch.max_seqlen = max(batch.input_lengths)
batch.all_input_ids_tensor=[]
batch.input_ids = torch.tensor(
np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
)
batch.position_ids = torch.tensor(
np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
)
batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32)
batch.past_key_values=None