Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Making it work ?

Commit 3539ea37e2, parent a55917fb43

.github/workflows/build.yaml (vendored)
@@ -39,4 +39,5 @@ jobs:
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:2.0.4-rocm
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export DEVICES=/dev/kfd,/dev/dri
           python -m pytest -s -vv integration-tests/models/test_flash_gpt2.py
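The added export lists the AMD ROCm device nodes (/dev/kfd is the kernel compute interface, /dev/dri holds the render devices) that the test harness is expected to forward into the Docker container. Below is a minimal sketch of parsing such a comma-separated device list on the Python side; the helper name and the exact variable consumed by the harness (the conftest hunk further down reads DOCKER_DEVICES, while this step exports DEVICES) are assumptions.

import os

# Hypothetical helper: split a comma-separated list of host device paths
# taken from the environment; an unset or empty variable yields no devices.
def parse_device_list(var_name: str = "DOCKER_DEVICES") -> list[str]:
    raw = os.getenv(var_name, "")
    return [part.strip() for part in raw.split(",") if part.strip()]

if __name__ == "__main__":
    os.environ["DOCKER_DEVICES"] = "/dev/kfd,/dev/dri"
    print(parse_device_list())  # ['/dev/kfd', '/dev/dri']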
integration-tests/conftest.py

@@ -34,6 +34,7 @@ from text_generation.types import (
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
 HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
 DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
+DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
 
 
 class ResponseComparator(JSONSnapshotExtension):
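DOCKER_DEVICES is read with no default, so it stays None unless the CI job (or a developer) exports it, and the launcher change below keys its branching on that truthiness. A small illustration of the three cases; the sample values are assumptions.

import os

for value in (None, "", "/dev/kfd,/dev/dri"):
    if value is None:
        os.environ.pop("DOCKER_DEVICES", None)  # variable not exported at all
    else:
        os.environ["DOCKER_DEVICES"] = value

    docker_devices = os.getenv("DOCKER_DEVICES")
    if docker_devices:
        print("map host devices:", docker_devices.split(","))
    else:
        print("no host devices, fall back to a generic GPU device request")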
@@ -453,16 +454,27 @@ def launcher(event_loop):
         if DOCKER_VOLUME:
             volumes = [f"{DOCKER_VOLUME}:/data"]
 
+        if DOCKER_DEVICES:
+            devices = DOCKER_DEVICES.split(",")
+            visible = os.getenv("ROCR_VISIBLE_DEVICES")
+            if visible:
+                env["ROCR_VISIBLE_DEVICES"] = visible
+            device_requests = []
+        else:
+            devices = []
+            device_requests = [
+                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
+            ]
+
         container = client.containers.run(
             DOCKER_IMAGE,
             command=args,
             name=container_name,
             environment=env,
             auto_remove=False,
+            devices=devices,
+            device_requests=device_requests,
             detach=True,
-            device_requests=[
-                docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
-            ],
             volumes=volumes,
             ports={"80/tcp": port},
             shm_size="1G",
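Taken together, the hunk makes the launcher pass explicit host device nodes (the ROCm path, optionally narrowed by ROCR_VISIBLE_DEVICES) when DOCKER_DEVICES is set, and keep the previous NVIDIA-style DeviceRequest otherwise. A standalone sketch of the same branching with the Docker SDK for Python; the helper name and the placeholder invocation at the bottom are assumptions, not part of the repository.

import os

import docker  # Docker SDK for Python (docker-py)


def device_kwargs(gpu_count: int = 1) -> dict:
    # Mirror the launcher's branching: explicit host device nodes (ROCm)
    # when DOCKER_DEVICES is set, a generic GPU device request otherwise.
    docker_devices = os.getenv("DOCKER_DEVICES")
    if docker_devices:
        return {
            "devices": docker_devices.split(","),  # e.g. ["/dev/kfd", "/dev/dri"]
            "device_requests": [],
        }
    return {
        "devices": [],
        "device_requests": [
            docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
        ],
    }


if __name__ == "__main__":
    os.environ.setdefault("DOCKER_DEVICES", "/dev/kfd,/dev/dri")
    print(device_kwargs())
    # With a running daemon, these kwargs would be spliced into
    # docker.from_env().containers.run(IMAGE, command=..., detach=True, **device_kwargs()).

docker-py accepts plain host paths for devices (each is expanded to a host:container:permissions mapping), while device_requests is what the NVIDIA container runtime consumes, so the two mechanisms do not conflict when one of them is left empty.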