Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Updating Phi3 (long context).
This commit is contained in: parent de079d607a, commit 7fac2978b3
Cargo.lock (generated): 8 changed lines
@@ -3393,7 +3393,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-benchmark"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "average",
  "clap",
@@ -3414,7 +3414,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "futures",
  "grpc-metadata",
@@ -3430,7 +3430,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "clap",
  "ctrlc",
@@ -3448,7 +3448,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "2.0.1"
+version = "2.0.2"
 dependencies = [
  "async-stream",
  "axum",
@@ -136,6 +136,7 @@ pub enum Config {
     Phi,
     #[serde(rename = "phi-msft")]
     PhiMsft,
+    Phi3,
     Llama,
     Baichuan,
     Gemma,
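Note: the hunk above adds a Phi3 variant to the Rust Config enum. As a quick, hypothetical check of which variant a given checkpoint would hit (assuming the new variant maps to the lowercase model_type string "phi3", which this hunk does not show), a minimal Python sketch:

import json

# Hypothetical helper: peek at a checkpoint's config.json and report the
# declared model_type; a Phi-3 checkpoint is expected to declare "phi3".
def read_model_type(config_path="config.json"):
    with open(config_path) as f:
        return json.load(f)["model_type"]

print(read_model_type())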
@@ -1029,10 +1029,10 @@ try:
             scaling_factor = None
             rope_scaling = _get_rope_config(config)
             if rope_scaling is not None:
-                scaling_factor = rope_scaling["factor"]
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
+                    scaling_factor = rope_scaling["factor"]
                     return DynamicPositionRotaryEmbedding(
                         dim=dim,
                         max_position_embeddings=config.max_position_embeddings,
@@ -1041,6 +1041,7 @@ try:
                         scaling_factor=scaling_factor,
                     )
                 elif rope_scaling["type"] == "yarn":
+                    scaling_factor = rope_scaling["factor"]
                     return YarnPositionRotaryEmbedding(
                         dim=2 * inv_freq.shape[0],
                         max_position_embeddings=rope_scaling[
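Note: the two hunks above move the rope_scaling["factor"] lookup from the shared path into the "dynamic" and "yarn" branches. A plausible reason (not stated in the commit message) is that the new "su" configuration introduced below carries no "factor" key, so reading it unconditionally would raise a KeyError. A minimal sketch with hypothetical payloads:

# Illustrative rope_scaling payloads (hypothetical values, lists shortened);
# only the first two carry a "factor" key, so the lookup has to live inside
# the branches that actually use it.
dynamic_cfg = {"type": "dynamic", "factor": 2.0}
yarn_cfg = {"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 8192}
su_cfg = {"type": "su", "short_factor": [1.0] * 48, "long_factor": [2.5] * 48}

for cfg in (dynamic_cfg, yarn_cfg, su_cfg):
    print(cfg["type"], cfg.get("factor"))  # .get() returns None for the "su" entry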
@@ -1054,6 +1055,52 @@ try:
                         beta_fast=32,
                         beta_slow=1,
                     )
+                elif rope_scaling["type"] == "su":
+                    short_factor = torch.tensor(
+                        rope_scaling["short_factor"], dtype=torch.float32, device=device
+                    )
+                    short_inv_freq = 1.0 / (
+                        short_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    )
+                    long_factor = torch.tensor(
+                        rope_scaling["long_factor"], dtype=torch.float32, device=device
+                    )
+                    long_inv_freq = 1.0 / (
+                        long_factor
+                        * base
+                        ** (
+                            torch.arange(0, dim, 2, device=device, dtype=torch.float32)
+                            / dim
+                        )
+                    )
+
+                    original_max_position_embeddings = (
+                        config.original_max_position_embeddings
+                    )
+                    max_position_embeddings = config.max_position_embeddings
+                    if max_position_embeddings <= original_max_position_embeddings:
+                        scaling_factor = 1.0
+                    else:
+                        scale = (
+                            max_position_embeddings / original_max_position_embeddings
+                        )
+                        scaling_factor = math.sqrt(
+                            1
+                            + math.log(scale)
+                            / math.log(original_max_position_embeddings)
+                        )
+
+                    return SuRotaryEmbedding(
+                        short_inv_freq=short_inv_freq,
+                        long_inv_freq=long_inv_freq,
+                        scaling_factor=scaling_factor,
+                        original_max_position_embeddings=original_max_position_embeddings,
+                    )
                 else:
                     raise NotImplementedError(
                         f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
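Note: a worked example of the scaling_factor formula added above, using illustrative numbers only (an original context of 4096 positions extended to 131072, in the spirit of a Phi-3 long-context configuration; the real values come from the model config and are not part of this diff):

import math

# Hypothetical sizes for illustration.
original_max_position_embeddings = 4096
max_position_embeddings = 131072

if max_position_embeddings <= original_max_position_embeddings:
    scaling_factor = 1.0
else:
    scale = max_position_embeddings / original_max_position_embeddings  # 32.0
    scaling_factor = math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings)
    )

print(round(scaling_factor, 4))  # 1.1902; SuRotaryEmbedding divides the position indices t by this value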
@@ -1141,6 +1188,49 @@ try:
             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
             return cos.unsqueeze(1), sin.unsqueeze(1)
 
+    class SuRotaryEmbedding(PositionRotaryEmbedding):
+        def __init__(
+            self,
+            short_inv_freq,
+            long_inv_freq,
+            scaling_factor,
+            original_max_position_embeddings,
+        ):
+            super(PositionRotaryEmbedding, self).__init__()
+            self.short_inv_freq = short_inv_freq
+            self.long_inv_freq = long_inv_freq
+            self.scaling_factor = scaling_factor
+            self.original_max_position_embeddings = original_max_position_embeddings
+            self._seq_len_cached = 0
+            self._cos_cached = None
+            self._sin_cached = None
+            self._cos_k_cached = None
+            self._sin_k_cached = None
+            self.dynamic_args = None
+
+        def _update_cos_sin_cache(self, dtype, device, seqlen):
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if (
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
+            ):
+                self._seq_len_cached = seqlen
+                if seqlen > self.original_max_position_embeddings:
+                    inv_freq = self.long_inv_freq
+                else:
+                    inv_freq = self.short_inv_freq
+                t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
+                if self.scaling_factor is not None:
+                    t /= self.scaling_factor
+                # Don't do einsum, it converts fp32 to fp16
+                # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+                freqs = torch.outer(t, inv_freq.to(device=t.device))
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
             inv_freq = _create_inv_freq(dim, base, device)
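Note: a self-contained sketch of the cache behaviour added above, with hypothetical toy sizes, showing the switch from the short to the long inverse frequencies once the sequence length exceeds the original context window:

import torch

# Toy sizes for illustration only.
dim = 8                                # hypothetical head dimension
base = 10000.0
original_max_position_embeddings = 16
scaling_factor = 1.19

exponent = torch.arange(0, dim, 2, dtype=torch.float32) / dim
short_inv_freq = 1.0 / (1.0 * base**exponent)  # as if short_factor were all ones
long_inv_freq = 1.0 / (2.5 * base**exponent)   # as if long_factor were all 2.5

def cos_sin(seqlen):
    # Same selection rule as SuRotaryEmbedding._update_cos_sin_cache above.
    inv_freq = long_inv_freq if seqlen > original_max_position_embeddings else short_inv_freq
    t = torch.arange(seqlen, dtype=inv_freq.dtype) / scaling_factor
    freqs = torch.outer(t, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)

cos_short, _ = cos_sin(8)   # within the original window: short frequencies
cos_long, _ = cos_sin(32)   # beyond it: long frequencies
print(cos_short.shape, cos_long.shape)  # torch.Size([8, 4]) torch.Size([32, 4])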