diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 67d1c730..65db9bfb 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -80,8 +80,8 @@ jobs:
             latest=auto
           images: |
             registry.internal.huggingface.tech/api-inference/community/text-generation-inference
-            ghcr.io/huggingface/text-generation-inference
-            db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
+#            ghcr.io/huggingface/text-generation-inference
+#            db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
           tags: |
             type=semver,pattern={{version}}
             type=semver,pattern={{major}}.{{minor}}
@@ -93,7 +93,8 @@ jobs:
       with:
         context: .
         file: Dockerfile
-        push: ${{ github.event_name != 'pull_request' }}
+#        push: ${{ github.event_name != 'pull_request' }}
+        push: true
         platforms: 'linux/amd64'
         build-args: |
           GIT_SHA=${{ env.GITHUB_SHA }}
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 1c248907..c9650b39 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -337,6 +337,10 @@ class CausalLMBatch(Batch):
                         layer[k] = t.view(len(batch), -1, *t.shape[-2:])
             start_index = end_index
 
+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length - batch.max_input_length
+            ) * len(batch)
 
         first_past_kvs = batches[0].past_key_values
         _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
@@ -404,10 +408,6 @@ class CausalLMBatch(Batch):
                 # Update values
                 start_index = end_index
 
-                # Add eventual padding tokens that were added while concatenating
-                max_tokens += batch.max_tokens + (
-                    max_input_length - batch.max_input_length
-                ) * len(batch)
 
             past_key_values.append([padded_past_keys, padded_past_values])
 
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index a61aeccf..e2becb6f 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -390,6 +390,13 @@ class Seq2SeqLMBatch(Batch):
             batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
             start_index = end_index
 
+            # Add eventual padding tokens that were added while concatenating
+            max_tokens += batch.max_tokens + (
+                max_input_length
+                - batch.max_input_length
+                + max_decoder_input_length
+                - batch.max_decoder_input_length
+            ) * len(batch)
 
         # Determine shapes for new past kv tensors
         first_past_kvs = batches[0].past_key_values
@@ -455,13 +462,7 @@ class Seq2SeqLMBatch(Batch):
                 del t
 
                 start_index = end_index
 
-                # Add eventual padding tokens that were added while concatenating
-                max_tokens += batch.max_tokens + (
-                    max_input_length
-                    - batch.max_input_length
-                    + max_decoder_input_length
-                    - batch.max_decoder_input_length
-                ) * len(batch)
+
         return cls(
             batch_id=batches[0].batch_id,
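
For context on the server-side hunks (a reading of the diff, not part of the patch): `concatenate` merges several running batches, left-padding every sequence to the longest input across all of them, and `max_tokens` must budget for those new padding tokens. Judging from the surrounding context, the accounting previously sat inside the per-layer past-key-values loops, so it ran more than once per batch; the patch moves it into the main per-batch loop, where it runs exactly once. Below is a minimal sketch of the arithmetic, using hypothetical numbers and a stand-in dataclass instead of the real `CausalLMBatch`:

```python
# Illustration only: hypothetical numbers, stand-in dataclass.
from dataclasses import dataclass


@dataclass
class FakeBatch:
    size: int              # number of sequences, i.e. len(batch)
    max_input_length: int  # longest prompt in this batch
    max_tokens: int        # token budget this batch reserved before merging

    def __len__(self) -> int:
        return self.size


batches = [
    FakeBatch(size=2, max_input_length=5, max_tokens=20),
    FakeBatch(size=3, max_input_length=3, max_tokens=18),
]

# Concatenation left-pads every sequence to the longest input length.
max_input_length = max(b.max_input_length for b in batches)

max_tokens = 0
for batch in batches:
    # Each sequence in `batch` gains (max_input_length - batch.max_input_length)
    # padding tokens, so the merged batch budgets for them once per batch.
    max_tokens += batch.max_tokens + (
        max_input_length - batch.max_input_length
    ) * len(batch)

print(max_tokens)  # 20 + (5 - 5) * 2 + 18 + (5 - 3) * 3 = 44
```

The seq2seq variant applies the same idea to both sides of the encoder-decoder: it adds `max_decoder_input_length - batch.max_decoder_input_length` to the per-sequence padding delta, since decoder sequences are padded to the longest decoder length as well.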