push image to test

This commit is contained in:
OlivierDehaene 2023-04-24 16:08:27 +02:00
parent c69f24d16b
commit 885411e747
3 changed files with 16 additions and 14 deletions

View File

@@ -80,8 +80,8 @@ jobs:
latest=auto
images: |
registry.internal.huggingface.tech/api-inference/community/text-generation-inference
ghcr.io/huggingface/text-generation-inference
db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
# ghcr.io/huggingface/text-generation-inference
# db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
@@ -93,7 +93,8 @@ jobs:
with:
context: .
file: Dockerfile
push: ${{ github.event_name != 'pull_request' }}
# push: ${{ github.event_name != 'pull_request' }}
push: true
platforms: 'linux/amd64'
build-args: |
GIT_SHA=${{ env.GITHUB_SHA }}

View File

@@ -337,6 +337,10 @@ class CausalLMBatch(Batch):
layer[k] = t.view(len(batch), -1, *t.shape[-2:])
start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
first_past_kvs = batches[0].past_key_values
_, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape
@@ -404,10 +404,6 @@ class CausalLMBatch(Batch):
# Update values
start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length - batch.max_input_length
) * len(batch)
past_key_values.append([padded_past_keys, padded_past_values])

View File

@@ -390,6 +390,13 @@ class Seq2SeqLMBatch(Batch):
batch.past_key_values = [[t for t in layer] for layer in batch.past_key_values]
start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length
- batch.max_input_length
+ max_decoder_input_length
- batch.max_decoder_input_length
) * len(batch)
# Determine shapes for new past kv tensors
first_past_kvs = batches[0].past_key_values
@@ -455,13 +462,7 @@ class Seq2SeqLMBatch(Batch):
del t
start_index = end_index
# Add eventual padding tokens that were added while concatenating
max_tokens += batch.max_tokens + (
max_input_length
- batch.max_input_length
+ max_decoder_input_length
- batch.max_decoder_input_length
) * len(batch)
return cls(
batch_id=batches[0].batch_id,