From 3e9f5ba159ec7f0ca34832c2df88e95eb785ffd5 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Mon, 28 Apr 2025 07:50:15 +0000 Subject: [PATCH] Format --- .../server/text_generation_server/models/causal_lm.py | 7 ++++++- .../server/text_generation_server/models/vlm_causal_lm.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/causal_lm.py b/backends/gaudi/server/text_generation_server/models/causal_lm.py index 21989ba8..374b6fd6 100644 --- a/backends/gaudi/server/text_generation_server/models/causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/causal_lm.py @@ -882,7 +882,12 @@ class CausalLM(Model): if load_to_meta: # model loaded to meta is managed differently checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - checkpoint_files = [str(f) for f in weight_files(model_id, revision=revision, extension=".safetensors")] + checkpoint_files = [ + str(f) + for f in weight_files( + model_id, revision=revision, extension=".safetensors" + ) + ] data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} json.dump(data, checkpoints_json) checkpoints_json.flush() diff --git a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py index b15b6d4e..6929b2ef 100644 --- a/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py @@ -1017,7 +1017,12 @@ class VlmCausalLM(Model): if load_to_meta: # model loaded to meta is managed differently checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - checkpoint_files = [str(f) for f in weight_files(model_id, revision=revision, extension=".safetensors")] + checkpoint_files = [ + str(f) + for f in weight_files( + model_id, revision=revision, extension=".safetensors" + ) + ] data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0} json.dump(data, checkpoints_json) checkpoints_json.flush()