mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 00:12:08 +00:00

Co-authored-by: Krzysztof Laskowski <klaskowski@habana.ai>

This commit is contained in:
  parent 6248c5610e
  commit 941d36f3fd
@@ -207,6 +207,9 @@ class CausalLMBatch(Batch):
     input_length: int
 
+    logits = None
+    past = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
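Note: the two new fields let a batch carry the not-yet-consumed outputs of its last forward pass between generate_token calls, so sampling for step N can run at the start of step N+1. A minimal sketch of that idea (toy class, not the actual CausalLMBatch definition):

    # Toy sketch only: a batch that holds deferred forward outputs between steps.
    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class DeferredOutputBatch:          # hypothetical stand-in for CausalLMBatch
        batch_id: int
        logits: Optional[Any] = None    # written by forward(), read by the next call
        past: Optional[Any] = None      # KV-cache handle from the same forward()

        @property
        def has_pending_tokens(self) -> bool:
            # Stage 1 below only processes batches that carry pending logits.
            return self.logits is not None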
@@ -719,8 +722,100 @@ class CausalLM(Model):
         return outputs.logits, outputs.past_key_values
 
     @tracer.start_as_current_span("generate_token")
-    def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+    def generate_token(self, batches: List[CausalLMBatch]) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+        # Results
+        generations: List[Generation] = []
+        prev_batches = []
+        requests_to_generate = []
+
+        # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
+        # Stage 1. Collect next token ids of any previously started generations
+        for batch_id, batch in enumerate(batches):
+            if batch.logits is not None:
+                logits = batch.logits
+                past = batch.past
+                prefill = batch.past_key_values is None
+                if self.is_optimized_for_gaudi:
+                    if prefill:
+                        # no right padding for prefill
+                        token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
+                    else:
+                        token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+                else:
+                    token_idx = None
+
+                # Select next token
+                input_length = batch.input_length
+                if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
+                    next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
+                        batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
+                    )
+                else:
+                    next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
+                        batch.input_ids[:, :token_idx], logits.squeeze(-2)
+                    )
+                batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+                    batch.top_n_tokens,
+                    batch.top_n_tokens_tensor,
+                    logprobs,
+                )
+
+                prev_batches.append({
+                    'next_token_ids': next_token_ids,
+                    'next_token_logprobs': next_token_logprobs,
+                })
+
+                for req_idx, req in enumerate(batch.requests):
+                    requests_to_generate.append({
+                        'req': req,
+                        'prev_req_idx': req.idx,
+                        'batch_id': batch_id,
+                        'seed': batch.next_token_chooser.seeds[req_idx],
+                        'do_sample': batch.next_token_chooser.do_sample[req_idx],
+                        'top_n_tokens': batch.top_n_tokens[req_idx],
+                        'top_token_ids': batch_top_token_ids[req_idx],
+                        'top_token_logprobs': batch_top_token_logprobs[req_idx],
+                    })
+
+                htorch.core.mark_step()
+
+                if token_idx is None:
+                    batch.input_ids[:, 0] = next_token_ids[:, 0]
+                else:
+                    batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1))
+
+                # Slice unused values from prefill, use it to store next token
+                if token_idx is None:
+                    batch.input_ids = batch.input_ids[:, :1]
+
+                # Update attention_mask as we added a new token to input_ids
+                if self.is_optimized_for_gaudi:
+                    batch.attention_mask.index_fill_(1, token_idx, 1)
+                else:
+                    batch.attention_mask[:, -batch.padding_right_offset] = 1
+
+                # Adjust lengths
+                batch.input_length += 1
+
+                # Update position_ids
+                if prefill:
+                    batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
+                else:
+                    batch.position_ids += 1
+                # Update past key values
+                if prefill:
+                    batch.past_key_values = past
+
+                htorch.core.mark_step()
+
+        # Stage 2. Prepare new batch for speculative scheduling
+        if len(batches) > 1:
+            batch = self.batch_type.concatenate(batches, self.tokenizer.pad_token_id)
+        else:
+            batch = batches[0]
+
+        prefill = batch.past_key_values is None
+
+        # Check if we need to do any bookkeeping first
+        if not prefill:
+            batch = batch.__class__.recombine([batch], self.tokenizer.pad_token_id)
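The comment above names the key idea: token selection for the previous step is handled on the CPU while the accelerator can already be working on the next forward pass. A self-contained toy version of that three-stage shape (pure Python, no torch/HPU; every name here is illustrative, not the TGI API):

    # Toy, runnable sketch of the 3-stage pattern: sampling for step N-1 happens
    # at the start of call N, so host-side work overlaps with the (simulated)
    # device forward launched in between.
    from dataclasses import dataclass
    from typing import List, Optional, Tuple


    @dataclass
    class ToyBatch:
        input_ids: List[int]
        logits: Optional[List[int]] = None   # pending output of the previous forward


    def toy_forward(batch: ToyBatch) -> List[int]:
        # Stand-in for the asynchronous HPU forward: "logits" are just next ids.
        return [tok + 1 for tok in batch.input_ids]


    def generate_token(batches: List[ToyBatch]) -> Tuple[List[int], ToyBatch]:
        generated: List[int] = []

        # Stage 1: consume logits produced by the previous call (cheap CPU work).
        for batch in batches:
            if batch.logits is not None:
                next_ids = batch.logits                  # "sampling" in the toy model
                batch.input_ids = next_ids               # feed tokens back as next input
                generated.extend(next_ids)

        # Stage 2: merge batches and launch the next forward without consuming it.
        batch = ToyBatch(sum((b.input_ids for b in batches), [])) if len(batches) > 1 else batches[0]
        batch.logits = toy_forward(batch)

        # Stage 3: detokenize / return the Stage-1 tokens while the device would
        # still be executing the Stage-2 forward.
        return generated, batch


    if __name__ == "__main__":
        b = ToyBatch(input_ids=[1, 2, 3])
        out, b = generate_token([b])    # first call: nothing pending yet -> out == []
        out, b = generate_token([b])    # second call: tokens from the first forward
        print(out)                      # [2, 3, 4]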
@@ -729,10 +824,6 @@ class CausalLM(Model):
         dbg_trace(
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'
-        self.step = self.step + 1
-        if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
-            self.hb_profer.stop()
-            self.hb_profer_started = False
 
         if self.is_optimized_for_gaudi:
             if prefill:
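The step counter and the profiler stop condition are removed here and re-added near the end of generate_token (see the @@ -899,42 hunk below); the window itself is unchanged: profile for profiling_warmup_steps + profiling_steps calls, then stop. A small hedged sketch of that windowing logic (hypothetical class, not the hb_profer API):

    # Illustrative only: stop a profiler once the step counter passes
    # warmup + active steps, mirroring the condition in the hunks above/below.
    class StepWindow:
        def __init__(self, profiling_warmup_steps: int, profiling_steps: int):
            self.limit = profiling_warmup_steps + profiling_steps
            self.step = 0
            self.started = True          # assume profiling was started at step 0

        def on_step(self, stop_fn) -> None:
            self.step += 1
            if self.started and self.step > self.limit:
                stop_fn()                # the real code calls self.hb_profer.stop()
                self.started = False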
@@ -753,7 +844,7 @@ class CausalLM(Model):
         input_ids = batch.input_ids
 
         if prefill:
-            logits, past = self.forward(
+            batch.logits, batch.past = self.forward(
                 input_ids,
                 attention_mask,
                 batch.position_ids,
@@ -762,7 +853,7 @@ class CausalLM(Model):
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
         else:
-            logits = self.forward(
+            batch.logits = self.forward(
                 input_ids,
                 attention_mask,
                 batch.position_ids,
@@ -771,46 +862,36 @@ class CausalLM(Model):
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
 
-        # Results
-        generations: List[Generation] = []
-        stopped = True
-
-        # Select next token
-        input_length = batch.input_length
-        if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
-            next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits[:, input_length - 1: input_length, :].squeeze(-2)
-            )
-        else:
-            next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-                batch.input_ids[:, :token_idx], logits.squeeze(-2)
-            )
-
-        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens,
-            batch.top_n_tokens_tensor,
-            logprobs,
-        )
-
-        next_token_logprobs = next_token_logprobs.tolist()
-        next_token_ids_cpu = next_token_ids.cpu()
         htorch.core.mark_step()
 
-        for req_idx, req in enumerate(batch.requests):
-            i = req.idx
+        # Stage 3. Finish and return previous generations
+        stopped = len(requests_to_generate) > 0
+        for prev_batch in prev_batches:
+            prev_batch['next_token_logprobs'] = prev_batch['next_token_logprobs'].tolist()
+            prev_batch['next_token_ids_cpu'] = prev_batch['next_token_ids'].cpu()
+        htorch.core.mark_step()
+
+        for req_data in requests_to_generate:
+            req = req_data['req']
+            i = req_data['prev_req_idx']
+            prev_batch_id = req_data['batch_id']
+            assert len(prev_batches) > prev_batch_id
+            next_token_ids_cpu = prev_batches[prev_batch_id]['next_token_ids_cpu']
+            next_token_logprobs = prev_batches[prev_batch_id]['next_token_logprobs']
+
             request = req.data
             input_length = req.input_length
             prefix_offset = req.prefix_offset
             read_offset = req.read_offset
-            do_sample = batch.next_token_chooser.do_sample[req_idx]
-            seed = batch.next_token_chooser.seeds[req_idx]
+            do_sample = req_data['do_sample']
+            seed = req_data['seed']
             stopping_criteria = req.stopping_criteria
             all_input_ids = req.all_input_ids
-            top_n_tokens = batch.top_n_tokens[req_idx]
             next_token_id = next_token_ids_cpu[i]
             next_token_logprob = next_token_logprobs[i]
-            top_token_ids = batch_top_token_ids[req_idx]
-            top_token_logprobs = batch_top_token_logprobs[req_idx]
+            top_n_tokens = req_data['top_n_tokens']
+            top_token_ids = req_data['top_token_ids']
+            top_token_logprobs = req_data['top_token_logprobs']
 
             # Append next token to all tokens
             if self.is_optimized_for_gaudi:
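Stage 3 resolves each request back to the tokens sampled in Stage 1: tensors are moved to the host once per previous batch (.cpu()/.tolist()), and each req_data entry then indexes into its batch's transferred lists. A runnable toy version of that lookup (plain lists instead of tensors, illustrative values only):

    # Sketch of the Stage-3 lookup: one host transfer per batch, then
    # per-request indexing into the transferred lists.
    prev_batches = [
        {"next_token_ids": [11, 12], "next_token_logprobs": [-0.1, -0.2]},  # batch 0
        {"next_token_ids": [21],     "next_token_logprobs": [-0.3]},        # batch 1
    ]
    requests_to_generate = [
        {"req_id": "a", "prev_req_idx": 0, "batch_id": 0},
        {"req_id": "b", "prev_req_idx": 1, "batch_id": 0},
        {"req_id": "c", "prev_req_idx": 0, "batch_id": 1},
    ]

    for req_data in requests_to_generate:
        slot = prev_batches[req_data["batch_id"]]
        i = req_data["prev_req_idx"]
        token = slot["next_token_ids"][i]
        logprob = slot["next_token_logprobs"][i]
        print(req_data["req_id"], token, logprob)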
@@ -899,42 +980,9 @@ class CausalLM(Model):
             req.read_offset = read_offset
             htorch.core.mark_step()
 
-        if token_idx is None:
-            batch.input_ids[:, 0] = next_token_ids[:, 0]
-        else:
-            batch.input_ids.index_copy_(1, token_idx.cpu(), next_token_ids.unsqueeze(1))
+        self.step = self.step + 1
+        if self.hb_profer_started == True and self.step > self.profiling_warmup_steps + self.profiling_steps:
+            self.hb_profer.stop()
+            self.hb_profer_started = False
 
-        # We finished all generations in the batch; there is no next batch
-        if stopped:
-            if self.hb_profer_started == True:
-                self.hb_profer.step()
-            htorch.core.mark_step()
-            return generations, None
-
-        # Slice unused values from prefill, use it to store next token
-        if token_idx is None:
-            batch.input_ids = batch.input_ids[:, :1]
-
-        # Update attention_mask as we added a new token to input_ids
-        if self.is_optimized_for_gaudi:
-            batch.attention_mask.index_fill_(1, token_idx, 1)
-        else:
-            batch.attention_mask[:, -batch.padding_right_offset] = 1
-
-        # Adjust lengths
-        batch.input_length += 1
-
-        # Update position_ids
-        if prefill:
-            batch.position_ids = batch.position_ids[:, token_idx - 1: token_idx] + 1
-        else:
-            batch.position_ids += 1
-        # Update past key values
-        if prefill:
-            batch.past_key_values = past
-
         if self.hb_profer_started == True:
             self.hb_profer.step()
         htorch.core.mark_step()
 
-        return generations, batch
+        return generations, batch if not stopped else None
@@ -84,7 +84,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         with self.profiler.record_event("external", "prefill", {"batch_size": batch.input_ids.size(0)}):
 
             with self.profiler.record_event("internal", "generate_token"):
-                generations, next_batch = self.model.generate_token(batch)
+                generations, next_batch = self.model.generate_token([batch])
             self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
@@ -111,14 +111,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         if len(batches) == 0:
             raise ValueError("All batches are empty")
 
-        if len(batches) > 1:
-            with self.profiler.record_event("internal", "concatenate"):
-                batch = self.model.batch_type.concatenate(batches, self.model.tokenizer.pad_token_id)
-        else:
-            batch = batches[0]
-
         with self.profiler.record_event("internal", "generate_token"):
-            generations, next_batch = self.model.generate_token(batch)
+            generations, next_batch = self.model.generate_token(batches)
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(
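On the service side the change is only the call convention: Prefill wraps its single batch in a list and Decode hands the cached batches straight to the model, since concatenation now happens inside generate_token. A minimal sketch of the new call shape (placeholder model/cache objects, not the real servicer):

    # Minimal sketch of the new call convention: generate_token always receives
    # a list of batches; merging multiple decode batches is now the model's job.
    def prefill_step(model, cache, batch):
        generations, next_batch = model.generate_token([batch])   # single batch, wrapped
        cache.set(next_batch)
        return generations


    def decode_step(model, cache, batches):
        if len(batches) == 0:
            raise ValueError("All batches are empty")
        generations, next_batch = model.generate_token(batches)   # no concatenate() here
        cache.set(next_batch)
        return generations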