Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-10 20:04:52 +00:00
Add the option to force another dtype than f16.

Adds a new flag, propagated everywhere. It is disjoint from `--quantize`, which also changes the actual dtype of the layers. Fixes #490
This commit is contained in:
parent 70f485bf9f
commit da7e104241
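To make the intent concrete, here is a minimal sketch of the behaviour the new flag introduces (not the actual server code; the function name is made up for illustration): a forced dtype maps onto a `torch.dtype`, and combining it with `--quantize` is rejected because both options decide the dtype of the final weights.

```python
from typing import Optional

import torch


def resolve_dtype(dtype: Optional[str], quantize: Optional[str]) -> torch.dtype:
    """Sketch of the dtype resolution this commit adds."""
    if dtype is not None and quantize is not None:
        # Mirrors the new check in the CLI: the two options are mutually exclusive.
        raise RuntimeError("Only one of `dtype` and `quantize` can be set.")
    if dtype is None:
        return torch.float16  # previous default behaviour
    mapping = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    if dtype not in mapping:
        raise RuntimeError(f"Unknown dtype {dtype}")
    return mapping[dtype]


print(resolve_dtype("bfloat16", None))  # torch.bfloat16
```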
@@ -36,6 +36,26 @@ impl std::fmt::Display for Quantization {
     }
 }
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Dtype {
+    Float16,
+    BFloat16,
+}
+
+impl std::fmt::Display for Dtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Kept in sync with `server`.
+        match self {
+            Dtype::Float16 => {
+                write!(f, "float16")
+            }
+            Dtype::BFloat16 => {
+                write!(f, "bfloat16")
+            }
+        }
+    }
+}
 
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -71,6 +91,10 @@ struct Args {
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
+    #[clap(long, env, value_enum)]
+    dtype: Option<Dtype>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
@@ -16,12 +16,18 @@ class Quantization(str, Enum):
     gptq = "gptq"
 
 
+class Dtype(str, Enum):
+    float16 = "float16"
+    bfloat16 = "bfloat16"
+
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
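Since `Dtype` (like `Quantization`) subclasses both `str` and `Enum`, its members compare equal to their string values and expose them through `.value`, which is what the "downgrade enum into str" step in the next hunk relies on. A small standalone illustration (not the server code itself):

```python
from enum import Enum


class Dtype(str, Enum):
    # Mirrors the enum added above.
    float16 = "float16"
    bfloat16 = "bfloat16"


dtype = Dtype.bfloat16
print(dtype.value)          # "bfloat16"
print(dtype == "bfloat16")  # True: str-backed enums compare equal to their value
print(None if dtype is None else dtype.value)  # the downgrade pattern used in serve()
```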
@@ -64,7 +70,14 @@ def serve(
 
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
-    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
+    dtype = None if dtype is None else dtype.value
+    if dtype is not None and quantize is not None:
+        raise RuntimeError(
+            "Only one of `dtype` and `quantize` can be set, as they both decide the dtype of the final model."
+        )
+    server.serve(
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+    )
 
 
 @app.command()
@@ -100,11 +100,25 @@ def get_model(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    if dtype is None:
+        dtype = torch.float16
+    elif dtype == "float16":
+        dtype = torch.float16
+    elif dtype == "bfloat16":
+        dtype = torch.bfloat16
+    else:
+        raise RuntimeError(f"Unknown dtype {dtype}")
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     if model_id.startswith("bigcode/"):
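For reference, the `if`/`elif` chain above amounts to a small lookup table with a float16 default; a compact equivalent as a sketch (not the code in the commit, `DTYPE_MAP` is a hypothetical name):

```python
import torch

DTYPE_MAP = {None: torch.float16, "float16": torch.float16, "bfloat16": torch.bfloat16}


def parse_dtype(dtype):
    try:
        return DTYPE_MAP[dtype]
    except KeyError:
        raise RuntimeError(f"Unknown dtype {dtype}") from None


print(parse_dtype(None))        # torch.float16 (default)
print(parse_dtype("bfloat16"))  # torch.bfloat16
```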
@@ -113,6 +127,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -124,6 +139,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -138,6 +154,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -149,12 +166,17 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
     if model_type == "bloom":
         return BLOOMSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "gpt_neox":
@@ -163,6 +185,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -170,6 +193,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     else:
@@ -177,6 +201,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -186,6 +211,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -195,6 +221,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -210,6 +237,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
         raise NotImplementedError(
@@ -221,6 +249,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     else:
@@ -228,12 +257,17 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "opt":
         return OPTSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "t5":
@@ -241,6 +275,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -253,11 +288,19 @@ def get_model(
 
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
        )
     if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
         return Seq2SeqLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     auto_map = config_dict.get("auto_map", None)
@@ -267,6 +310,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     if "AutoModelForSeq2SeqLM" in auto_map.keys():
@@ -274,6 +318,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
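The `torch.float16 if dtype is None else dtype` line above is the pattern repeated in every model constructor touched by this commit: a user-forced dtype wins when CUDA is available, while the CPU paths keep their previous behaviour (float32 for the non-flash models, an error for the flash ones). A minimal sketch of that precedence (not library code):

```python
from typing import Optional

import torch


def pick_dtype(forced: Optional[torch.dtype], cuda_available: bool) -> torch.dtype:
    # Sketch of the non-flash constructors in this diff.
    if cuda_available:
        return torch.float16 if forced is None else forced
    return torch.float32  # the forced dtype is ignored on CPU


print(pick_dtype(torch.bfloat16, True))   # torch.bfloat16
print(pick_dtype(torch.bfloat16, False))  # torch.float32
```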
@@ -454,11 +454,12 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -0,0 +1,361 @@
"""A simple, flexible implementation of a GPT model.

Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
"""
# import math
# import warnings
# from typing import List, Optional, Tuple, Union
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
# from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
# from .attention import attn_bias_shape, build_attn_bias
# from .blocks import MPTBlock
# from .custom_embedding import SharedEmbedding
# from .norm import NORM_CLASS_REGISTRY
# from .configuration_mpt import MPTConfig
# from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
# from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
# from .meta_init_context import init_empty_weights
# from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
# try:
#     from .flash_attn_triton import flash_attn_func
# except:
#     pass

"""GPT Blocks used for the GPT Model."""
import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

# Needed by MPTAttention.forward below, which calls flash_attn_cuda.fwd directly.
import flash_attn_cuda

from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    TensorParallelHead,
    FastLayerNorm,
)

EPS = 1e-5


def _gen_slopes(n_heads, alibi_bias_max=8, device=None):
    # One geometric ALiBi slope per head, padded up to the next power of two.
    _n_heads = 2 ** math.ceil(math.log2(n_heads))
    m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
    m = m.mul(alibi_bias_max / _n_heads)
    slopes = 1.0 / torch.pow(2, m)
    if _n_heads != n_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)


def _build_alibi_bias(n_heads, seq_len, device, dtype, alibi_bias_max):
    # The bias grows linearly with the (negative) distance to the current position.
    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
        1, 1, 1, seq_len
    )
    slopes = _gen_slopes(n_heads, alibi_bias_max, device=device)
    alibi_bias = alibi_bias * slopes
    return alibi_bias.to(dtype=dtype)


# Global cache: the bias is rebuilt only when a longer sequence is requested.
ALIBI = None


def build_alibi_bias(n_heads, seq_len, device, dtype, alibi_bias_max=8):
    global ALIBI
    if ALIBI is None or seq_len > ALIBI.shape[-1]:
        ALIBI = _build_alibi_bias(
            n_heads, seq_len, device, dtype, alibi_bias_max=alibi_bias_max
        )
    return ALIBI[:, :, :, :seq_len]
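A quick check of the ALiBi helpers above, assuming `build_alibi_bias` is in scope (importing the full module requires the flash-attention CUDA extension): the bias has shape `(1, n_heads, 1, seq_len)`, each head gets its own slope, and subsequent calls only slice the cached tensor.

```python
import torch

# Assumes build_alibi_bias from the module above is importable/in scope.
bias = build_alibi_bias(
    n_heads=8, seq_len=16, device=torch.device("cpu"), dtype=torch.float32
)
print(bias.shape)         # torch.Size([1, 8, 1, 16])
print(bias[0, 0, 0, -1])  # 0: no penalty on the current token
print(bias[0, 0, 0, 0])   # most negative entry for head 0 (largest slope, furthest token)
```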
class MPTAttention(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()

        self.num_heads = config.n_heads
        self.hidden_size = config.d_model
        self.head_size = self.hidden_size // self.num_heads
        # Standard attention scaling, used by the flash attention kernels below.
        self.softmax_scale = self.head_size**-0.5
        self.Wqkv = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.Wqkv",
            weights=weights,
            bias=False,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.out_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self,
        hidden_states,
        alibi,
        start_seq,
        end_seq,
        start_seq_q,
        end_seq_q,
        max_s,
        layer_past,
        past_present_indices,
        prefill,
    ):
        qkv = self.Wqkv(hidden_states)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

        # TODO: the alibi bias is not applied yet.
        raise Exception("Apply alibi ?")

        # Prefill
        if prefill:
            # Copy to layer past
            layer_past[...] = qkv[:, 1:]

            # output
            attn_output = torch.empty_like(qkv[:, 0])
            # flash attention
            flash_attn_cuda.fwd(
                qkv[:, 0],
                qkv[:, 1],
                qkv[:, 2],
                attn_output,
                start_seq,
                end_seq,
                start_seq,
                end_seq,
                max_s,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                True,
                False,
                0,
                None,
            )
        # Decode
        else:
            query = qkv[:, 0]
            # Add present to the layer_past tensor at the correct indices
            layer_past[past_present_indices] = qkv[:, 1:]

            # output
            attn_output = torch.empty_like(query)
            # flash attention
            flash_attn_cuda.fwd(
                query,
                layer_past[:, 0],
                layer_past[:, 1],
                attn_output,
                start_seq_q,
                end_seq_q,
                start_seq,
                end_seq,
                1,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                False,
                False,
                0,
                None,
            )

        return self.out_proj(attn_output.view(-1, self.num_heads * self.head_size))
class MPTMLP(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()

        self.up_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.up_proj",
            weights=weights,
            bias=False,
        )
        self.act = nn.GELU(approximate="none")
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )

    def forward(self, x):
        return self.down_proj(self.act(self.up_proj(x)))
class MPTBlock(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        self.norm_1 = FastLayerNorm.load_no_bias(
            prefix=f"{prefix}.norm_1", weights=weights, eps=EPS
        )
        self.attn = MPTAttention(config, prefix=f"{prefix}.attn", weights=weights)
        self.norm_2 = FastLayerNorm.load_no_bias(
            prefix=f"{prefix}.norm_2", weights=weights, eps=EPS
        )
        self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights)

    def forward(
        self,
        hidden_states,
        residual,
        alibi,
        start_seq,
        end_seq,
        start_seq_q,
        end_seq_q,
        max_s,
        past_key_values,
        past_present_indices,
        prefill,
    ):
        residual = hidden_states
        hidden_states, _ = self.norm_1(hidden_states)
        # (hidden_states, attn_weights) = self.attn(
        hidden_states = self.attn(
            hidden_states,
            alibi,
            start_seq,
            end_seq,
            start_seq_q,
            end_seq_q,
            max_s,
            past_key_values,
            past_present_indices,
            prefill,
        )
        hidden_states += residual
        residual = hidden_states
        hidden_states, _ = self.norm_2(hidden_states)
        hidden_states = self.ffn(hidden_states)
        hidden_states += residual
        # The residual additions are already done inside the block, so no
        # residual is carried over to the caller.
        return hidden_states, None
class MPTModel(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.wte = TensorParallelEmbedding(prefix="transformer.wte", weights=weights)
        self.num_heads = config.n_heads
        self.hidden_size = config.d_model
        self.head_size = self.hidden_size // self.num_heads
        self.blocks = nn.ModuleList(
            [
                MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
                for i in range(config.n_layers)
            ]
        )
        self.norm_f = FastLayerNorm.load_no_bias(
            prefix="transformer.norm_f", weights=weights, eps=EPS
        )

        # Create a default sizeable global alibi
        build_alibi_bias(
            n_heads=self.num_heads,
            seq_len=1024,
            device=weights.device,
            dtype=weights.dtype,
        )

    def forward(
        self,
        input_ids,
        position_ids,
        start_seq,
        end_seq,
        start_seq_q,
        end_seq_q,
        max_s,
        past_present_indices,
        past_key_values=None,
        pre_allocate_past_size: Optional[int] = None,
    ):
        hidden_states = self.wte(input_ids)

        # Prefill
        if past_key_values is None:
            assert pre_allocate_past_size is not None

            prefill = True

            # Create past tensor
            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
            past_key_values = hidden_states.new_empty(
                (
                    len(input_ids),
                    len(self.blocks),
                    2,
                    self.num_heads,
                    self.head_size,
                )
            )
        # Decode
        else:
            prefill = False

        alibi = build_alibi_bias(
            n_heads=self.num_heads,
            seq_len=max_s,
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        # Cast alibi into correct shape
        alibi = alibi[:, :, :, position_ids]

        residual = None
        for i, layer in enumerate(self.blocks):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                alibi,
                start_seq,
                end_seq,
                start_seq_q,
                end_seq_q,
                max_s,
                past_key_values[:, i],
                past_present_indices,
                prefill,
            )

        if prefill:
            present = past_key_values
            # Create padded past tensor
            past_key_values = hidden_states.new_empty(
                (
                    pre_allocate_past_size,
                    len(self.blocks),
                    2,
                    self.num_heads,
                    self.head_size,
                )
            )
            # We slice only once instead of at every layer
            past_key_values[past_present_indices] = present

        hidden_states, _ = self.norm_f(hidden_states, residual)

        return hidden_states, past_key_values
class MPTForCausalLM(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.transformer = MPTModel(config, weights)
        self.lm_head = TensorParallelHead.load(
            config,
            prefix="transformer.wte",
            weights=weights,
        )

    def forward(
        self,
        input_ids,
        position_ids,
        start_seq,
        end_seq,
        start_seq_q,
        end_seq_q,
        max_s,
        past_present_indices,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ):
        hidden_states, present = self.transformer(
            input_ids,
            position_ids,
            start_seq,
            end_seq,
            start_seq_q,
            end_seq_q,
            max_s,
            past_present_indices,
            past_key_values,
            pre_allocate_past_size,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits, present
@@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
server/text_generation_server/models/flash_mpt.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
from typing import Optional
from huggingface_hub import hf_hub_download
import json

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_mpt_modeling import (
    MPTForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class MPTSharded(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashMPT is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        filename = hf_hub_download(model_id, revision=revision, filename="config.json")
        with open(filename, "r") as f:
            config = json.load(f)
        config = PretrainedConfig(**config)
        config.quantize = quantize
        # config = AutoConfig.from_pretrained(
        #     # model_id, revision=revision, trust_remote_code=trust_remote_code
        #     model_id, revision=revision, trust_remote_code=False
        # )

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)

        config.quantize = quantize
        model = MPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
@@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 
@@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 
@@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 
@@ -52,8 +53,11 @@ class FlashSantacoderSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group,
-            aliases = {"transformer.wte.weight": ["lm_head.weight"]}
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
 
         model = FlashSantacoderForCausalLM(config, weights)
@@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -22,12 +22,13 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -12,11 +12,12 @@ class RW(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -19,11 +19,12 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -504,11 +504,12 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16
+            dtype = torch.float16 if dtype is None else dtype
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -99,6 +99,7 @@ def serve(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
 ):
@@ -107,6 +108,7 @@ def serve(
         revision: Optional[str],
         sharded: bool = False,
         quantize: Optional[str] = None,
+        dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         unix_socket_template = "unix://{}-{}"
@@ -121,7 +123,9 @@ def serve(
             server_urls = [local_url]
 
         try:
-            model = get_model(model_id, revision, sharded, quantize, trust_remote_code)
+            model = get_model(
+                model_id, revision, sharded, quantize, dtype, trust_remote_code
+            )
         except Exception:
             logger.exception("Error when initializing model")
             raise