Updating Phi3 (long context).

Nicolas Patry 2024-05-02 14:14:24 +00:00
parent de079d607a
commit 7fac2978b3
3 changed files with 96 additions and 5 deletions

Cargo.lock (generated)

@@ -3393,7 +3393,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "average",
  "clap",
@@ -3414,7 +3414,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -3430,7 +3430,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3448,7 +3448,7 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "async-stream",
  "axum",


@@ -136,6 +136,7 @@ pub enum Config {
     Phi,
     #[serde(rename = "phi-msft")]
     PhiMsft,
+    Phi3,
     Llama,
     Baichuan,
     Gemma,
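
For reference, Phi-3 checkpoints declare "phi3" as their model_type in config.json, which is presumably what the new Phi3 variant gets deserialized from (the serde tagging for Config lives outside this diff). Below is a minimal Python sketch of the config fields the rest of this commit relies on, with illustrative values rather than ones copied from a real checkpoint:

# Illustrative Phi-3-style config fields; the factor lists normally have head_dim // 2 entries.
phi3_config = {
    "model_type": "phi3",                      # the value the new Config::Phi3 variant corresponds to
    "max_position_embeddings": 131072,         # extended ("long") context window
    "original_max_position_embeddings": 4096,  # window the base model was pretrained on
    "rope_scaling": {
        "type": "su",                          # handled by the new branch further down
        "short_factor": [1.0] * 48,            # per-frequency rescaling inside the original window
        "long_factor": [4.0] * 48,             # per-frequency rescaling beyond it
    },
}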


@@ -1029,10 +1029,10 @@ try:
             scaling_factor = None
             rope_scaling = _get_rope_config(config)
             if rope_scaling is not None:
-                scaling_factor = rope_scaling["factor"]
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
+                    scaling_factor = rope_scaling["factor"]
                     return DynamicPositionRotaryEmbedding(
                         dim=dim,
                         max_position_embeddings=config.max_position_embeddings,
@@ -1041,6 +1041,7 @@ try:
                         scaling_factor=scaling_factor,
                     )
                 elif rope_scaling["type"] == "yarn":
+                    scaling_factor = rope_scaling["factor"]
                     return YarnPositionRotaryEmbedding(
                         dim=2 * inv_freq.shape[0],
                         max_position_embeddings=rope_scaling[
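
The two hunks above move the unconditional scaling_factor = rope_scaling["factor"] read down into the "dynamic" and "yarn" branches. A likely reason, sketched below with an illustrative dict, is that the "su" configuration added later in this commit carries per-frequency factor lists rather than a single "factor" key, so reading it before branching would fail:

# Illustrative Phi-3-style "su" rope_scaling dict (toy values); note there is no "factor" key.
su_rope_scaling = {"type": "su", "short_factor": [1.0] * 48, "long_factor": [4.0] * 48}

try:
    scaling_factor = su_rope_scaling["factor"]  # roughly what the old pre-branch line did
except KeyError:
    scaling_factor = None  # after this patch, "factor" is only read for "dynamic" and "yarn"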
@@ -1054,6 +1055,52 @@ try:
                         beta_fast=32,
                         beta_slow=1,
                     )
+                elif rope_scaling["type"] == "su":
+                    short_factor = torch.tensor(
+                        rope_scaling["short_factor"], dtype=torch.float32, device=device
+                    )
+                    short_inv_freq = 1.0 / (
+                        short_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    )
+                    long_factor = torch.tensor(
+                        rope_scaling["long_factor"], dtype=torch.float32, device=device
+                    )
+                    long_inv_freq = 1.0 / (
+                        long_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    )
+
+                    original_max_position_embeddings = (
+                        config.original_max_position_embeddings
+                    )
+                    max_position_embeddings = config.max_position_embeddings
+                    if max_position_embeddings <= original_max_position_embeddings:
+                        scaling_factor = 1.0
+                    else:
+                        scale = (
+                            max_position_embeddings / original_max_position_embeddings
+                        )
+                        scaling_factor = math.sqrt(
+                            1
+                            + math.log(scale)
+                            / math.log(original_max_position_embeddings)
+                        )
+
+                    return SuRotaryEmbedding(
+                        short_inv_freq=short_inv_freq,
+                        long_inv_freq=long_inv_freq,
+                        scaling_factor=scaling_factor,
+                        original_max_position_embeddings=original_max_position_embeddings,
+                    )
                 else:
                     raise NotImplementedError(
                         f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
@@ -1141,6 +1188,49 @@ try:
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
             return cos.unsqueeze(1), sin.unsqueeze(1)

+    class SuRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(
+            self,
+            short_inv_freq,
+            long_inv_freq,
+            scaling_factor,
+            original_max_position_embeddings,
+        ):
+            super(PositionRotaryEmbedding, self).__init__()
+            self.short_inv_freq = short_inv_freq
+            self.long_inv_freq = long_inv_freq
+            self.scaling_factor = scaling_factor
+            self.original_max_position_embeddings = original_max_position_embeddings
+
+            self._seq_len_cached = 0
+            self._cos_cached = None
+            self._sin_cached = None
+            self._cos_k_cached = None
+            self._sin_k_cached = None
+            self.dynamic_args = None
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                self._seq_len_cached = seqlen
+                if seqlen > self.original_max_position_embeddings:
+                    inv_freq = self.long_inv_freq
+                else:
+                    inv_freq = self.short_inv_freq
+
+                t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
 class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
     def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
         inv_freq = _create_inv_freq(dim, base, device)
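
As a rough illustration of what the new SuRotaryEmbedding cache does (the same logic rewritten standalone, not the server's actual wiring through PositionRotaryEmbedding.static), the sketch below switches from the short to the long inverse-frequency table once the requested length passes original_max_position_embeddings:

import torch

def build_cos_sin(seqlen, short_inv_freq, long_inv_freq,
                  original_max_position_embeddings, scaling_factor, dtype=torch.float32):
    # Past the original training window, use the "long" inverse frequencies.
    inv_freq = long_inv_freq if seqlen > original_max_position_embeddings else short_inv_freq
    t = torch.arange(seqlen, dtype=inv_freq.dtype)
    if scaling_factor is not None:
        t /= scaling_factor           # mirrors `t /= self.scaling_factor` above
    freqs = torch.outer(t, inv_freq)  # one row of rotary angles per position
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

# Toy setup: head dim 8, so 4 frequencies; real factors come from the rope_scaling config.
dim, base = 8, 10000.0
inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
short_inv_freq, long_inv_freq = inv, inv / 4.0

cos_short, _ = build_cos_sin(16, short_inv_freq, long_inv_freq, 32, 1.1902)  # uses short table
cos_long, _ = build_cos_sin(64, short_inv_freq, long_inv_freq, 32, 1.1902)   # uses long table
print(cos_short.shape, cos_long.shape)  # torch.Size([16, 4]) torch.Size([64, 4])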