mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

Revert prefill optimization and fix accuracy issue in shift operation (#29)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Co-authored-by: madamczykhabana <110973826+madamczykhabana@users.noreply.github.com>
Co-authored-by: jkaniecki <153085639+jkaniecki@users.noreply.github.com>

This commit is contained in:
parent ac3bc0e95e
commit 2a7a967de3
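Summary: prefill once again allocates room for every future token (max_new_tokens plus extra padding) at tokenization time, instead of reserving a single slot and padding lazily before the first decode; and shift() now builds its gather indices by biasing with the element count and taking a modulo, so sequences wrap instead of clamping at the edge. filter, concatenate, and recombine drop their pad_token_id parameter in favor of is_optimized_for_gaudi, and the TextGenerationService call sites are updated to match.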
@@ -101,9 +101,12 @@ def shift(tensor, dim, offset):
     if offset == 0 or abs(offset) > elements:
         return tensor
     htorch.core.mark_step()
+    # We generate indices from (0 - offset + elements) to (elements - offset + elements)
+    # so that next modulo operation operates on positive values
     indices = torch.arange(0, elements, dtype=torch.int32, device=tensor.device)
-    offset = torch.tensor(offset, dtype=torch.int32, device=tensor.device)
-    indices = torch.clamp(indices - offset, 0, elements - 1)
+    offset = torch.tensor(-offset + elements, dtype=torch.int32, device=tensor.device)
+    indices.add_(offset)
+    indices.remainder_(elements)
     target_shape = [1,] * len(tensor.shape)
     target_shape[dim] = elements
     indices = indices.view(target_shape).expand(shape)
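The accuracy fix in miniature: clamping out-of-range gather indices duplicates the edge element, while biasing by the element count and taking the remainder gives an exact circular shift, matching torch.roll. A minimal sketch under that reading; index_select stands in for the view/expand/gather the real HPU code performs, and default int64 indices replace its int32 ones:

    import torch

    def shift_clamped(tensor, dim, offset):
        # Pre-fix indexing: out-of-range positions are clamped, so the edge
        # element is duplicated instead of wrapping around.
        elements = tensor.size(dim)
        indices = torch.arange(elements, device=tensor.device)
        indices = torch.clamp(indices - offset, 0, elements - 1)
        return tensor.index_select(dim, indices)

    def shift_modulo(tensor, dim, offset):
        # Fixed indexing: bias by `elements` so every value stays positive,
        # then take the remainder -- an exact circular shift.
        elements = tensor.size(dim)
        indices = torch.arange(elements, device=tensor.device)
        indices = (indices - offset + elements) % elements
        return tensor.index_select(dim, indices)

    x = torch.arange(6)
    print(shift_clamped(x, 0, 2))           # tensor([0, 0, 0, 1, 2, 3]) -- edge duplicated
    print(shift_modulo(x, 0, 2))            # tensor([4, 5, 0, 1, 2, 3])
    print(torch.roll(x, shifts=2, dims=0))  # tensor([4, 5, 0, 1, 2, 3]) -- matches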
@@ -137,15 +140,6 @@ def remove_kv_cache_from_output(module):
     return module
 
 
-def pad_tensors(tensors, paddings, dim, value):
-    for i, (tensor, padding) in enumerate(zip(tensors, paddings)):
-        if padding > 0:
-            pad_shape = (0, 0, 0, padding) if dim == -2 else (0, padding)
-            tensors[i] = torch.nn.functional.pad(tensor, pad_shape, value=value)
-    htorch.core.mark_step()
-    return tensors
-
-
 @dataclass
 class CausalLMRequest:
     idx: int
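With recombine no longer growing tensors to the max total token count (prefill now allocates everything up front, per the from_pb hunk below), this helper loses its callers: the four pad_tensors call sites are deleted in the recombine hunks further down.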
@@ -202,7 +196,7 @@ class CausalLMBatch(Batch):
         )
 
     @classmethod
-    def recombine(cls, batches: List["CausalLMBatch"], pad_token_id: int) -> "CausalLMBatch":
+    def recombine(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
         total_requests = sum(len(b) for b in batches)
         new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
         batch_id = batches[0].batch_id
@@ -221,7 +215,7 @@ class CausalLMBatch(Batch):
             scenario = 'CONCAT'
         elif batches[0].batch_size != new_bs:
             scenario = 'RESHAPE'
-        elif padding[0] <= 1:
+        elif padding[0] <= 0:
             scenario = 'SHIFT'
             offsets = [b.max_input_length - max_input_length for b in batches]
             max_input_length = max(b.max_input_length for b in batches)
@@ -234,7 +228,7 @@ class CausalLMBatch(Batch):
 
         grouped_requests = [[req for req in batch.requests] for batch in batches]
         flat_requests = list(itertools.chain(*grouped_requests))
-        if inplace and scenario != 'SHIFT':
+        if inplace:
             # The data is already present in the batch. No need to move it
             grouped_requests[target_batch_idx] = []
             free_indices = batches[target_batch_idx].free_indices()
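One plausible reading of both relaxations above: now that every batch carries its full token budget, a single slot of free right padding no longer signals an imminent reallocation, so SHIFT triggers only when no padding remains (padding[0] <= 0); and since shifting happens in place on preallocated tensors, the inplace fast path no longer needs to exclude the SHIFT scenario.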
@@ -244,6 +238,10 @@ class CausalLMBatch(Batch):
         to_tensors = lambda ind: (torch.tensor(ind[0], device=device), torch.tensor(ind[1], device=device))
         indices = [[to_tensors(req.update_idx(next(free_indices))) for req in batch_reqs] for batch_reqs in grouped_requests]
 
+        max_seq_len = batches[0].attention_mask.size(1)
+        input_length = max_input_length
+        right_padding = max_seq_len - input_length
+
         chunk_size = batches[0].past_key_values[0][0].size(0) // batches[0].batch_size
         num_layers = len(batches[0].past_key_values)
         past_key_values_type = type(batches[0].past_key_values)
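These three values are now computed up front from batches[0].attention_mask, before the per-batch tensors are deleted and recombined; the computation that previously ran afterwards on the rebuilt attention_mask is removed in the next_token_chooser hunk below. With padding gone from recombine, the mask width no longer changes during recombination, which is presumably why reading it early is safe.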
@@ -259,14 +257,9 @@ class CausalLMBatch(Batch):
         for b in batches:
             b.past_key_values = list(b.past_key_values)
 
-        # For prefill there is a space allocated only for first token
-        # Need to add padding to the max total tokens before first decode
-        paddings = [(batch.input_length + batch.right_padding) - batch.seq_length for batch in batches]
-
         src = [b.input_ids for b in batches]
         for b in batches:
             del b.input_ids
-        src = pad_tensors(src, paddings, seq_dim, pad_token_id)
         src = shift_all(src, seq_dim, offsets)
         input_ids = prepare_memory(new_bs, src[target_batch_idx], inplace)
         input_ids = move_data(input_ids, 1, indices, src)
@@ -274,7 +267,6 @@ class CausalLMBatch(Batch):
         src = [b.attention_mask for b in batches]
         for b in batches:
             del b.attention_mask
-        src = pad_tensors(src, paddings, seq_dim, 0)
         src = shift_all(src, seq_dim, offsets)
         attention_mask = prepare_memory(new_bs, src[target_batch_idx], inplace)
         attention_mask = move_data(attention_mask, 1, indices, src)
@@ -289,13 +281,11 @@ class CausalLMBatch(Batch):
         past_key_values = []
         for layer_num in range(num_layers):
             src = [b.past_key_values[layer_num][0] for b in batches]
-            src = pad_tensors(src, paddings, key_dim, 0)
             src = shift_all(src, key_dim, offsets)
             updated_key = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_key = move_data(updated_key, chunk_size, indices, src)
 
             src = [b.past_key_values[layer_num][1] for b in batches]
-            src = pad_tensors(src, paddings, value_dim, 0)
             src = shift_all(src, value_dim, offsets)
             updated_value = prepare_memory(new_bs * chunk_size, src[target_batch_idx], inplace)
             updated_value = move_data(updated_value, chunk_size, indices, src)
@@ -310,14 +300,10 @@ class CausalLMBatch(Batch):
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             [r.data.parameters for r in flat_requests],
-            batches[0].next_token_chooser.device,
-            batches[0].next_token_chooser.dtype
+            batches[0].next_token_chooser.dtype,
+            batches[0].next_token_chooser.device
         )
 
-        max_seq_len = attention_mask.size(1)
-        input_length = max_input_length
-        right_padding = max_seq_len - input_length
-
         htorch.core.mark_step()
 
         return cls(
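Two things happen here: the max_seq_len/input_length/right_padding recomputation is dropped (it moved earlier, see above), and a swapped-argument bug is fixed: HeterogeneousNextTokenChooser.from_pb takes dtype before device in upstream TGI, so the old order passed a device where a dtype was expected.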
@@ -392,16 +378,12 @@ class CausalLMBatch(Batch):
         attention_mask = tokenized_inputs["attention_mask"]
 
         if is_optimized_for_gaudi:
-            # Allocate space for first token
             input_ids = torch.nn.functional.pad(
-                input_ids, (0, 1), value=tokenizer.pad_token_id
+                input_ids, (0, max_new_tokens + extra_padding), value=tokenizer.pad_token_id
             )
             attention_mask = torch.nn.functional.pad(
-                attention_mask, (0, 1), value=0
-            )
-            all_input_ids = torch.nn.functional.pad(
-                input_ids, (0, max_new_tokens + extra_padding - 1), value=tokenizer.pad_token_id
-            ).T.split(1, dim=1)
+                attention_mask, (0, max_new_tokens + extra_padding), value=0)
+            all_input_ids = input_ids.T.split(1, dim=1)
         else:
             all_input_ids = input_ids.clone().T.split(1, dim=1)
 
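This is the revert proper: instead of reserving one slot for the first token and padding the rest lazily (the deleted pad_tensors path), prefill pads input_ids and attention_mask to the full token budget immediately. A toy sketch of the shape arithmetic, with made-up numbers; extra_padding and pad_token_id stand in for the real values computed elsewhere in from_pb:

    import torch
    import torch.nn.functional as F

    # Toy values: batch of 2, input_length 4, room for 3 new tokens, no bucketing.
    input_ids = torch.ones(2, 4, dtype=torch.int64)
    attention_mask = torch.ones(2, 4, dtype=torch.int64)
    max_new_tokens, extra_padding, pad_token_id = 3, 0, 0

    # Reverted behaviour: allocate every future token position at prefill time,
    # so later decode steps write in place and never reallocate.
    input_ids = F.pad(input_ids, (0, max_new_tokens + extra_padding), value=pad_token_id)
    attention_mask = F.pad(attention_mask, (0, max_new_tokens + extra_padding), value=0)

    print(input_ids.shape)       # torch.Size([2, 7])
    print(attention_mask.shape)  # torch.Size([2, 7])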
@@ -430,16 +412,16 @@ class CausalLMBatch(Batch):
         )
 
     @tracer.start_as_current_span("filter")
-    def filter(self, request_ids: List[int], pad_token_id: int = 0) -> Optional["CausalLMBatch"]:
+    def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -> Optional["CausalLMBatch"]:
         dbg_trace('FILTER', f'num_reqs:{len(self.requests)} -> {len(request_ids)}')
         request_ids = set(request_ids)
         self.requests = [req for req in self.requests if req.data.id in request_ids]
-        return self.__class__.recombine([self], pad_token_id)
+        return self
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"], pad_token_id: int = 0) -> "CausalLMBatch":
-        return cls.recombine(batches, pad_token_id)
+    def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: bool = False) -> "CausalLMBatch":
+        return cls.recombine(batches, is_optimized_for_gaudi)
 
     def __len__(self):
         return len(self.requests)
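filter becomes pure bookkeeping: it drops finished requests and returns self, deferring compaction to the recombine call that generate_token now issues before every decode (next hunk). This is also why pad_token_id disappears across filter, concatenate, and recombine: with no padding left to do, callers only need to forward is_optimized_for_gaudi.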
@@ -664,30 +646,27 @@ class CausalLM(Model):
         prefill = batch.past_key_values is None
         # Check if we need to do any bookkeeping first
         if not prefill:
-            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
+            batch = batch.__class__.recombine([batch], self.is_optimized_for_gaudi)
 
         scenario = 'PREFILL' if prefill else 'GENERATE'
-        dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length}')
+        dbg_trace(scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
+        assert batch.right_padding > 0, 'No more room for next token!'
         self.step = self.step + 1
         if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
             self.hb_profer.stop()
             self.hb_profer_started = False
 
         if self.is_optimized_for_gaudi:
-            if prefill:
-                # no right padding for prefill
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
-            else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             attention_mask = batch.attention_mask
         else:
             token_idx = None
             # slice the attention mask to the correct shape
             # TODO fix me!
             attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-
-        if not prefill and token_idx is not None:
-            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-        else:
-            input_ids = batch.input_ids
+        if batch.past_key_values:
+            if token_idx is not None:
+                input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
+            else:
+                input_ids = batch.input_ids
 
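With prefill tensors padded to the full budget, right_padding is positive in both scenarios, so the prefill/decode split for token_idx collapses into one expression guarded by the new assert. Toy arithmetic, with invented numbers, just to show where the next token lands:

    # Mask width 16 with 6 slots of right padding: positions 0..9 are occupied,
    # so the next token is written at index 10.
    mask_width, right_padding = 16, 6
    assert right_padding > 0, 'No more room for next token!'
    token_idx = mask_width - right_padding
    print(token_idx)  # 10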
@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                                       {"util": len(batch.requests)}):
             if batch is None:
                 raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
-            filtered_batch = batch.filter(request.request_ids, self.model.tokenizer.pad_token_id)
+            filtered_batch = batch.filter(request.request_ids, self.model.is_optimized_for_gaudi)
             self.cache.set(filtered_batch)
 
             return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
@@ -113,7 +113,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 
         if len(batches) > 1:
             with self.profiler.record_event("internal", "concatenate"):
-                batch = self.model.batch_type.concatenate(batches, self.model.tokenizer.pad_token_id)
+                batch = self.model.batch_type.concatenate(batches, self.model.is_optimized_for_gaudi)
         else:
             batch = batches[0]
 