Use the generation config.

This commit is contained in:
Nicolas Patry 2024-04-25 14:57:20 +02:00
parent 4c698fa6c2
commit 80fda35249
3 changed files with 14 additions and 60 deletions

View File

@ -38,58 +38,6 @@ from text_generation_server.utils.layers import (
) )
class LlamaConfig(PretrainedConfig):
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) return _load_gqa(config, prefix, weights)

View File

@ -2,14 +2,13 @@ import torch
import torch.distributed import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.llama import LlamaTokenizer from transformers.models.llama import LlamaTokenizer
from typing import Optional from typing import Optional
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import ( from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
LlamaConfig,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
@ -53,8 +52,13 @@ class FlashLlama(FlashCausalLM):
truncation_side="left", truncation_side="left",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
tokenizer.eos_token_id = set(tokenizer.eos_token_id)
config = LlamaConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize

View File

@ -1,5 +1,5 @@
import re import re
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Set
import math import math
import torch import torch
@ -143,12 +143,12 @@ class StopSequenceCriteria:
class StoppingCriteria: class StoppingCriteria:
def __init__( def __init__(
self, self,
eos_token_id: int, eos_token_ids: Set[int],
stop_sequence_criterias: List[StopSequenceCriteria], stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens: int = 20, max_new_tokens: int = 20,
ignore_eos_token: bool = False, ignore_eos_token: bool = False,
): ):
self.eos_token_id = eos_token_id self.eos_token_ids = eos_token_ids
self.stop_sequence_criterias = stop_sequence_criterias self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens self.max_new_tokens = max_new_tokens
self.current_tokens = 0 self.current_tokens = 0
@ -160,7 +160,7 @@ class StoppingCriteria:
if self.current_tokens >= self.max_new_tokens: if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH return True, FinishReason.FINISH_REASON_LENGTH
if not self.ignore_eos_token and last_token == self.eos_token_id: if not self.ignore_eos_token and last_token in self.eos_token_ids:
return True, FinishReason.FINISH_REASON_EOS_TOKEN return True, FinishReason.FINISH_REASON_EOS_TOKEN
if self.stop_sequence_criterias: if self.stop_sequence_criterias:
@ -184,8 +184,10 @@ class StoppingCriteria:
stop_sequence_criterias = [ stop_sequence_criterias = [
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
] ]
eos_token_id = tokenizer.eos_token_id
eos_token_ids: Set[int] = eos_token_id if isinstance(eos_token_id, set) else {eos_token_id}
return StoppingCriteria( return StoppingCriteria(
tokenizer.eos_token_id, eos_token_ids,
stop_sequence_criterias, stop_sequence_criterias,
pb.max_new_tokens, pb.max_new_tokens,
pb.ignore_eos_token, pb.ignore_eos_token,