Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
feat: Support sampling seeding (#37)

Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com>

Commit cd298bc5e5 (parent 1539d3cbbe)
@@ -36,6 +36,8 @@ message NextTokenChooserParameters {
     float top_p = 3;
     /// apply sampling on the logits
     bool do_sample = 4;
+    /// random seed for sampling
+    optional uint64 seed = 5;
 }

 message StoppingCriteriaParameters {

@@ -82,6 +84,8 @@ message GeneratedText {
     repeated float logprobs = 6;
     /// Finish reason
     string finish_reason = 7;
+    /// Seed
+    optional uint64 seed = 8;
 }

 message GenerateRequest {
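The `optional` keyword on the new `seed` fields gives them explicit presence tracking, which is how the server later distinguishes "no seed requested" from a literal seed value of 0 (see the `pb.HasField("seed")` check in `NextTokenChooser.from_pb` further down). A minimal sketch of that presence check, assuming the generated Python module is importable as `text_generation.pb.generate_pb2` as in the imports below:

    from text_generation.pb import generate_pb2

    # Without an explicit assignment, the optional field reports "not set".
    params = generate_pb2.NextTokenChooserParameters(do_sample=True)
    assert not params.HasField("seed")

    # Once assigned, presence is tracked even for values equal to the proto default (0).
    params.seed = 0
    assert params.HasField("seed")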
@@ -1,6 +1,7 @@
 use std::fs;

 fn main() -> Result<(), Box<dyn std::error::Error>> {
+    println!("cargo:rerun-if-changed=../../proto/generate.proto");
     fs::create_dir("src/pb").unwrap_or(());
     tonic_build::configure()
         .build_client(true)
@@ -191,6 +191,7 @@ fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>
             tokens: output.tokens,
             logprobs: output.logprobs,
             finish_reason: output.finish_reason,
+            seed: output.seed,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),

@@ -208,6 +209,7 @@ pub(crate) struct InferResponse {
     pub(crate) tokens: Vec<String>,
     pub(crate) logprobs: Vec<f32>,
     pub(crate) finish_reason: String,
+    pub(crate) seed: Option<u64>,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,
@@ -166,6 +166,7 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
             top_k: parameters.top_k as u32,
             top_p: parameters.top_p,
             do_sample: parameters.do_sample,
+            seed: parameters.seed,
         }
     }
 }
@@ -25,6 +25,8 @@ pub(crate) struct GenerateParameters {
     pub stop: Vec<String>,
     #[serde(default)]
     pub details: bool,
+    #[serde(default)]
+    pub seed: Option<u64>,
 }

 fn default_temperature() -> f32 {

@@ -56,6 +58,7 @@ fn default_parameters() -> GenerateParameters {
         max_new_tokens: default_max_new_tokens(),
         stop: vec![],
         details: false,
+        seed: None,
     }
 }

@@ -70,6 +73,7 @@ pub(crate) struct GenerateRequest {
 pub(crate) struct Details {
     pub finish_reason: String,
     pub generated_tokens: u32,
+    pub seed: Option<u64>,
     pub tokens: Vec<(u32, String, f32)>,
 }
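With `seed` added to `GenerateParameters` and echoed back through `Details`, a caller can pin the sampling RNG per request and recover the seed that was actually used. A hypothetical client call is sketched below; the route, host, and overall request body shape are assumptions, and only the `parameters.seed`, `parameters.do_sample`, and `parameters.details` fields come from the structs above:

    import requests

    # Hypothetical request against a locally running router; the payload layout
    # outside "parameters" is assumed for illustration.
    response = requests.post(
        "http://localhost:3000/generate",
        json={
            "inputs": "def fibonacci(n):",
            "parameters": {"do_sample": True, "seed": 42, "details": True},
        },
    )
    # When details are requested, the response should carry Details.seed, so an
    # unseeded sampled generation can later be replayed with the reported seed.
    print(response.json())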
@@ -55,6 +55,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
                 max_new_tokens: 1,
                 stop: vec![],
                 details: false,
+                seed: None,
             },
         },
     )

@@ -70,7 +71,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
         validation_time,
         queue_time,
         inference_time,
-        time_per_token
+        time_per_token,
+        seed
     )
 )]
 async fn generate(

@@ -118,6 +120,7 @@ async fn generate(
             .map(|((id, text), logprob)| (id, text, logprob))
             .collect();
         Some(Details {
+            seed: response.seed,
             finish_reason: response.finish_reason,
             generated_tokens: response.generated_tokens,
             tokens,

@@ -162,6 +165,7 @@ async fn generate(
     tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
     tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
+    tracing::Span::current().record("seed", format!("{:?}", response.seed));
     tracing::info!("Output: {}", response.output_text);

     // Send response
@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
                 if name == "word_embeddings.weight":
                     model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
 from text_generation.models import Model
 from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria
+from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling


 @dataclass

@@ -296,7 +296,10 @@ class CausalLM(Model):
         )
         with context_manager():
             logits, past = self.forward(
-                batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
+                batch.input_ids,
+                batch.attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
             )

         # List of indices to cache

@@ -373,6 +376,12 @@ class CausalLM(Model):
                     1
                 ).tolist()

+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(

@@ -383,6 +392,7 @@ class CausalLM(Model):
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
                         reason=reason,
+                        seed=seed,
                     )
                 )
             # add to the next batch
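The seed reported in `GeneratedText` is simply the generator's initial seed when the chooser actually sampled, and `None` for greedy decoding. A small standalone check of that logic (a sketch, assuming `NextTokenChooser` and `Sampling` are importable from `text_generation.utils` as in the import change above):

    from text_generation.utils import NextTokenChooser, Sampling

    # Seeded sampling: the chooser's choice is a Sampling instance and reports
    # back the seed it was constructed with.
    chooser = NextTokenChooser(do_sample=True, seed=1234)
    seed = chooser.choice.seed if isinstance(chooser.choice, Sampling) else None
    assert seed == 1234

    # Greedy decoding with default parameters: no Sampling instance, no seed.
    greedy = NextTokenChooser(do_sample=False)
    assert not isinstance(greedy.choice, Sampling)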
@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
             }
         )

-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            load_in_8bit=quantize,
-            trust_remote_code=True,  # required
-        ).to(device).eval()
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                load_in_8bit=quantize,
+                trust_remote_code=True,  # required
+            )
+            .to(device)
+            .eval()
+        )

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
 from text_generation.models import Model
 from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria
+from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling


 @dataclass

@@ -451,6 +451,13 @@ class Seq2SeqLM(Model):
                 logprobs = [float("nan")] + decoder_logprobs[
                     -decoder_input_length:
                 ].tolist()
+
+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(

@@ -461,6 +468,7 @@ class Seq2SeqLM(Model):
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
                         reason=reason,
+                        seed=seed,
                     )
                 )
             # add to the next batch
@@ -2,7 +2,7 @@ import torch

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List
+from typing import List, Optional

 from transformers import PreTrainedTokenizerBase

@@ -39,6 +39,7 @@ class GeneratedText:
     token_ids: List[int]
     logprobs: List[float]
     reason: str
+    seed: Optional[int]

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(

@@ -49,4 +50,5 @@ class GeneratedText:
             token_ids=self.token_ids,
             logprobs=self.logprobs,
             finish_reason=self.reason,
+            seed=self.seed,
         )
@@ -24,11 +24,24 @@ from text_generation.pb import generate_pb2


 class Sampling:
+    def __init__(self, seed: Optional[int] = None):
+        self.generator = torch.Generator()
+        if seed is not None:
+            self.generator.manual_seed(seed)
+        else:
+            self.generator.seed()
+
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        next_tokens = torch.multinomial(
+            probs, num_samples=1, generator=self.generator
+        ).squeeze(1)
         return next_tokens

+    @property
+    def seed(self) -> int:
+        return self.generator.initial_seed()
+

 class Greedy:
     def __call__(self, logits):

@@ -36,7 +49,9 @@ class Greedy:


 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False):
+    def __init__(
+        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`

@@ -53,7 +68,7 @@ class NextTokenChooser:
             sampling = True

         self.warpers = warpers
-        self.choice = Sampling() if sampling else Greedy()
+        self.choice = Sampling(seed) if sampling else Greedy()

     def __call__(self, input_ids, scores):
         # Warp logits

@@ -66,11 +81,14 @@ class NextTokenChooser:

     @classmethod
     def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
+        # handle protobuf making default values 0
+        seed = pb.seed if pb.HasField("seed") else None
         return NextTokenChooser(
             temperature=pb.temperature,
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
+            seed=seed,
         )
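Because `Sampling` now draws from its own `torch.Generator`, two instances constructed with the same seed walk through identical random states, which is what makes a returned seed replayable. A quick reproducibility check (sketch; `Sampling` is the class defined above):

    import torch

    from text_generation.utils import Sampling

    logits = torch.randn(1, 32000)  # fake next-token logits for one sequence

    # Same seed, same logits -> the multinomial draw picks the same token.
    assert torch.equal(Sampling(seed=42)(logits), Sampling(seed=42)(logits))

    # Unseeded: torch picks a fresh seed, which the `seed` property exposes so it
    # can be reported back to the client and reused later.
    sampler = Sampling()
    print(sampler.seed)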