add model id

fxmarty 2024-05-02 15:33:00 +00:00
parent d2b4b02c0e
commit 51b0c25f37
25 changed files with 60 additions and 10 deletions

@@ -46,6 +46,8 @@ class BLOOMSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -486,6 +486,7 @@ class CausalLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
         if use_medusa:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")

@@ -834,15 +834,21 @@ class FlashCausalLM(Model):
         if os.environ.get("PYTORCH_TUNABLEOP_TUNING", "1"):
             torch.cuda.tunable.tuning_enable(True)
-            tuning_sequences = range(1, 8)
-            tunableop_filename = f"tunableop_tp{self.world_size}_rank{self.rank}.csv"
-            logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join(tuning_sequences)}.")
-            torch.cuda.tunable.read_file(tunableop_filename)
-            for seqlen in range(1, 8):
+            tuning_sequences = list(range(1, 3))
+            tunableop_filepath = os.path.join("/data", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv")
+            logger.info(f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])} (typical decoding lengths). The picked GEMMs are saved in the file {tunableop_filepath}.")
+
+            if os.path.isfile(tunableop_filepath):
+                logger.info(f"The file {tunableop_filepath} already exists and will be reused.")
+                torch.cuda.tunable.read_file(tunableop_filepath)
+
+            os.makedirs("/data", exist_ok=True)
+
+            for seqlen in tuning_sequences:
+                logger.info(f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen)
-            torch.cuda.tunable.write_file(tunableop_filename)
+            torch.cuda.tunable.write_file(tunableop_filepath)
             torch.cuda.tunable.tuning_enable(False)
 
         if CUDA_GRAPHS:
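
Note: the hunk above follows the standard torch.cuda.tunable API from recent PyTorch ROCm builds, namely enable tuning, replay previously tuned GEMMs from a CSV, exercise the decode-shaped matmuls, then persist the selections. The sketch below is a hypothetical standalone version of that flow; run_decode_gemm and the example path are placeholders, not part of this commit.

import os
import torch

def run_decode_gemm(seqlen: int) -> None:
    # Placeholder for self.tunableop_warmup(seqlen): one decode-shaped fp16 matmul.
    a = torch.randn(seqlen, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    torch.matmul(a, b)

def warmup_tunableop(csv_path: str, seqlens=(1, 2)) -> None:
    torch.cuda.tunable.enable(True)          # turn TunableOp on (illustrative; TGI drives this via env vars)
    torch.cuda.tunable.tuning_enable(True)   # allow new GEMM configurations to be tuned

    if os.path.isfile(csv_path):
        torch.cuda.tunable.read_file(csv_path)   # reuse kernels tuned in a previous run

    for seqlen in seqlens:
        run_decode_gemm(seqlen)

    os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
    torch.cuda.tunable.write_file(csv_path)      # persist the picked kernels
    torch.cuda.tunable.tuning_enable(False)      # freeze selections before serving

warmup_tunableop("/tmp/tunableop_example.csv")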

@@ -28,6 +28,8 @@ class FlashCohere(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -30,6 +30,8 @@ class FlashDbrx(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -29,6 +29,8 @@ class FlashGemma(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -31,6 +31,8 @@ class FlashLlama(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -318,6 +318,8 @@ class BaseFlashMistral(FlashCausalLM):
         trust_remote_code: bool = False,
         tokenizer_class=AutoTokenizer,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -29,6 +29,8 @@ class FlashNeoXSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -29,6 +29,8 @@ class FlashPhi(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -34,6 +34,8 @@ class FlashQwen2(BaseFlashMistral):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -30,6 +30,8 @@ class FlashRWSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -33,6 +33,8 @@ class FlashSantacoderSharded(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -33,6 +33,8 @@ class FlashStarcoder2(BaseFlashMistral):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -171,6 +171,8 @@ class GalacticaSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -28,6 +28,8 @@ class GPTNeoxSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -35,6 +35,8 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -412,6 +412,8 @@ class Mamba(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, _rank, world_size = initialize_torch_distributed()
         if world_size > 1:
             raise RuntimeError("Mamba does not support Tensor Parallelism (TP)")

@@ -47,6 +47,8 @@ class MPTSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -26,6 +26,8 @@ class OPTSharded(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -26,6 +26,8 @@ class Phi(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, _rank, _world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -23,6 +23,8 @@ class SantaCoder(CausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype

@@ -536,6 +536,8 @@ class Seq2SeqLM(Model):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         if use_medusa:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")

@@ -29,6 +29,8 @@ class T5Sharded(Seq2SeqLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
+        self.model_id = model_id
+
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -1018,7 +1018,7 @@ try:
         from flash_attn.layers.rotary import RotaryEmbedding
         import rotary_emb
     elif IS_ROCM_SYSTEM:
-        from vllm import pos_encoding_ops
+        from vllm._C import ops
 
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (

@@ -1339,6 +1339,5 @@ try:
             freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
-except ImportError:
-    pass
+except ImportError as e:
+    logger.warning(f"ImportError in layers.py, beware that this may cause issues later on: {e}")
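
Note: the last hunk swaps the silent except ImportError: pass for a logged warning, so a missing optional kernel package (such as the ROCm vllm._C extension) is visible at startup instead of surfacing later as an obscure NameError. A minimal sketch of that guarded-import pattern, with the module shown only as an example:

from loguru import logger  # TGI's logger; the pattern works with any logging setup

try:
    # Optional fused rotary / positional-encoding kernels.
    from vllm._C import ops
except ImportError as e:
    ops = None  # callers must check for None (or re-raise) before taking the fast path
    logger.warning(
        f"ImportError in layers.py, beware that this may cause issues later on: {e}"
    )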