feat(server): support trust_remote_code

OlivierDehaene 2023-05-23 19:23:01 +02:00
parent e9669a4085
commit de3491854b
17 changed files with 285 additions and 57 deletions

View File

@ -53,7 +53,7 @@ struct Args {
#[clap(long, env)]
revision: Option<String>,
/// Wether to shard or not the model across multiple GPUs
/// Whether to shard the model across multiple GPUs
/// By default text-generation-inference will use all available GPUs to run
/// the model. Setting it to `false` deactivates `num_shard`.
#[clap(long, env)]
@ -66,11 +66,17 @@ struct Args {
#[clap(long, env)]
num_shard: Option<usize>,
/// Wether you want the model to be quantized or not. This will use `bitsandbytes` for
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
/// quantization on the fly, or `gptq`.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,
/// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
/// encouraged when loading a model with custom code to ensure no malicious code has been
/// contributed in a newer revision.
#[clap(long, env, value_enum)]
trust_remote_code: bool,
/// The maximum amount of concurrent requests for this particular deployment.
/// Having a low limit will refuse client requests instead of having them
/// wait for too long and is usually good to handle backpressure correctly.
@ -239,6 +245,7 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
trust_remote_code: bool,
uds_path: String,
rank: usize,
world_size: usize,
@ -272,6 +279,11 @@ fn shard_manager(
"--json-output".to_string(),
];
// Activate trust remote code
if trust_remote_code {
shard_argv.push("--trust-remote-code".to_string());
}
// Activate tensor parallelism
if world_size > 1 {
shard_argv.push("--sharded".to_string());
@ -692,6 +704,16 @@ fn spawn_shards(
status_sender: mpsc::Sender<ShardStatus>,
running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
if args.trust_remote_code {
tracing::warn!(
"`trust_remote_code` is set. Trusting that model `{}` does not contain malicious code.",
args.model_id
);
if args.revision.is_none() {
tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
}
}
// Start shard processes
for rank in 0..num_shard {
let model_id = args.model_id.clone();
@ -705,6 +727,7 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
let disable_custom_kernels = args.disable_custom_kernels;
let watermark_gamma = args.watermark_gamma;
@ -714,6 +737,7 @@ fn spawn_shards(
model_id,
revision,
quantize,
trust_remote_code,
uds_path,
rank,
num_shard,
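The launcher only warns; it cannot audit the remote code itself. As a pre-flight check before enabling `--trust-remote-code`, one could inspect the model's config for an `auto_map` entry, which is how hub repos declare custom modeling code. A minimal sketch, assuming `huggingface_hub` is installed; the helper name and model id are hypothetical and not part of this commit:

import json
from typing import Optional

from huggingface_hub import hf_hub_download


def requires_remote_code(model_id: str, revision: Optional[str] = None) -> bool:
    # Download only config.json; models that ship custom modeling code
    # list their classes under the `auto_map` key.
    config_path = hf_hub_download(model_id, filename="config.json", revision=revision)
    with open(config_path) as f:
        return "auto_map" in json.load(f)

If it returns True, pinning `--revision` to the commit you reviewed, as the warning above suggests, prevents later pushes to the repo from silently changing the code the shards execute.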

View File

@ -1,4 +1,4 @@
transformers_commit := 69009822aa7897ffab97afb814e38126b83f639e
transformers_commit := 63d5605212c5d88c0cc29996b3bf76840bdd1489
transformers:
# Clone fork of transformers with custom CUDA kernels and sharding logic

View File

@ -22,6 +22,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: Optional[Quantization] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
@ -63,7 +64,7 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
server.serve(model_id, revision, sharded, quantize, uds_path)
server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
@app.command()

View File

@ -91,13 +91,27 @@ torch.set_grad_enabled(False)
def get_model(
model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
trust_remote_code: bool,
) -> Model:
if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded(model_id, revision, quantize=quantize)
return GalacticaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return Galactica(model_id, revision, quantize=quantize)
return Galactica(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_id.startswith("bigcode/"):
if sharded:
@ -105,12 +119,24 @@ def get_model(
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize=quantize)
return santacoder_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(model_id, revision=revision)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
model_type = config.model_type
if model_type == "gpt_bigcode":
@ -119,52 +145,128 @@ def get_model(
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize=quantize)
return santacoder_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "bloom":
if sharded:
return BLOOMSharded(model_id, revision, quantize=quantize)
return BLOOMSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return BLOOM(model_id, revision, quantize=quantize)
return BLOOM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "gpt_neox":
if sharded:
neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
return neox_cls(model_id, revision, quantize=quantize)
return neox_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
return neox_cls(model_id, revision, quantize=quantize)
return neox_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "llama":
if sharded:
if FLASH_ATTENTION:
return FlashLlamaSharded(model_id, revision, quantize=quantize)
return FlashLlamaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
else:
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
return llama_cls(model_id, revision, quantize=quantize)
return llama_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if config.model_type == "opt":
if sharded:
return OPTSharded(model_id, revision, quantize=quantize)
return OPTSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return OPT(model_id, revision, quantize=quantize)
return OPT(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "t5":
if sharded:
return T5Sharded(model_id, revision, quantize=quantize)
return T5Sharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return Seq2SeqLM(model_id, revision, quantize=quantize)
return Seq2SeqLM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if sharded:
raise ValueError("sharded is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(model_id, revision, quantize=quantize)
return CausalLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM(model_id, revision, quantize=quantize)
return Seq2SeqLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
auto_map = getattr(config, "auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
raise ValueError(f"Unsupported model type {model_type}")
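The `auto_map` fallback above is what lets models without a native implementation in text-generation-inference still load when `trust_remote_code` is set. A minimal sketch of that dispatch decision, using a hypothetical hub repo id:

from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "my-org/my-custom-model",  # hypothetical remote-code model
    trust_remote_code=True,
)
# For remote-code models, `auto_map` maps auto classes to modules shipped in
# the repo, e.g. {"AutoConfig": "configuration_foo.FooConfig",
#                 "AutoModelForCausalLM": "modeling_foo.FooForCausalLM"}
auto_map = getattr(config, "auto_map", None)
if auto_map is not None and "AutoModelForCausalLM" in auto_map:
    print("get_model dispatches to CausalLM with trust_remote_code=True")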

View File

@ -54,9 +54,13 @@ class BLOOM(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id, revision=revision, quantize=quantize
model_id=model_id,
revision=revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
@property
@ -70,6 +74,7 @@ class BLOOMSharded(BLOOM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -80,11 +85,19 @@ class BLOOMSharded(BLOOM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, slow_but_exact=False, tp_parallel=True
model_id,
revision=revision,
slow_but_exact=False,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.pad_token_id = 3
@ -92,7 +105,9 @@ class BLOOMSharded(BLOOM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
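Several of the sharded classes touched by this commit follow the same pattern: build an empty model skeleton from the (now possibly remote) config, then stream real weights in from safetensors shards. A minimal sketch of that pattern with a hypothetical model id, leaving out the actual weight loading:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(
    "my-org/my-custom-model",  # hypothetical
    trust_remote_code=True,
)
with init_empty_weights():
    # Parameters are created on the meta device: no memory is allocated and
    # no pretrained weights are downloaded at this point.
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
# The load_weights calls above then copy tensors shard-by-shard from the
# downloaded .safetensors files onto the right devices.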

View File

@ -450,6 +450,7 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -462,7 +463,11 @@ class CausalLM(Model):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
@ -470,6 +475,7 @@ class CausalLM(Model):
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
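For the non-sharded `CausalLM` path, the change boils down to forwarding the flag to the standard transformers entry points. A hedged sketch of the equivalent stand-alone call, with a hypothetical model id and the revision pinned as the launcher warning recommends:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "my-org/my-custom-model"  # hypothetical
revision = "a1b2c3d"                 # pin the revision you audited

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    revision=revision,
    padding_side="left",
    truncation_side="left",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,  # executes the modeling code shipped in the repo
)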

View File

@ -394,6 +394,7 @@ class FlashCausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -402,13 +403,18 @@ class FlashCausalLM(Model):
raise NotImplementedError("FlashCausalLM is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
).to(device)
super(FlashCausalLM, self).__init__(

View File

@ -33,6 +33,7 @@ class FlashLlama(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -45,11 +46,11 @@ class FlashLlama(FlashCausalLM):
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# We do not use from_pretrained as we modified the model internal module layout
@ -153,6 +154,7 @@ class FlashLlamaSharded(FlashLlama):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -166,11 +168,11 @@ class FlashLlamaSharded(FlashLlama):
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)

View File

@ -28,9 +28,14 @@ class FlashNeoX(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(FlashNeoX, self).__init__(
FlashGPTNeoXForCausalLM, model_id, revision, quantize
FlashGPTNeoXForCausalLM,
model_id,
revision,
quantize,
trust_remote_code=trust_remote_code,
)
@ -40,6 +45,7 @@ class FlashNeoXSharded(FlashNeoX):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -49,12 +55,15 @@ class FlashNeoXSharded(FlashNeoX):
raise NotImplementedError("FlashNeoX is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)

View File

@ -32,6 +32,7 @@ class FlashSantacoder(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -40,7 +41,11 @@ class FlashSantacoder(FlashCausalLM):
raise NotImplementedError("FlashSantacoder is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(
@ -178,6 +183,7 @@ class FlashSantacoderSharded(FlashSantacoder):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -187,7 +193,11 @@ class FlashSantacoderSharded(FlashSantacoder):
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(

View File

@ -199,6 +199,7 @@ class GalacticaSharded(Galactica):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -209,11 +210,18 @@ class GalacticaSharded(Galactica):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token_id = config.pad_token_id
@ -221,7 +229,9 @@ class GalacticaSharded(Galactica):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

View File

@ -36,6 +36,7 @@ class GPTNeoxSharded(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -46,19 +47,28 @@ class GPTNeoxSharded(CausalLM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

View File

@ -52,6 +52,7 @@ class OPTSharded(OPT):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -62,11 +63,18 @@ class OPTSharded(OPT):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token_id = config.pad_token_id
@ -74,7 +82,9 @@ class OPTSharded(OPT):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

View File

@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -31,7 +32,11 @@ class SantaCoder(CausalLM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.add_special_tokens(
{
@ -51,7 +56,7 @@ class SantaCoder(CausalLM):
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=True, # required
trust_remote_code=trust_remote_code,
).to(device)
super(CausalLM, self).__init__(
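SantaCoder previously hard-coded `trust_remote_code=True`; with this change the user's flag is honored, so a model whose implementation exists only as hub code will refuse to load unless the flag is passed. A hedged sketch of what that refusal looks like from transformers (the exact exception and message depend on the transformers version; the model id is hypothetical):

from transformers import AutoModelForCausalLM

try:
    AutoModelForCausalLM.from_pretrained(
        "my-org/my-custom-model",  # hypothetical repo that ships custom code
        trust_remote_code=False,
    )
except Exception as err:
    # transformers raises (typically a ValueError) instead of executing
    # untrusted hub code when trust_remote_code is not enabled.
    print(f"refused to load remote code: {err}")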

View File

@ -503,6 +503,7 @@ class Seq2SeqLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@ -520,12 +521,17 @@ class Seq2SeqLM(Model):
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = model.config.decoder_start_token_id

View File

@ -36,6 +36,7 @@ class T5Sharded(Seq2SeqLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -46,11 +47,18 @@ class T5Sharded(Seq2SeqLM):
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
@ -58,7 +66,9 @@ class T5Sharded(Seq2SeqLM):
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(config)
model = AutoModelForSeq2SeqLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(

View File

@ -101,6 +101,7 @@ def serve(
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
trust_remote_code: bool,
uds_path: Path,
):
async def serve_inner(
@ -108,6 +109,7 @@ def serve(
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
if sharded:
@ -121,7 +123,7 @@ def serve(
server_urls = [local_url]
try:
model = get_model(model_id, revision, sharded, quantize)
model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
except Exception:
logger.exception("Error when initializing model")
raise
@ -152,4 +154,4 @@ def serve(
logger.info("Signal received. Shutting down")
await server.stop(0)
asyncio.run(serve_inner(model_id, revision, sharded, quantize))
asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))
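End to end, the flag travels from the launcher CLI through `text-generation-server serve` into `get_model`, which picks the model class. A minimal sketch of that final hand-off, calling the function from this commit with a hypothetical model id:

from text_generation_server.models import get_model

model = get_model(
    model_id="my-org/my-custom-model",  # hypothetical remote-code model
    revision="a1b2c3d",                 # pin the revision you audited
    sharded=False,
    quantize=None,
    trust_remote_code=True,
)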