remove vocab size

This commit is contained in:
OlivierDehaene 2023-03-08 18:17:34 +01:00
parent 7a6a7ed27b
commit 4abf27ce81
6 changed files with 15 additions and 15 deletions

View File

@@ -73,7 +73,7 @@ class CausalLMBatch(Batch):
inputs.append(r.inputs)
input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
NextTokenChooser.from_pb(r.parameters, device)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer

View File

@@ -103,7 +103,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
inputs.append(escape_custom_split_sequence(r.inputs))
input_lengths.append(r.input_length)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
NextTokenChooser.from_pb(r.parameters, device)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer

View File

@@ -16,8 +16,13 @@ class Model(ABC):
self.device = device
# see `decode_token` method
self.special_decode_token_id = self.tokenizer.pad_token_id
self.special_decode_token_length = len(self.tokenizer.pad_token)
self.tokenizer.add_special_tokens(
{"additional_special_tokens": ["<decode-token>"]}
)
self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
"<decode-token>"
)
self.special_decode_token_length = len("<decode-token>")
@property
@abstractmethod

View File

@@ -83,7 +83,7 @@ class Seq2SeqLMBatch(Batch):
decoder_input_ids.append(tokenizer.bos_token_id)
decoder_input_lengths.append(1)
next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, len(tokenizer), device)
NextTokenChooser.from_pb(r.parameters, device)
)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer

View File

@@ -36,7 +36,6 @@ class Greedy:
class NextTokenChooser:
def __init__(
self,
vocab_size,
watermark=False,
temperature=1.0,
repetition_penalty=1.0,
@@ -52,7 +51,7 @@ class NextTokenChooser:
sampling = do_sample
if watermark:
warpers.append(WatermarkLogitsProcessor(vocab_size, device=device))
warpers.append(WatermarkLogitsProcessor(device=device))
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if temperature is not None and temperature != 1.0:
@@ -85,11 +84,9 @@ class NextTokenChooser:
def from_pb(
cls,
pb: generate_pb2.NextTokenChooserParameters,
vocab_size: int,
device: torch.device,
) -> "NextTokenChooser":
return NextTokenChooser(
vocab_size=vocab_size,
watermark=pb.watermark,
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,

View File

@@ -25,14 +25,12 @@ DELTA = os.getenv("WATERMARK_DELTA", 2.0)
class WatermarkLogitsProcessor(LogitsProcessor):
def __init__(
self,
vocab_size: int,
gamma: float = GAMMA,
delta: float = DELTA,
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
device: str = "cpu",
):
# watermarking parameters
self.vocab_size = vocab_size
self.gamma = gamma
self.delta = delta
self.rng = torch.Generator(device=device)
@@ -45,13 +43,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
prev_token = input_ids[-1].item()
self.rng.manual_seed(self.hash_key * prev_token)
def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
def _get_greenlist_ids(self, input_ids: torch.LongTensor, max_value: int) -> list[int]:
# seed the rng using the previous tokens/prefix
self._seed_rng(input_ids)
greenlist_size = int(self.vocab_size * self.gamma)
greenlist_size = int(max_value * self.gamma)
vocab_permutation = torch.randperm(
self.vocab_size, device=input_ids.device, generator=self.rng
max_value, device=input_ids.device, generator=self.rng
)
greenlist_ids = vocab_permutation[:greenlist_size]
return greenlist_ids
@@ -76,7 +74,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
assert len(input_ids) == 1
greenlist_ids = self._get_greenlist_ids(input_ids[0])
greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1])
green_tokens_mask = self._calc_greenlist_mask(
scores=scores, greenlist_token_ids=greenlist_ids
)