diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 09d8b69b4..9d931faa0 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -17,14 +17,12 @@ class BloomCausalLMBatch(CausalLMBatch): tokenizer: PreTrainedTokenizerBase, dtype: torch.dtype, device: torch.device, - is_optimized_for_gaudi: bool = False, ) -> "CausalLMBatch": batch = super().from_pb( pb=pb, tokenizer=tokenizer, dtype=dtype, device=device, - is_optimized_for_gaudi=is_optimized_for_gaudi, ) batch.keys_head_dim_last = False return batch diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 7c5b93ee1..9083b786a 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,28 +1,33 @@ +import bisect +from dataclasses import dataclass +from functools import wraps +import itertools +import math import os import tempfile -import itertools -import bisect -import math +from typing import Dict, List, Optional, Tuple, Type import torch - -from dataclasses import dataclass +from loguru import logger from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig -from typing import Optional, Tuple, List, Type, Dict import text_generation_server.habana_quantization_env as hq_env import habana_frameworks.torch as htorch from habana_frameworks.torch.hpu import wrap_in_hpu_graph -from contextlib import nullcontext +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi from optimum.habana.utils import HabanaProfile - from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES from optimum.habana.checkpoint_utils import ( get_repo_root, model_on_meta, write_checkpoints_json, ) +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + PreTrainedTokenizerBase, + AutoConfig, +) from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model @@ -34,11 +39,13 @@ from text_generation_server.models.types import ( TopTokens, ) from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent +from text_generation_server.utils import ( + HeterogeneousNextTokenChooser, + StoppingCriteria, + make_tokenizer_optional, + is_tokenizer_transparent, +) from text_generation_server.utils.debug import dbg_trace -from loguru import logger -from functools import wraps - tracer = trace.get_tracer(__name__) @@ -384,7 +391,7 @@ class CausalLMBatch(Batch): parameters = [r.data.parameters for r in flat_requests] if len(flat_requests) < new_bs: for i in range(new_bs-len(flat_requests)) : - #append the dummy parameters for dummy request + # append the dummy parameters for dummy request parameters.append(parameters[0]) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( @@ -423,7 +430,6 @@ class CausalLMBatch(Batch): tokenizer: PreTrainedTokenizerBase, dtype: torch.dtype, device: torch.device, - is_optimized_for_gaudi: bool = False, ) -> "CausalLMBatch": dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}') requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)] @@ -474,19 +480,16 @@ class CausalLMBatch(Batch): input_ids = tokenized_inputs["input_ids"] 
attention_mask = tokenized_inputs["attention_mask"] - if is_optimized_for_gaudi: - # Allocate space for first token - input_ids = torch.nn.functional.pad( - input_ids, (left_padding, 1), value=tokenizer.pad_token_id - ) - attention_mask = torch.nn.functional.pad( - attention_mask, (left_padding, 1), value=0 - ) - all_input_ids = torch.nn.functional.pad( - input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id - ).T.split(1, dim=1) - else: - all_input_ids = input_ids.clone().T.split(1, dim=1) + # Allocate space for first token + input_ids = torch.nn.functional.pad( + input_ids, (left_padding, 1), value=tokenizer.pad_token_id + ) + attention_mask = torch.nn.functional.pad( + attention_mask, (left_padding, 1), value=0 + ) + all_input_ids = torch.nn.functional.pad( + input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id + ).T.split(1, dim=1) # New input length after left padding input_len = bucket_size @@ -562,16 +565,9 @@ class CausalLM(Model): revision: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): - device = torch.device("hpu") - if hq_env.is_quantization_enabled: - htorch.core.hpu_set_env() - - dtype = torch.bfloat16 if dtype is None else dtype - - from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi - adapt_transformers_to_gaudi() + # Create tokenizer tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, @@ -580,79 +576,42 @@ class CausalLM(Model): ) make_tokenizer_optional(tokenizer) - model_kwargs = { - "revision": revision, - } - + # Create model world_size = int(os.getenv("WORLD_SIZE", "1")) rank = int(os.getenv("RANK", "0")) - self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" - self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" + dtype = torch.bfloat16 if dtype is None else dtype + device = torch.device("hpu") + + if hq_env.is_quantization_enabled: + htorch.core.hpu_set_env() if world_size > 1: - import habana_frameworks.torch.hpu as torch_hpu - - # Get world size, rank and local rank - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - - world_size, rank, local_rank = initialize_distributed_hpu() - import deepspeed - - # Initialize process(es) for DeepSpeed - deepspeed.init_distributed(dist_backend="hccl") - logger.info( - "DeepSpeed is enabled. 
world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) + model = self.get_deepspeed_model( + model_id, dtype, revision ) - config = AutoConfig.from_pretrained(model_id, **model_kwargs) - load_to_meta = model_on_meta(config) - - if load_to_meta: - # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load - with deepspeed.OnDevice(dtype=dtype, device="meta"): - model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) - else: - get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) - # TODO: revisit placement on CPU when auto-injection is possible - with deepspeed.OnDevice(dtype=dtype, device="cpu"): - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) - model = model.eval() - - # Initialize the model - ds_inference_kwargs = {"dtype": dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - - if load_to_meta: - # model loaded to meta is managed differently - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - write_checkpoints_json(model_id, local_rank, checkpoints_json) - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - model = deepspeed.init_inference(model, **ds_inference_kwargs) - model = model.module model = self.prepare_model_for_quantization(model) - model = remove_kv_cache_from_output(model) - if self.enable_hpu_graph: - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) - else: get_repo_root(model_id) + rope_scaling = self.get_rope_scaling() model = AutoModelForCausalLM.from_pretrained( model_id, revision=revision, torch_dtype=dtype, + rope_scaling=rope_scaling ) model = self.prepare_model_for_quantization(model) model = model.eval().to(device) - # wrap in hpu_graph only if self.enable_hpu_graph is set - model = remove_kv_cache_from_output(model) - if self.enable_hpu_graph: - model = wrap_in_hpu_graph(model, disable_tensor_cache=True) + + self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" + self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true" + model = remove_kv_cache_from_output(model) + if self.enable_hpu_graph: + model = wrap_in_hpu_graph(model, disable_tensor_cache=True) + model = self.setup_quantization(model) - if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: - self.is_optimized_for_gaudi = True - else: - self.is_optimized_for_gaudi = False + if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES: + raise ValueError(f"Model type {model.config.model_type} is not supported!") if tokenizer.pad_token_id is None: if model.config.pad_token_id is not None: @@ -682,24 +641,85 @@ class CausalLM(Model): rank=rank, kwargs=kwargs, ) - prof_ranks = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] - self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in prof_ranks else 0 - self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in prof_ranks else 0 - self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) + + # Create profiler + ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')] record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true" output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile") + self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0 + self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in 
ranks_to_profile else 0 + self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0")) if self.profiling_steps > 0: self.hb_profiler = HabanaProfile( wait=self.profiling_wait_steps, warmup=self.profiling_warmup_steps, active=self.profiling_steps, - output_dir=output_dir, record_shapes=record_shapes + output_dir=output_dir, + record_shapes=record_shapes ) self.hb_profiler.start() else: self.hb_profiler = None self.step = 0 + def get_deepspeed_model( + self, + model_id: str, + dtype: torch.dtype, + revision: Optional[str] = None + ) -> torch.nn.Module: + import deepspeed + from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + + world_size, rank, local_rank = initialize_distributed_hpu() + model_kwargs = { + "revision": revision, + 'rope_scaling': self.get_rope_scaling() + } + + # Initialize process(es) for DeepSpeed + deepspeed.init_distributed(dist_backend="hccl") + logger.info( + "DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank) + ) + config = AutoConfig.from_pretrained(model_id, **model_kwargs) + load_to_meta = model_on_meta(config) + + if load_to_meta: + # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load + with deepspeed.OnDevice(dtype=dtype, device="meta"): + model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype) + else: + get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK")) + # TODO: revisit placement on CPU when auto-injection is possible + with deepspeed.OnDevice(dtype=dtype, device="cpu"): + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs) + model = model.eval() + + # Initialize the model + ds_inference_kwargs = {"dtype": dtype} + ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} + ds_inference_kwargs["enable_cuda_graph"] = False + + if load_to_meta: + # model loaded to meta is managed differently + checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") + write_checkpoints_json(model_id, local_rank, checkpoints_json) + ds_inference_kwargs["checkpoint"] = checkpoints_json.name + model = deepspeed.init_inference(model, **ds_inference_kwargs) + + return model.module + + def get_rope_scaling(self) -> Optional[Dict]: + rope_scaling = os.getenv("ROPE_SCALING", None) + if rope_scaling is None: + return None + + rope_factor = float(os.getenv("ROPE_FACTOR", 1.0)) + return { + 'type': rope_scaling, 'factor': float(rope_factor) + } + def setup_quantization(self, model): if hq_env.is_quantization_enabled: htorch.core.quantization._mark_params_as_const(model) @@ -754,7 +774,7 @@ class CausalLM(Model): input_ids, attention_mask, position_ids, - token_idx: Optional = None, + token_idx, past_key_values: Optional = None, bypass_hpu_graph: Optional = None, ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: @@ -763,11 +783,9 @@ class CausalLM(Model): "input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values, + "token_idx": token_idx } - if self.is_optimized_for_gaudi: - kwargs["token_idx"] = token_idx - if self.has_position_ids: kwargs["position_ids"] = position_ids @@ -794,20 +812,17 @@ class CausalLM(Model): logits = batch.logits past = batch.past prefill = batch.past_key_values is None - if self.is_optimized_for_gaudi: - if prefill: - # no right padding for prefill - token_idx_scalar = batch.attention_mask.shape[-1] - 1 - token_idx = torch.tensor(token_idx_scalar).to(self.device) - else: - token_idx_scalar = 
batch.attention_mask.shape[-1] - batch.right_padding - token_idx = torch.tensor(token_idx_scalar).to(self.device) + if prefill: + # no right padding for prefill + token_idx_scalar = batch.attention_mask.shape[-1] - 1 + token_idx = torch.tensor(token_idx_scalar).to(self.device) else: - token_idx = None + token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding + token_idx = torch.tensor(token_idx_scalar).to(self.device) # Select next token input_length = batch.input_length - if self.is_optimized_for_gaudi and logits.shape[-2] > 1: + if logits.shape[-2] > 1: next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser( batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2) ) @@ -840,20 +855,11 @@ class CausalLM(Model): htorch.core.mark_step() - if token_idx is None: - batch.input_ids[:, 0] = next_token_ids[:, 0] - else: - batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) - - # Slice unused values from prefill, use it to store next token - if token_idx is None: - batch.input_ids = batch.input_ids[:, :1] + # Add new token into input_ids + batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) # Update attention_mask as we added a new token to input_ids - if self.is_optimized_for_gaudi: - batch.attention_mask.index_fill_(1, token_idx, 1) - else: - batch.attention_mask[:, -batch.padding_right_offset] = 1 + batch.attention_mask.index_fill_(1, token_idx, 1) # Adjust lengths batch.input_length += 1 @@ -886,37 +892,24 @@ class CausalLM(Model): scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}') assert batch.right_padding > 0, 'No more room for next token!' - if self.is_optimized_for_gaudi: - if prefill: - # no right padding for prefill - token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) - else: - token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) - attention_mask = batch.attention_mask - else: - token_idx = None - # slice the attention mask to the correct shape - # TODO fix me! 
- attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - - if not prefill and token_idx is not None: - input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) - else: - input_ids = batch.input_ids - + # Execute batch if prefill: + # no right padding for prefill + token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device) batch.logits, batch.past = self.forward( - input_ids, - attention_mask, + batch.input_ids, + batch.attention_mask, batch.position_ids, token_idx, batch.past_key_values, bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None ) else: + token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device) + input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1) batch.logits = self.forward( input_ids, - attention_mask, + batch.attention_mask, batch.position_ids, token_idx, batch.past_key_values, @@ -955,10 +948,7 @@ class CausalLM(Model): top_token_logprobs = req_data['top_token_logprobs'] # Append next token to all tokens - if self.is_optimized_for_gaudi: - all_input_ids[input_length] = next_token_id - else: - all_input_ids = torch.cat([all_input_ids, next_token_id]) + all_input_ids[input_length] = next_token_id new_input_length = input_length + 1 # Generated token diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 32b1b9142..99bd9517d 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): async def Warmup(self, request, context): def batch_from_pb(batch): return self.model.batch_type.from_pb( - batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi + batch, self.model.tokenizer, self.model.dtype, self.model.device ) batches = [batch_from_pb(batch) for batch in request.batches] @@ -69,7 +69,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi + request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) generations, next_batch = self.model.generate_token([batch]) self.cache.set(next_batch)
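For reference, a minimal sketch of how the ROPE_SCALING / ROPE_FACTOR environment variables consumed by the new get_rope_scaling() method map onto the rope_scaling dict handed to the model. The helper name and the example values are illustrative only, not part of the patch.

import os
from typing import Dict, Optional

def rope_scaling_from_env() -> Optional[Dict]:
    # Mirrors CausalLM.get_rope_scaling(): ROPE_SCALING selects the scaling
    # type (e.g. "linear" or "dynamic") and ROPE_FACTOR the scaling factor;
    # when ROPE_SCALING is unset, no rope_scaling override is applied.
    rope_scaling = os.getenv("ROPE_SCALING", None)
    if rope_scaling is None:
        return None
    return {"type": rope_scaling, "factor": float(os.getenv("ROPE_FACTOR", 1.0))}

# e.g. ROPE_SCALING=linear ROPE_FACTOR=2.0 yields {"type": "linear", "factor": 2.0},
# which the patch forwards to AutoModelForCausalLM.from_pretrained(..., rope_scaling=...)
# and, on the multi-card path, into the DeepSpeed model_kwargs.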
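Likewise, a small sketch of the prompt pre-allocation that from_pb now performs unconditionally (previously gated on is_optimized_for_gaudi): one slot is reserved for the first generated token and max_new_tokens further slots for decoding. The tensor contents and sizes below are made-up illustrations.

import torch
import torch.nn.functional as F

pad_token_id, left_padding, max_new_tokens = 0, 2, 4
input_ids = torch.tensor([[11, 12, 13]])  # a single tokenized prompt

# Left-pad to the bucket size and reserve one slot for the first new token.
input_ids = F.pad(input_ids, (left_padding, 1), value=pad_token_id)
# Pre-allocate room for all remaining tokens, then split into one column per request.
all_input_ids = F.pad(input_ids, (0, max_new_tokens), value=pad_token_id).T.split(1, dim=1)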
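Finally, a hedged sketch of the token_idx computation that generate_token now applies on every pass (the non-Gaudi branch is gone): during prefill the next token lands in the last pre-allocated slot, during decode in the first slot still covered by right padding. The function below is a standalone illustration, not a helper from the patch.

import torch

def next_token_slot(attention_mask: torch.Tensor, right_padding: int, prefill: bool) -> torch.Tensor:
    # Prefill has no right padding yet, so the write position is the final
    # column; decode writes just before the remaining right padding.
    if prefill:
        return torch.tensor(attention_mask.shape[-1] - 1)
    return torch.tensor(attention_mask.shape[-1] - right_padding)

# batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)) then stores the
# sampled token at that position, and attention_mask.index_fill_(1, token_idx, 1) marks it valid.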