Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 11:54:52 +00:00.
commit 72eefa3612 (parent a515fbde4c)

    fix
@@ -449,7 +449,7 @@ class FlashCausalLM(Model):
             pre_allocate_past_size=pre_allocate_past_size,
         )
 
-    def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, use_cache: Optional[torch.dtype]):
+    def fast_forward(self, batch: FlashCausalLMBatch, max_input_length: int, cache_dtype: Optional[torch.dtype]):
         diff = max_input_length - max(batch.input_lengths)
         for i in range(len(batch)):
             batch.input_lengths[i] += diff
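This first hunk only renames the method's last parameter: use_cache was already annotated as Optional[torch.dtype], so calling it cache_dtype makes the name match the type. The context lines also show the length bookkeeping fast_forward does before anything else; below is a minimal, self-contained sketch of that arithmetic with made-up lengths (only the diff = max_input_length - max(...) formula comes from the code above):

input_lengths = [3, 5, 2]                     # hypothetical per-request lengths in a batch
max_input_length = 9
diff = max_input_length - max(input_lengths)  # same formula as in fast_forward: 9 - 5 = 4
input_lengths = [length + diff for length in input_lengths]
print(diff, input_lengths)                    # 4 [7, 9, 6]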
@@ -459,7 +459,19 @@ class FlashCausalLM(Model):
             # TODO: Bug!?!
             batch.stopping_criterias[i].current_tokens += diff
 
-        if use_cache:
+        if cache_dtype is None:
+            batch.max_seqlen = max(batch.input_lengths)
+            batch.all_input_ids_tensor=[]
+
+            batch.input_ids = torch.tensor(
+                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
+            )
+            batch.position_ids = torch.tensor(
+                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
+            )
+            batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32)
+            batch.past_key_values=None
+        else:
             assert len(batch.all_input_ids_tensor)>0, "Must run prefill first"
             batch.input_ids.fill_(self.tokenizer.pad_token_id)
             batch.position_ids += diff
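When cache_dtype is None, the new branch resets the batch to a prefill-like state: per-token input_ids and position_ids are flattened across all requests, cu_seqlens is rebuilt as cumulative sequence offsets with a leading zero, and past_key_values is dropped. The sketch below reproduces just that layout with made-up lengths and plain tensors rather than a FlashCausalLMBatch; only the np.concatenate / np.pad(np.cumsum(...)) pattern is taken from the diff:

import numpy as np
import torch

input_lengths = [3, 5, 2]  # hypothetical per-request lengths

# One position index per token, with every request flattened into a single dimension.
position_ids = torch.tensor(
    np.concatenate([np.arange(0, length) for length in input_lengths]), dtype=torch.int32
)
# Cumulative sequence boundaries; np.pad prepends the leading 0 offset.
cu_seqlens = torch.tensor(np.pad(np.cumsum(input_lengths), (1, 0)), dtype=torch.int32)

print(position_ids.tolist())  # [0, 1, 2, 0, 1, 2, 3, 4, 0, 1]
print(cu_seqlens.tolist())    # [0, 3, 8, 10]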
@@ -475,20 +487,8 @@ class FlashCausalLM(Model):
                     batch.past_key_values.shape[0],
                     batch.past_key_values.shape[1] + len(batch.requests),
                     *batch.past_key_values.shape[2:],
-                ), device=batch.past_key_values.device, dtype= batch.past_key_values.dtype
+                ), device=batch.past_key_values.device, dtype= cache_dtype
             )
-        else:
-            batch.max_seqlen = max(batch.input_lengths)
-            batch.all_input_ids_tensor=[]
-
-            batch.input_ids = torch.tensor(
-                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int64, device=batch.input_ids.device
-            )
-            batch.position_ids = torch.tensor(
-                np.concatenate([np.arange(0, input_length) for input_length in batch.input_lengths]), dtype=torch.int32, device=batch.input_ids.device
-            )
-            batch.cu_seqlens = torch.tensor(np.pad(np.cumsum(batch.input_lengths),(1,0)), device=batch.input_ids.device, dtype=torch.int32)
-            batch.past_key_values=None
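In the other branch (cache_dtype is not None, formerly if use_cache:), the only functional change is that the enlarged past_key_values buffer is now allocated with the caller-supplied cache_dtype instead of inheriting batch.past_key_values.dtype. The deleted else block is not lost: it is the rebuild logic that the second hunk moved into the cache_dtype is None branch. Below is a hedged sketch of the allocation pattern; the shapes, the use of torch.empty, and the copy of the old entries are assumptions, since the constructor and its surrounding lines fall outside this hunk:

import torch

past_key_values = torch.zeros(2, 6, 16, 64, dtype=torch.float32)  # hypothetical existing cache
num_new_requests = 2                                              # stands in for len(batch.requests)
cache_dtype = torch.float16

new_cache = torch.empty(
    (
        past_key_values.shape[0],
        past_key_values.shape[1] + num_new_requests,  # room for the extra requests
        *past_key_values.shape[2:],
    ),
    device=past_key_values.device,
    dtype=cache_dtype,  # before this commit: dtype=past_key_values.dtype
)
# Assumed step, not shown in the hunk: carry over the existing entries in the new dtype.
new_cache[:, : past_key_values.shape[1]] = past_key_values.to(cache_dtype)
print(new_cache.shape, new_cache.dtype)  # torch.Size([2, 8, 16, 64]) torch.float16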