diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index d106ce43..d9feb953 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -654,12 +654,6 @@ class LogitBiasProcessor(LogitsProcessor): self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - # Apply bias tensor as a broadcasted addition - if self.bias_tensor.shape[0] != scores.shape[1]: - # Pad the bias matrix to match the scores if it's smaller - self.bias_tensor = torch.nn.functional.pad( - self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0]) - ) scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype)) return scores @@ -697,13 +691,6 @@ class HeterogeneousLogitBiasProcessor(LogitsProcessor): self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True) def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: - # Apply bias matrix as a broadcasted addition - if self.bias_matrix.shape[1] != scores.shape[1]: - # Pad the bias matrix to match the scores if it's smaller - self.bias_matrix = torch.nn.functional.pad( - self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1]) - ) - scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype)) return scores