diff --git a/.github/workflows/server-tests.yaml b/.github/workflows/server-tests.yaml new file mode 100644 index 00000000..69da5130 --- /dev/null +++ b/.github/workflows/server-tests.yaml @@ -0,0 +1,29 @@ +name: Server Tests + +on: + pull_request: + paths: + - "server/**" + - "proto/**" + +jobs: + run_tests: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.9 + - name: Loading cache. + uses: actions/cache@v2 + id: model_cache + with: + path: ~/.cache/huggingface/ + key: models + - name: Install server dependencies + run: | + make install-server + - name: Run tests + run: | + pytest -sv server/tests diff --git a/proto/generate.proto b/proto/generate.proto index 0c67de03..16539f8b 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -27,7 +27,7 @@ message ClearCacheRequest {} /// Empty response message ClearCacheResponse {} -message LogitsWarperParameters { +message NextTokenChooserParameters { /// exponential scaling output probability distribution float temperature = 1; /// restricting to the k highest probability elements @@ -52,8 +52,8 @@ message Request { string inputs = 2; /// The number of tokens inside inputs uint32 input_length = 3; - /// Logits Warper Parameters - LogitsWarperParameters parameters = 4; + /// Next Token Chooser Parameters + NextTokenChooserParameters parameters = 4; /// Stopping Criteria Parameters StoppingCriteriaParameters stopping_parameters = 5; } @@ -71,11 +71,17 @@ message GeneratedText { /// Request Request request = 1; /// Output - string output = 2; + string output_text = 2; /// Number of generated tokens - uint32 tokens = 3; + uint32 generated_tokens = 3; + /// Tokens + repeated string tokens = 4; + /// Token IDs + repeated uint32 token_ids = 5; + /// Logprobs + repeated float logprobs = 6; /// Finish reason - string finish_reason = 4; + string finish_reason = 7; } message GenerateRequest { diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs index ae337dd6..295b009b 100644 --- a/router/client/src/lib.rs +++ b/router/client/src/lib.rs @@ -7,7 +7,7 @@ mod sharded_client; pub use client::Client; pub use pb::generate::v1::{ - Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters, + Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; pub use sharded_client::ShardedClient; use thiserror::Error; diff --git a/router/src/batcher.rs b/router/src/batcher.rs index a72a6e44..1484434c 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -187,9 +187,13 @@ fn send_generated(finished: Vec, db: &Db) { let entry = db .remove(&output.request.unwrap().id) .expect("ID not found in db. 
This is a bug."); + let response = InferResponse { - output: output.output, + output_text: output.output_text, + generated_tokens: output.generated_tokens, + token_ids: output.token_ids, tokens: output.tokens, + logprobs: output.logprobs, finish_reason: output.finish_reason, queued: entry.time, start: entry.batch_time.unwrap(), // unwrap is always valid @@ -202,8 +206,11 @@ fn send_generated(finished: Vec, db: &Db) { #[derive(Debug)] pub(crate) struct InferResponse { - pub(crate) output: String, - pub(crate) tokens: u32, + pub(crate) output_text: String, + pub(crate) generated_tokens: u32, + pub(crate) token_ids: Vec, + pub(crate) tokens: Vec, + pub(crate) logprobs: Vec, pub(crate) finish_reason: String, pub(crate) queued: Instant, pub(crate) start: Instant, diff --git a/router/src/db.rs b/router/src/db.rs index 24fb7a09..df9f2b8e 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -5,7 +5,7 @@ use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; use text_generation_client::{ - Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters, + Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use tokio::sync::oneshot::Sender; use tokio::time::Instant; @@ -71,7 +71,7 @@ impl State { id: *id, inputs: entry.request.inputs.clone(), input_length: entry.input_length as u32, - parameters: Some(LogitsWarperParameters::from( + parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), )), stopping_parameters: Some(StoppingCriteriaParameters::from( @@ -162,7 +162,7 @@ impl Db { } } -impl From for LogitsWarperParameters { +impl From for NextTokenChooserParameters { fn from(parameters: GenerateParameters) -> Self { Self { temperature: parameters.temperature, diff --git a/router/src/lib.rs b/router/src/lib.rs index b6c694ee..03711580 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -21,7 +21,10 @@ pub(crate) struct GenerateParameters { pub do_sample: bool, #[serde(default = "default_max_new_tokens")] pub max_new_tokens: u32, + #[serde(default)] pub stop: Vec, + #[serde(default)] + pub details: bool, } fn default_temperature() -> f32 { @@ -52,6 +55,7 @@ fn default_parameters() -> GenerateParameters { do_sample: default_do_sample(), max_new_tokens: default_max_new_tokens(), stop: vec![], + details: false, } } @@ -62,10 +66,18 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } +#[derive(Serialize)] +pub(crate) struct Details { + pub finish_reason: String, + pub generated_tokens: u32, + pub tokens: Vec<(u32, String, f32)>, +} + #[derive(Serialize)] pub(crate) struct GeneratedText { pub generated_text: String, - pub finish_reason: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option
, } #[derive(Serialize)] diff --git a/router/src/server.rs b/router/src/server.rs index 59296269..0ded187d 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,5 +1,5 @@ use crate::{ - Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, + Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, StatusCode}; @@ -54,6 +54,7 @@ async fn health(state: Extension) -> Result<(), (StatusCode, Json) -> Result<(), (StatusCode, Json, @@ -89,6 +90,7 @@ async fn generate( })?; // Validate request + let details = req.0.parameters.details; let (input_length, validated_request) = state.validation.validate(req.0).await.map_err(|err| { tracing::error!("{}", err.to_string()); @@ -105,12 +107,31 @@ async fn generate( err })?; + // Token details + let details = match details { + true => { + let tokens = response + .token_ids + .into_iter() + .zip(response.tokens.into_iter()) + .zip(response.logprobs.into_iter()) + .map(|((id, text), logprob)| (id, text, logprob)) + .collect(); + Some(Details { + finish_reason: response.finish_reason, + generated_tokens: response.generated_tokens, + tokens, + }) + } + false => None + }; + // Timings let total_time = start_time.elapsed(); let validation_time = response.queued - start_time; let queue_time = response.start - response.queued; let inference_time = response.end - response.start; - let time_per_token = inference_time / response.tokens; + let time_per_token = inference_time / response.generated_tokens; // Headers let mut headers = HeaderMap::new(); @@ -141,12 +162,13 @@ async fn generate( tracing::Span::current().record("queue_time", format!("{:?}", queue_time)); tracing::Span::current().record("inference_time", format!("{:?}", inference_time)); tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token)); - tracing::info!("Output: {}", response.output); + tracing::info!("Output: {}", response.output_text); + // Send response let response = vec![GeneratedText { - generated_text: response.output, - finish_reason: response.finish_reason, + generated_text: response.output_text, + details, }]; Ok((headers, Json(response))) } @@ -197,7 +219,7 @@ async fn shutdown_signal() { }; #[cfg(unix)] - let terminate = async { + let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() @@ -205,7 +227,7 @@ async fn shutdown_signal() { }; #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); + let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, diff --git a/server/Makefile b/server/Makefile index ac9cb785..38881af1 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,6 +1,6 @@ gen-server: # Compile protos - pip install grpcio-tools==1.49.1 --no-cache-dir + #pip install grpcio-tools==1.49.1 --no-cache-dir mkdir text_generation/pb || true python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . 
\1/g' {} \; diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 24cdafac..eb72b8a2 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -7,7 +7,7 @@ from text_generation.pb import generate_pb2 @pytest.fixture def default_pb_parameters(): - return generate_pb2.LogitsWarperParameters( + return generate_pb2.NextTokenChooserParameters( temperature=1.0, top_k=0, top_p=1.0, diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index c5dbaa3e..ee16b95f 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -128,10 +128,10 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch) assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest" + assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" assert generated_texts[0].request == default_bloom_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -151,10 +151,10 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTest" + assert generated_texts[0].output_text == "TestTestTestTestTestTest" assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) @@ -170,10 +170,10 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest" + assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -240,10 +240,10 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTest" + assert generated_texts[0].output_text == "TestTestTestTestTestTest" assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) @@ -259,10 +259,10 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest" + assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" assert generated_texts[0].request == default_bloom_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_bloom_batch.stopping_criterias[0].max_new_tokens ) @@ -279,9 +279,9 @@ def test_batch_concatenate( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest" + assert generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest" assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0] assert ( - 
generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index f38776cc..683d9fdd 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -127,10 +127,11 @@ def test_causal_lm_generate_token_completion( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784) at net.minecraft." + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert generated_texts[0].request == default_causal_lm_batch.requests[0] + assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -150,12 +151,12 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784)" + assert generated_texts[0].output_text == "Test.java:784)" assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] ) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) @@ -171,12 +172,12 @@ def test_causal_lm_generate_token_completion_multi( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784) at net.minecraft." + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] ) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -242,12 +243,12 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784)" + assert generated_texts[0].output_text == "Test.java:784)" assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1] ) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) @@ -263,10 +264,10 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784) at net.minecraft." + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." assert generated_texts[0].request == default_causal_lm_batch.requests[0] assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) @@ -283,11 +284,11 @@ def test_batch_concatenate( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "Test.java:784) at net.minecraft." + assert generated_texts[0].output_text == "Test.java:784) at net.minecraft." 
assert ( generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0] ) assert ( - generated_texts[0].tokens + generated_texts[0].generated_tokens == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens ) diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 94ec70d5..f1b11bc2 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -148,9 +148,9 @@ def test_seq2seq_lm_generate_token_completion( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few weeks" + assert generated_texts[0].output_text == "a few weeks" assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] - assert generated_texts[0].tokens == 7 + assert generated_texts[0].generated_tokens == 7 def test_seq2seq_lm_generate_token_completion_multi( @@ -166,12 +166,12 @@ def test_seq2seq_lm_generate_token_completion_multi( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few " + assert generated_texts[0].output_text == "a few " assert ( generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1] ) - assert generated_texts[0].tokens == 5 + assert generated_texts[0].generated_tokens == 5 generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert generated_texts == [] @@ -180,12 +180,12 @@ def test_seq2seq_lm_generate_token_completion_multi( assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few weeks" + assert generated_texts[0].output_text == "a few weeks" assert ( generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0] ) - assert generated_texts[0].tokens == 7 + assert generated_texts[0].generated_tokens == 7 def test_batch_concatenate( @@ -287,28 +287,28 @@ def test_batch_concatenate( assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few " + assert generated_texts[0].output_text == "a few " assert ( generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[1] ) - assert generated_texts[0].tokens == 5 + assert generated_texts[0].generated_tokens == 5 generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few weeks" + assert generated_texts[0].output_text == "a few weeks" assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0] - assert generated_texts[0].tokens == 7 + assert generated_texts[0].generated_tokens == 7 generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is None assert len(generated_texts) == 1 - assert generated_texts[0].output == "a few weeks" + assert generated_texts[0].output_text == "a few weeks" assert ( generated_texts[0].request == default_multi_requests_seq2seq_lm_batch.requests[0] ) - assert generated_texts[0].tokens == 7 + assert generated_texts[0].generated_tokens == 7 diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py index 20e26419..3561a8ea 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation/models/bloom.py @@ -246,12 +246,8 @@ class BLOOMSharded(BLOOM): ) # Logits are sharded, so we need to gather them - logits_shard = outputs.logits[:, -1, :].contiguous() - - batch_size, vocab_shard_size = logits_shard.shape - 
vocab_size = self.world_size * vocab_shard_size - logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, logits_shard, group=self.process_group) - logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size) + logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] + torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) + logits = torch.cat(logits, dim=2) return logits, outputs.past_key_values diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 72095858..7fa027a5 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -22,6 +22,7 @@ class CausalLMBatch: # All tokens all_input_ids: List[torch.Tensor] + all_logprobs: List[Optional[torch.Tensor]] # Lengths of all generations present in the batch input_lengths: List[int] @@ -46,12 +47,13 @@ class CausalLMBatch: @classmethod def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device + cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device ) -> "CausalLMBatch": inputs = [] next_token_choosers = [] stopping_criterias = [] input_lengths = [] + all_logprobs = [] # Parse batch for r in pb.requests: @@ -61,6 +63,7 @@ class CausalLMBatch: stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) + all_logprobs.append(None) pad_to_multiple_of = 8 if "gpu" in str(device) else None tokenized_inputs = tokenizer( @@ -78,6 +81,7 @@ class CausalLMBatch: attention_mask=tokenized_inputs["attention_mask"], past_key_values=None, all_input_ids=all_input_ids, + all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -95,6 +99,7 @@ class CausalLMBatch: requests = [] input_lengths = [] all_input_ids = [] + all_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -110,6 +115,7 @@ class CausalLMBatch: requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) all_input_ids.extend(batch.all_input_ids) + all_logprobs.extend(batch.all_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -142,8 +148,8 @@ class CausalLMBatch: # We need to slice the attention mask to remove padding from previous steps attention_mask[ - start_index:end_index, -batch.max_sequence_length : - ] = batch.attention_mask[:, -batch.max_sequence_length :] + start_index:end_index, -batch.max_sequence_length: + ] = batch.attention_mask[:, -batch.max_sequence_length:] for j, past in enumerate(batch.past_key_values): past_keys, past_values = past @@ -191,22 +197,22 @@ class CausalLMBatch: # We slice the past keys and values to remove the padding from previous batches if batch.keys_head_dim_last: past_key_values[j][0][ - start_index:end_index, - :, - -(batch.max_sequence_length - 1) :, - :, - ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :] + start_index:end_index, + :, + -(batch.max_sequence_length - 1):, + :, + ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :] else: past_key_values[j][0][ - start_index:end_index, - :, - :, - -(batch.max_sequence_length - 1) :, - ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] + start_index:end_index, + :, + :, + -(batch.max_sequence_length - 1):, + ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):] past_key_values[j][1][ - start_index:end_index, :, 
-(batch.max_sequence_length - 1) :, : - ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] + start_index:end_index, :, -(batch.max_sequence_length - 1):, : + ] = past_values[:, :, -(batch.max_sequence_length - 1):, :] start_index += batch.size @@ -217,6 +223,7 @@ class CausalLMBatch: attention_mask=attention_mask, past_key_values=past_key_values, all_input_ids=all_input_ids, + all_logprobs=all_logprobs, input_lengths=input_lengths, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, @@ -261,7 +268,7 @@ class CausalLM(Model): return CausalLMBatch def forward( - self, input_ids, attention_mask, past_key_values: Optional = None + self, input_ids, attention_mask, past_key_values: Optional = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward outputs = self.model.forward( @@ -273,7 +280,7 @@ class CausalLM(Model): return outputs.logits, outputs.past_key_values def generate_token( - self, batch: CausalLMBatch + self, batch: CausalLMBatch ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]: # For some reason, inference_mode does not work well with GLOO which we use on CPU context_manager = ( @@ -291,6 +298,7 @@ class CausalLM(Model): next_batch_input_lengths = [] next_batch_input_ids = [] next_batch_all_input_ids = [] + next_batch_all_logprobs = [] # Metadata next_batch_size = 0 @@ -307,43 +315,67 @@ class CausalLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, + batch.all_logprobs, ) # For each member of the batch for i, ( - request, - input_length, - logits, - next_token_chooser, - stopping_criteria, - all_tokens, + request, + input_length, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + all_logprobs, ) in enumerate(iterator): # Select next token - next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + tokens, logprobs = next_token_chooser(all_input_ids, logits) + next_token = tokens[-1].view(1, 1) # Append next token to all tokens - all_tokens = torch.cat([all_tokens, next_token]) + all_input_ids = torch.cat([all_input_ids, next_token]) + new_input_length = input_length + 1 + + if all_logprobs is None: + # logprobs of all prompt tokens (except the first one) and the generated token + all_logprobs = logprobs.gather(1, all_input_ids[1:]) + else: + # logprob of the generated token + next_token_logprob = logprobs[-1, next_token] + all_logprobs = torch.cat([all_logprobs, next_token_logprob]) # Evaluate stopping criteria - stop, reason = stopping_criteria(all_tokens) + stop, reason = stopping_criteria(all_input_ids) if stop: # Decode all tokens - output = self.tokenizer.decode( - all_tokens.squeeze(-1), skip_special_tokens=True + output_text = self.tokenizer.decode( + all_input_ids.squeeze(-1), skip_special_tokens=True ) + # Slice with input_length to remove padding + token_ids = all_input_ids[-new_input_length:] + tokens = self.tokenizer.batch_decode(token_ids) + # Add NaN for the first prompt token + logprobs = [float('nan')] + all_logprobs[-new_input_length:].squeeze(1).tolist() + # Add to the list of finished generations with the original request generated_texts.append( GeneratedText( - request, output, stopping_criteria.current_tokens, reason + request=request, + output_text=output_text, + generated_tokens=stopping_criteria.current_tokens, + tokens=tokens, + token_ids=token_ids.squeeze(1).tolist(), + logprobs=logprobs, + reason=reason ) ) # add to the next batch else: next_batch_keep_indices.append(i) next_batch_input_ids.append(next_token) 
- next_batch_all_input_ids.append(all_tokens) + next_batch_all_input_ids.append(all_input_ids) + next_batch_all_logprobs.append(all_logprobs) next_batch_size += 1 - new_input_length = input_length + 1 next_batch_input_lengths.append(new_input_length) next_batch_max_sequence_length = max( next_batch_max_sequence_length, new_input_length @@ -397,6 +429,7 @@ class CausalLM(Model): attention_mask=next_batch_attention_mask, past_key_values=next_batch_past_key_values, all_input_ids=next_batch_all_input_ids, + all_logprobs=next_batch_all_logprobs, input_lengths=next_batch_input_lengths, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py index 680ea43e..a713e69e 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation/models/galactica.py @@ -321,12 +321,8 @@ class GalacticaSharded(Galactica): ) # Logits are sharded, so we need to gather them - logits_shard = outputs.logits[:, -1, :].contiguous() - - batch_size, vocab_shard_size = logits_shard.shape - vocab_size = self.world_size * vocab_shard_size - logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, logits_shard, group=self.process_group) - logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size) + logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] + torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) + logits = torch.cat(logits, dim=2) return logits, outputs.past_key_values diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 93e31b4a..35c18518 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -30,6 +30,7 @@ class Seq2SeqLMBatch: # Lengths of all generations present in the batch input_lengths: List[int] decoder_input_lengths: List[int] + decoder_logprobs: List[Optional[torch.Tensor]] # Generation helpers next_token_choosers: List[NextTokenChooser] @@ -60,6 +61,7 @@ class Seq2SeqLMBatch: decoder_input_ids = [] decoder_input_lengths = [] + decoder_logprobs = [] # Parse batch for r in pb.requests: @@ -72,6 +74,7 @@ class Seq2SeqLMBatch: stopping_criterias.append( StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) ) + decoder_logprobs.append(None) # Tokenize batch pad_to_multiple_of = 8 if "gpu" in str(device) else None @@ -95,6 +98,7 @@ class Seq2SeqLMBatch: past_key_values=None, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, + decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, size=len(pb.requests), @@ -117,6 +121,7 @@ class Seq2SeqLMBatch: requests = [] input_lengths = [] decoder_input_lengths = [] + decoder_logprobs = [] next_token_choosers = [] stopping_criterias = [] @@ -137,6 +142,7 @@ class Seq2SeqLMBatch: requests.extend(batch.requests) input_lengths.extend(batch.input_lengths) decoder_input_lengths.extend(batch.decoder_input_lengths) + decoder_logprobs.extend(batch.decoder_logprobs) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) @@ -286,6 +292,7 @@ class Seq2SeqLMBatch: past_key_values=past_key_values, input_lengths=input_lengths, decoder_input_lengths=decoder_input_lengths, + decoder_logprobs=decoder_logprobs, next_token_choosers=next_token_choosers, 
stopping_criterias=stopping_criterias, size=total_batch_size, @@ -385,6 +392,7 @@ class Seq2SeqLM(Model): next_batch_input_lengths = [] next_batch_decoder_input_ids = [] next_batch_decoder_input_lengths = [] + next_batch_decoder_logprobs = [] # Metadata next_batch_size = 0 @@ -399,6 +407,7 @@ class Seq2SeqLM(Model): batch.requests, batch.input_lengths, batch.decoder_input_lengths, + batch.decoder_logprobs, logits, batch.next_token_choosers, batch.stopping_criterias, @@ -411,38 +420,57 @@ class Seq2SeqLM(Model): request, input_length, decoder_input_length, + decoder_logprobs, logits, next_token_chooser, stopping_criteria, input_tokens, - decoder_tokens, + decoder_input_ids, ) in enumerate(iterator): - all_tokens = torch.cat([input_tokens, decoder_tokens]) # Select next token - next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + next_token, logprobs = next_token_chooser(decoder_input_ids, logits) # Append next token to decoder tokens - decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)]) + decoder_input_ids = torch.cat([decoder_input_ids, next_token]) + new_decoder_input_length = decoder_input_length + 1 + + next_token_logprob = logprobs[-1, next_token] + if decoder_logprobs is None: + decoder_logprobs = next_token_logprob + else: + decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob]) # Evaluate stopping criteria - stop, reason = stopping_criteria(decoder_tokens) + stop, reason = stopping_criteria(decoder_input_ids) if stop: - # Decode tokens - output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True) + # Slice with decoder_input_length to remove padding + # Decode all tokens + token_ids = decoder_input_ids[-new_decoder_input_length:] + output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True) + tokens = self.tokenizer.batch_decode(token_ids) + print(tokens) + # Add NaN for the bos token + logprobs = [float('nan')] + decoder_logprobs[-new_decoder_input_length:].tolist() # Add to the list of finished generations with the original request generated_texts.append( GeneratedText( - request, output, stopping_criteria.current_tokens, reason + request=request, + output_text=output_text, + generated_tokens=stopping_criteria.current_tokens, + tokens=tokens, + token_ids=token_ids.tolist(), + logprobs=logprobs, + reason=reason ) ) # add to the next batch else: next_batch_keep_indices.append(i) - next_batch_decoder_input_ids.append(decoder_tokens.unsqueeze(0)) + next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0)) next_batch_size += 1 - new_decoder_input_length = decoder_input_length + 1 next_batch_input_lengths.append(input_length) next_batch_decoder_input_lengths.append(new_decoder_input_length) + next_batch_decoder_logprobs.append(decoder_logprobs) next_batch_max_input_length = max( next_batch_max_input_length, input_length ) @@ -515,6 +543,7 @@ class Seq2SeqLM(Model): past_key_values=next_batch_past_key_values, input_lengths=next_batch_input_lengths, decoder_input_lengths=next_batch_decoder_input_lengths, + decoder_logprobs=next_batch_decoder_logprobs, next_token_choosers=next_batch_next_token_choosers, stopping_criterias=next_batch_stopping_criterias, size=next_batch_size, diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index 91ec75d0..e76cf697 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -30,14 +30,20 @@ class Batch(ABC): @dataclass class GeneratedText: request: generate_pb2.Request - output: str - 
tokens: int + output_text: str + generated_tokens: int + tokens: List[str] + token_ids: List[int] + logprobs: List[float] reason: str def to_pb(self) -> generate_pb2.GeneratedText: return generate_pb2.GeneratedText( request=self.request, - output=self.output, + output_text=self.output_text, + generated_tokens=self.generated_tokens, tokens=self.tokens, + token_ids=self.token_ids, + logprobs=self.logprobs, finish_reason=self.reason, ) diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py index 661df05e..8ecf4be0 100644 --- a/server/text_generation/utils.py +++ b/server/text_generation/utils.py @@ -55,12 +55,18 @@ class NextTokenChooser: self.choice = Sampling() if sampling else Greedy() def __call__(self, input_ids, scores): + # Warp logits scores = self.warpers(input_ids, scores) + # Compute logprobs + logprobs = torch.log_softmax(scores, -1) + # Choose tokens next_ids = self.choice(scores) - return next_ids.unsqueeze(-1) + + # return next_ids, logprobs.gather(1, next_ids.unsqueeze(1)).squeeze(1) + return next_ids, logprobs @classmethod - def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser": + def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser": return NextTokenChooser( temperature=pb.temperature, top_k=pb.top_k, @@ -93,7 +99,7 @@ class StopSequenceCriteria: class StoppingCriteria: def __init__( - self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20 + self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20 ): self.stop_sequence_criterias = stop_sequence_criterias self.max_new_tokens = max_new_tokens @@ -113,7 +119,7 @@ class StoppingCriteria: @classmethod def from_pb( - cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer + cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer ) -> "StoppingCriteria": stop_sequence_criterias = [] for stop_sequence in pb.stop_sequences:
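The server-side contract changes as well: NextTokenChooser in server/text_generation/utils.py now returns both the chosen token ids and the full log-probabilities, which causal_lm.py and seq2seq_lm.py use to accumulate per-token logprobs. Below is a minimal, standalone sketch of that contract (not the server's implementation): greedy choice only, logits warpers omitted.

import torch

def choose_next_token(input_ids, scores: torch.Tensor):
    # Compute logprobs over the vocabulary from the (warped) scores
    logprobs = torch.log_softmax(scores, dim=-1)
    # Greedy choice of the next token id
    next_ids = torch.argmax(scores, dim=-1)
    # Return both, so callers can record the logprob of each generated token
    return next_ids, logprobs

scores = torch.randn(1, 8)  # [batch, vocab], dummy logits for illustration
next_ids, logprobs = choose_next_token(None, scores)
next_token_logprob = logprobs[-1, next_ids[-1]]  # logprob of the chosen token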
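bloom.py and galactica.py now gather the full sharded logits instead of only the last position. A sketch of that gather, assuming torch.distributed is already initialised and that each shard holds logits of shape [batch, seq_len, vocab_size // world_size], split on the vocabulary dimension:

import torch
import torch.distributed as dist

def gather_full_logits(shard_logits: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # One empty buffer per shard, then all_gather the local shard into every rank
    buffers = [torch.empty_like(shard_logits) for _ in range(world_size)]
    dist.all_gather(buffers, shard_logits, group=group)
    # Concatenate along the vocabulary dimension to rebuild the full logits
    return torch.cat(buffers, dim=2)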
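Finally, as a usage note for the router changes (router/src/lib.rs and router/src/server.rs): a client opts into token-level details with the new `details` parameter and reads them back from the optional `details` field of the response. The sketch below is illustrative only; the host, port and `/generate` route are assumptions, while the field names (`details`, `finish_reason`, `generated_tokens`, `tokens`) come from the diff above.

import requests

response = requests.post(
    "http://localhost:3000/generate",  # assumed router address and route
    json={
        "inputs": "Test",
        "parameters": {"max_new_tokens": 10, "details": True},
    },
)
generated = response.json()[0]  # the router returns a list of GeneratedText
print(generated["generated_text"])
if generated.get("details") is not None:
    # each entry is a (token_id, token_text, logprob) triple
    for token_id, token_text, logprob in generated["details"]["tokens"]:
        print(token_id, token_text, logprob)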