Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 23:12:07 +00:00.
fix warping
Commit: a794c677ae ("fix warping")
Parent: a86e4bf713
File changed: .github/workflows/build.yaml (vendored) — 7 lines changed
@@ -80,8 +80,8 @@ jobs:
           latest=auto
         images: |
           registry.internal.huggingface.tech/api-inference/community/text-generation-inference
-          ghcr.io/huggingface/text-generation-inference
-          db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
+          # ghcr.io/huggingface/text-generation-inference
+          # db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference
         tags: |
           type=semver,pattern={{version}}
           type=semver,pattern={{major}}.{{minor}}
@@ -93,7 +93,8 @@ jobs:
       with:
         context: .
         file: Dockerfile
-        push: ${{ github.event_name != 'pull_request' }}
+        # push: ${{ github.event_name != 'pull_request' }}
+        push: true
         platforms: 'linux/amd64'
         build-args: |
           GIT_SHA=${{ env.GITHUB_SHA }}
|
@@ -67,9 +67,11 @@ class StaticWarper:
         self.cuda_graph = torch.cuda.CUDAGraph()

         with torch.cuda.graph(self.cuda_graph):
+            local_scores = self.static_scores
             for warper in self.warpers:
-                self.static_warped_scores = warper(None, self.static_scores)
+                local_scores = warper(None, local_scores)

+            self.static_warped_scores = local_scores
             # Compute logprobs
             self.static_next_logprob = torch.log_softmax(
                 self.static_warped_scores, -1
|
Loading…
Reference in New Issue
Block a user