feat(server): add frequency penalty (#1541)

OlivierDehaene 2024-02-08 18:41:25 +01:00 committed by Karol Damaszke
parent cec954e391
commit f1d8da3ba6
21 changed files with 412 additions and 97 deletions

Cargo.lock (generated)

@@ -2824,7 +2824,7 @@ dependencies = [
  "tabled",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.14.1",
  "tokio",
  "tracing",
  "tracing-subscriber",
@@ -2888,7 +2888,7 @@ dependencies = [
  "serde_json",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.15.1",
  "tokio",
  "tokio-stream",
  "tower-http",
@@ -3012,6 +3012,40 @@ dependencies = [
  "unicode_categories",
 ]

+[[package]]
+name = "tokenizers"
+version = "0.15.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
+dependencies = [
+ "aho-corasick",
+ "clap",
+ "derive_builder",
+ "esaxx-rs",
+ "getrandom",
+ "hf-hub",
+ "indicatif",
+ "itertools 0.11.0",
+ "lazy_static",
+ "log",
+ "macro_rules_attribute",
+ "monostate",
+ "onig",
+ "paste",
+ "rand",
+ "rayon",
+ "rayon-cond",
+ "regex",
+ "regex-syntax 0.7.5",
+ "serde",
+ "serde_json",
+ "spm_precompiled",
+ "thiserror",
+ "unicode-normalization-alignments",
+ "unicode-segmentation",
+ "unicode_categories",
+]
+
 [[package]]
 name = "tokio"
 version = "1.37.0"


@@ -30,6 +30,7 @@ pub async fn run(
     top_p: Option<f32>,
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     watermark: bool,
     do_sample: bool,
     client: ShardedClient,
@@ -42,6 +43,7 @@ pub async fn run(
         do_sample,
         seed: 0,
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
+        frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
     };
@@ -140,6 +142,7 @@ pub async fn run(
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
     );


@@ -84,6 +84,11 @@ struct Args {
     #[clap(long, env)]
     repetition_penalty: Option<f32>,
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    frequency_penalty: Option<f32>,
     /// Generation parameter in case you want to specifically test/debug particular
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
@@ -119,6 +124,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
         master_shard_uds_path,
@@ -187,6 +193,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
         sharded_client,


@@ -15,6 +15,7 @@ pub(crate) fn parameters_table(
     top_p: Option<f32>,
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     watermark: bool,
     do_sample: bool,
 ) -> Table {
@@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Top P", &format!("{top_p:?}")]);
     builder.push_record(["Typical P", &format!("{typical_p:?}")]);
     builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
+    builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
     builder.push_record(["Watermark", &watermark.to_string()]);
     builder.push_record(["Do Sample", &do_sample.to_string()]);


@@ -24,6 +24,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
     assert response.generated_text == "\n\nDeep learning is a new type of machine"
     assert response == response_snapshot

+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@@ -44,13 +45,19 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     )

     assert response.details.generated_tokens == 10
-    assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
+    assert (
+        response.generated_text
+        == "blue, red, yellow, \nand orange (in the order they appear in"
+    )
     assert response == response_snapshot

+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-    responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
+    responses = await generate_load(
+        fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
+    )

     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])


@@ -66,6 +66,8 @@ message NextTokenChooserParameters {
     uint64 seed = 6;
     /// repetition penalty
     float repetition_penalty = 7;
+    /// frequency penalty
+    float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
 }
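
For reference, the new field reaches the Python shards through the stubs generated from this proto. A minimal sketch of building the message from those stubs, using the generate_pb2 import path that appears later in this diff; the concrete values are placeholders, not defaults from the commit:

    from text_generation_server.pb import generate_pb2

    # Field numbers follow the message above (frequency_penalty = 9);
    # the values below are illustrative only.
    params = generate_pb2.NextTokenChooserParameters(
        temperature=0.9,
        top_k=10,
        top_p=0.9,
        typical_p=0.95,
        do_sample=True,
        seed=0,
        repetition_penalty=1.2,
        frequency_penalty=0.1,
        watermark=False,
    )
    print(params)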


@@ -227,6 +227,7 @@ impl Client {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 1.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             })
         } else {
@@ -238,6 +239,7 @@ impl Client {
                 do_sample: true,
                 seed: 0,
                 repetition_penalty: 1.2,
+                frequency_penalty: 0.1,
                 watermark: false,
             })
         };


@@ -43,6 +43,7 @@ impl Health {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 1.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {


@@ -106,6 +106,14 @@ pub(crate) struct GenerateParameters {
     )]
     pub repetition_penalty: Option<f32>,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = -2.0,
+        nullable = true,
+        default = "null",
+        example = 0.1
+    )]
+    pub frequency_penalty: Option<f32>,
+    #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
     pub top_k: Option<i32>,
     #[serde(default)]
@@ -172,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
         best_of: None,
         temperature: None,
         repetition_penalty: None,
+        frequency_penalty: None,
         top_k: None,
         top_p: None,
         typical_p: None,
@@ -205,10 +214,71 @@ pub(crate) struct ChatCompletion {
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
-    pub logprobs: Option<Vec<f32>>,
+    pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: String,
 }

+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionLogprobs {
+    content: Vec<ChatCompletionLogprob>,
+}
+
+impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
+    fn from(value: (Token, Vec<Token>)) -> Self {
+        let (token, top_tokens) = value;
+        Self {
+            content: vec![ChatCompletionLogprob {
+                token: token.text,
+                logprob: token.logprob,
+                top_logprobs: top_tokens
+                    .into_iter()
+                    .map(|t| ChatCompletionTopLogprob {
+                        token: t.text,
+                        logprob: t.logprob,
+                    })
+                    .collect(),
+            }],
+        }
+    }
+}
+
+impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
+    fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
+        let (tokens, top_tokens) = value;
+        Self {
+            content: tokens
+                .into_iter()
+                .zip(top_tokens)
+                .map(|(t, top_t)| ChatCompletionLogprob {
+                    token: t.text,
+                    logprob: t.logprob,
+                    top_logprobs: top_t
+                        .into_iter()
+                        .map(|t| ChatCompletionTopLogprob {
+                            token: t.text,
+                            logprob: t.logprob,
+                        })
+                        .collect(),
+                })
+                .collect(),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionLogprob {
+    token: String,
+    logprob: f32,
+    top_logprobs: Vec<ChatCompletionTopLogprob>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionTopLogprob {
+    token: String,
+    logprob: f32,
+}
+
 #[derive(Clone, Deserialize, Serialize)]
 pub(crate) struct Usage {
     pub prompt_tokens: u32,
@@ -238,7 +308,7 @@ impl ChatCompletion {
                     content: output,
                 },
                 logprobs: return_logprobs
-                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                    .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
                 finish_reason: details.finish_reason.to_string(),
             }],
             usage: Usage {
@@ -266,7 +336,7 @@ pub(crate) struct ChatCompletionChunk {
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
-    pub logprobs: Option<f32>,
+    pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: Option<String>,
 }
@@ -285,7 +355,7 @@ impl ChatCompletionChunk {
         delta: String,
         created: u64,
         index: u32,
-        logprobs: Option<f32>,
+        logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
         Self {
@@ -319,8 +389,8 @@ pub(crate) struct ChatRequest {
     /// UNUSED
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
-    pub model: String, /* NOTE: UNUSED */
+    pub model: String,
+    /* NOTE: UNUSED */
     /// A list of messages comprising the conversation so far.
     #[serde(default = "default_request_messages")]
     pub messages: Vec<Message>,
@@ -346,7 +416,6 @@ pub(crate) struct ChatRequest {
     #[schema(example = "false")]
     pub logprobs: Option<bool>,

-    /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
     /// an associated log probability. logprobs must be set to true if this parameter is used.
     #[serde(default)]
@@ -365,7 +434,6 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "2")]
     pub n: Option<u32>,

-    /// UNUSED
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     /// increasing the model's likelihood to talk about new topics
     #[serde(default)]
@@ -447,7 +515,7 @@ pub struct PrefillToken {
     logprob: f32,
 }

-#[derive(Debug, Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema, Clone)]
 pub struct Token {
     #[schema(example = 0)]
     id: u32,
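
The new `frequency_penalty` field is part of the public generate parameters above. A hedged usage sketch of calling it over HTTP (host, port, prompt, and endpoint path are assumptions based on a standard text-generation-inference deployment, not part of this diff; the parameter names and the [-2.0, 2.0] range come from this commit):

    import requests

    # Assumes a router listening on localhost:3000.
    response = requests.post(
        "http://localhost:3000/generate",
        json={
            "inputs": "What is Deep Learning?",
            "parameters": {
                "max_new_tokens": 10,
                "repetition_penalty": 1.2,
                # New in this commit: OpenAI-style penalty in [-2.0, 2.0].
                "frequency_penalty": 0.1,
            },
        },
    )
    print(response.json())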


@@ -496,6 +496,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {


@@ -6,9 +6,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
-    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
+    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
+    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
+    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
+    StreamResponse, Token, TokenizeResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -572,8 +573,8 @@ async fn chat_completions(
     let stream = req.stream;
     let max_new_tokens = req.max_tokens.or(Some(100));
     let repetition_penalty = req
-        .frequency_penalty
-        // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
+        .presence_penalty
+        // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
     let logprobs = req.logprobs.unwrap_or(false);
     let seed = req.seed;
@@ -601,6 +602,7 @@ async fn chat_completions(
             best_of: None,
             temperature: req.temperature,
             repetition_penalty,
+            frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
@@ -632,6 +634,10 @@ async fn chat_completions(
                         .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                         .as_secs();

+                    let logprobs = logprobs.then(|| {
+                        ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
+                    });
+
                     event
                         .json_data(ChatCompletionChunk::new(
@@ -639,7 +645,7 @@ async fn chat_completions(
                             model_id.clone(),
                             stream_token.token.text,
                             current_time,
                             stream_token.index,
-                            logprobs.then_some(stream_token.token.logprob),
+                            logprobs,
                             stream_token.details.map(|d| d.finish_reason.to_string()),
                         ))
                         .map_or_else(
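
Worth spelling out: `chat_completions` keeps the OpenAI-style `presence_penalty` but maps it onto the internal `repetition_penalty` by shifting it from (-2.0, 2.0) into (0.0, 4.0), while `frequency_penalty` is now forwarded as-is. A small sketch of that mapping (Python used purely for illustration of the `.map(|x| x + 2.0)` logic above):

    def presence_to_repetition_penalty(presence_penalty):
        # Mirrors Option::map(|x| x + 2.0): a missing value stays missing,
        # otherwise the value is shifted by 2.0.
        if presence_penalty is None:
            return None
        return presence_penalty + 2.0

    assert presence_to_repetition_penalty(None) is None
    assert presence_to_repetition_penalty(-1.5) == 0.5
    assert presence_to_repetition_penalty(2.0) == 4.0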


@@ -187,6 +187,7 @@ impl Validation {
             best_of,
             temperature,
             repetition_penalty,
+            frequency_penalty,
             top_k,
             top_p,
             typical_p,
@@ -223,12 +224,17 @@ impl Validation {
             return Err(ValidationError::RepetitionPenalty);
         }

+        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
+        if !(-2.0..=2.0).contains(&frequency_penalty) {
+            return Err(ValidationError::FrequencyPenalty);
+        }
+
         // TODO: enable watermark with fp8 quantization
         let quantization_enabled = env::var("QUANT_CONFIG")
             .ok()
             .map_or(false, |value| !value.is_empty());
         if watermark && quantization_enabled {
-            return Err(ValidationError::WatermarkWithQuantization)
+            return Err(ValidationError::WatermarkWithQuantization);
         }

         // Different because the proto default value is not a valid value
@@ -314,6 +320,7 @@ impl Validation {
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
+            frequency_penalty,
             top_k,
             top_p,
             typical_p,
@@ -445,6 +452,8 @@ pub enum ValidationError {
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
+    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
+    FrequencyPenalty,
     #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
     #[error("`top_k` must be strictly positive")]


@@ -798,8 +798,8 @@ class CausalLM(Model):
         attention_mask,
         position_ids,
         token_idx,
-        past_key_values: Optional = None,
-        bypass_hpu_graph: Optional = None,
+        past_key_values: Optional[List[Tuple]] = None,
+        bypass_hpu_graph: Optional[bool] = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
@@ -1040,14 +1040,17 @@ class CausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,


@@ -19,6 +19,7 @@ from einops import rearrange
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math

+
 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
@@ -53,6 +54,7 @@ class MambaConfig(PretrainedConfig):
             **kwargs,
         )

+
 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -60,8 +62,12 @@ class MambaBlock(nn.Module):
         self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
         self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
         self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
-        self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
-        self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
+        self.dt_proj_no_bias = FastLinear.load(
+            config, f"{prefix}.dt_proj", weights, bias=False
+        )
+        self.out_proj = FastLinear.load(
+            config, f"{prefix}.out_proj", weights, bias=False
+        )
         self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
         self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
         self.D = weights.get_tensor(f"{prefix}.D")
@@ -85,7 +91,9 @@ class MambaBlock(nn.Module):
         conv_state = F.pad(x, (self.d_conv - seqlen, 0))
         x = causal_conv1d_fn(
             x=x,
-            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            weight=self.conv1d.weight.view(
+                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+            ),
             bias=self.conv1d.bias,
             activation=self.activation,
         )
@@ -94,7 +102,9 @@ class MambaBlock(nn.Module):
         # We want dt to have d as the slowest moving dimension
         # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
         x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
-        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        dt, B, C = torch.split(
+            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
+        )
         dt = self.dt_proj.weight @ dt.t()
         dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
@@ -121,25 +131,36 @@ class MambaBlock(nn.Module):
             conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
             conv_out = causal_conv1d_fn(
                 x=conv_state_new,
-                weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+                weight=self.conv1d.weight.view(
+                    self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+                ),
                 bias=self.conv1d.bias,
-                activation=self.activation
+                activation=self.activation,
             )
             conv_state = conv_state_new[:, :, 1:]
             bsz, seqlen, dim = hidden_states.shape
             output_tensor = torch.zeros(
-                (bsz, seqlen, dim),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype
+                (bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype
             )
             for i in range(0, bsz):
                 x = conv_out[i : i + 1, :, -1]
                 z = _z[i : i + 1, -1, :]
                 x_db = self.x_proj(x)
-                dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+                dt, B, C = torch.split(
+                    x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1
+                )
                 dt = F.linear(dt, self.dt_proj.weight)
                 y = selective_state_update(
-                    ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+                    ssm_state[i : i + 1, :, :],
+                    x,
+                    dt,
+                    self.negA,
+                    B,
+                    C,
+                    self.D,
+                    z=z,
+                    dt_bias=self.dt_proj.bias,
+                    dt_softplus=True,
                 )
                 out = self.out_proj(y)
                 output_tensor[i] = out
@@ -147,12 +168,15 @@ class MambaBlock(nn.Module):
         return output_tensor, conv_state, ssm_state


 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
-        self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
-        self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
+        self.mamba_block = MambaBlock(
+            prefix=f"{layer_id}.mixer", config=config, weights=weights
+        )
+        self.layer_norm = FastRMSNorm.load(
+            prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
+        )

     def forward(
         self,
@@ -163,28 +187,47 @@ class ResidualBlock(nn.Module):
         residual = (hidden_states + residual) if residual is not None else hidden_states
         shape = residual.shape
         hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
-        hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        hidden_states, conv_state, last_ssm_state = self.mamba_block(
+            hidden_states.view(*shape), inference_params
+        )
         return hidden_states, residual, conv_state, last_ssm_state


 class MambaModel(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         prefix = "backbone"
         self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
         self.blocks = nn.ModuleList(
-            [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
+            [
+                ResidualBlock(f"{prefix}.layers.{i}", config, weights)
+                for i in range(config.n_layer)
+            ]
+        )
+        self.norm_f = FastRMSNorm.load(
+            f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
+        )
+        self.lm_head = FastLinear.load(
+            config, f"{prefix}.embedding", weights, bias=False
         )
-        self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
-        self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
         self.config = config

-    def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+    def forward(
+        self, input_ids: torch.Tensor, inference_params=None, residual=None
+    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
         hidden_states = self.embed_tokens(input_ids)
         for block in self.blocks:
-            hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
-            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
+            hidden_states, residual, conv_state, ssm_state = block(
+                hidden_states, residual, inference_params
+            )
+            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (
+                conv_state,
+                ssm_state,
+            )

-        hidden_states = hidden_states + residual if residual is not None else hidden_states
+        hidden_states = (
+            hidden_states + residual if residual is not None else hidden_states
+        )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
         logits = self.lm_head(hidden_states)


@@ -842,7 +842,6 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out

-
         speculate = get_speculate()
         (
             next_input_ids,
@@ -1064,14 +1063,17 @@ class FlashCausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,


@@ -26,6 +26,7 @@ from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 from mamba_ssm.utils.generation import InferenceParams

+
 @dataclass
 class MambaBatch(Batch):
     batch_id: int
@@ -221,7 +222,10 @@ class MambaBatch(Batch):
         # TODO
         # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
         key_value_memory_dict = {}
-        for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
+        for i, (
+            conv_state,
+            ssm_state,
+        ) in self.inference_params.key_value_memory_dict.items():
             key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
         self.inference_params.key_value_memory_dict = key_value_memory_dict
@@ -305,8 +309,9 @@ class MambaBatch(Batch):
             start_index = end_index

-        (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
+        (_, d_model, d_conv) = (
+            batches[0].inference_params.key_value_memory_dict[0][0].shape
+        )
         (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
         n_blocks = len(batches[0].inference_params.key_value_memory_dict)
         dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
@@ -344,9 +349,15 @@ class MambaBatch(Batch):
         for i in range(n_blocks):
             conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
             batch_size = batch.inference_params.max_batch_size
-            inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
-            inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
-            inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
+            inference_params.key_value_memory_dict[i][0][
+                current_batch : current_batch + batch_size
+            ] = conv_state
+            inference_params.key_value_memory_dict[i][1][
+                current_batch : current_batch + batch_size
+            ] = ssm_state
+            inference_params.lengths_per_sample[
+                current_batch : current_batch + batch_size
+            ] = batch.inference_params.lengths_per_sample
             current_batch += batch_size

         return cls(
@@ -366,12 +377,13 @@ class MambaBatch(Batch):
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
             max_tokens=max_tokens,
-            inference_params=inference_params
+            inference_params=inference_params,
         )

     def __len__(self):
         return len(self.requests)

+
 class Mamba(Model):
     def __init__(
         self,
@@ -441,7 +453,9 @@ class Mamba(Model):
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
-        input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
+        input_ids = (
+            batch.input_ids
+        )  # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids

         batch_size = input_ids.shape[0]
         max_seqlen = input_ids.shape[1]
@@ -450,7 +464,10 @@ class Mamba(Model):
         # Inference params
         seqlen_og = 0
         inf_cache = {}
-        lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
+        lengths_per_sample = (
+            torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
+            * max_seqlen
+        )

         if batch.inference_params is None:
             inference_params = InferenceParams(
@@ -478,11 +495,16 @@ class Mamba(Model):
                     device=block.dt_proj.weight.device,
                     dtype=block.dt_proj.weight.dtype,
                 )
-                inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
+                inference_params.key_value_memory_dict[block.layer_idx] = (
+                    conv_state,
+                    ssm_state,
+                )
             batch.inference_params = inference_params

         # Forward pass
-        logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
+        logits, past_input_ids, new_inference_params = self.model(
+            input_ids, batch.inference_params
+        )

         batch.inference_params = new_inference_params
         # Results
@@ -564,7 +586,8 @@ class Mamba(Model):
                         prefix_offset=len(all_input_ids)
                         - stopping_criteria.current_tokens
                         - 1,
-                        read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                        read_offset=len(all_input_ids)
+                        - stopping_criteria.current_tokens,
                         skip_special_tokens=True,
                     )
                     # Get seed


@@ -750,14 +750,17 @@ class Seq2SeqLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,


@@ -95,5 +95,7 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
+            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
+            if self.top_tokens is not None
+            else None,
         )


@@ -107,6 +107,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
         return None


+class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
+    r"""
+    Frequency penalty as defined by OpenAI
+
+    Args:
+        penalty (`float`):
+            The parameter for frequency penalty. 0.0 means no penalty.
+    """
+
+    def __init__(self, penalty: float):
+        self.penalty = penalty
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        score = torch.gather(scores, 1, input_ids)
+        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
+        score = -torch.where(
+            score < 0, score * self.penalty, score / self.penalty
+        )
+
+        return scores.scatter_add_(1, input_ids, score)
+
+
+class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
+    r"""
+    Frequency penalty as defined by OpenAI
+
+    Args:
+        frequency_penalty (`List[float]`):
+            The parameter for frequency penalty. 0.0 means no penalty.
+    """
+
+    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
+        self.penalty = penalty
+        self.penalty_tensor = torch.tensor(
+            penalty, dtype=dtype, device=device
+        ).unsqueeze(1)
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        score = torch.gather(scores, 1, input_ids)
+        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
+        score = -torch.where(
+            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
+        )
+
+        return scores.scatter_add_(1, input_ids, score)
+
+    def filter(self, indices):
+        self.penalty = [self.penalty[i] for i in indices]
+        if any([x != 0.0 for x in self.penalty]):
+            self.penalty_tensor = self.penalty_tensor[indices]
+            return self
+        return None
+
+
 class HeterogeneousTemperatureLogitsWarper:
     r"""
     [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
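
To see what the new processor does to a batch of logits, here is a small self-contained sketch of the same gather/scatter_add math on toy tensors (the shapes are assumptions: scores is (batch, vocab_size), input_ids holds the previously generated token ids). Because scatter_add_ accumulates duplicate indices, a token is penalized once per occurrence:

    import torch

    penalty = 0.1  # matches the schema example added in this commit
    scores = torch.tensor([[1.0, -1.0, 2.0, 0.5]])  # logits over a 4-token vocab
    input_ids = torch.tensor([[0, 2, 2]])           # previously generated tokens

    score = torch.gather(scores, 1, input_ids)
    # negative logits are multiplied, positive ones divided, then negated
    score = -torch.where(score < 0, score * penalty, score / penalty)
    scores.scatter_add_(1, input_ids, score)

    print(scores)  # token 2 appeared twice, so its logit is penalized twice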


@@ -1,14 +1,16 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

 import re
-from typing import Callable, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
 from text_generation_server.utils.logits_process import (
+    FrequencyPenaltyLogitsProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
+    HeterogeneousFrequencyPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
@@ -26,6 +28,7 @@ class NextTokenChooser:
         watermark=False,
         temperature=1.0,
         repetition_penalty=1.0,
+        frequency_penalty=0.0,
         top_k=None,
         top_p=None,
         typical_p=None,
@@ -35,7 +38,14 @@ class NextTokenChooser:
     ):
         self.watermark_processor = WatermarkLogitsProcessor(device=device) if watermark else None
         self.repetition_processor = (
-            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) if repetition_penalty else None
+            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
+            if repetition_penalty and repetition_penalty != 1.0
+            else None
+        )
+        self.frequency_processor = (
+            FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
+            if frequency_penalty and frequency_penalty != 0.0
+            else None
         )

         has_warpers = (
@@ -57,6 +67,8 @@ class NextTokenChooser:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.frequency_processor is not None:
+            scores = self.frequency_processor(input_ids, scores)

         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -77,6 +89,7 @@ class NextTokenChooser:
             watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
+            frequency_penalty=pb.frequency_penalty,
             top_k=pb.top_k,
             top_p=pb.top_p,
             typical_p=pb.typical_p,
@@ -171,7 +184,6 @@ def create_n_gram_speculation(
     return speculative_ids


-
 class HeterogeneousNextTokenChooser:
     def __init__(
         self,
@@ -180,6 +192,7 @@ class HeterogeneousNextTokenChooser:
         watermark: List[bool],
         temperature: List[float],
         repetition_penalty: List[float],
+        frequency_penalty: List[float],
         top_k: List[int],
         top_p: List[float],
         typical_p: List[float],
@@ -203,14 +216,28 @@ class HeterogeneousNextTokenChooser:
         )

         self.repetition_processor = (
-            HeterogeneousRepetitionPenaltyLogitsProcessor(repetition_penalty, dtype, device)
+            HeterogeneousRepetitionPenaltyLogitsProcessor(
+                repetition_penalty, dtype, device
+            )
             if any([x != 1.0 for x in repetition_penalty])
             else None
         )
+
+        self.frequency_processor = (
+            HeterogeneousFrequencyPenaltyLogitsProcessor(
+                frequency_penalty, dtype, device
+            )
+            if any([x != 0.0 for x in frequency_penalty])
+            else None
+        )

         if any([x != 1.0 for x in temperature]):
-            do_sample = [sample or x != 1.0 for x, sample in zip(temperature, do_sample)]
-            warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, dtype, device))
+            do_sample = [
+                sample or x != 1.0 for x, sample in zip(temperature, do_sample)
+            ]
+            warpers.append(
+                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
+            )

         if any([x != 0 for x in top_k]):
             do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
@@ -261,6 +288,8 @@ class HeterogeneousNextTokenChooser:
                     _scores = self.watermark_processor(input_ids, _scores)
                 if self.repetition_processor is not None:
                     _scores = self.repetition_processor(input_ids, _scores)
+                if self.frequency_processor is not None:
+                    _scores = self.frequency_processor(input_ids, _scores)

             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
@@ -329,6 +358,9 @@ class HeterogeneousNextTokenChooser:
         if self.repetition_processor is not None:
             self.repetition_processor = self.repetition_processor.filter(indices)

+        if self.frequency_processor is not None:
+            self.frequency_processor = self.frequency_processor.filter(indices)
+
         filtered_warpers = []
         for warper in self.warpers:
             filtered_warper = warper.filter(indices)
@@ -358,6 +390,7 @@ class HeterogeneousNextTokenChooser:
             watermark=[pb_.watermark for pb_ in pb],
             temperature=[pb_.temperature for pb_ in pb],
             repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
+            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
             top_k=[pb_.top_k for pb_ in pb],
             top_p=[pb_.top_p for pb_ in pb],
             typical_p=[pb_.typical_p for pb_ in pb],
@@ -431,7 +464,10 @@ class HeterogeneousSampling:
 def batch_top_tokens(
-    top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
+    top_n_tokens: List[int],
+    top_n_tokens_tensor: torch.Tensor,
+    logprobs: torch.Tensor,
+    accepted_ids: torch.Tensor,
 ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
     """Find the top n most likely tokens for a batch of generations.
@@ -443,12 +479,15 @@ def batch_top_tokens(
     if max_top_n == 0:
         return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)

     batch_size = accepted_ids.shape[0]
     speculate_size = logprobs.shape[0] // batch_size
     top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
     # Ensure top_n doesn't exceed vocab size
-    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
+    top_n_tokens = [
+        min(tok, logprobs.size(-1))
+        for tok in top_n_tokens
+        for _ in range(speculate_size)
+    ]

     # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
     # Sorted topk is faster than torch.sort() since we only need a small subset
@@ -489,7 +528,9 @@ def batch_top_tokens(
         row_top_token_ids = []
         row_top_token_logprobs = []

-        for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
+        for idxs, vals, n, req_n in zip(
+            _top_indices, _top_values, _top_n_ishes, _top_n_tokens
+        ):
             indices = idxs[:n] if req_n > 0 else []
             values = vals[:n] if req_n > 0 else []
@@ -527,8 +568,8 @@ def make_tokenizer_optional(tokenizer):
                     for inner_text in text]
             if padding == "longest":
                 max_length = max(len(tokens) for tokens in all_tokens)
-                return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens], dtype=torch.int32),
-                        "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens], dtype=torch.int32)}
+                return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]),
+                        "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])}

     def decode(
         self,