Enable deferred token generation (#44) (#75)

Co-authored-by: Krzysztof Laskowski <klaskowski@habana.ai>
Authored by Karol Damaszke on 2024-02-27 15:46:40 +01:00; committed by GitHub
parent 6248c5610e
commit 941d36f3fd
2 changed files with 125 additions and 83 deletions

Changed file 1 of 2

@@ -207,6 +207,9 @@ class CausalLMBatch(Batch):
    input_length: int
    logits = None
    past = None
    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
@@ -719,8 +722,100 @@ class CausalLM(Model):
            return outputs.logits, outputs.past_key_values
    @tracer.start_as_current_span("generate_token")
    def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
    def generate_token(self, batches: List[CausalLMBatch]) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        # Results
        generations: List[Generation] = []
        prev_batches = []
        requests_to_generate = []
        # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
        # Stage 1. Collect next token ids of any previously started generations
        for batch_id, batch in enumerate(batches):
            if batch.logits is not None:
                logits = batch.logits
                past = batch.past
                prefill = batch.past_key_values is None
                if self.is_optimized_for_gaudi:
                    if prefill:
                        # no right padding for prefill
                        token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
                    else:
                        token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
                else:
                    token_idx = None
                # Select next token
                input_length = batch.input_length
                if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
                    next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
                        batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
                    )
                else:
                    next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
                        batch.input_ids[:, :token_idx], logits.squeeze(-2)
                    )
                batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                    batch.top_n_tokens,
                    batch.top_n_tokens_tensor,
                    logprobs,
                )
                prev_batches.append({
                    'next_token_ids': next_token_ids,
                    'next_token_logprobs': next_token_logprobs,
                })
                for req_idx, req in enumerate(batch.requests):
                    requests_to_generate.append({
                        'req': req,
                        'prev_req_idx': req.idx,
                        'batch_id': batch_id,
                        'seed': batch.next_token_chooser.seeds[req_idx],
                        'do_sample': batch.next_token_chooser.do_sample[req_idx],
                        'top_n_tokens': batch.top_n_tokens[req_idx],
                        'top_token_ids': batch_top_token_ids[req_idx],
                        'top_token_logprobs': batch_top_token_logprobs[req_idx],
                    })
                htorch.core.mark_step()
                if token_idx is None:
                    batch.input_ids[:, 0] = next_token_ids[:, 0]
                else:
                    batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1))
                # Slice unused values from prefill, use it to store next token
                if token_idx is None:
                    batch.input_ids = batch.input_ids[:, :1]
                # Update attention_mask as we added a new token to input_ids
                if self.is_optimized_for_gaudi:
                    batch.attention_mask.index_fill_(1, token_idx, 1)
                else:
                    batch.attention_mask[:, -batch.padding_right_offset] = 1
                # Adjust lengths
                batch.input_length += 1
                # Update position_ids
                if prefill:
                    batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
                else:
                    batch.position_ids += 1
                # Update past key values
                if prefill:
                    batch.past_key_values = past
        htorch.core.mark_step()
        # Stage 2. Prepare new batch for speculative scheduling
        if len(batches) > 1:
            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
        else:
            batch = batches[0]
        prefill = batch.past_key_values is None
        # Check if we need to do any bookkeeping first
        if not prefill:
            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
@@ -729,10 +824,6 @@ class CausalLM(Model):
        dbg_trace(
            scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
        assert batch.right_padding > 0, 'No more room for next token!'
        self.step = self.step + 1
        if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
            self.hb_profer.stop()
            self.hb_profer_started = False
        if self.is_optimized_for_gaudi:
            if prefill:
@@ -753,7 +844,7 @@ class CausalLM(Model):
            input_ids = batch.input_ids
        if prefill:
            logits, past = self.forward(
            batch.logits, batch.past = self.forward(
                input_ids,
                attention_mask,
                batch.position_ids,
@@ -762,7 +853,7 @@ class CausalLM(Model):
                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
            )
        else:
            logits = self.forward(
            batch.logits = self.forward(
                input_ids,
                attention_mask,
                batch.position_ids,
@@ -771,46 +862,36 @@ class CausalLM(Model):
                bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
            )
        # Results
        generations: List[Generation] = []
        stopped = True
        # Select next token
        input_length = batch.input_length
        if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
            next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
                batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
            )
        else:
            next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
                batch.input_ids[:, :token_idx], logits.squeeze(-2)
            )
        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens,
            batch.top_n_tokens_tensor,
            logprobs,
        )
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids_cpu = next_token_ids.cpu()
        htorch.core.mark_step()
        for req_idx, req in enumerate(batch.requests):
            i = req.idx
        # Stage 3. Finish and return previous generations
        stopped = len(requests_to_generate) > 0
        for prev_batch in prev_batches:
            prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
            prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
        htorch.core.mark_step()
        for req_data in requests_to_generate:
            req = req_data['req']
            i = req_data['prev_req_idx']
            prev_batch_id = req_data['batch_id']
            assert len(prev_batches) > prev_batch_id
            next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
            next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']
            request = req.data
            input_length = req.input_length
            prefix_offset = req.prefix_offset
            read_offset = req.read_offset
            do_sample = batch.next_token_chooser.do_sample[req_idx]
            seed = batch.next_token_chooser.seeds[req_idx]
            do_sample = req_data['do_sample']
            seed = req_data['seed']
            stopping_criteria = req.stopping_criteria
            all_input_ids = req.all_input_ids
            top_n_tokens = batch.top_n_tokens[req_idx]
            next_token_id = next_token_ids_cpu[i]
            next_token_logprob = next_token_logprobs[i]
            top_token_ids = batch_top_token_ids[req_idx]
            top_token_logprobs = batch_top_token_logprobs[req_idx]
            top_n_tokens = req_data['top_n_tokens']
            top_token_ids = req_data['top_token_ids']
            top_token_logprobs = req_data['top_token_logprobs']
            # Append next token to all tokens
            if self.is_optimized_for_gaudi:
@@ -899,42 +980,9 @@ class CausalLM(Model):
            req.read_offset = read_offset
        htorch.core.mark_step()
        if token_idx is None:
            batch.input_ids[:, 0] = next_token_ids[:, 0]
        else:
            batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1))
        self.step = self.step + 1
        if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
            self.hb_profer.stop()
            self.hb_profer_started = False
        # We finished all generations in the batch; there is no next batch
        if stopped:
            if self.hb_profer_started == True:
                self.hb_profer.step()
            htorch.core.mark_step()
            return generations, None
        # Slice unused values from prefill, use it to store next token
        if token_idx is None:
            batch.input_ids = batch.input_ids[:, :1]
        # Update attention_mask as we added a new token to input_ids
        if self.is_optimized_for_gaudi:
            batch.attention_mask.index_fill_(1, token_idx, 1)
        else:
            batch.attention_mask[:, -batch.padding_right_offset] = 1
        # Adjust lengths
        batch.input_length += 1
        # Update position_ids
        if prefill:
            batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
        else:
            batch.position_ids += 1
        # Update past key values
        if prefill:
            batch.past_key_values = past
        if self.hb_profer_started == True:
            self.hb_profer.step()
        htorch.core.mark_step()
        return generations, batch
        return generations, batch if not stopped else None
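The three-stage structure introduced above defers next-token selection by one step: the forward pass only stashes batch.logits and batch.past, while the CPU-side sampling and bookkeeping for step N run at the start of step N+1, overlapping with scheduling the next forward pass. Below is a minimal, self-contained sketch of that control flow; ToyBatch, toy_forward, and pick_next_token are illustrative stand-ins, not the real CausalLMBatch, CausalLM.forward, or next-token chooser.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ToyBatch:
    tokens: List[int]
    logits: Optional[List[float]] = None  # stashed by the previous call ("deferred")

def toy_forward(batch: ToyBatch) -> List[float]:
    # Pretend device forward pass: fake logits derived from the last tokens.
    return [float(t) for t in batch.tokens[-3:]]

def pick_next_token(logits: List[float]) -> int:
    # Pretend CPU-side sampling: argmax over the fake logits.
    return logits.index(max(logits))

def generate_token(batches: List[ToyBatch]) -> List[int]:
    generated: List[int] = []
    # Stage 1: finish the *previous* step - turn stashed logits into token ids.
    for batch in batches:
        if batch.logits is not None:
            next_id = pick_next_token(batch.logits)
            batch.tokens.append(next_id)
            generated.append(next_id)
    # Stage 2: run the forward pass for the *current* step and stash the logits
    # on the batch instead of sampling from them right away.
    batch = batches[0]  # concatenation of multiple batches is omitted here
    batch.logits = toy_forward(batch)
    # Stage 3: return only what was finalized from the previous step.
    return generated

b = ToyBatch(tokens=[5, 1, 7])
print(generate_token([b]))  # [] - the first call only schedules work
print(generate_token([b]))  # [2] - token picked from the previous call's logits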

Changed file 2 of 2

@@ -84,7 +84,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}):
            with self.profiler.record_event("internal", "generate_token"):
                generations, next_batch = self.model.generate_token(batch)
                generations, next_batch = self.model.generate_token([batch])
        self.cache.set(next_batch)
        return generate_pb2.PrefillResponse(
@@ -111,14 +111,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        if len(batches) == 0:
            raise ValueError("All batches are empty")
        if len(batches) > 1:
            with self.profiler.record_event("internal", "concatenate"):
                batch = self.model.batch_type.concatenate(batches, self.model.tokenizer.pad_token_id)
        else:
            batch = batches[0]
        with self.profiler.record_event("internal", "generate_token"):
            generations, next_batch = self.model.generate_token(batch)
            generations, next_batch = self.model.generate_token(batches)
        self.cache.set(next_batch)
        return generate_pb2.DecodeResponse(
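On the gRPC side, generate_token now takes a list of batches: Prefill wraps its single new batch in a one-element list, and Decode forwards all cached batches unchanged, since concatenation has moved into the model (Stage 2 above). A rough sketch of the new call contract, using hypothetical FakeModel/FakeBatch stand-ins rather than the real classes:

from typing import List, Optional, Tuple

class FakeBatch:
    """Hypothetical stand-in for CausalLMBatch."""

class FakeModel:
    """Hypothetical stand-in for CausalLM; only the call shape matters here."""

    def generate_token(self, batches: List[FakeBatch]) -> Tuple[list, Optional[FakeBatch]]:
        # Concatenation of multiple cached batches now happens inside the model
        # (Stage 2 of generate_token), not in the gRPC service.
        batch = batches[0] if len(batches) == 1 else self._concatenate(batches)
        return [], batch

    def _concatenate(self, batches: List[FakeBatch]) -> FakeBatch:
        return batches[0]  # placeholder for the real batch concatenation

model = FakeModel()
# Prefill: the single new batch is wrapped in a one-element list.
generations, next_batch = model.generate_token([FakeBatch()])
# Decode: all cached batches are forwarded as-is.
generations, next_batch = model.generate_token([FakeBatch(), FakeBatch()])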