From 3f2542bb6a6df97b617be82398833dcb3d66eca5 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 5 Apr 2023 19:37:41 +0200 Subject: [PATCH] fix(server): fix escape characters in stop sequence (#155) --- server/tests/utils/test_tokens.py | 9 ++ .../flash_santacoder_modeling.py | 116 +++++++++--------- .../models/flash_santacoder.py | 33 +++-- server/text_generation_server/utils/tokens.py | 1 + 4 files changed, 90 insertions(+), 69 deletions(-) diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index 3883ad97..da0006e4 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -14,6 +14,15 @@ def test_stop_sequence_criteria(): assert not criteria("/test; ") +def test_stop_sequence_criteria_escape(): + criteria = StopSequenceCriteria("<|stop|>") + + assert not criteria("<") + assert not criteria("<|stop") + assert criteria("<|stop|>") + assert not criteria("<|stop|> ") + + def test_stopping_criteria(): criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) assert criteria(65827, "/test") == (False, None) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index ef073636..799e7054 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -47,12 +47,12 @@ class FastLayerNorm(nn.LayerNorm): class FastLinear(nn.Linear): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, ) -> None: super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) @@ -67,10 +67,10 @@ class FastLinear(nn.Linear): class FlashMQAttention(torch.nn.Module): def __init__( - self, - num_heads, - hidden_size, - process_group=None, + self, + num_heads, + hidden_size, + process_group=None, ): super().__init__() self.num_heads = num_heads @@ -86,13 +86,13 @@ class FlashMQAttention(torch.nn.Module): raise NotImplementedError def forward( - self, - hidden_states, - cu_seqlens, - max_s, - layer_past, - layer_past_present_indices, - cu_seqlens_q, + self, + hidden_states, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, ): qkv = self.attn(hidden_states) @@ -162,15 +162,17 @@ class FlashMQAttention(torch.nn.Module): class MLP(nn.Module): - def __init__( - self, act, hidden_size, intermediate_size, process_group=None - ): + def __init__(self, act, hidden_size, intermediate_size, process_group=None): super().__init__() self.act = ( ACT2FN[act] if "gelu" not in act - else lambda x: torch.nn.functional.gelu(x, approximate="tanh" if act in ["gelu_fast", - "gelu_pytorch_tanh"] else None) + else lambda x: torch.nn.functional.gelu( + x, + approximate="tanh" + if act in ["gelu_fast", "gelu_pytorch_tanh"] + else None, + ) ) if process_group is None: @@ -188,13 +190,13 @@ class MLP(nn.Module): class Block(nn.Module): def __init__( - self, - num_heads, - act, - hidden_size, - intermediate_size, - layer_norm_eps, - process_group=None, + self, + num_heads, + act, + hidden_size, + intermediate_size, + layer_norm_eps, + process_group=None, ): super().__init__() self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps) @@ -212,14 +214,14 @@ class Block(nn.Module): ) def forward( - self, - 
hidden_states, - residual, - cu_seqlens, - max_s, - layer_past, - layer_past_present_indices, - cu_seqlens_q, + self, + hidden_states, + residual, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, ): hidden_states, residual = self.ln_1(hidden_states, residual) @@ -232,9 +234,7 @@ class Block(nn.Module): cu_seqlens_q, ) - hidden_states, residual = self.ln_2( - hidden_states, residual - ) + hidden_states, residual = self.ln_2(hidden_states, residual) mlp_output = self.mlp(hidden_states) @@ -258,16 +258,16 @@ class FlashSantacoderModel(nn.Module): config.num_attention_heads, config.activation_function, config.hidden_size, - config.n_inner if config.n_inner is not None else 4 * config.hidden_size, + config.n_inner + if config.n_inner is not None + else 4 * config.hidden_size, config.layer_norm_epsilon, process_group, ) for _ in range(config.num_hidden_layers) ] ) - self.ln_f = FastLayerNorm( - config.hidden_size, eps=config.layer_norm_epsilon - ) + self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.head_size = self.h[0].attn.head_size self.num_heads = self.h[0].attn.num_heads @@ -281,12 +281,12 @@ class FlashSantacoderModel(nn.Module): layer.mlp.c_proj.transpose_weight() def forward( - self, - input_ids, - position_ids, - cu_seqlens, - max_s, - past_key_values=None, + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, ): hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -335,21 +335,19 @@ class FlashSantacoderForCausalLM(nn.Module): self.transformer = FlashSantacoderModel(config, process_group) - self.lm_head = FastLinear( - config.hidden_size, config.vocab_size, bias=False - ) + self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) def post_load_weights(self): self.transformer.post_load_weights() self.lm_head.transpose_weight() def forward( - self, - input_ids, - position_ids, - cu_seqlens, - max_s, - past_key_values=None, + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, ): hidden_states, present = self.transformer( input_ids, position_ids, cu_seqlens, max_s, past_key_values diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index b33d0477..f0207d55 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -9,7 +9,7 @@ from typing import Optional, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( - FlashSantacoderForCausalLM + FlashSantacoderForCausalLM, ) from text_generation_server.utils import ( weight_files, @@ -37,8 +37,9 @@ class FlashSantacoder(FlashCausalLM): ) config = AutoConfig.from_pretrained( - model_id, revision=revision, - trust_remote_code=True # Needed as the config is not part of Transformers + model_id, + revision=revision, + trust_remote_code=True, # Needed as the config is not part of Transformers ) # We do not use from_pretrained as we modified the model internal module layout @@ -65,8 +66,8 @@ class FlashSantacoder(FlashCausalLM): @staticmethod def load_weights( - model: FlashSantacoderForCausalLM, - filenames: List[Path], + model: FlashSantacoderForCausalLM, + filenames: List[Path], ): for filename in filenames: state_dict = torch.load(filename, map_location="cpu") @@ -91,7 +92,12 @@ class FlashSantacoder(FlashCausalLM): current_parameter_tensor = None if 
current_parameter_tensor is not None: - if "c_fc.weight" in key or "c_proj.weight" in key or "q_attn.weight" in key or "kv_attn.weight" in key: + if ( + "c_fc.weight" in key + or "c_proj.weight" in key + or "q_attn.weight" in key + or "kv_attn.weight" in key + ): # Tranpose as we use nn.Linear instead of Conv1D value = value.T @@ -99,11 +105,18 @@ class FlashSantacoder(FlashCausalLM): # Init qkv if "attn.weight" in final_key: module._parameters[param_name] = value.new_empty( - (model.transformer.head_size * (model.transformer.num_heads + 2), value.shape[1]) + ( + model.transformer.head_size + * (model.transformer.num_heads + 2), + value.shape[1], + ) ) elif "attn.bias" in final_key: module._parameters[param_name] = value.new_empty( - (model.transformer.head_size * (model.transformer.num_heads + 2)) + ( + model.transformer.head_size + * (model.transformer.num_heads + 2) + ) ) # Copy to correct slice @@ -113,11 +126,11 @@ class FlashSantacoder(FlashCausalLM): module._parameters[param_name][: value.shape[0]] = value elif "kv_attn.weight" in key: module._parameters[param_name][ - model.transformer.head_size * model.transformer.num_heads: + model.transformer.head_size * model.transformer.num_heads : ] = value elif "kv_attn.bias" in key: module._parameters[param_name][ - model.transformer.head_size * model.transformer.num_heads: + model.transformer.head_size * model.transformer.num_heads : ] = value else: if current_parameter_tensor.shape != value.shape: diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index b923857d..23f504c6 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -110,6 +110,7 @@ class NextTokenChooser: class StopSequenceCriteria: def __init__(self, stop_sequence: str): + stop_sequence = re.escape(stop_sequence) self.regex = re.compile(f".*{stop_sequence}$") def __call__(self, output: str) -> bool:
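
Note on the core change: the functional fix in this patch is the added re.escape() call in
StopSequenceCriteria.__init__ (server/text_generation_server/utils/tokens.py); the remaining
hunks are formatting-only reshuffles of the santacoder modeling/loading code plus the new
regression test above. Before the fix, the user-supplied stop sequence was interpolated
verbatim into the pattern f".*{stop_sequence}$", so any regex metacharacter changed the
meaning of the pattern: for "<|stop|>" the compiled pattern ".*<|stop|>$" is parsed as three
alternatives (".*<", "stop", ">$") and fires on any output that merely contains "<". The
sketch below is a minimal, standalone illustration of that behaviour; the class body is
simplified (a plain re.match test rather than the exact upstream __call__), and the example
strings are made up for demonstration.

    import re


    class StopSequenceCriteria:
        """Simplified sketch of the patched stop-sequence check."""

        def __init__(self, stop_sequence: str):
            # re.escape() neutralises metacharacters such as |, (, ), ., *
            # before the stop sequence is embedded in the pattern.
            stop_sequence = re.escape(stop_sequence)
            self.regex = re.compile(f".*{stop_sequence}$")

        def __call__(self, output: str) -> bool:
            return self.regex.match(output) is not None


    # Without escaping, ".*<|stop|>$" alternates on "|": any output that
    # merely contains "<" is treated as a stop, ending generation too early.
    unescaped = re.compile(".*<|stop|>$")
    assert unescaped.match("if a < b:") is not None  # false positive

    # With escaping, only the literal token at the end of the output matches,
    # mirroring the new test_stop_sequence_criteria_escape test above.
    criteria = StopSequenceCriteria("<|stop|>")
    assert not criteria("<")
    assert not criteria("<|stop")
    assert criteria("<|stop|>")
    assert not criteria("<|stop|> ")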