From a794c677ae3237cca652bb9072b88c47cc297214 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 24 May 2023 19:37:42 +0200 Subject: [PATCH] fix warping --- .github/workflows/build.yaml | 7 ++++--- server/text_generation_server/utils/tokens.py | 4 +++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 289e4f67f..b2e02f0b1 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -80,8 +80,8 @@ jobs: latest=auto images: | registry.internal.huggingface.tech/api-inference/community/text-generation-inference - ghcr.io/huggingface/text-generation-inference - db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference +# ghcr.io/huggingface/text-generation-inference +# db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference tags: | type=semver,pattern={{version}} type=semver,pattern={{major}}.{{minor}} @@ -93,7 +93,8 @@ jobs: with: context: . file: Dockerfile - push: ${{ github.event_name != 'pull_request' }} +# push: ${{ github.event_name != 'pull_request' }} + push: true platforms: 'linux/amd64' build-args: | GIT_SHA=${{ env.GITHUB_SHA }} diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index e9fb96b06..ab6b8d474 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -67,9 +67,11 @@ class StaticWarper: self.cuda_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.cuda_graph): + local_scores = self.static_scores for warper in self.warpers: - self.static_warped_scores = warper(None, self.static_scores) + local_scores = warper(None, local_scores) + self.static_warped_scores = local_scores # Compute logprobs self.static_next_logprob = torch.log_softmax( self.static_warped_scores, -1