mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 08:22:07 +00:00)
Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
parent 3831f1bed5
commit 7dbf4bf7a4
@@ -70,7 +70,7 @@ def calculate_chunks(offset):
 def biggest_single_chunk(offset):
     if offset != 0:
         idx = bisect.bisect(CHUNK_SIZES, abs(offset))
-        return int(math.copysign(CHUNK_SIZES[idx-1], offset))
+        return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
     else:
         return 0
 
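Note: the change above is purely cosmetic (PEP 8 spacing around the minus). For readers unfamiliar with the lookup, here is a self-contained sketch of what biggest_single_chunk computes; the CHUNK_SIZES values are an assumption, only the sorted-ascending property matters:

    import bisect
    import math

    CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]  # assumed values; must be sorted ascending

    def biggest_single_chunk(offset):
        if offset != 0:
            # bisect.bisect returns the insertion point, so idx - 1 indexes
            # the largest chunk size that is <= abs(offset)
            idx = bisect.bisect(CHUNK_SIZES, abs(offset))
            return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
        else:
            return 0

    assert biggest_single_chunk(100) == 64   # largest chunk <= 100
    assert biggest_single_chunk(-5) == -4    # sign of the offset is preserved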
@@ -145,7 +145,7 @@ def extend_tensor(tensor, padding, dim):
 
 def extend_batch(tensors, target_bs, dim):
     diff = target_bs - tensors[0].size(dim)
-    #TODO: add support for shrinking bs
+    # TODO: add support for shrinking bs
     if diff <= 0:
         return tensors
     shape = list(tensors[0].shape)
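Note: again only a comment-style fix. For context, extend_batch grows every tensor in a batch group along the batch dimension; a minimal sketch, assuming extend_tensor zero-pads along the given dim (the real implementation sits just above this hunk and is not shown here):

    import torch

    def extend_tensor(tensor, padding, dim):
        # assumed behavior: append zero padding of size `padding` along `dim`
        shape = list(tensor.shape)
        shape[dim] = padding
        pad = torch.zeros(shape, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, pad], dim=dim)

    def extend_batch(tensors, target_bs, dim):
        diff = target_bs - tensors[0].size(dim)
        # TODO: add support for shrinking bs
        if diff <= 0:
            return tensors
        return [extend_tensor(t, diff, dim) for t in tensors]

    batch = [torch.ones(2, 5), torch.ones(2, 5)]
    extended = extend_batch(batch, 4, 0)
    assert extended[0].shape == (4, 5)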
@@ -335,11 +335,14 @@ class CausalLMBatch(Batch):
         # [[input_ids], [attention_mask], [position_ids], past_keys, past_values]
 
         # move only past_keys
-        dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, src_tensors[3:4], src_dims[3:4], src_indices)
+        dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices,
+                                        src_tensors[3:4], src_dims[3:4], src_indices)
         # move only past_values
-        dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, src_tensors[4:5], src_dims[4:5], src_indices)
+        dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices,
+                                        src_tensors[4:5], src_dims[4:5], src_indices)
         # move only input_ids, attention_mask and position_ids
-        dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
+        dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices,
+                                       src_tensors[:3], src_dims[:3], src_indices)
         self.set_tensor_groups(dst_tensors)
 
     @classmethod
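Note: the three grouped_move call sites are only re-wrapped to stay under the line-length limit. grouped_move itself is not part of this diff; the following is a plausible sketch of its semantics inferred from the call sites (copy rows selected by src_indices along each source dim into the rows selected by dst_indices along each destination dim), not the repo's implementation:

    import torch

    def grouped_move(dst_tensors, dst_dims, dst_indices,
                     src_tensors, src_dims, src_indices):
        # inferred semantics: per-tensor indexed copy from src into dst
        for dst, dst_dim, src, src_dim in zip(dst_tensors, dst_dims,
                                              src_tensors, src_dims):
            dst.index_copy_(dst_dim, dst_indices,
                            torch.index_select(src, src_dim, src_indices))
        return dst_tensors

Splitting the moves into past_keys, past_values, and the small tensors lets each group be transferred separately rather than in one monolithic call.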
@@ -792,9 +795,11 @@ class CausalLM(Model):
         if self.is_optimized_for_gaudi:
             if prefill:
                 # no right padding for prefill
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+                token_idx_scalar = batch.attention_mask.shape[-1] - 1
+                token_idx = torch.tensor(token_idx_scalar).to(self.device)
             else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+                token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
+                token_idx = torch.tensor(token_idx_scalar).to(self.device)
         else:
             token_idx = None
 
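Note: the refactor keeps the last valid token position around twice: token_idx_scalar as a plain Python int and token_idx as a device tensor. The int can be used in slice expressions on the host without pulling a value back from the device; a minimal illustration of the pattern (shapes and values invented for the example):

    import torch

    attention_mask = torch.ones(2, 10)
    right_padding = 3

    token_idx_scalar = attention_mask.shape[-1] - right_padding  # plain int: 7
    token_idx = torch.tensor(token_idx_scalar)  # .to(self.device) in the real code

    input_ids = torch.arange(20).reshape(2, 10)
    prefix = input_ids[:, :token_idx_scalar]  # int-based slice, no device sync
    assert prefix.shape == (2, 7)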
@@ -802,11 +807,11 @@ class CausalLM(Model):
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits.squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2)
             )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
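Note: this hunk is the consumer of token_idx_scalar introduced above, so the slice bounds become host-side ints. The prefill branch also shows why the logits are sliced at input_length - 1: with right padding, the logits that predict the next token sit at the last real prompt position, not at the end of the padded sequence. A small illustration with invented shapes:

    import torch

    batch_size, padded_len, vocab = 2, 10, 50
    logits = torch.randn(batch_size, padded_len, vocab)
    input_length = 7  # real (unpadded) prompt length

    # logits for the token that follows the prompt live at input_length - 1
    next_token_logits = logits[:, input_length - 1: input_length, :].squeeze(-2)
    assert next_token_logits.shape == (batch_size, vocab)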
@@ -853,7 +858,7 @@ class CausalLM(Model):
 
         # Update position_ids
         if prefill:
-            batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
+            batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
         else:
             batch.position_ids += 1
         # Update past key values
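Note: replacing the tensor-bounded slice position_ids[:, token_idx - 1: token_idx] with torch.index_select keeps the gather on the device; slice bounds that are device tensors must first be synchronized back to the host, which is costly under lazy/graph execution on Gaudi (this reading of the motivation is ours, not stated in the commit). The two forms produce the same result for an in-range index:

    import torch

    position_ids = torch.arange(10).repeat(2, 1)  # shape (2, 10)
    idx = torch.tensor([4])  # 1-D here; the real code uses a 0-dim token_idx

    sliced = position_ids[:, idx - 1: idx] + 1
    gathered = torch.index_select(position_ids, 1, idx - 1) + 1
    assert torch.equal(sliced, gathered)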