diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index ad769990..d106ce43 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -627,7 +627,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): class LogitBiasProcessor(LogitsProcessor): """ - `LogitsProcessor` creates a bias tensor from a dictionary of token IDs and their + `LogitBiasProcessor` creates a bias tensor from a dictionary of token IDs and their corresponding bias values. Bias are applied to the logits during each forward pass. Supports token IDs provided as strings (e.g., {"9707": -100}). @@ -656,7 +656,7 @@ class LogitBiasProcessor(LogitsProcessor): def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: # Apply bias tensor as a broadcasted addition if self.bias_tensor.shape[0] != scores.shape[1]: - # Fix if the bias tensor is smaller than the scores + # Pad the bias matrix to match the scores if it's smaller self.bias_tensor = torch.nn.functional.pad( self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0]) ) @@ -699,7 +699,7 @@ class HeterogeneousLogitBiasProcessor(LogitsProcessor): def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: # Apply bias matrix as a broadcasted addition if self.bias_matrix.shape[1] != scores.shape[1]: - # Fix if the bias matrix is smaller than the scores + # Pad the bias matrix to match the scores if it's smaller self.bias_matrix = torch.nn.functional.pad( self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1]) )