Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 14:52:20 +00:00)

feat(server): support trust_remote_code (#363)

parent e9669a4085
commit e3e487dc71
.github/workflows/build.yaml (vendored)

@@ -213,13 +213,12 @@ jobs:
          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
      - name: Install
        run: |
-          pip install pytest-xdist
          make install-integration-tests
      - name: Run tests
        run: |
          export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests

  stop-runner:
    name: Stop self-hosted EC2 runner
@@ -53,7 +53,7 @@ struct Args {
    #[clap(long, env)]
    revision: Option<String>,

-    /// Wether to shard or not the model across multiple GPUs
+    /// Whether to shard the model across multiple GPUs
    /// By default text-generation-inference will use all available GPUs to run
    /// the model. Setting it to `false` deactivates `num_shard`.
    #[clap(long, env)]
@@ -66,11 +66,17 @@ struct Args {
    #[clap(long, env)]
    num_shard: Option<usize>,

-    /// Wether you want the model to be quantized or not. This will use `bitsandbytes` for
+    /// Whether you want the model to be quantized. This will use `bitsandbytes` for
    /// quantization on the fly, or `gptq`.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

+    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
+    /// encouraged when loading a model with custom code to ensure no malicious code has been
+    /// contributed in a newer revision.
+    #[clap(long, env, value_enum)]
+    trust_remote_code: bool,
+
    /// The maximum amount of concurrent requests for this particular deployment.
    /// Having a low limit will refuse clients requests instead of having them
    /// wait for too long and is usually good to handle backpressure correctly.
@@ -239,6 +245,7 @@ fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
+    trust_remote_code: bool,
    uds_path: String,
    rank: usize,
    world_size: usize,
@@ -272,6 +279,11 @@ fn shard_manager(
        "--json-output".to_string(),
    ];

+    // Activate trust remote code
+    if trust_remote_code {
+        shard_argv.push("--trust-remote-code".to_string());
+    }
+
    // Activate tensor parallelism
    if world_size > 1 {
        shard_argv.push("--sharded".to_string());
@@ -692,6 +704,16 @@ fn spawn_shards(
    status_sender: mpsc::Sender<ShardStatus>,
    running: Arc<AtomicBool>,
) -> Result<(), LauncherError> {
+    if args.trust_remote_code {
+        tracing::warn!(
+            "`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
+            args.model_id
+        );
+        if args.revision.is_none() {
+            tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
+        }
+    }
+
    // Start shard processes
    for rank in 0..num_shard {
        let model_id = args.model_id.clone();
@@ -705,6 +727,7 @@ fn spawn_shards(
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize;
+        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;
        let watermark_gamma = args.watermark_gamma;
@@ -714,6 +737,7 @@ fn spawn_shards(
            model_id,
            revision,
            quantize,
+            trust_remote_code,
            uds_path,
            rank,
            num_shard,
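Usage sketch (not part of the diff): with clap's `#[clap(long, env)]` derivation, the new field should surface on the launcher as a `--trust-remote-code` flag (and, presumably, a matching TRUST_REMOTE_CODE environment variable); when set, the launcher pushes `--trust-remote-code` onto every shard's argument list, as the shard_manager hunk above shows. The model id and revision below are placeholders:

text-generation-launcher --model-id some-org/model-with-custom-code --revision <commit-sha> --trust-remote-code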
@@ -22,6 +22,7 @@ def serve(
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
+    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
@@ -63,7 +64,7 @@ def serve(

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
-    server.serve(model_id, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)


@app.command()
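A similar hedged sketch for the Python side (not part of the diff): `serve` here appears to be a Typer/Click command, so the new boolean parameter should become a `--trust-remote-code` flag on the server CLI, which matches the argument the launcher forwards to each shard. Assuming the entry point is named `text-generation-server` and with a placeholder model id:

text-generation-server serve some-org/model-with-custom-code --trust-remote-code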
@@ -91,13 +91,27 @@ torch.set_grad_enabled(False)


def get_model(
-    model_id: str, revision: Optional[str], sharded: bool, quantize: Optional[str]
+    model_id: str,
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    trust_remote_code: bool,
) -> Model:
    if "facebook/galactica" in model_id:
        if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
+            return GalacticaSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
-            return Galactica(model_id, revision, quantize=quantize)
+            return Galactica(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_id.startswith("bigcode/"):
        if sharded:
@@ -105,12 +119,24 @@ def get_model(
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

-    config = AutoConfig.from_pretrained(model_id, revision=revision)
+    config = AutoConfig.from_pretrained(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
    model_type = config.model_type

    if model_type == "gpt_bigcode":
@@ -119,52 +145,133 @@ def get_model(
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                )
-            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+            return FlashSantacoderSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize=quantize)
+            return santacoder_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_type == "bloom":
        if sharded:
-            return BLOOMSharded(model_id, revision, quantize=quantize)
+            return BLOOMSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
-            return BLOOM(model_id, revision, quantize=quantize)
+            return BLOOM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
-            return neox_cls(model_id, revision, quantize=quantize)
+            return neox_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_type == "llama":
        if sharded:
            if FLASH_ATTENTION:
-                return FlashLlamaSharded(model_id, revision, quantize=quantize)
+                return FlashLlamaSharded(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
        else:
            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
-            return llama_cls(model_id, revision, quantize=quantize)
+            return llama_cls(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if config.model_type == "opt":
        if sharded:
-            return OPTSharded(model_id, revision, quantize=quantize)
+            return OPTSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
-            return OPT(model_id, revision, quantize=quantize)
+            return OPT(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if model_type == "t5":
        if sharded:
-            return T5Sharded(model_id, revision, quantize=quantize)
+            return T5Sharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
        else:
-            return Seq2SeqLM(model_id, revision, quantize=quantize)
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(model_id, revision, quantize=quantize)
+        return CausalLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM(model_id, revision, quantize=quantize)
+        return Seq2SeqLM(
+            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+        )

+    auto_map = getattr(config, "auto_map", None)
+    if trust_remote_code and auto_map is not None:
+        if "AutoModelForCausalLM" in auto_map.keys():
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
+        if "AutoModelForSeq2SeqLM" in auto_map.keys:
+            return Seq2SeqLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
+            )
+
    raise ValueError(f"Unsupported model type {model_type}")
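For context on the new fallback at the end of get_model: Transformers configs loaded from a hub repo that ships its own modelling code expose an `auto_map` attribute, and the fallback keys off that. A minimal sketch, not part of the diff, assuming a hypothetical hub repo "some-org/custom-causal-lm" that defines a custom AutoModelForCausalLM implementation:

from transformers import AutoConfig

# Hypothetical repo with custom modelling code; trust_remote_code must be set
# explicitly before its code may be fetched and executed.
config = AutoConfig.from_pretrained("some-org/custom-causal-lm", trust_remote_code=True)

auto_map = getattr(config, "auto_map", None)
if auto_map is not None and "AutoModelForCausalLM" in auto_map:
    # get_model would route such a repo through CausalLM with trust_remote_code=True
    print("custom implementation:", auto_map["AutoModelForCausalLM"])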
@@ -54,9 +54,13 @@ class BLOOM(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        super(BLOOM, self).__init__(
-            model_id=model_id, revision=revision, quantize=quantize
+            model_id=model_id,
+            revision=revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
        )

    @property
@@ -70,6 +74,7 @@ class BLOOMSharded(BLOOM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -80,11 +85,19 @@ class BLOOMSharded(BLOOM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id,
+            revision=revision,
+            slow_but_exact=False,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
        )
        config.pad_token_id = 3

@@ -92,7 +105,9 @@ class BLOOMSharded(BLOOM):
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
@@ -1,4 +1,5 @@
import torch
+import inspect

from dataclasses import dataclass
from opentelemetry import trace
@@ -450,6 +451,7 @@ class CausalLM(Model):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
@@ -462,22 +464,38 @@ class CausalLM(Model):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

-        tokenizer.pad_token_id = (
-            model.config.pad_token_id
-            if model.config.pad_token_id is not None
-            else model.config.eos_token_id
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
        )

        super(CausalLM, self).__init__(
@@ -501,14 +519,17 @@ class CausalLM(Model):
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-            return_dict=True,
-        )
+        kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if self.has_position_ids:
+            kwargs["position_ids"] = position_ids
+
+        outputs = self.model.forward(**kwargs)
        return outputs.logits, outputs.past_key_values

    @tracer.start_as_current_span("generate_token")
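A small illustration of the inspect.signature check the reworked forward relies on (not part of the diff; GPT-2 is used here only as an example of a model whose forward accepts position_ids, whereas some custom remote-code models do not):

import inspect

from transformers import GPT2Config, GPT2LMHeadModel

# Tiny randomly initialised model, no weight download needed.
model = GPT2LMHeadModel(GPT2Config(n_layer=1, n_head=1, n_embd=8))

has_position_ids = (
    inspect.signature(model.forward).parameters.get("position_ids", None) is not None
)
print(has_position_ids)  # True for GPT-2; when False, position_ids is left out of kwargs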
@@ -394,6 +394,7 @@ class FlashCausalLM(Model):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
@@ -402,13 +403,18 @@ class FlashCausalLM(Model):
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )
        model = model_cls.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
        ).to(device)

        super(FlashCausalLM, self).__init__(
@@ -33,6 +33,7 @@ class FlashLlama(FlashCausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
@@ -45,11 +46,11 @@ class FlashLlama(FlashCausalLM):
            revision=revision,
            padding_side="left",
            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        # We do not use from_pretrained as we modified the model internal module layout
@@ -153,6 +154,7 @@ class FlashLlamaSharded(FlashLlama):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -166,11 +168,11 @@ class FlashLlamaSharded(FlashLlama):
            revision=revision,
            padding_side="left",
            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        torch.distributed.barrier(group=self.process_group)
@@ -28,9 +28,14 @@ class FlashNeoX(FlashCausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        super(FlashNeoX, self).__init__(
-            FlashGPTNeoXForCausalLM, model_id, revision, quantize
+            FlashGPTNeoXForCausalLM,
+            model_id,
+            revision,
+            quantize,
+            trust_remote_code=trust_remote_code,
        )


@@ -40,6 +45,7 @@ class FlashNeoXSharded(FlashNeoX):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -49,12 +55,15 @@ class FlashNeoXSharded(FlashNeoX):
            raise NotImplementedError("FlashNeoX is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
-            model_id,
-            revision=revision,
+            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        torch.distributed.barrier(group=self.process_group)
@@ -32,6 +32,7 @@ class FlashSantacoder(FlashCausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
@@ -40,7 +41,11 @@ class FlashSantacoder(FlashCausalLM):
            raise NotImplementedError("FlashSantacoder is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = GPT2Config.from_pretrained(
@@ -178,6 +183,7 @@ class FlashSantacoderSharded(FlashSantacoder):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -187,7 +193,11 @@ class FlashSantacoderSharded(FlashSantacoder):
            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = GPT2Config.from_pretrained(
@@ -199,6 +199,7 @@ class GalacticaSharded(Galactica):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -209,11 +210,18 @@ class GalacticaSharded(Galactica):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token_id = config.pad_token_id

@@ -221,7 +229,9 @@ class GalacticaSharded(Galactica):
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
@@ -36,6 +36,7 @@ class GPTNeoxSharded(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -46,19 +47,28 @@ class GPTNeoxSharded(CausalLM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
@@ -52,6 +52,7 @@ class OPTSharded(OPT):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -62,11 +63,18 @@ class OPTSharded(OPT):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token_id = config.pad_token_id

@@ -74,7 +82,9 @@ class OPTSharded(OPT):
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config)
+            model = AutoModelForCausalLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
@@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
@@ -31,7 +32,11 @@ class SantaCoder(CausalLM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )
        tokenizer.add_special_tokens(
            {
@@ -51,7 +56,7 @@ class SantaCoder(CausalLM):
            revision=revision,
            torch_dtype=dtype,
            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=True,  # required
+            trust_remote_code=trust_remote_code,
        ).to(device)

        super(CausalLM, self).__init__(
@@ -503,6 +503,7 @@ class Seq2SeqLM(Model):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
@@ -518,14 +519,21 @@ class Seq2SeqLM(Model):
            model_id,
            revision=revision,
            torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
        )
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id

@@ -36,6 +36,7 @@ class T5Sharded(Seq2SeqLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
@@ -46,11 +47,18 @@ class T5Sharded(Seq2SeqLM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left", truncation_side="left"
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
+            tp_parallel=True,
+            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

@@ -58,7 +66,9 @@ class T5Sharded(Seq2SeqLM):
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
-            model = AutoModelForSeq2SeqLM.from_config(config)
+            model = AutoModelForSeq2SeqLM.from_config(
+                config, trust_remote_code=trust_remote_code
+            )

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
@@ -101,6 +101,7 @@ def serve(
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
+    trust_remote_code: bool,
    uds_path: Path,
):
    async def serve_inner(
@@ -108,6 +109,7 @@ def serve(
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"
        if sharded:
@@ -121,7 +123,7 @@ def serve(
            server_urls = [local_url]

        try:
-            model = get_model(model_id, revision, sharded, quantize)
+            model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
        except Exception:
            logger.exception("Error when initializing model")
            raise
@@ -152,4 +154,4 @@ def serve(
        logger.info("Signal received. Shutting down")
        await server.stop(0)

-    asyncio.run(serve_inner(model_id, revision, sharded, quantize))
+    asyncio.run(serve_inner(model_id, revision, sharded, quantize, trust_remote_code))