Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
fix: cleanup config, remove unused values and fix non flash init
commit 6134f0108d
parent 9bcd21a0b0
@@ -242,7 +242,13 @@ def get_model(
                 use_medusa=use_medusa,
             )
         else:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Phi"))
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )

     elif model_type == "phi-msft":
         if FLASH_ATTENTION:
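The first hunk changes the non-flash initialization path for model_type == "phi": instead of raising NotImplementedError when the flash attention kernels are unavailable, get_model now falls back to the generic CausalLM implementation. A minimal sketch of that dispatch pattern follows; the function name load_phi and the tuples it returns are illustrative stand-ins, not code from the repository.

# Hedged sketch of the dispatch fixed by this hunk (stand-in names, not TGI code).
FLASH_ATTENTION = False  # TGI detects this at import time from the available kernels


def load_phi(model_id, revision=None, quantize=None, dtype=None, trust_remote_code=False):
    if FLASH_ATTENTION:
        # Flash path, analogous to returning FlashPhi(...) in get_model.
        return ("FlashPhi", model_id, revision, quantize, dtype, trust_remote_code)
    # Before this commit the non-flash branch raised
    # NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Phi"));
    # it now returns the plain CausalLM implementation instead.
    return ("CausalLM", model_id, revision, quantize, dtype, trust_remote_code)


print(load_phi("microsoft/phi-2"))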
@@ -26,33 +26,23 @@ class PhiConfig(PretrainedConfig):
         num_attention_heads=32,
         num_key_value_heads=32,
         hidden_act="gelu_fast",  # llama uses silu
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        layer_norm_eps=1e-05,  # rms in llama
-        use_cache=True,
+        layer_norm_eps=1e-05,  # rms in llama,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
-        pretraining_tp=1,
         tie_word_embeddings=False,
-        rope_scaling=None,
         rope_theta=10000.0,
         resid_pdrop=0.1,  # llama doesn't have this
-        partial_rotary_factor=0.5,
+        partial_rotary_factor=0.5,  # important difference between llama and phi
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.resid_pdrop = resid_pdrop
         self.partial_rotary_factor = partial_rotary_factor
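Taken together, the second hunk leaves PhiConfig with only the values the flash Phi modeling code actually reads; max_position_embeddings, initializer_range, use_cache, pretraining_tp, and rope_scaling are dropped. The sketch below reconstructs the resulting class from the "+" side of the diff. It is not a copy of the file: the defaults for vocab_size, hidden_size, and num_hidden_layers sit above the hunk and are assumed here.

# Hedged reconstruction of the slimmed-down config (assumes transformers is installed).
from transformers import PretrainedConfig


class PhiConfigSketch(PretrainedConfig):
    def __init__(
        self,
        vocab_size=51200,            # assumed default, defined above this hunk
        hidden_size=2560,            # assumed default, defined above this hunk
        num_hidden_layers=32,        # assumed default, defined above this hunk
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="gelu_fast",      # llama uses silu
        layer_norm_eps=1e-05,        # rms in llama
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        resid_pdrop=0.1,             # llama doesn't have this
        partial_rotary_factor=0.5,   # important difference between llama and phi
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.rope_theta = rope_theta
        self.resid_pdrop = resid_pdrop
        self.partial_rotary_factor = partial_rotary_factor
        # Token ids and tie_word_embeddings are handled by PretrainedConfig itself.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )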
@@ -116,16 +106,16 @@ class FlashPhiAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

+        self.softmax_scale = self.head_size**-0.5
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.num_heads,
+            dim=self.rotary_dim,
             base=config.rope_theta,
             device=weights.device,
         )

-        self.softmax_scale = self.head_size**-0.5
-        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
-
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
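The third hunk fixes the rotary embedding setup in FlashPhiAttention: the embedding was built with dim=self.num_heads, but Phi rotates only a fraction of each head (partial_rotary_factor * head_size), so the correct dimension is self.rotary_dim, which now has to be computed before PositionRotaryEmbedding.static is called. The self-contained sketch below shows what a partial rotary dimension means in practice; it is not TGI code, and the rotate-half convention and shapes are illustrative assumptions.

import torch


def apply_partial_rotary(q, cos, sin, rotary_dim):
    # q: [seq_len, num_heads, head_size]; cos/sin: [seq_len, 1, rotary_dim // 2].
    # Only the first `rotary_dim` channels of each head are rotated; the
    # remaining channels pass through unchanged.
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    x1, x2 = q_rot[..., : rotary_dim // 2], q_rot[..., rotary_dim // 2 :]
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, q_pass), dim=-1)


head_size = 80                       # e.g. hidden_size 2560 / 32 heads
rotary_dim = int(0.5 * head_size)    # partial_rotary_factor = 0.5 -> 40
seq_len, num_heads = 16, 32

q = torch.randn(seq_len, num_heads, head_size)
t = torch.arange(seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(t, inv_freq)     # [seq_len, rotary_dim // 2]
cos, sin = freqs.cos()[:, None, :], freqs.sin()[:, None, :]

out = apply_partial_rotary(q, cos, sin, rotary_dim)
assert out.shape == q.shape          # rotary touched only the first 40 of 80 channels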