mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
fix(server): do not warp prefill logits
This commit is contained in:
parent
941cd42e0c
commit
f49786ccba
@@ -70,7 +70,11 @@ class NextTokenChooser:
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
# Warp logits
|
||||
scores = self.warpers(input_ids, scores)
|
||||
if scores.shape[0] > 1:
|
||||
# only warp the last token logits
|
||||
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
|
||||
else:
|
||||
scores = self.warpers(input_ids, scores)
|
||||
|
||||
# Compute logprobs
|
||||
logprobs = torch.log_softmax(scores, -1)
|
||||
|
Loading…
Reference in New Issue
Block a user