mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 03:14:53 +00:00
Format
This commit is contained in:
parent
790a3b5ed2
commit
3e9f5ba159
@ -882,7 +882,12 @@ class CausalLM(Model):
|
||||
if load_to_meta:
|
||||
# model loaded to meta is managed differently
|
||||
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
|
||||
checkpoint_files = [str(f) for f in weight_files(model_id, revision=revision, extension=".safetensors")]
|
||||
checkpoint_files = [
|
||||
str(f)
|
||||
for f in weight_files(
|
||||
model_id, revision=revision, extension=".safetensors"
|
||||
)
|
||||
]
|
||||
data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
|
||||
json.dump(data, checkpoints_json)
|
||||
checkpoints_json.flush()
|
||||
|
@ -1017,7 +1017,12 @@ class VlmCausalLM(Model):
|
||||
if load_to_meta:
|
||||
# model loaded to meta is managed differently
|
||||
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
|
||||
checkpoint_files = [str(f) for f in weight_files(model_id, revision=revision, extension=".safetensors")]
|
||||
checkpoint_files = [
|
||||
str(f)
|
||||
for f in weight_files(
|
||||
model_id, revision=revision, extension=".safetensors"
|
||||
)
|
||||
]
|
||||
data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
|
||||
json.dump(data, checkpoints_json)
|
||||
checkpoints_json.flush()
|
||||
|
Loading…
Reference in New Issue
Block a user