mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Ensure classmethods use cls instead of the class directly
This commit is contained in:
parent
ecf6dc3a5a
commit
51f2735f6c
@@ -30,7 +30,7 @@ class BloomCausalLMBatch(CausalLMBatch):
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
- ) -> "CausalLMBatch":
|
||||
+ ) -> "BloomCausalLMBatch":
|
||||
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
|
||||
batch.keys_head_dim_last = False
|
||||
return batch
|
||||
|
@@ -646,7 +646,7 @@ class FlashCausalLMBatch(Batch):
|
||||
for b in batches:
|
||||
b.block_tables = None
|
||||
|
||||
- return FlashCausalLMBatch(
|
||||
+ return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
|
@@ -80,7 +80,7 @@ class NextTokenChooser:
|
||||
pb: generate_pb2.NextTokenChooserParameters,
|
||||
device: torch.device,
|
||||
) -> "NextTokenChooser":
|
||||
- return NextTokenChooser(
|
||||
+ return cls(
|
||||
watermark=pb.watermark,
|
||||
temperature=pb.temperature,
|
||||
repetition_penalty=pb.repetition_penalty,
|
||||
@@ -143,7 +143,7 @@ class StoppingCriteria:
|
||||
stop_sequence_criterias = [
|
||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
||||
]
|
||||
- return StoppingCriteria(
|
||||
+ return cls(
|
||||
tokenizer.eos_token_id,
|
||||
stop_sequence_criterias,
|
||||
pb.max_new_tokens,
|
||||
@@ -266,7 +266,7 @@ class HeterogeneousNextTokenChooser:
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "HeterogeneousNextTokenChooser":
|
||||
- return HeterogeneousNextTokenChooser(
|
||||
+ return cls(
|
||||
watermark=[pb_.watermark for pb_ in pb],
|
||||
temperature=[pb_.temperature for pb_ in pb],
|
||||
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
|
||||
|
Loading…
Reference in New Issue
Block a user