feat: Support sampling seeding (#37)

Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com>
OlivierDehaene 2023-01-30 15:36:16 +01:00 committed by GitHub
parent 1539d3cbbe
commit cd298bc5e5
13 changed files with 78 additions and 16 deletions

View File

@@ -36,6 +36,8 @@ message NextTokenChooserParameters {
     float top_p = 3;
     /// apply sampling on the logits
     bool do_sample = 4;
+    /// random seed for sampling
+    optional uint64 seed = 5;
 }

 message StoppingCriteriaParameters {
@@ -82,6 +84,8 @@ message GeneratedText {
     repeated float logprobs = 6;
     /// Finish reason
     string finish_reason = 7;
+    /// Seed
+    optional uint64 seed = 8;
 }

 message GenerateRequest {
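A note on the optional keyword: in proto3 it enables explicit field presence for these scalar seed fields, which is what lets the server tell "no seed supplied" apart from a literal seed of 0. A minimal sketch with the generated Python stubs (assuming they have been regenerated from this proto):

    from text_generation.pb import generate_pb2

    params = generate_pb2.NextTokenChooserParameters(do_sample=True)
    assert not params.HasField("seed")  # field is absent, not merely zero

    params.seed = 0
    assert params.HasField("seed")  # presence is tracked even for a value of 0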

View File

@@ -1,6 +1,7 @@
 use std::fs;

 fn main() -> Result<(), Box<dyn std::error::Error>> {
+    println!("cargo:rerun-if-changed=../../proto/generate.proto");
     fs::create_dir("src/pb").unwrap_or(());
     tonic_build::configure()
         .build_client(true)
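The added cargo:rerun-if-changed directive tells Cargo to re-run this build script whenever proto/generate.proto changes, so tonic regenerates the gRPC stubs (including the new seed field) without requiring a clean rebuild.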

View File

@@ -191,6 +191,7 @@ fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>
             tokens: output.tokens,
             logprobs: output.logprobs,
             finish_reason: output.finish_reason,
+            seed: output.seed,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),
@@ -208,6 +209,7 @@ pub(crate) struct InferResponse {
     pub(crate) tokens: Vec<String>,
     pub(crate) logprobs: Vec<f32>,
     pub(crate) finish_reason: String,
+    pub(crate) seed: Option<u64>,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,

View File

@@ -166,6 +166,7 @@ impl From<&GenerateParameters> for NextTokenChooserParameters {
             top_k: parameters.top_k as u32,
             top_p: parameters.top_p,
             do_sample: parameters.do_sample,
+            seed: parameters.seed,
         }
     }
 }

View File

@@ -25,6 +25,8 @@ pub(crate) struct GenerateParameters {
     pub stop: Vec<String>,
     #[serde(default)]
     pub details: bool,
+    #[serde(default)]
+    pub seed: Option<u64>,
 }

 fn default_temperature() -> f32 {
@@ -56,6 +58,7 @@ fn default_parameters() -> GenerateParameters {
         max_new_tokens: default_max_new_tokens(),
         stop: vec![],
         details: false,
+        seed: None,
     }
 }
@@ -70,6 +73,7 @@ pub(crate) struct GenerateRequest {
 pub(crate) struct Details {
     pub finish_reason: String,
     pub generated_tokens: u32,
+    pub seed: Option<u64>,
     pub tokens: Vec<(u32, String, f32)>,
 }
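On the HTTP side, the new field means a caller can pin the sampling seed in the request parameters and, when details is requested, read back the seed that was actually used. A hedged sketch of such a call (the host, port, and prompt are assumptions about a local deployment, not part of this change):

    import requests

    # Hypothetical local router; adjust the URL for your deployment.
    resp = requests.post(
        "http://127.0.0.1:3000/generate",
        json={
            "inputs": "The seed makes sampling reproducible because",
            "parameters": {"do_sample": True, "seed": 42, "details": True},
        },
    )
    # With sampling enabled, the seed that was used is reported in the
    # response details object (the Details struct above).
    print(resp.json())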

View File

@@ -55,6 +55,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
                 max_new_tokens: 1,
                 stop: vec![],
                 details: false,
+                seed: None,
             },
         },
     )
@@ -70,7 +71,8 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
         validation_time,
         queue_time,
         inference_time,
-        time_per_token
+        time_per_token,
+        seed
     )
 )]
 async fn generate(
@@ -118,6 +120,7 @@ async fn generate(
             .map(|((id, text), logprob)| (id, text, logprob))
             .collect();
         Some(Details {
+            seed: response.seed,
             finish_reason: response.finish_reason,
             generated_tokens: response.generated_tokens,
             tokens,
@@ -162,6 +165,7 @@ async fn generate(
     tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
     tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
+    tracing::Span::current().record("seed", format!("{:?}", response.seed));
     tracing::info!("Output: {}", response.output_text);

     // Send response

View File

@@ -234,7 +234,9 @@ class BLOOMSharded(BLOOM):
             if name == "word_embeddings.weight":
                 model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,

View File

@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
 from text_generation.models import Model
 from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria
+from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling

 @dataclass
@@ -296,7 +296,10 @@ class CausalLM(Model):
         )
         with context_manager():
             logits, past = self.forward(
-                batch.input_ids, batch.attention_mask, batch.position_ids, batch.past_key_values
+                batch.input_ids,
+                batch.attention_mask,
+                batch.position_ids,
+                batch.past_key_values,
             )

         # List of indices to cache
@@ -373,6 +376,12 @@ class CausalLM(Model):
                     1
                 ).tolist()

+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(
@@ -383,6 +392,7 @@ class CausalLM(Model):
                         token_ids=token_ids.squeeze(1).tolist(),
                         logprobs=logprobs,
                         reason=reason,
+                        seed=seed,
                     )
                 )
         # add to the next batch

View File

@@ -333,7 +333,9 @@ class GalacticaSharded(Galactica):
             if name == "model.decoder.embed_tokens.weight":
                 model.lm_head._parameters["weight"] = tensor

-    def forward(self, input_ids, attention_mask, position_ids, past_key_values: Optional = None):
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ):
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,

View File

@@ -39,12 +39,16 @@ class SantaCoder(CausalLM):
             }
         )

-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            load_in_8bit=quantize,
-            trust_remote_code=True,  # required
-        ).to(device).eval()
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=dtype,
+                load_in_8bit=quantize,
+                trust_remote_code=True,  # required
+            )
+            .to(device)
+            .eval()
+        )

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,

View File

@@ -7,7 +7,7 @@ from typing import Optional, Tuple, List, Type
 from text_generation.models import Model
 from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria
+from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling

 @dataclass
@@ -451,6 +451,13 @@ class Seq2SeqLM(Model):
                 logprobs = [float("nan")] + decoder_logprobs[
                     -decoder_input_length:
                 ].tolist()

+                # Get seed
+                if isinstance(next_token_chooser.choice, Sampling):
+                    seed = next_token_chooser.choice.seed
+                else:
+                    seed = None
+
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
                     GeneratedText(
@@ -461,6 +468,7 @@ class Seq2SeqLM(Model):
                         token_ids=token_ids.tolist(),
                         logprobs=logprobs,
                         reason=reason,
+                        seed=seed,
                     )
                 )
         # add to the next batch

View File

@@ -2,7 +2,7 @@ import torch
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List
+from typing import List, Optional

 from transformers import PreTrainedTokenizerBase
@@ -39,6 +39,7 @@ class GeneratedText:
     token_ids: List[int]
     logprobs: List[float]
     reason: str
+    seed: Optional[int]

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
@@ -49,4 +50,5 @@ class GeneratedText:
             token_ids=self.token_ids,
             logprobs=self.logprobs,
             finish_reason=self.reason,
+            seed=self.seed,
         )

View File

@@ -24,11 +24,24 @@ from text_generation.pb import generate_pb2
 class Sampling:
+    def __init__(self, seed: Optional[int] = None):
+        self.generator = torch.Generator()
+        if seed is not None:
+            self.generator.manual_seed(seed)
+        else:
+            self.generator.seed()
+
     def __call__(self, logits):
         probs = torch.nn.functional.softmax(logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        next_tokens = torch.multinomial(
+            probs, num_samples=1, generator=self.generator
+        ).squeeze(1)
         return next_tokens

+    @property
+    def seed(self) -> int:
+        return self.generator.initial_seed()

 class Greedy:
     def __call__(self, logits):
@@ -36,7 +49,9 @@ class Greedy:
 class NextTokenChooser:
-    def __init__(self, temperature=1.0, top_k=None, top_p=None, do_sample=False):
+    def __init__(
+        self, temperature=1.0, top_k=None, top_p=None, do_sample=False, seed=None
+    ):
         warpers = LogitsProcessorList()
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
@@ -53,7 +68,7 @@ class NextTokenChooser:
             sampling = True

         self.warpers = warpers
-        self.choice = Sampling() if sampling else Greedy()
+        self.choice = Sampling(seed) if sampling else Greedy()

     def __call__(self, input_ids, scores):
         # Warp logits
@@ -66,11 +81,14 @@ class NextTokenChooser:
     @classmethod
     def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
+        # handle protobuf making default values 0
+        seed = pb.seed if pb.HasField("seed") else None
         return NextTokenChooser(
             temperature=pb.temperature,
             top_k=pb.top_k,
             top_p=pb.top_p,
             do_sample=pb.do_sample,
+            seed=seed,
         )
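Taken together, this is what makes sampled generations reproducible: the same seed drives the same torch.Generator, so the multinomial draws repeat, while greedy decoding never gets a seed (hence the None reported by the models). A small sketch of that behaviour in isolation, assuming these helpers are importable from text_generation.utils as shown in the diff:

    import torch

    from text_generation.utils import Greedy, NextTokenChooser, Sampling

    logits = torch.randn(1, 32)

    # Two samplers seeded identically draw the same next token.
    first = Sampling(seed=1234)
    second = Sampling(seed=1234)
    assert torch.equal(first(logits), second(logits))
    assert first.seed == 1234

    # Greedy decoding carries no generator, which is why its seed stays None.
    chooser = NextTokenChooser(do_sample=False)
    assert isinstance(chooser.choice, Greedy)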