Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)

fix(server): fix generate_stream by forcing tokens to be decoded correctly (#100)

parent 1c19b0934e
commit 9b205d33cc
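
The change is mechanical, but the underlying problem is easy to reproduce: with SentencePiece-based tokenizers such as the mT0 tokenizer used in these tests, decoding a single token id on its own strips the token's leading-space marker, so a token-by-token generate_stream runs words together. The sketch below is not part of the diff; it only illustrates the symptom, assumes the transformers library, and uses "bigscience/mt0-base" purely as an example checkpoint.

# Illustration only (not part of this commit); checkpoint is an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")
ids = tokenizer("a test of the network", add_special_tokens=False).input_ids

# Decoding the whole sequence keeps the spaces between words...
print(tokenizer.decode(ids))

# ...but decoding one id at a time drops each piece's leading "▁" marker,
# so a naive streaming path would emit the text with its spaces missing.
print("".join(tokenizer.decode([token_id]) for token_id in ids))
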
@@ -14,7 +14,7 @@
   "tokens": [
     {
       "id": 259,
-      "text": "",
+      "text": " ",
       "logprob": -1.3656927,
       "special": false
     },
@@ -32,13 +32,13 @@
     },
     {
       "id": 287,
-      "text": "the",
+      "text": " the",
       "logprob": -1.2102449,
       "special": false
     },
     {
       "id": 259,
-      "text": "",
+      "text": " ",
       "logprob": -1.6057279,
       "special": false
     },
@@ -50,19 +50,19 @@
     },
     {
       "id": 304,
-      "text": "of",
+      "text": " of",
       "logprob": -0.5270343,
       "special": false
     },
     {
       "id": 287,
-      "text": "the",
+      "text": " the",
       "logprob": -0.62522805,
       "special": false
     },
     {
       "id": 259,
-      "text": "",
+      "text": " ",
       "logprob": -1.4069618,
       "special": false
     },
@@ -74,19 +74,19 @@
     },
     {
       "id": 304,
-      "text": "of",
+      "text": " of",
       "logprob": -1.3172221,
       "special": false
     },
    {
       "id": 287,
-      "text": "the",
+      "text": " the",
       "logprob": -0.3501925,
       "special": false
     },
     {
       "id": 259,
-      "text": "",
+      "text": " ",
       "logprob": -0.7219573,
       "special": false
     },
@@ -104,7 +104,7 @@
     },
     {
       "id": 259,
-      "text": "",
+      "text": " ",
       "logprob": -0.32933083,
       "special": false
     },
@@ -116,7 +116,7 @@
     },
     {
       "id": 2978,
-      "text": "test",
+      "text": " test",
       "logprob": -1.5846587,
       "special": false
     },
@@ -148,7 +148,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
     assert all([generation.token_id.item() == 259 for generation in generations])
-    assert all([generation.token_text == "" for generation in generations])
+    assert all([generation.token_text == " " for generation in generations])
     assert generations[0].request_id == 0
 
 
@@ -385,10 +385,8 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(
+            next_token_text = self.decode_token(
                 next_token_id_squeezed,
-                clean_up_tokenization_spaces=False,
-                skip_special_tokens=False,
             )
 
             # Evaluate stopping criteria
@@ -15,6 +15,15 @@ class Model(ABC):
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device
 
+        # see `decode_token` method
+        self.tokenizer.add_special_tokens(
+            {"additional_special_tokens": ["<decode-token>"]}
+        )
+        self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
+            "<decode-token>"
+        )
+        self.special_decode_token_length = len("<decode-token>")
+
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
@@ -23,3 +32,12 @@ class Model(ABC):
     @abstractmethod
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError
+
+    def decode_token(self, token_id: int) -> str:
+        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        # append token to special decode token and decode both
+        result = self.tokenizer.decode(
+            [self.special_decode_token_id, token_id], skip_special_tokens=False
+        )
+        # slice to remove special decode token
+        return result[self.special_decode_token_length :]
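
A standalone sketch of the trick that decode_token above relies on: decoding a known sentinel token together with the target token keeps the target's leading space, and the sentinel's text is then sliced off. This is illustration only, not part of the commit; the checkpoint and the helper name are assumptions.

# Sketch of the sentinel-token trick used by `decode_token`; names and checkpoint are illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")
tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})

sentinel_id = tokenizer.convert_tokens_to_ids("<decode-token>")
sentinel_length = len("<decode-token>")

def decode_single_token(token_id: int) -> str:
    # Decode the sentinel and the target token together so the tokenizer cannot
    # strip the target's leading space, then cut the sentinel text back off.
    text = tokenizer.decode([sentinel_id, token_id], skip_special_tokens=False)
    return text[sentinel_length:]

the_id = tokenizer.convert_tokens_to_ids("▁the")
print(repr(decode_single_token(the_id)))  # expected: ' the', with the leading space preserved
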
@@ -342,7 +342,9 @@ class Seq2SeqLM(Model):
         return Seq2SeqLMBatch
 
     def decode(self, decoder_ids: List[int]) -> str:
-        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)
+        return self.tokenizer.decode(
+            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
 
     def forward(
         self,
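
The decode change above also pins clean_up_tokenization_spaces=False. With the cleanup enabled, transformers collapses the space that SentencePiece emits before punctuation, so the fully decoded text could disagree with the concatenation of the streamed tokens. A minimal illustration, again with an example checkpoint:

# Illustration of what `clean_up_tokenization_spaces` changes; checkpoint is an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")
ids = tokenizer("a test .", add_special_tokens=False).input_ids

print(repr(tokenizer.decode(ids, clean_up_tokenization_spaces=True)))   # e.g. 'a test.'
print(repr(tokenizer.decode(ids, clean_up_tokenization_spaces=False)))  # e.g. 'a test .'
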
@@ -457,10 +459,8 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.tokenizer.decode(
+            next_token_text = self.decode_token(
                 next_token_id_squeezed,
-                clean_up_tokenization_spaces=False,
-                skip_special_tokens=False,
             )
 
             # Evaluate stopping criteria
@@ -24,12 +24,12 @@ DELTA = os.getenv("WATERMARK_DELTA", 2.0)
 
 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
         self,
         vocab_size: int,
         gamma: float = GAMMA,
         delta: float = DELTA,
         hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
         device: str = "cpu",
     ):
         # watermarking parameters
         self.vocab_size = vocab_size
@@ -40,7 +40,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     def _seed_rng(self, input_ids: torch.LongTensor) -> None:
         assert (
             input_ids.shape[-1] >= 1
         ), "requires at least a 1 token prefix sequence to seed rng"
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
@@ -58,7 +58,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     @staticmethod
     def _calc_greenlist_mask(
         scores: torch.FloatTensor, greenlist_token_ids
     ) -> torch.BoolTensor:
         green_tokens_mask = torch.zeros_like(scores)
         green_tokens_mask[-1, greenlist_token_ids] = 1
@@ -67,13 +67,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     @staticmethod
     def _bias_greenlist_logits(
         scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
     ) -> torch.Tensor:
         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
         return scores
 
     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
         greenlist_ids = self._get_greenlist_ids(input_ids[0])