diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 6d796ac3..6fa85d4e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -38,58 +38,6 @@ from text_generation_server.utils.layers import (
 )
 
 
-class LlamaConfig(PretrainedConfig):
-    def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        pretraining_tp=1,
-        tie_word_embeddings=False,
-        rope_scaling=None,
-        rope_theta=10000.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
-
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-
-
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 56768942..612a071d 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -2,14 +2,13 @@ import torch
 import torch.distributed
 
 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, GenerationConfig
 from transformers.models.llama import LlamaTokenizer
 from typing import Optional
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
-    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -53,8 +52,13 @@ class FlashLlama(FlashCausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+        generation_config = GenerationConfig.from_pretrained(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        if isinstance(generation_config.eos_token_id, (list, set)):
+            tokenizer.eos_token_id = set(generation_config.eos_token_id)
 
-        config = LlamaConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 7c8a18f0..e1d00aa1 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,5 +1,5 @@
 import re
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Set
 
 import math
 import torch
@@ -143,12 +143,12 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
-        eos_token_id: int,
+        eos_token_ids: Set[int],
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
         ignore_eos_token: bool = False,
     ):
-        self.eos_token_id = eos_token_id
+        self.eos_token_ids = eos_token_ids
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
@@ -160,7 +160,7 @@ class StoppingCriteria:
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH
 
-        if not self.ignore_eos_token and last_token == self.eos_token_id:
+        if not self.ignore_eos_token and last_token in self.eos_token_ids:
             return True, FinishReason.FINISH_REASON_EOS_TOKEN
 
         if self.stop_sequence_criterias:
@@ -184,8 +184,10 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
        ]
+        eos_token_id = tokenizer.eos_token_id
+        eos_token_ids: Set[int] = eos_token_id if isinstance(eos_token_id, set) else {eos_token_id}
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_ids,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,
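
For reviewers, a minimal standalone sketch (not part of the diff) of the behaviour the `tokens.py` hunks introduce: `StoppingCriteria` now keeps a set of EOS ids and stops when the last token is any member of that set. The `FinishReason` enum below is a stand-in for the generated protobuf type, and the example token ids are made up for illustration.

```python
from enum import Enum
from typing import Optional, Set, Tuple


class FinishReason(Enum):
    # Stand-ins for the generated protobuf enum used in tokens.py.
    FINISH_REASON_LENGTH = 0
    FINISH_REASON_EOS_TOKEN = 1


class StoppingCriteria:
    """Trimmed-down version of the class after this diff: it checks
    membership in `eos_token_ids` instead of equality with a single id."""

    def __init__(
        self,
        eos_token_ids: Set[int],
        max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
    ):
        self.eos_token_ids = eos_token_ids
        self.max_new_tokens = max_new_tokens
        self.current_tokens = 0
        self.ignore_eos_token = ignore_eos_token

    def __call__(self, last_token: int) -> Tuple[bool, Optional[FinishReason]]:
        self.current_tokens += 1
        if self.current_tokens >= self.max_new_tokens:
            return True, FinishReason.FINISH_REASON_LENGTH
        if not self.ignore_eos_token and last_token in self.eos_token_ids:
            return True, FinishReason.FINISH_REASON_EOS_TOKEN
        return False, None


# A model may declare several EOS ids in its generation_config.json;
# the ids here are hypothetical.
criteria = StoppingCriteria(eos_token_ids={128001, 128009}, max_new_tokens=8)
assert criteria(42) == (False, None)
assert criteria(128009) == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
```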