This commit is contained in:
Nicolas Patry 2024-10-09 11:42:38 +02:00
parent 3e8d722733
commit b18ed0f443
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863

View File

@ -484,6 +484,7 @@ def launcher(event_loop):
try: try:
container = client.containers.get(container_name) container = client.containers.get(container_name)
container.stop() container.stop()
container.remove()
container.wait() container.wait()
except NotFound: except NotFound:
pass pass
@ -513,24 +514,22 @@ def launcher(event_loop):
device_requests = [] device_requests = []
if not devices: if not devices:
devices = None devices = None
elif devices == ["nvidia.com/gpu=all"]:
devices = None
device_requests = [
docker.types.DeviceRequest(
driver="cdi",
# count=gpu_count,
device_ids=[f"nvidia.com/gpu={i}"],
)
for i in range(gpu_count)
]
else: else:
devices = [] devices = []
device_requests = [ device_requests = [
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]]) docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
] ]
# raise Exception(
# f"""
# Docoker image: {DOCKER_IMAGE}
# args: {args}
# container name: {container_name}
# env: {env}
# device_requests: {device_requests}
# devices: {devices}
# """
# )
# env.pop("LOG_LEVEL")
# env.pop("ROCR_VISIBLE_DEVICES")
container = client.containers.run( container = client.containers.run(
DOCKER_IMAGE, DOCKER_IMAGE,
command=args, command=args,
@ -546,25 +545,23 @@ def launcher(event_loop):
shm_size="1G", shm_size="1G",
) )
import time
time.sleep(600)
yield ContainerLauncherHandle(client, container.name, port)
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
try: try:
container.stop() yield ContainerLauncherHandle(client, container.name, port)
container.wait()
except NotFound:
pass
container_output = container.logs().decode("utf-8") if not use_flash_attention:
print(container_output, file=sys.stderr) del env["USE_FLASH_ATTENTION"]
container.remove() try:
container.stop()
container.wait()
except NotFound:
pass
container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)
finally:
container.remove()
if DOCKER_IMAGE is not None: if DOCKER_IMAGE is not None:
return docker_launcher return docker_launcher