From a4d5c3f40f31ac8d3fd07c182396cab1f278449c Mon Sep 17 00:00:00 2001
From: yuanwu2017
Date: Fri, 15 Mar 2024 17:54:56 +0800
Subject: [PATCH] Fix the generate_stream crash in concurrent query (#105)

Signed-off-by: yuanwu
---
 .../text_generation_server/models/causal_lm.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index f65947e4..cff7686f 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -380,8 +380,15 @@ class CausalLMBatch(Batch):
 
         top_n_tokens = [r.data.top_n_tokens for r in flat_requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
+
+        parameters = [r.data.parameters for r in flat_requests]
+        if len(flat_requests) < new_bs:
+            for i in range(new_bs - len(flat_requests)):
+                # append dummy parameters for the dummy requests
+                parameters.append(parameters[0])
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            [r.data.parameters for r in flat_requests],
+            parameters,
             batches[dst_batch_idx].next_token_chooser.dtype,
             batches[dst_batch_idx].next_token_chooser.device
         )
@@ -426,13 +433,19 @@ class CausalLMBatch(Batch):
         # TODO: Add support for sparse batches
         top_n_tokens = [r.top_n_tokens for r in pb.requests]
         top_n_tokens_tensor = torch.tensor(top_n_tokens, device=device, dtype=torch.int64)
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb([r.parameters for r in pb.requests], dtype, device)
 
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
         new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
         dummy_inputs = ["?"] * (new_bs - len(requests))
+        parameters = [r.parameters for r in pb.requests]
+        if len(pb.requests) < new_bs:
+            for i in range(new_bs - len(pb.requests)):
+                # append dummy parameters for the dummy requests
+                parameters.append(parameters[0])
+
+        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(parameters, dtype, device)
         tokenized_inputs = tokenizer(
             [r.data.inputs for r in requests] + dummy_inputs,
             return_tensors="pt",
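
Context for the fix: the server rounds batch sizes up to a bucket boundary
(round_up with PREFILL_BATCH_BUCKET_SIZE) and fills the extra slots with dummy
"?" requests, but HeterogeneousNextTokenChooser was still built from only the
real requests' parameters. The diff suggests that under concurrent
generate_stream queries this left the chooser with fewer entries than batch
slots, crashing when the padded batch was processed. Both hunks apply the same
remedy: pad the parameters list with copies of the first request's parameters
before calling from_pb, so the chooser always has one entry per slot. The
sketch below restates that padding step as a standalone helper; the name
pad_parameters and the sample values are illustrative only, not part of the
patch, which inlines the loop at each call site.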
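# Standalone sketch of the padding step added in both hunks above.
# `pad_parameters` is a hypothetical helper used here for illustration.
from typing import List, TypeVar

T = TypeVar("T")

def pad_parameters(parameters: List[T], new_bs: int) -> List[T]:
    """Pad `parameters` with copies of the first entry until it has new_bs items."""
    if not parameters:
        raise ValueError("need at least one real request to borrow parameters from")
    # Dummy requests only exist to fill the bucket; reusing the first
    # request's sampling parameters keeps one entry per batch slot.
    return parameters + [parameters[0]] * max(0, new_bs - len(parameters))

# Three real requests padded to a bucket of four slots:
print(pad_parameters(["p0", "p1", "p2"], 4))  # ['p0', 'p1', 'p2', 'p0']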
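
On the design choice: reusing parameters[0] is the cheapest valid filler,
since the dummy slots' outputs are presumably discarded; any well-formed
parameter set would do. The invariant that matters is that the parameters
list length matches the bucketed batch size, so the tensors from_pb builds
line up with the padded batch.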