mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-20 06:12:07 +00:00
# What does this PR do? Fixes #843 <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
99 lines
3.5 KiB
Python
99 lines
3.5 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 Authors of "A Watermark for Large Language Models"
|
|
# available at https://arxiv.org/abs/2301.10226
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
|
|
import torch
|
|
from transformers import LogitsProcessor
|
|
from typing import List, Union
|
|
|
|
# Fraction of the vocabulary that goes into the green list
# (WatermarkLogitsProcessor uses int(vocab_size * gamma) ids).
# Override with the WATERMARK_GAMMA environment variable.
GAMMA = float(os.environ.get("WATERMARK_GAMMA", 0.5))
# Value added to the logits of green-list tokens.
# Override with the WATERMARK_DELTA environment variable.
DELTA = float(os.environ.get("WATERMARK_DELTA", 2.0))
|
|
|
|
|
|
class WatermarkLogitsProcessor(LogitsProcessor):
    """Bias logits toward a pseudo-random "green list" of token ids.

    Implements the soft watermark from "A Watermark for Large Language
    Models" (https://arxiv.org/abs/2301.10226):

    * the last token of the prefix seeds an RNG;
    * the RNG draws a permutation of the vocabulary;
    * the first ``gamma`` fraction of that permutation is the green list;
    * ``delta`` is added to the logit of every green-list token.
    """

    def __init__(
        self,
        gamma: float = GAMMA,
        delta: float = DELTA,
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        device: str = "cpu",
    ):
        """Store the watermark parameters and create the seeded generator.

        Args:
            gamma: Fraction of the vocabulary placed in the green list. The
                list holds ``int(vocab_size * gamma)`` ids.
            delta: Value added to the logit of every green-list token.
            hash_key: Multiplier applied to the previous token id to form
                the RNG seed.
            device: Device of the internal ``torch.Generator``.
                NOTE(review): ``__call__`` draws the permutation on
                ``scores.device`` using this generator, so this presumably
                has to match the device of the logits. Confirm this.
        """
        # watermarking parameters
        self.gamma = gamma
        self.delta = delta
        self.rng = torch.Generator(device=device)
        self.hash_key = hash_key

    def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]) -> None:
        """Reseed ``self.rng`` from the last token id of ``input_ids``.

        Accepted inputs:
            * a non-empty ``list`` of token ids;
            * a non-empty 1-D tensor of token ids;
            * a tensor whose first dimension has size 1 (e.g. a
              ``(1, seq_len)`` batch); its single row is used.

        Raises:
            AssertionError: if the prefix is empty, or if a tensor that is
                not 1-D has more than one row. These are ``assert``
                checks, so they are skipped under ``python -O``.
        """
        if isinstance(input_ids, list):
            assert (
                len(input_ids) >= 1
            ), "requires at least a 1 token prefix sequence to seed rng"
            prev_token = input_ids[-1]
        else:
            if input_ids.dim() != 1:
                # Only one sequence can seed the RNG: unwrap its single row.
                # A 1-D tensor is already that sequence and is used as-is.
                assert (
                    len(input_ids) == 1
                ), "expected a single sequence (first dimension of size 1)"
                input_ids = input_ids[0]
            assert (
                input_ids.shape[-1] >= 1
            ), "requires at least a 1 token prefix sequence to seed rng"
            prev_token = input_ids[-1].item()
        self.rng.manual_seed(self.hash_key * prev_token)

    def _get_greenlist_ids(
        self,
        input_ids: Union[List[int], torch.LongTensor],
        max_value: int,
        device: torch.device,
    ) -> torch.LongTensor:
        """Return the green-list token ids for the next position.

        Args:
            input_ids: Prefix used to seed the RNG (see ``_seed_rng``).
            max_value: Vocabulary size. Ids are drawn from ``range(max_value)``.
            device: Device on which the permutation is created.

        Returns:
            A 1-D tensor holding the first ``int(max_value * self.gamma)``
            entries of a seeded random permutation of ``range(max_value)``.
        """
        # seed the rng using the previous tokens/prefix
        self._seed_rng(input_ids)

        greenlist_size = int(max_value * self.gamma)
        vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
        return vocab_permutation[:greenlist_size]

    @staticmethod
    def _calc_greenlist_mask(
        scores: torch.FloatTensor, greenlist_token_ids: torch.LongTensor
    ) -> torch.BoolTensor:
        """Build a boolean mask shaped like ``scores`` that marks green-list ids.

        Only the last row (index ``-1`` of the first dimension) is marked.
        Every other entry is ``False``.
        """
        # Allocate the mask as bool directly, rather than allocating a
        # float tensor the size of ``scores`` and converting it afterwards.
        green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
        green_tokens_mask[-1, greenlist_token_ids] = True
        return green_tokens_mask

    @staticmethod
    def _bias_greenlist_logits(
        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
    ) -> torch.Tensor:
        """Add ``greenlist_bias`` to ``scores`` wherever ``greenlist_mask`` is True.

        ``scores`` is modified in place and is also returned.
        """
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(
        self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Apply the watermark bias to ``scores`` for the next token.

        Args:
            input_ids: Token prefix. Its last token seeds the green list.
            scores: Logits over the vocabulary. The size of the last
                dimension is taken as the vocabulary size. This tensor is
                modified in place.

        Returns:
            ``scores``, with ``self.delta`` added to the green-list logits
            of its last row.
        """
        greenlist_ids = self._get_greenlist_ids(
            input_ids, scores.shape[-1], scores.device
        )
        green_tokens_mask = self._calc_greenlist_mask(
            scores=scores, greenlist_token_ids=greenlist_ids
        )
        return self._bias_greenlist_logits(
            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta
        )
|