mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Ensure classmethods use cls
instead of the class directly
This commit is contained in:
parent
ecf6dc3a5a
commit
51f2735f6c
@ -30,7 +30,7 @@ class BloomCausalLMBatch(CausalLMBatch):
|
|||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "CausalLMBatch":
|
) -> "BloomCausalLMBatch":
|
||||||
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
|
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
|
||||||
batch.keys_head_dim_last = False
|
batch.keys_head_dim_last = False
|
||||||
return batch
|
return batch
|
||||||
|
@ -646,7 +646,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
for b in batches:
|
for b in batches:
|
||||||
b.block_tables = None
|
b.block_tables = None
|
||||||
|
|
||||||
return FlashCausalLMBatch(
|
return cls(
|
||||||
batch_id=batches[0].batch_id,
|
batch_id=batches[0].batch_id,
|
||||||
requests=requests,
|
requests=requests,
|
||||||
requests_idx_mapping=requests_idx_mapping,
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
@ -80,7 +80,7 @@ class NextTokenChooser:
|
|||||||
pb: generate_pb2.NextTokenChooserParameters,
|
pb: generate_pb2.NextTokenChooserParameters,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "NextTokenChooser":
|
) -> "NextTokenChooser":
|
||||||
return NextTokenChooser(
|
return cls(
|
||||||
watermark=pb.watermark,
|
watermark=pb.watermark,
|
||||||
temperature=pb.temperature,
|
temperature=pb.temperature,
|
||||||
repetition_penalty=pb.repetition_penalty,
|
repetition_penalty=pb.repetition_penalty,
|
||||||
@ -143,7 +143,7 @@ class StoppingCriteria:
|
|||||||
stop_sequence_criterias = [
|
stop_sequence_criterias = [
|
||||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
||||||
]
|
]
|
||||||
return StoppingCriteria(
|
return cls(
|
||||||
tokenizer.eos_token_id,
|
tokenizer.eos_token_id,
|
||||||
stop_sequence_criterias,
|
stop_sequence_criterias,
|
||||||
pb.max_new_tokens,
|
pb.max_new_tokens,
|
||||||
@ -266,7 +266,7 @@ class HeterogeneousNextTokenChooser:
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "HeterogeneousNextTokenChooser":
|
) -> "HeterogeneousNextTokenChooser":
|
||||||
return HeterogeneousNextTokenChooser(
|
return cls(
|
||||||
watermark=[pb_.watermark for pb_ in pb],
|
watermark=[pb_.watermark for pb_ in pb],
|
||||||
temperature=[pb_.temperature for pb_ in pb],
|
temperature=[pb_.temperature for pb_ in pb],
|
||||||
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
|
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
|
||||||
|
Loading…
Reference in New Issue
Block a user