Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)

fix: cleanup config, remove unused values and fix non flash init

This commit is contained in:
parent 9bcd21a0b0
commit 6134f0108d
@@ -242,7 +242,13 @@ def get_model(
                 use_medusa=use_medusa,
             )
         else:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Phi"))
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
 
     elif model_type == "phi-msft":
         if FLASH_ATTENTION:
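This hunk is the "fix non flash init" part of the commit: when flash-attention kernels are unavailable, get_model used to raise NotImplementedError for the "phi" model type and now falls back to the generic CausalLM implementation instead. A minimal sketch of the resulting dispatch, wrapped in a hypothetical helper so it stands alone (FLASH_ATTENTION, CausalLM, and the parameter names come from the diff context; _dispatch_phi itself is not a real function in the codebase):

def _dispatch_phi(model_id, revision, quantize, dtype, trust_remote_code, use_medusa):
    if FLASH_ATTENTION:
        ...  # flash-attention path, unchanged by this hunk
    else:
        # Previously: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Phi"))
        # Phi can now still be served without flash kernels, through the
        # generic (and slower) CausalLM path.
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

In practice this means that on hardware or builds where flash attention cannot be imported, loading a Phi model should no longer abort at init time.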
@@ -25,34 +25,24 @@ class PhiConfig(PretrainedConfig):
         num_hidden_layers=32,
         num_attention_heads=32,
         num_key_value_heads=32,
         hidden_act="gelu_fast",  # llama uses silu
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        layer_norm_eps=1e-05,  # rms in llama
-        use_cache=True,
+        layer_norm_eps=1e-05,  # rms in llama,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
-        pretraining_tp=1,
         tie_word_embeddings=False,
-        rope_scaling=None,
         rope_theta=10000.0,
         resid_pdrop=0.1,  # llama doesn't have this
-        partial_rotary_factor=0.5,
+        partial_rotary_factor=0.5,  # important difference between llama and phi
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.pretraining_tp = pretraining_tp
-        self.use_cache = use_cache
-        self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.resid_pdrop = resid_pdrop
         self.partial_rotary_factor = partial_rotary_factor
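This config hunk is the "remove unused values" part: keyword arguments that the flash Phi modeling code never reads (max_position_embeddings, initializer_range, use_cache, pretraining_tp, rope_scaling) are dropped along with the attributes that stored them, and partial_rotary_factor gains a comment marking it as the important difference from llama. A compressed sketch of the trimmed constructor, built from the lines kept above; defaults that sit outside this hunk (vocab_size, hidden_size, ...) are omitted rather than guessed, and the super().__init__ call follows the usual PretrainedConfig pattern as an assumption, since it lies below the hunk:

from transformers import PretrainedConfig


class PhiConfigSketch(PretrainedConfig):  # illustrative stand-in, not the real PhiConfig
    def __init__(
        self,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="gelu_fast",        # llama uses silu
        layer_norm_eps=1e-05,          # rms norm in llama
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        resid_pdrop=0.1,               # llama doesn't have this
        partial_rotary_factor=0.5,     # important difference between llama and phi
        **kwargs,                      # removed options passed by older configs land here
    ):
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.rope_theta = rope_theta
        self.resid_pdrop = resid_pdrop
        self.partial_rotary_factor = partial_rotary_factor
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

Because the signature still ends in **kwargs, any of the removed keys present in an existing Phi config.json should simply flow through to PretrainedConfig instead of breaking loading.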
@@ -116,16 +106,16 @@ class FlashPhiAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
+        self.softmax_scale = self.head_size**-0.5
+        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
+
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.num_heads,
+            dim=self.rotary_dim,
             base=config.rope_theta,
             device=weights.device,
         )
 
-        self.softmax_scale = self.head_size**-0.5
-        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)
-
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
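The attention hunk contains the actual bug fix: the rotary embedding table was previously constructed with dim=self.num_heads, which has nothing to do with how many dimensions of each head RoPE should rotate. Phi uses a partial rotary embedding, so the correct width is rotary_dim = int(partial_rotary_factor * head_size); the commit computes softmax_scale and rotary_dim before the PositionRotaryEmbedding.static call and passes dim=self.rotary_dim. Worked numbers as a sketch, assuming phi-2-like shapes (hidden_size=2560 is not visible in this hunk; num_heads=32 and partial_rotary_factor=0.5 are the defaults shown in the config hunk above):

# Arithmetic behind the dim= fix; the concrete shapes are assumptions for illustration.
hidden_size = 2560            # assumed, not shown in this hunk
num_heads = 32                # default from the PhiConfig hunk
partial_rotary_factor = 0.5   # default from the PhiConfig hunk

head_size = hidden_size // num_heads                 # 80 dims per attention head
rotary_dim = int(partial_rotary_factor * head_size)  # 40: RoPE rotates only part of each head
softmax_scale = head_size ** -0.5                    # ~0.112, relocated but unchanged by the fix

# Before the fix the cos/sin cache was sized with dim=num_heads (32) rather than
# rotary_dim (40), so it did not match the per-head slice it is applied to.
assert (head_size, rotary_dim) == (80, 40)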