diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 101da207..64db573c 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -30,7 +30,7 @@ class BloomCausalLMBatch(CausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-    ) -> "CausalLMBatch":
+    ) -> "BloomCausalLMBatch":
         batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
         batch.keys_head_dim_last = False
         return batch
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 94b14f85..d0fa1e1f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -646,7 +646,7 @@ class FlashCausalLMBatch(Batch):
         for b in batches:
             b.block_tables = None

-        return FlashCausalLMBatch(
+        return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index b83af591..cf7ece38 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -80,7 +80,7 @@ class NextTokenChooser:
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
     ) -> "NextTokenChooser":
-        return NextTokenChooser(
+        return cls(
             watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
@@ -143,7 +143,7 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
-        return StoppingCriteria(
+        return cls(
             tokenizer.eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
@@ -266,7 +266,7 @@ class HeterogeneousNextTokenChooser:
         dtype: torch.dtype,
         device: torch.device,
     ) -> "HeterogeneousNextTokenChooser":
-        return HeterogeneousNextTokenChooser(
+        return cls(
             watermark=[pb_.watermark for pb_ in pb],
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
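
Note on the pattern these hunks apply: below is a minimal sketch, using only standard Python and toy Batch/SubBatch classes that are not repository code, of why the factory classmethods now return cls(...). With the class name hard-coded, a subclass that inherits the factory would still construct the base class.

from dataclasses import dataclass


@dataclass
class Batch:
    batch_id: int

    @classmethod
    def from_pb(cls, pb: dict) -> "Batch":
        # cls is whatever class the method was called on, so subclasses
        # get instances of their own type from the inherited factory.
        return cls(batch_id=pb["id"])


@dataclass
class SubBatch(Batch):
    keys_head_dim_last: bool = True


sub = SubBatch.from_pb({"id": 0})
assert type(sub) is SubBatch  # would be Batch if from_pb returned Batch(...)

The bloom.py hunk additionally tightens the return annotation to "BloomCausalLMBatch" so the type hint matches what the overridden classmethod actually produces.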