fix: cleanup typos

This commit is contained in:
drbh 2025-04-28 13:59:23 +00:00
parent 9eeccbf9a5
commit b3ead6e959

View File

@ -627,7 +627,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
class LogitBiasProcessor(LogitsProcessor):
"""
`LogitsProcessor` creates a bias tensor from a dictionary of token IDs and their
`LogitBiasProcessor` creates a bias tensor from a dictionary of token IDs and their
corresponding bias values. Bias are applied to the logits during each forward pass.
Supports token IDs provided as strings (e.g., {"9707": -100}).
@ -656,7 +656,7 @@ class LogitBiasProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
# Apply bias tensor as a broadcasted addition
if self.bias_tensor.shape[0] != scores.shape[1]:
# Fix if the bias tensor is smaller than the scores
# Pad the bias matrix to match the scores if it's smaller
self.bias_tensor = torch.nn.functional.pad(
self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0])
)
@ -699,7 +699,7 @@ class HeterogeneousLogitBiasProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
# Apply bias matrix as a broadcasted addition
if self.bias_matrix.shape[1] != scores.shape[1]:
# Fix if the bias matrix is smaller than the scores
# Pad the bias matrix to match the scores if it's smaller
self.bias_matrix = torch.nn.functional.pad(
self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
)