Fix test_watermark (#124)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Karol Damaszke 2024-04-09 11:29:21 +02:00 committed by GitHub
parent 757c12dbac
commit c6739526c6
2 changed files with 66 additions and 24 deletions


@@ -1,6 +1,9 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
+import pytest
 import torch
+from transformers import AutoTokenizer
 from text_generation_server.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
@@ -8,10 +11,8 @@ from text_generation_server.utils.tokens import (
     batch_top_tokens,
     make_tokenizer_optional,
 )
-from transformers import AutoTokenizer
-import pytest
 
 
 @pytest.fixture
 def skip_tokenizer_env_var():
     import os


@@ -1,7 +1,8 @@
-# test_watermark_logits_processor.py
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 import os
 import numpy as np
+import pytest
 import torch
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -10,35 +11,71 @@ GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
 DELTA = os.getenv("WATERMARK_DELTA", 2.0)
 
 
-def test_seed_rng():
-    input_ids = [101, 2036, 3731, 102, 2003, 103]
-    processor = WatermarkLogitsProcessor()
-    processor._seed_rng(input_ids)
+@pytest.fixture
+def hpu_device():
+    return torch.device("hpu")
+
+
+@pytest.fixture
+def input_ids_list():
+    return [101, 2036, 3731, 102, 2003, 103]
+
+
+@pytest.fixture
+def input_ids_tensor(hpu_device):
+    return torch.tensor(
+        [[101, 2036, 3731, 102, 2003, 103]],
+        dtype=torch.int64,
+        device=hpu_device
+    )
+
+
+@pytest.fixture
+def scores(hpu_device):
+    return torch.tensor(
+        [[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]],
+        device=hpu_device
+    )
+
+
+def test_seed_rng(input_ids_list, hpu_device):
+    processor = WatermarkLogitsProcessor(device=hpu_device)
+    processor._seed_rng(input_ids_list)
     assert isinstance(processor.rng, torch.Generator)
 
 
-def test_get_greenlist_ids():
-    input_ids = [101, 2036, 3731, 102, 2003, 103]
-    processor = WatermarkLogitsProcessor()
-    result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu"))
+def test_seed_rng_tensor(input_ids_tensor, hpu_device):
+    processor = WatermarkLogitsProcessor(device=hpu_device)
+    processor._seed_rng(input_ids_tensor)
+    assert isinstance(processor.rng, torch.Generator)
+
+
+def test_get_greenlist_ids(input_ids_list, hpu_device):
+    processor = WatermarkLogitsProcessor(device=hpu_device)
+    result = processor._get_greenlist_ids(input_ids_list, 10, hpu_device)
+    assert max(result) <= 10
+    assert len(result) == int(10 * 0.5)
+
+
+def test_get_greenlist_ids_tensor(input_ids_tensor, hpu_device):
+    processor = WatermarkLogitsProcessor(device=hpu_device)
+    result = processor._get_greenlist_ids(input_ids_tensor, 10, hpu_device)
     assert max(result) <= 10
     assert len(result) == int(10 * 0.5)
 
 
-def test_calc_greenlist_mask():
-    processor = WatermarkLogitsProcessor()
-    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
-    greenlist_token_ids = torch.tensor([2, 3])
+def test_calc_greenlist_mask(scores, hpu_device):
+    processor = WatermarkLogitsProcessor(device=hpu_device)
+    greenlist_token_ids = torch.tensor([2, 3], device=hpu_device)
     result = processor._calc_greenlist_mask(scores, greenlist_token_ids)
     assert result.tolist() == [[False, False, False, False], [False, False, True, True]]
     assert result.shape == scores.shape
 
 
-def test_bias_greenlist_logits():
-    processor = WatermarkLogitsProcessor()
-    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
+def test_bias_greenlist_logits(scores, hpu_device):
+    processor = WatermarkLogitsProcessor(device=hpu_device)
     green_tokens_mask = torch.tensor(
-        [[False, False, True, True], [False, False, False, True]]
+        [[False, False, True, True], [False, False, False, True]], device=hpu_device
     )
     greenlist_bias = 2.0
     result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias)
@@ -46,9 +83,13 @@ def test_bias_greenlist_logits():
     assert result.shape == scores.shape
 
 
-def test_call():
-    input_ids = [101, 2036, 3731, 102, 2003, 103]
-    processor = WatermarkLogitsProcessor()
-    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
-    result = processor(input_ids, scores)
+def test_call(input_ids_list, scores, hpu_device):
+    processor = WatermarkLogitsProcessor(device=hpu_device)
+    result = processor(input_ids_list, scores)
+    assert result.shape == scores.shape
+
+
+def test_call_tensor(input_ids_tensor, scores, hpu_device):
+    processor = WatermarkLogitsProcessor(device=hpu_device)
+    result = processor(input_ids_tensor, scores)
     assert result.shape == scores.shape
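
For reference, a minimal standalone sketch of the pattern the updated tests exercise. This is not part of the commit: it assumes an HPU-enabled PyTorch build (habana_frameworks) so that torch.device("hpu") is available; substitute "cpu" to run it elsewhere.

    import torch
    from text_generation_server.utils.watermark import WatermarkLogitsProcessor

    # Target device; the tests above use HPU, which requires habana_frameworks.
    device = torch.device("hpu")

    # Prompt token ids and next-token logits, both resident on the target device,
    # mirroring the input_ids_tensor and scores fixtures above.
    input_ids = torch.tensor([[101, 2036, 3731, 102, 2003, 103]],
                             dtype=torch.int64, device=device)
    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]], device=device)

    processor = WatermarkLogitsProcessor(device=device)
    biased_scores = processor(input_ids, scores)  # greenlist logits receive a positive bias
    assert biased_scores.shape == scores.shape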