Ensure classmethods use cls instead of the class directly

This commit is contained in:
Antoni Baum 2023-06-30 11:47:42 -07:00 committed by GitHub
parent ecf6dc3a5a
commit 51f2735f6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 5 deletions

View File

@ -30,7 +30,7 @@ class BloomCausalLMBatch(CausalLMBatch):
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "CausalLMBatch": ) -> "BloomCausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False batch.keys_head_dim_last = False
return batch return batch

View File

@ -646,7 +646,7 @@ class FlashCausalLMBatch(Batch):
for b in batches: for b in batches:
b.block_tables = None b.block_tables = None
return FlashCausalLMBatch( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,

View File

@ -80,7 +80,7 @@ class NextTokenChooser:
pb: generate_pb2.NextTokenChooserParameters, pb: generate_pb2.NextTokenChooserParameters,
device: torch.device, device: torch.device,
) -> "NextTokenChooser": ) -> "NextTokenChooser":
return NextTokenChooser( return cls(
watermark=pb.watermark, watermark=pb.watermark,
temperature=pb.temperature, temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty, repetition_penalty=pb.repetition_penalty,
@ -143,7 +143,7 @@ class StoppingCriteria:
stop_sequence_criterias = [ stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
] ]
return StoppingCriteria( return cls(
tokenizer.eos_token_id, tokenizer.eos_token_id,
stop_sequence_criterias, stop_sequence_criterias,
pb.max_new_tokens, pb.max_new_tokens,
@ -266,7 +266,7 @@ class HeterogeneousNextTokenChooser:
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "HeterogeneousNextTokenChooser": ) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser( return cls(
watermark=[pb_.watermark for pb_ in pb], watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb], temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb], repetition_penalty=[pb_.repetition_penalty for pb_ in pb],