Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

Commit: 7fac2978b3
Parent: de079d607a

Updating Phi3 (long context).
Cargo.lock (generated): 8 changes
@@ -3393,7 +3393,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "average",
  "clap",
@@ -3414,7 +3414,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -3430,7 +3430,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3448,7 +3448,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "async-stream",
  "axum",
@@ -136,6 +136,7 @@ pub enum Config {
     Phi,
     #[serde(rename = "phi-msft")]
     PhiMsft,
+    Phi3,
     Llama,
     Baichuan,
     Gemma,
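For context, this enum is presumably deserialized from the model_type field of a model's config.json, so the new variant would match configs whose model_type spells out the Phi3 name. A minimal Python sketch of that mapping follows; it is an illustration only, not the router's serde code, and the "phi3" spelling is an assumption based on the enum's naming convention.

# Illustrative sketch only: not the router's actual deserialization logic.
import json

MODEL_TYPE_TO_VARIANT = {
    "phi": "Phi",
    "phi-msft": "PhiMsft",   # mirrors the explicit #[serde(rename = "phi-msft")]
    "phi3": "Phi3",          # newly added variant (spelling assumed)
    "llama": "Llama",
    "baichuan": "Baichuan",
    "gemma": "Gemma",
}

def config_variant(config_json: str) -> str:
    """Return the config variant name for a raw config.json string."""
    return MODEL_TYPE_TO_VARIANT[json.loads(config_json)["model_type"]]

print(config_variant('{"model_type": "phi3"}'))  # -> Phi3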
@@ -1029,10 +1029,10 @@ try:
             scaling_factor = None
             rope_scaling = _get_rope_config(config)
             if rope_scaling is not None:
-                scaling_factor = rope_scaling["factor"]
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
+                    scaling_factor = rope_scaling["factor"]
                     return DynamicPositionRotaryEmbedding(
                         dim=dim,
                         max_position_embeddings=config.max_position_embeddings,
@@ -1041,6 +1041,7 @@ try:
                         scaling_factor=scaling_factor,
                     )
                 elif rope_scaling["type"] == "yarn":
+                    scaling_factor = rope_scaling["factor"]
                     return YarnPositionRotaryEmbedding(
                         dim=2 * inv_freq.shape[0],
                         max_position_embeddings=rope_scaling[
@@ -1054,6 +1055,52 @@ try:
                         beta_fast=32,
                         beta_slow=1,
                     )
+                elif rope_scaling["type"] == "su":
+                    short_factor = torch.tensor(
+                        rope_scaling["short_factor"], dtype=torch.float32, device=device
+                    )
+                    short_inv_freq = 1.0 / (
+                        short_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    )
+                    long_factor = torch.tensor(
+                        rope_scaling["long_factor"], dtype=torch.float32, device=device
+                    )
+                    long_inv_freq = 1.0 / (
+                        long_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    )
+
+                    original_max_position_embeddings = (
+                        config.original_max_position_embeddings
+                    )
+                    max_position_embeddings = config.max_position_embeddings
+                    if max_position_embeddings <= original_max_position_embeddings:
+                        scaling_factor = 1.0
+                    else:
+                        scale = (
+                            max_position_embeddings / original_max_position_embeddings
+                        )
+                        scaling_factor = math.sqrt(
+                            1
+                            + math.log(scale)
+                            / math.log(original_max_position_embeddings)
+                        )
+
+                    return SuRotaryEmbedding(
+                        short_inv_freq=short_inv_freq,
+                        long_inv_freq=long_inv_freq,
+                        scaling_factor=scaling_factor,
+                        original_max_position_embeddings=original_max_position_embeddings,
+                    )
                 else:
                     raise NotImplementedError(
                         f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
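The hunks above also move the scaling_factor = rope_scaling["factor"] read into the "dynamic" and "yarn" branches; a likely reason (an inference, not stated in the commit) is that the new "su" config carries short_factor/long_factor lists rather than a single "factor" key, so reading "factor" unconditionally would fail for Phi-3. Below is a minimal, standalone sketch of the "su" attention-scaling factor and per-frequency rescaling added above; the helper names and the 4k-to-128k numbers are illustrative assumptions, not part of the commit.

# Standalone sketch of the "su" rope-scaling math shown in the diff above.
import math
import torch

def su_scaling_factor(max_position_embeddings, original_max_position_embeddings):
    # The attention scale grows with the log of the context-extension ratio.
    if max_position_embeddings <= original_max_position_embeddings:
        return 1.0
    scale = max_position_embeddings / original_max_position_embeddings
    return math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

def su_inv_freq(factor, dim, base):
    # Each rotary frequency is divided by its own (short or long) rescale factor.
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    return 1.0 / (factor * base ** exponents)

# Illustrative numbers for a 4k -> 128k extension:
print(su_scaling_factor(131072, 4096))                  # ~1.19
print(su_inv_freq(torch.ones(32), 64, 10000.0).shape)   # torch.Size([32])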
@@ -1141,6 +1188,49 @@ try:
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
             return cos.unsqueeze(1), sin.unsqueeze(1)
 
+    class SuRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(
+            self,
+            short_inv_freq,
+            long_inv_freq,
+            scaling_factor,
+            original_max_position_embeddings,
+        ):
+            super(PositionRotaryEmbedding, self).__init__()
+            self.short_inv_freq = short_inv_freq
+            self.long_inv_freq = long_inv_freq
+            self.scaling_factor = scaling_factor
+            self.original_max_position_embeddings = original_max_position_embeddings
+            self._seq_len_cached = 0
+            self._cos_cached = None
+            self._sin_cached = None
+            self._cos_k_cached = None
+            self._sin_k_cached = None
+            self.dynamic_args = None
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                self._seq_len_cached = seqlen
+                if seqlen > self.original_max_position_embeddings:
+                    inv_freq = self.long_inv_freq
+                else:
+                    inv_freq = self.short_inv_freq
+                t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
             inv_freq = _create_inv_freq(dim, base, device)
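As a rough usage sketch, the cache update above amounts to the following standalone function; the names, the 4096 threshold, and the example numbers are illustrative assumptions rather than values taken from the commit.

# Standalone sketch of SuRotaryEmbedding's cos/sin cache update (illustrative only).
import torch

def su_cos_sin_cache(seqlen, short_inv_freq, long_inv_freq,
                     original_max_position_embeddings, scaling_factor,
                     dtype=torch.float32, device="cpu"):
    # Sequences longer than the original training window use the "long" frequencies.
    inv_freq = long_inv_freq if seqlen > original_max_position_embeddings else short_inv_freq
    t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
    if scaling_factor is not None:
        t /= scaling_factor
    freqs = torch.outer(t, inv_freq.to(device=t.device))
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

# Example: 64-dim rotary head with unscaled frequencies on both branches.
dim, base = 64, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
cos, sin = su_cos_sin_cache(8192, inv_freq, inv_freq, 4096, 1.19)
print(cos.shape)  # torch.Size([8192, 32])

The class in the diff additionally caches the resulting cos/sin tables and only rebuilds them when the sequence length grows or the dtype/device changes.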