mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
Fix test_watermark (#124)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
This commit is contained in:
parent
757c12dbac
commit
c6739526c6
@ -1,6 +1,9 @@
|
|||||||
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from text_generation_server.utils.tokens import (
|
from text_generation_server.utils.tokens import (
|
||||||
StopSequenceCriteria,
|
StopSequenceCriteria,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
@ -8,10 +11,8 @@ from text_generation_server.utils.tokens import (
|
|||||||
batch_top_tokens,
|
batch_top_tokens,
|
||||||
make_tokenizer_optional,
|
make_tokenizer_optional,
|
||||||
)
|
)
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def skip_tokenizer_env_var():
|
def skip_tokenizer_env_var():
|
||||||
import os
|
import os
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
# test_watermark_logits_processor.py
|
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||||
|
|
||||||
@ -10,35 +11,71 @@ GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
|
|||||||
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
|
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
|
||||||
|
|
||||||
|
|
||||||
def test_seed_rng():
|
@pytest.fixture
|
||||||
input_ids = [101, 2036, 3731, 102, 2003, 103]
|
def hpu_device():
|
||||||
processor = WatermarkLogitsProcessor()
|
return torch.device("hpu")
|
||||||
processor._seed_rng(input_ids)
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def input_ids_list():
|
||||||
|
return [101, 2036, 3731, 102, 2003, 103]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def input_ids_tensor(hpu_device):
|
||||||
|
return torch.tensor(
|
||||||
|
[[101, 2036, 3731, 102, 2003, 103]],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=hpu_device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def scores(hpu_device):
|
||||||
|
return torch.tensor(
|
||||||
|
[[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]],
|
||||||
|
device=hpu_device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_rng(input_ids_list, hpu_device):
|
||||||
|
processor = WatermarkLogitsProcessor(device=hpu_device)
|
||||||
|
processor._seed_rng(input_ids_list)
|
||||||
assert isinstance(processor.rng, torch.Generator)
|
assert isinstance(processor.rng, torch.Generator)
|
||||||
|
|
||||||
|
|
||||||
def test_get_greenlist_ids():
|
def test_seed_rng_tensor(input_ids_tensor, hpu_device):
|
||||||
input_ids = [101, 2036, 3731, 102, 2003, 103]
|
processor = WatermarkLogitsProcessor(device=hpu_device)
|
||||||
processor = WatermarkLogitsProcessor()
|
processor._seed_rng(input_ids_tensor)
|
||||||
result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu"))
|
assert isinstance(processor.rng, torch.Generator)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_greenlist_ids(input_ids_list, hpu_device):
|
||||||
|
processor = WatermarkLogitsProcessor(device=hpu_device)
|
||||||
|
result = processor._get_greenlist_ids(input_ids_list, 10, hpu_device)
|
||||||
assert max(result) <= 10
|
assert max(result) <= 10
|
||||||
assert len(result) == int(10 * 0.5)
|
assert len(result) == int(10 * 0.5)
|
||||||
|
|
||||||
|
|
||||||
def test_calc_greenlist_mask():
|
def test_get_greenlist_ids_tensor(input_ids_tensor, hpu_device):
|
||||||
processor = WatermarkLogitsProcessor()
|
processor = WatermarkLogitsProcessor(device=hpu_device)
|
||||||
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
|
result = processor._get_greenlist_ids(input_ids_tensor, 10, hpu_device)
|
||||||
greenlist_token_ids = torch.tensor([2, 3])
|
assert max(result) <= 10
|
||||||
|
assert len(result) == int(10 * 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_greenlist_mask(scores, hpu_device):
|
||||||
|
processor = WatermarkLogitsProcessor(device=hpu_device)
|
||||||
|
greenlist_token_ids = torch.tensor([2, 3], device=hpu_device)
|
||||||
result = processor._calc_greenlist_mask(scores, greenlist_token_ids)
|
result = processor._calc_greenlist_mask(scores, greenlist_token_ids)
|
||||||
assert result.tolist() == [[False, False, False, False], [False, False, True, True]]
|
assert result.tolist() == [[False, False, False, False], [False, False, True, True]]
|
||||||
assert result.shape == scores.shape
|
assert result.shape == scores.shape
|
||||||
|
|
||||||
|
|
||||||
def test_bias_greenlist_logits():
|
def test_bias_greenlist_logits(scores, hpu_device):
|
||||||
processor = WatermarkLogitsProcessor()
|
processor = WatermarkLogitsProcessor(device=hpu_device)
|
||||||
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
|
|
||||||
green_tokens_mask = torch.tensor(
|
green_tokens_mask = torch.tensor(
|
||||||
[[False, False, True, True], [False, False, False, True]]
|
[[False, False, True, True], [False, False, False, True]], device=hpu_device
|
||||||
)
|
)
|
||||||
greenlist_bias = 2.0
|
greenlist_bias = 2.0
|
||||||
result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias)
|
result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias)
|
||||||
@ -46,9 +83,13 @@ def test_bias_greenlist_logits():
|
|||||||
assert result.shape == scores.shape
|
assert result.shape == scores.shape
|
||||||
|
|
||||||
|
|
||||||
def test_call():
|
def test_call(input_ids_list, scores, hpu_device):
|
||||||
input_ids = [101, 2036, 3731, 102, 2003, 103]
|
processor = WatermarkLogitsProcessor(device=hpu_device)
|
||||||
processor = WatermarkLogitsProcessor()
|
result = processor(input_ids_list, scores)
|
||||||
scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
|
assert result.shape == scores.shape
|
||||||
result = processor(input_ids, scores)
|
|
||||||
|
|
||||||
|
def test_call_tensor(input_ids_tensor, scores, hpu_device):
|
||||||
|
processor = WatermarkLogitsProcessor(device=hpu_device)
|
||||||
|
result = processor(input_ids_tensor, scores)
|
||||||
assert result.shape == scores.shape
|
assert result.shape == scores.shape
|
||||||
|
Loading…
Reference in New Issue
Block a user