mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-09 19:02:09 +00:00
fix: cleanup typos
This commit is contained in:
parent
9eeccbf9a5
commit
b3ead6e959
@ -627,7 +627,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
|
||||
class LogitBiasProcessor(LogitsProcessor):
|
||||
"""
|
||||
`LogitsProcessor` creates a bias tensor from a dictionary of token IDs and their
|
||||
`LogitBiasProcessor` creates a bias tensor from a dictionary of token IDs and their
|
||||
corresponding bias values. Bias are applied to the logits during each forward pass.
|
||||
|
||||
Supports token IDs provided as strings (e.g., {"9707": -100}).
|
||||
@ -656,7 +656,7 @@ class LogitBiasProcessor(LogitsProcessor):
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
# Apply bias tensor as a broadcasted addition
|
||||
if self.bias_tensor.shape[0] != scores.shape[1]:
|
||||
# Fix if the bias tensor is smaller than the scores
|
||||
# Pad the bias matrix to match the scores if it's smaller
|
||||
self.bias_tensor = torch.nn.functional.pad(
|
||||
self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0])
|
||||
)
|
||||
@ -699,7 +699,7 @@ class HeterogeneousLogitBiasProcessor(LogitsProcessor):
|
||||
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
# Apply bias matrix as a broadcasted addition
|
||||
if self.bias_matrix.shape[1] != scores.shape[1]:
|
||||
# Fix if the bias matrix is smaller than the scores
|
||||
# Pad the bias matrix to match the scores if it's smaller
|
||||
self.bias_matrix = torch.nn.functional.pad(
|
||||
self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user