This commit is contained in:
regisss 2025-04-28 07:50:15 +00:00
parent 790a3b5ed2
commit 3e9f5ba159
2 changed files with 12 additions and 2 deletions

View File

@ -882,7 +882,12 @@ class CausalLM(Model):
if load_to_meta: if load_to_meta:
# model loaded to meta is managed differently # model loaded to meta is managed differently
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
checkpoint_files = [str(f) for f in weight_files(model_id, revision=revision, extension=".safetensors")] checkpoint_files = [
str(f)
for f in weight_files(
model_id, revision=revision, extension=".safetensors"
)
]
data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
json.dump(data, checkpoints_json) json.dump(data, checkpoints_json)
checkpoints_json.flush() checkpoints_json.flush()

View File

@ -1017,7 +1017,12 @@ class VlmCausalLM(Model):
if load_to_meta: if load_to_meta:
# model loaded to meta is managed differently # model loaded to meta is managed differently
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
checkpoint_files = [str(f) for f in weight_files(model_id, revision=revision, extension=".safetensors")] checkpoint_files = [
str(f)
for f in weight_files(
model_id, revision=revision, extension=".safetensors"
)
]
data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
json.dump(data, checkpoints_json) json.dump(data, checkpoints_json)
checkpoints_json.flush() checkpoints_json.flush()