mirror of
				https://github.com/huggingface/text-generation-inference.git
				synced 2025-10-25 23:05:22 +00:00 
			
		
		
		
	Co-authored-by: mswiniarsk <156412439+mswiniarsk@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									3831f1bed5
								
							
						
					
					
						commit
						7dbf4bf7a4
					
				| @ -70,7 +70,7 @@ def calculate_chunks(offset): | ||||
| def biggest_single_chunk(offset): | ||||
|     if offset != 0: | ||||
|         idx = bisect.bisect(CHUNK_SIZES, abs(offset)) | ||||
|         return int(math.copysign(CHUNK_SIZES[idx-1], offset)) | ||||
|         return int(math.copysign(CHUNK_SIZES[idx - 1], offset)) | ||||
|     else: | ||||
|         return 0 | ||||
| 
 | ||||
| @ -145,7 +145,7 @@ def extend_tensor(tensor, padding, dim): | ||||
| 
 | ||||
| def extend_batch(tensors, target_bs, dim): | ||||
|     diff = target_bs - tensors[0].size(dim) | ||||
|     #TODO: add support for shrinking bs | ||||
|     # TODO: add support for shrinking bs | ||||
|     if diff <= 0: | ||||
|         return tensors | ||||
|     shape = list(tensors[0].shape) | ||||
| @ -335,11 +335,14 @@ class CausalLMBatch(Batch): | ||||
|             # [[position_ids], [attention_mask], [position_ids], past_keys, past_values] | ||||
| 
 | ||||
|             # move only past_keys | ||||
|             dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, src_tensors[3:4], src_dims[3:4], src_indices) | ||||
|             dst_tensors[3:4] = grouped_move(dst_tensors[3:4], dst_dims[3:4], dst_indices, | ||||
|                                             src_tensors[3:4], src_dims[3:4], src_indices) | ||||
|             # move only past_values | ||||
|             dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, src_tensors[4:5], src_dims[4:5], src_indices) | ||||
|             dst_tensors[4:5] = grouped_move(dst_tensors[4:5], dst_dims[4:5], dst_indices, | ||||
|                                             src_tensors[4:5], src_dims[4:5], src_indices) | ||||
|             # move only input_ids, attention_mask and position_ids | ||||
|             dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, src_tensors[:3], src_dims[:3], src_indices) | ||||
|             dst_tensors[:3] = grouped_move(dst_tensors[:3], dst_dims[:3], dst_indices, | ||||
|                                            src_tensors[:3], src_dims[:3], src_indices) | ||||
|         self.set_tensor_groups(dst_tensors) | ||||
| 
 | ||||
|     @classmethod | ||||
| @ -792,9 +795,11 @@ class CausalLM(Model): | ||||
|                 if self.is_optimized_for_gaudi: | ||||
|                     if prefill: | ||||
|                         # no right padding for prefill | ||||
|                         token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) | ||||
|                         token_idx_scalar = batch.attention_mask.shape[-1] - 1 | ||||
|                         token_idx = torch.tensor(token_idx_scalar).to(self.device) | ||||
|                     else: | ||||
|                         token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) | ||||
|                         token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding | ||||
|                         token_idx = torch.tensor(token_idx_scalar).to(self.device) | ||||
|                 else: | ||||
|                     token_idx = None | ||||
| 
 | ||||
| @ -802,11 +807,11 @@ class CausalLM(Model): | ||||
|                 input_length = batch.input_length | ||||
|                 if self.is_optimized_for_gaudi and logits.shape[-2] > 1: | ||||
|                     next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( | ||||
|                         batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2) | ||||
|                         batch.input_ids[:, :token_idx_scalar], logits[:, input_length - 1: input_length, :].squeeze(-2) | ||||
|                     ) | ||||
|                 else: | ||||
|                     next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( | ||||
|                         batch.input_ids[:, :token_idx], logits.squeeze(-2) | ||||
|                         batch.input_ids[:, :token_idx_scalar], logits.squeeze(-2) | ||||
|                     ) | ||||
|                 batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( | ||||
|                     batch.top_n_tokens, | ||||
| @ -853,7 +858,7 @@ class CausalLM(Model): | ||||
| 
 | ||||
|                 # Update position_ids | ||||
|                 if prefill: | ||||
|                     batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1 | ||||
|                     batch.position_ids = torch.index_select(batch.position_ids, 1, token_idx - 1) + 1 | ||||
|                 else: | ||||
|                     batch.position_ids += 1 | ||||
|                 # Update past key values | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user