Mirror of https://github.com/huggingface/text-generation-inference.git

Commit: e9441a1ea2 ("use only last logits")
Parent: 189fba28d1
@@ -38,7 +38,8 @@ function sample_example(inputs, max_new_tokens, name) {
    parameters: {
      max_new_tokens: max_new_tokens,
      do_sample: true,
-     top_p: 0.9
+     top_p: 0.9,
+     seed: 0
    }
  });
  let params = {
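For context, a minimal sketch of exercising the same sampling parameters, including the new seed, against a running text-generation-inference server from Python. The host, port, prompt, and token budget below are placeholders, not values from this commit:

import requests

# Hypothetical local endpoint; adjust host/port to your deployment.
TGI_URL = "http://127.0.0.1:8080/generate"

payload = {
    "inputs": "What is Deep Learning?",
    "parameters": {
        "max_new_tokens": 64,
        "do_sample": True,
        "top_p": 0.9,
        # Fixing the seed makes sampled outputs reproducible across runs.
        "seed": 0,
    },
}

response = requests.post(TGI_URL, json=payload, timeout=60)
print(response.json())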
@@ -66,7 +66,7 @@ impl Client {
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
-   #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+   #[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
@@ -77,7 +77,7 @@ impl Client {
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
-   #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+   #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
    pub async fn decode(
        &mut self,
        batches: Vec<Batch>,
@@ -53,7 +53,7 @@ impl ShardedClient {
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
-   #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
+   #[instrument(skip_all, fields(id = &batch.id, size = &batch.size, max_sequence_length = batch.requests.iter().map(|request| request.input_length).max()))]
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
        let futures: Vec<_> = self
            .clients
@@ -69,7 +69,7 @@ impl ShardedClient {
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
-   #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
+   #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>(), max_sequence_length = batches.iter().map(|batch|{batch.requests.iter().map(|request| request.input_length).max()}).max()))]
    pub async fn decode(
        &mut self,
        batches: Vec<Batch>,
@@ -96,7 +96,7 @@ impl Infer {
            request: valid_request,
            response_tx,
            span: Span::current(),
-           batch_span: None,
+           temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            _permit: permit,
@@ -228,6 +228,18 @@ async fn batching_task(
                    .next_batch(min_size, max_batch_size - batch_size as usize)
                    .await
                {
+                   let new_batch_size = new_batch.size;
+                   entries.iter_mut().for_each(|(_, entry)| {
+                       // Create a new span to add the info that this entry is waiting
+                       // because a new batch is being computed
+                       let entry_waiting_span =
+                           info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
+                       // Add relationship
+                       entry_waiting_span.follows_from(&span);
+                       // Update entry
+                       entry.temp_span = Some(entry_waiting_span);
+                   });
+
                    // Generate one token for this new batch to have the attention past in cache
                    let new_cached_batch =
                        wrap_future(client.prefill(new_batch), &mut new_entries)
@@ -253,7 +265,7 @@ async fn batching_task(
                // Add relationship
                entry_batch_span.follows_from(&next_batch_span);
                // Update entry
-               entry.batch_span = Some(entry_batch_span);
+               entry.temp_span = Some(entry_batch_span);
            });

            cached_batch = wrap_future(client.decode(batches), &mut entries)
@@ -289,7 +301,7 @@ async fn wrap_future(
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // Create and enter a span to link this function back to the entry
-       let _send_error_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
+       let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
        let err = InferError::GenerationError(error.to_string());
        tracing::error!("{err}");

@@ -312,7 +324,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
            .expect("ID not found in entries. This is a bug.");

        // Create and enter a span to link this function back to the entry
-       let _generation_span = info_span!(parent: entry.batch_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
+       let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();

        if let Some(prefill_tokens) = generation.prefill_tokens {
            // Send message
@@ -18,9 +18,8 @@ pub(crate) struct Entry {
    pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as entry
    pub span: Span,
-   /// Span for every inference batch
-   /// This span will only live as long as one prefill/decode
-   pub batch_span: Option<Span>,
+   /// Temporary span used as a guard when logging inference, wait times...
+   pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
@@ -125,7 +124,12 @@ impl State {
    }

    /// Append an entry to the queue
-   fn append(&mut self, entry: Entry) {
+   fn append(&mut self, mut entry: Entry) {
+       // Create a span that will live as long as the entry is in the queue waiting to be batched
+       let queue_span = info_span!(parent: &entry.span, "queued");
+       entry.temp_span = Some(queue_span);
+
+       // Push entry in the queue
        self.entries.push((self.next_id, entry));
        self.next_id += 1;
    }
@@ -163,7 +167,7 @@ impl State {
            // Add relationship
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
-           entry.batch_span = Some(entry_batch_span);
+           entry.temp_span = Some(entry_batch_span);

            batch_requests.push(Request {
                id,
@@ -235,7 +239,7 @@ mod tests {
            },
            response_tx,
            span: info_span!("entry"),
-           batch_span: None,
+           temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            _permit: permit,
@@ -76,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
        queue_time,
        inference_time,
        time_per_token,
-       seed
+       seed,
    )
)]
async fn generate(
@@ -179,7 +179,8 @@ async fn generate(
        validation_time,
        queue_time,
        inference_time,
-       time_per_token
+       time_per_token,
+       seed,
    )
)]
async fn generate_stream(
@@ -241,13 +242,11 @@ async fn generate_stream(

                            // Tracing metadata
                            span.record("total_time", format!("{:?}", total_time));
-                           span
-                               .record("validation_time", format!("{:?}", validation_time));
+                           span.record("validation_time", format!("{:?}", validation_time));
                            span.record("queue_time", format!("{:?}", queue_time));
-                           span
-                               .record("inference_time", format!("{:?}", inference_time));
-                           span
-                               .record("time_per_token", format!("{:?}", time_per_token));
+                           span.record("inference_time", format!("{:?}", inference_time));
+                           span.record("time_per_token", format!("{:?}", time_per_token));
+                           span.record("seed", format!("{:?}", generated_text.seed));
                            tracing::info!(parent: &span, "Output: {}", generated_text.text);

                            // StreamResponse
@@ -277,7 +277,6 @@ class CausalLM(Model):
            generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
        )

-   @tracer.start_as_current_span("forward")
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -327,93 +326,93 @@ class CausalLM(Model):
            batch.all_input_ids,
        )

-       with tracer.start_as_current_span("post_processing"):
        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
        ) in enumerate(iterator):
            # Select next token
-           tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
-           next_token_id = tokens[-1].view(1, 1)
+           next_token_id, logprobs = next_token_chooser(
+               all_input_ids.view(1, -1), logits
+           )

            # Append next token to all tokens
            all_input_ids = torch.cat([all_input_ids, next_token_id])
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_squeezed,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :, 0]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token_id)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_size += 1
                next_batch_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids[1:]
                ).squeeze(1)[-new_input_length:-1].tolist()
                prefill_token_ids = all_input_ids[-new_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
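A rough standalone illustration (not code from this repository) of the shape change above: previously the chooser returned one token id per input position and the model kept only the last one; with this commit the chooser already returns a single (1, 1) id taken from the last logits row, so the model can concatenate it directly:

import torch

torch.manual_seed(0)

seq_len, vocab_size = 5, 10
logits = torch.randn(seq_len, vocab_size)   # one logits row per input position
all_input_ids = torch.randint(vocab_size, (seq_len, 1))

# Old behaviour (sketch, greedy for determinism): pick an id for every position,
# then keep only the last one.
old_tokens = logits.argmax(dim=-1)          # shape: (seq_len,)
old_next_token_id = old_tokens[-1].view(1, 1)

# New behaviour (sketch): work on the last row only and return (1, 1) directly.
new_next_token_id = logits[-1].argmax().view(1, 1)

assert torch.equal(old_next_token_id, new_next_token_id)

# Either way the result can be appended to the running sequence.
all_input_ids = torch.cat([all_input_ids, new_next_token_id])
print(all_input_ids.shape)  # torch.Size([6, 1])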
@@ -328,7 +328,6 @@ class Seq2SeqLM(Model):
    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(decoder_ids, skip_special_tokens=True)

-   @tracer.start_as_current_span("forward")
    def forward(
        self,
        input_ids,
@@ -407,94 +406,91 @@ class Seq2SeqLM(Model):
            batch.decoder_input_ids,
        )

-       with tracer.start_as_current_span("post_processing"):
        # For each member of the batch
        for i, (
            request,
            input_length,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            input_tokens,
            decoder_input_ids,
        ) in enumerate(iterator):
            # Select next token
            next_token_id, logprobs = next_token_chooser(
                decoder_input_ids.view(1, -1), logits
            )

            # Append next token to decoder tokens
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
            new_decoder_input_length = decoder_input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id]
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_text = self.tokenizer.decode(
                next_token_id_squeezed,
                clean_up_tokenization_spaces=False,
                skip_special_tokens=False,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(next_token_id, next_token_text)

            if stop:
                # Slice with decoder_input_length to remove padding
                # Decode all tokens
-               output_text = self.decode(
-                   decoder_input_ids[-new_decoder_input_length:]
-               )
+               output_text = self.decode(decoder_input_ids[-new_decoder_input_length:])

                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                generated_text = None
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                next_batch_size += 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

            # Prefill
            if stopping_criteria.current_tokens == 1:
                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, [float("nan")], prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
                generated_text,
            )

            generations.append(generation)

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
@@ -36,16 +36,14 @@ class Sampling:
        self.seed = seed

    def __call__(self, logits):
-       probs = torch.nn.functional.softmax(logits, dim=-1)
-       next_tokens = torch.multinomial(
-           probs, num_samples=1, generator=self.generator
-       ).squeeze(1)
+       probs = torch.nn.functional.softmax(logits)
+       next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
        return next_tokens


class Greedy:
    def __call__(self, logits):
-       return logits.argmax(dim=-1)
+       return logits.argmax()


class NextTokenChooser:
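A small standalone sketch (my own, under the assumption that these callables now receive a single 1-D logits vector rather than one row per position) of why the explicit dim arguments and the squeeze can be dropped:

import torch

torch.manual_seed(0)
vocab_size = 10
last_logits = torch.randn(vocab_size)  # the logits of the last position only

# Greedy on a 1-D tensor: argmax() over all elements is the argmax over the vocab,
# so it matches the old per-row argmax(dim=-1) for this single row.
assert last_logits.argmax() == last_logits.argmax(dim=-1)

# Sampling on a 1-D tensor: multinomial already returns shape (1,),
# so the old .squeeze(1) is no longer needed.
generator = torch.Generator().manual_seed(0)
probs = torch.nn.functional.softmax(last_logits, dim=-1)
sampled_id = torch.multinomial(probs, num_samples=1, generator=generator)
print(sampled_id.shape)  # torch.Size([1])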
@@ -87,8 +85,9 @@ class NextTokenChooser:
        logprobs = torch.log_softmax(scores, -1)

        # Choose tokens
-       next_ids = self.choice(scores)
-       return next_ids, logprobs
+       next_id = self.choice(scores[-1])
+
+       return next_id.view(1, 1), logprobs

    @classmethod
    def from_pb(
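Putting the pieces together, a minimal self-contained sketch of the pattern this hunk implements; the class below is my own illustration, not the repository's NextTokenChooser: log-probabilities are still computed for every position (useful for the prefill logprobs), while the token choice uses only the last logits row and is returned with shape (1, 1):

import torch


class LastLogitsChooser:
    """Toy stand-in for a next-token chooser that picks only from the last position."""

    def __init__(self, do_sample: bool = False, seed: int = 0):
        self.generator = torch.Generator().manual_seed(seed)
        self.do_sample = do_sample

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
        # Log-probabilities over every position, e.g. to report prefill logprobs.
        logprobs = torch.log_softmax(scores, -1)

        # Choose the next token from the last row of scores only.
        last_scores = scores[-1]
        if self.do_sample:
            probs = torch.softmax(last_scores, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1, generator=self.generator)
        else:
            next_id = last_scores.argmax()

        # Shape (1, 1) so the caller can torch.cat it onto the running sequence.
        return next_id.view(1, 1), logprobs


# Usage sketch with dummy shapes.
seq_len, vocab_size = 5, 10
scores = torch.randn(seq_len, vocab_size)
input_ids = torch.randint(vocab_size, (1, seq_len))
chooser = LastLogitsChooser(do_sample=True)
next_id, logprobs = chooser(input_ids, scores)
print(next_id.shape, logprobs.shape)  # torch.Size([1, 1]) torch.Size([5, 10])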