feat(server): load santacoder/starcoder models with safetensors

This commit is contained in:
OlivierDehaene 2023-06-01 10:55:26 +02:00
parent db2ebe3947
commit f6438ac352
2 changed files with 76 additions and 90 deletions

View File

@ -546,11 +546,7 @@ enum LauncherError {
WebserverCannotStart, WebserverCannotStart,
} }
fn download_convert_model( fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
args: &Args,
auto_convert: bool,
running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
let mut download_argv = vec![ let mut download_argv = vec![
"text-generation-server".to_string(), "text-generation-server".to_string(),
"download-weights".to_string(), "download-weights".to_string(),
@ -562,11 +558,6 @@ fn download_convert_model(
"--json-output".to_string(), "--json-output".to_string(),
]; ];
// Auto convert weights to safetensors
if auto_convert {
download_argv.push("--auto-convert".to_string());
}
// Model optional revision // Model optional revision
if let Some(revision) = &args.revision { if let Some(revision) = &args.revision {
download_argv.push("--revision".to_string()); download_argv.push("--revision".to_string());
@ -932,11 +923,8 @@ fn main() -> Result<(), LauncherError> {
}) })
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
// auto_convert is only needed for sharded models as we do not require safetensors in
// single shard mode
let auto_convert = num_shard > 1;
// Download and convert model weights // Download and convert model weights
download_convert_model(&args, auto_convert, running.clone())?; download_convert_model(&args, running.clone())?;
// Shared shutdown bool // Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false)); let shutdown = Arc::new(Mutex::new(false));

View File

@ -54,12 +54,7 @@ class FlashSantacoder(FlashCausalLM):
) )
# We do not use from_pretrained as we modified the model internal module layout # We do not use from_pretrained as we modified the model internal module layout
try: filenames = weight_files(model_id, revision, ".safetensors")
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights(): with init_empty_weights():
model = FlashSantacoderForCausalLM(config) model = FlashSantacoderForCausalLM(config)
@ -91,81 +86,84 @@ class FlashSantacoder(FlashCausalLM):
transpose: bool, transpose: bool,
): ):
for filename in filenames: for filename in filenames:
state_dict = torch.load(filename, map_location="cpu") with safe_open(
for key, value in state_dict.items(): filename, framework="pt", device=str(device) if quantize is None else "cpu"
value = value.to(device if quantize is None else "cpu").to(dtype) ) as f:
for key in f.keys():
value = f.get_slice(key)
value = value.to(device if quantize is None else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4]) layer_name = ".".join(key.split(".")[:4])
# Fused qkv # Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key: if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight" final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key: elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias" final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Tranpose as we use nn.Linear instead of Conv1D
value = value.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
value.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn.weight" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "q_attn.bias" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size * model.transformer.num_heads :
] = value
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size * model.transformer.num_heads :
] = value
else: else:
if current_parameter_tensor.shape != value.shape: final_key = key
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Tranpose as we use nn.Linear instead of Conv1D
value = value.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
value.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn.weight" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "q_attn.bias" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size * model.transformer.num_heads :
] = value
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size * model.transformer.num_heads :
] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.post_load_weights(quantize) model.post_load_weights(quantize)