Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
commit 7dbf4bf7a4 (parent 3831f1bed5)
@@ -70,7 +70,7 @@ def calculate_chunks(offset):
 def biggest_single_chunk(offset):
     if offset != 0:
         idx = bisect.bisect(CHUNK_SIZES, abs(offset))
-        return int(math.copysign(CHUNK_SIZES[idx-1], offset))
+        return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
     else:
         return 0
 
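For context, biggest_single_chunk returns the largest entry of CHUNK_SIZES that does not exceed abs(offset), with the sign of offset restored via math.copysign. A minimal runnable sketch, assuming a hypothetical ascending chunk table (the real CHUNK_SIZES is defined elsewhere in this file):

import bisect
import math

# Hypothetical chunk table; the real CHUNK_SIZES lives elsewhere in the file.
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]

def biggest_single_chunk(offset):
    if offset != 0:
        # bisect gives the insertion point, so idx - 1 is the largest
        # chunk size <= abs(offset)
        idx = bisect.bisect(CHUNK_SIZES, abs(offset))
        return int(math.copysign(CHUNK_SIZES[idx - 1], offset))
    else:
        return 0

assert biggest_single_chunk(100) == 64   # 64 <= 100 < 128
assert biggest_single_chunk(-5) == -4    # sign of offset is preserved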
@@ -145,7 +145,7 @@ def extend_tensor(tensor, padding, dim):
 
 def extend_batch(tensors, target_bs, dim):
     diff = target_bs - tensors[0].size(dim)
-    #TODO: add support for shrinking bs
+    # TODO: add support for shrinking bs
     if diff <= 0:
         return tensors
     shape = list(tensors[0].shape)
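extend_tensor is defined just above this hunk but not shown here; from the call site, extend_batch grows every tensor in the batch to target_bs along dim by padding, and (per the TODO) does not yet handle shrinking. A standalone sketch of that padding pattern, where pad_batch_dim is an illustrative name rather than the repo's helper:

import torch

def pad_batch_dim(tensor, target_bs, dim):
    # Sketch of the pattern: grow the batch dimension to target_bs by
    # appending zeros; shrinking is not handled, mirroring the TODO above.
    diff = target_bs - tensor.size(dim)
    if diff <= 0:
        return tensor
    shape = list(tensor.shape)
    shape[dim] = diff
    pad = torch.zeros(shape, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, pad], dim=dim)

x = torch.ones(2, 5)            # batch of 2
y = pad_batch_dim(x, 4, dim=0)  # padded to batch of 4
assert y.shape == (4, 5)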
@@ -283,7 +283,7 @@ class CausalLMBatch(Batch):
     def get_tensor_groups(self):
         past_keys, past_values = self.detach_kv_cache()
         seq_dim = -1
-        key_dim = -2 # TODO: Add case for Bloom and other models
+        key_dim = -2  # TODO: Add case for Bloom and other models
         value_dim = -2
         tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
         # We don't need to align position_ids
@@ -335,11 +335,14 @@ class CausalLMBatch(Batch):
         # [[position_ids], [attention_mask], [position_ids], past_keys, past_values]
 
         # move only past_keys
-        dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, src_tensors[3:4], src_dims[3:4], src_indices)
+        dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices,
+                                        src_tensors[3:4], src_dims[3:4], src_indices)
         # move only past_values
-        dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, src_tensors[4:5], src_dims[4:5], src_indices)
+        dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices,
+                                        src_tensors[4:5], src_dims[4:5], src_indices)
         # move only input_ids, attention_mask and position_ids
-        dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices)
+        dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices,
+                                       src_tensors[:3], src_dims[:3], src_indices)
         self.set_tensor_groups(dst_tensors)
 
     @classmethod
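grouped_move itself is not shown in this diff; judging from these call sites, it copies the batch entries selected by src_indices out of each source tensor group into the dst_indices positions of the matching destination group, along a per-group dim. A hypothetical sketch of that pattern, with the signature and semantics assumed rather than taken from the repo:

import torch

def grouped_move(dst_tensors, dst_dims, dst_indices,
                 src_tensors, src_dims, src_indices):
    # Hypothetical: each group is a list of tensors sharing one batch dim.
    for g, (dst_group, src_group) in enumerate(zip(dst_tensors, src_tensors)):
        for dst, src in zip(dst_group, src_group):
            rows = src.index_select(src_dims[g], src_indices)
            dst.index_copy_(dst_dims[g], dst_indices, rows)
    return dst_tensors

# Move rows 2 and 3 of the source into rows 0 and 1 of the destination.
dst = [[torch.zeros(4, 3)]]
src = [[torch.arange(12.0).reshape(4, 3)]]
grouped_move(dst, [0], torch.tensor([0, 1]), src, [0], torch.tensor([2, 3]))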
@@ -792,9 +795,11 @@ class CausalLM(Model):
         if self.is_optimized_for_gaudi:
             if prefill:
                 # no right padding for prefill
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+                token_idx_scalar = batch.attention_mask.shape[-1] - 1
+                token_idx = torch.tensor(token_idx_scalar).to(self.device)
             else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+                token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
+                token_idx = torch.tensor(token_idx_scalar).to(self.device)
         else:
             token_idx = None
 
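The apparent point of this change is to keep the token index available as a plain Python int (token_idx_scalar) alongside the device tensor (token_idx): the next hunk slices batch.input_ids with the scalar, which keeps the slice bounds static instead of forcing a read of a 0-dim device tensor. A small sketch of the difference, with made-up shapes:

import torch

attention_mask = torch.ones(4, 10)
right_padding = 2

# Compute the index once on the host ...
token_idx_scalar = attention_mask.shape[-1] - right_padding
# ... and materialize a tensor copy only where a tensor is needed.
token_idx = torch.tensor(token_idx_scalar)

input_ids = torch.arange(40).reshape(4, 10)
# Slicing with the Python int needs no device round-trip; slicing with
# the 0-dim tensor would have to read its value back first.
prefix = input_ids[:, :token_idx_scalar]
assert prefix.shape == (4, 8)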
@@ -802,11 +807,11 @@ class CausalLM(Model):
         input_length = batch.input_length
         if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
         else:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits.squeeze(-2)
+                batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2)
             )
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens,
@@ -853,7 +858,7 @@ class CausalLM(Model):
 
         # Update position_ids
         if prefill:
-            batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
+            batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1
         else:
             batch.position_ids += 1
         # Update past key values
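Here the Python-level slice, whose bounds would require reading the 0-dim token_idx tensor back to the host, is replaced by torch.index_select, which accepts the index as a tensor and so keeps the whole update on the device; that pattern matters for graph-compiled backends like Gaudi. A sketch showing the two forms agree, with made-up shapes:

import torch

position_ids = torch.arange(12).reshape(2, 6)
token_idx = torch.tensor(5)

# Old form: slice bounds are converted to Python ints, which reads the
# tensor's value on the host.
old = position_ids[:, token_idx - 1: token_idx] + 1

# New form: the index stays a tensor, so the op can stay on the device.
new = torch.index_select(position_ids, 1, token_idx - 1) + 1

assert torch.equal(old, new)  # both have shape (2, 1)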