fix: cleanup config, remove unused values and fix non-flash init

drbh 2024-01-24 17:38:17 +00:00
parent 9bcd21a0b0
commit 6134f0108d
2 changed files with 15 additions and 19 deletions


@@ -242,7 +242,13 @@ def get_model(
                 use_medusa=use_medusa,
             )
         else:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Phi"))
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )

     elif model_type == "phi-msft":
         if FLASH_ATTENTION:
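The "fix non-flash init" part of the title is this hunk: previously, loading a `phi` model on a machine without flash attention kernels raised `NotImplementedError`; after the change it falls back to the generic `CausalLM` path. A minimal sketch of that dispatch pattern, where the stub classes are hypothetical stand-ins, not the project's real `FlashPhi`/`CausalLM` implementations:

# Illustrative sketch of the flash/non-flash dispatch above; the stub
# classes are hypothetical stand-ins, not the real implementations.
FLASH_ATTENTION = False  # e.g. flash attention kernels unavailable

class CausalLM:
    def __init__(self, model_id, revision=None, **kwargs):
        self.model_id = model_id
        self.revision = revision

class FlashPhi(CausalLM):
    pass

def get_model(model_id, revision=None, **kwargs):
    if FLASH_ATTENTION:
        return FlashPhi(model_id, revision, **kwargs)
    # The fix: fall back to the slower generic path instead of raising.
    return CausalLM(model_id, revision, **kwargs)

assert type(get_model("microsoft/phi-2")) is CausalLM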

@@ -25,34 +25,24 @@ class PhiConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=32,
-        hidden_act="gelu_fast", # llama uses silu
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        layer_norm_eps=1e-05, # rms in llama
-        use_cache=True,
+        hidden_act="gelu_fast", # llama uses silu
+        layer_norm_eps=1e-05, # rms in llama,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
-        pretraining_tp=1,
         tie_word_embeddings=False,
-        rope_scaling=None,
         rope_theta=10000.0,
-        resid_pdrop=0.1, # llama doesn't have this
-        partial_rotary_factor=0.5,
+        resid_pdrop=0.1, # llama doesn't have this
+        partial_rotary_factor=0.5, # important difference between llama and phi
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.resid_pdrop = resid_pdrop
         self.partial_rotary_factor = partial_rotary_factor
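The removed parameters (`max_position_embeddings`, `initializer_range`, `use_cache`, `pretraining_tp`, `rope_scaling`) were stored on the config but never read by this Phi implementation. Because `PhiConfig` extends `PretrainedConfig`, checkpoints whose config.json still carries those keys should keep loading: the keys flow through `**kwargs` to the parent class, which stores unrecognized keyword arguments as plain attributes. A standalone sketch of that mechanism, with `MiniConfig` as a hypothetical stand-in mimicking (in spirit) `PretrainedConfig`:

# MiniConfig mimics, in spirit, how transformers.PretrainedConfig
# absorbs unrecognized keyword arguments as attributes.
class MiniConfig:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

class PhiConfig(MiniConfig):
    def __init__(self, vocab_size=51200, hidden_size=2560,
                 partial_rotary_factor=0.5, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.partial_rotary_factor = partial_rotary_factor
        super().__init__(**kwargs)

# Keys dropped from the signature still round-trip via **kwargs.
cfg = PhiConfig(max_position_embeddings=2048, use_cache=True)
print(cfg.max_position_embeddings, cfg.use_cache)  # 2048 True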
@@ -116,16 +106,16 @@ class FlashPhiAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

+        self.softmax_scale = self.head_size**-0.5
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.num_heads,
+            dim=self.rotary_dim,
             base=config.rope_theta,
             device=weights.device,
         )

-        self.softmax_scale = self.head_size**-0.5
-        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
-
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "