Add support for rope_scaling and remove is_optimized_for_gaudi (#112)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Author: Karol Damaszke, 2024-03-29 15:07:32 +01:00 (committed by GitHub)
parent bf5263b88b
commit 7342baa2eb
3 changed files with 140 additions and 152 deletions


@@ -17,14 +17,12 @@ class BloomCausalLMBatch(CausalLMBatch):
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
is_optimized_for_gaudi: bool = False,
) -> "CausalLMBatch":
batch = super().from_pb(
pb=pb,
tokenizer=tokenizer,
dtype=dtype,
device=device,
is_optimized_for_gaudi=is_optimized_for_gaudi,
)
batch.keys_head_dim_last = False
return batch


@@ -1,28 +1,33 @@
import bisect
from dataclasses import dataclass
from functools import wraps
import itertools
import math
import os
import tempfile
import itertools
import bisect
import math
from typing import Dict, List, Optional, Tuple, Type
import torch
from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Dict
import text_generation_server.habana_quantization_env as hq_env
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from contextlib import nullcontext
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.utils import HabanaProfile
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
from optimum.habana.checkpoint_utils import (
get_repo_root,
model_on_meta,
write_checkpoints_json,
)
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
AutoConfig,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
@@ -34,11 +39,13 @@ from text_generation_server.models.types import (
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
from text_generation_server.utils import (
HeterogeneousNextTokenChooser,
StoppingCriteria,
make_tokenizer_optional,
is_tokenizer_transparent,
)
from text_generation_server.utils.debug import dbg_trace
from loguru import logger
from functools import wraps
tracer = trace.get_tracer(__name__)
@@ -384,7 +391,7 @@ class CausalLMBatch(Batch):
parameters = [r.data.parameters for r in flat_requests]
if len(flat_requests) < new_bs:
for i in range(new_bs-len(flat_requests)) :
#append the dummy parameters for dummy request
# append the dummy parameters for dummy request
parameters.append(parameters[0])
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
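A toy sketch of the dummy-parameter padding above (plain dicts stand in for the protobuf parameters; new_bs is assumed to be the bucketed batch size):
# Minimal sketch, assuming at least one real request exists.
parameters = [{"temperature": 0.7, "top_p": 0.9}, {"temperature": 1.0, "top_p": 1.0}]
new_bs = 4
if len(parameters) < new_bs:
    for _ in range(new_bs - len(parameters)):
        # dummy requests simply reuse the first request's parameters
        parameters.append(parameters[0])
assert len(parameters) == new_bs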
@@ -423,7 +430,6 @@ class CausalLMBatch(Batch):
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
is_optimized_for_gaudi: bool = False,
) -> "CausalLMBatch":
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -474,19 +480,16 @@ class CausalLMBatch(Batch):
input_ids = tokenized_inputs["input_ids"]
attention_mask = tokenized_inputs["attention_mask"]
if is_optimized_for_gaudi:
# Allocate space for first token
input_ids = torch.nn.functional.pad(
input_ids, (left_padding, 1), value=tokenizer.pad_token_id
)
attention_mask = torch.nn.functional.pad(
attention_mask, (left_padding, 1), value=0
)
all_input_ids = torch.nn.functional.pad(
input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
).T.split(1, dim=1)
else:
all_input_ids = input_ids.clone().T.split(1, dim=1)
# Allocate space for first token
input_ids = torch.nn.functional.pad(
input_ids, (left_padding, 1), value=tokenizer.pad_token_id
)
attention_mask = torch.nn.functional.pad(
attention_mask, (left_padding, 1), value=0
)
all_input_ids = torch.nn.functional.pad(
input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
).T.split(1, dim=1)
# New input length after left padding
input_len = bucket_size
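For illustration, a runnable toy version of the padding that is now applied unconditionally (pad id, bucket_size, max_new_tokens and the left_padding arithmetic are simplified assumptions, not the exact upstream values):
import torch

pad_token_id = 0
bucket_size = 8
max_new_tokens = 4
input_ids = torch.tensor([[11, 12, 13], [21, 22, 23]])        # [batch, seq]
attention_mask = torch.ones_like(input_ids)

left_padding = bucket_size - input_ids.shape[-1] - 1           # leave one slot for the first generated token
input_ids = torch.nn.functional.pad(input_ids, (left_padding, 1), value=pad_token_id)
attention_mask = torch.nn.functional.pad(attention_mask, (left_padding, 1), value=0)
all_input_ids = torch.nn.functional.pad(
    input_ids, (0, max_new_tokens), value=pad_token_id
).T.split(1, dim=1)

print(input_ids.shape)         # torch.Size([2, 8]) == (batch, bucket_size)
print(all_input_ids[0].shape)  # torch.Size([12, 1]) == (bucket_size + max_new_tokens, 1) per request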
@@ -562,16 +565,9 @@ class CausalLM(Model):
revision: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
device = torch.device("hpu")
if hq_env.is_quantization_enabled:
htorch.core.hpu_set_env()
dtype = torch.bfloat16 if dtype is None else dtype
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
@@ -580,79 +576,42 @@ class CausalLM(Model):
)
make_tokenizer_optional(tokenizer)
model_kwargs = {
"revision": revision,
}
# Create model
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
dtype = torch.bfloat16 if dtype is None else dtype
device = torch.device("hpu")
if hq_env.is_quantization_enabled:
htorch.core.hpu_set_env()
if world_size > 1:
import habana_frameworks.torch.hpu as torch_hpu
# Get world size, rank and local rank
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu()
import deepspeed
# Initialize process(es) for DeepSpeed
deepspeed.init_distributed(dist_backend="hccl")
logger.info(
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
model = self.get_deepspeed_model(
model_id, dtype, revision
)
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
load_to_meta = model_on_meta(config)
if load_to_meta:
# Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
with deepspeed.OnDevice(dtype=dtype, device="meta"):
model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
else:
get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
# TODO: revisit placement on CPU when auto-injection is possible
with deepspeed.OnDevice(dtype=dtype, device="cpu"):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
model = model.eval()
# Initialize the model
ds_inference_kwargs = {"dtype": dtype}
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
ds_inference_kwargs["enable_cuda_graph"] = False
if load_to_meta:
# model loaded to meta is managed differently
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
write_checkpoints_json(model_id, local_rank, checkpoints_json)
ds_inference_kwargs["checkpoint"] = checkpoints_json.name
model = deepspeed.init_inference(model, **ds_inference_kwargs)
model = model.module
model = self.prepare_model_for_quantization(model)
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
else:
get_repo_root(model_id)
rope_scaling = self.get_rope_scaling()
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
rope_scaling=rope_scaling
)
model = self.prepare_model_for_quantization(model)
model = model.eval().to(device)
# wrap in hpu_graph only if self.enable_hpu_graph is set
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
model = self.setup_quantization(model)
if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
self.is_optimized_for_gaudi = True
else:
self.is_optimized_for_gaudi = False
if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
raise ValueError(f"Model type {model.config.model_type} is not supported!")
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
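With the is_optimized_for_gaudi fallback removed, unsupported architectures now fail fast at load time. A minimal sketch of that check, assuming optimum-habana is installed and the Hub is reachable (the model id is a hypothetical example):
from transformers import AutoConfig
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES

model_id = "bigscience/bloom-560m"  # hypothetical; only the config is fetched for this check
config = AutoConfig.from_pretrained(model_id)
if config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
    raise ValueError(f"Model type {config.model_type} is not supported!")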
@@ -682,24 +641,85 @@ class CausalLM(Model):
rank=rank,
kwargs=kwargs,
)
prof_ranks = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in prof_ranks else 0
self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in prof_ranks else 0
self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
# Create profiler
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
if self.profiling_steps > 0:
self.hb_profiler = HabanaProfile(
wait=self.profiling_wait_steps,
warmup=self.profiling_warmup_steps,
active=self.profiling_steps,
output_dir=output_dir, record_shapes=record_shapes
output_dir=output_dir,
record_shapes=record_shapes
)
self.hb_profiler.start()
else:
self.hb_profiler = None
self.step = 0
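The profiler is now driven entirely by environment variables; a standalone sketch of that parsing (names and defaults mirror the diff, HabanaProfile itself is not needed to build the settings):
import os

rank = int(os.getenv("RANK", "0"))
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(",")]
profiler_config = {
    "wait": int(os.getenv("PROF_WAITSTEP", "0")),
    "warmup": int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0,
    "active": int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0,
    "output_dir": os.getenv("PROF_PATH", "/tmp/hpu_profile"),
    "record_shapes": os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true",
}
# Profiling stays disabled unless PROF_STEP > 0, mirroring the `if self.profiling_steps > 0` guard.
print(profiler_config)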
def get_deepspeed_model(
self,
model_id: str,
dtype: torch.dtype,
revision: Optional[str] = None
) -> torch.nn.Module:
import deepspeed
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu()
model_kwargs = {
"revision": revision,
'rope_scaling': self.get_rope_scaling()
}
# Initialize process(es) for DeepSpeed
deepspeed.init_distributed(dist_backend="hccl")
logger.info(
"DeepSpeed is enabled. world_size {} rank {} local_rank {}".format(world_size, rank, local_rank)
)
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
load_to_meta = model_on_meta(config)
if load_to_meta:
# Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
with deepspeed.OnDevice(dtype=dtype, device="meta"):
model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
else:
get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
# TODO: revisit placement on CPU when auto-injection is possible
with deepspeed.OnDevice(dtype=dtype, device="cpu"):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, **model_kwargs)
model = model.eval()
# Initialize the model
ds_inference_kwargs = {"dtype": dtype}
ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size}
ds_inference_kwargs["enable_cuda_graph"] = False
if load_to_meta:
# model loaded to meta is managed differently
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
write_checkpoints_json(model_id, local_rank, checkpoints_json)
ds_inference_kwargs["checkpoint"] = checkpoints_json.name
model = deepspeed.init_inference(model, **ds_inference_kwargs)
return model.module
def get_rope_scaling(self) -> Optional[Dict]:
rope_scaling = os.getenv("ROPE_SCALING", None)
if rope_scaling is None:
return None
rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
return {
'type': rope_scaling, 'factor': float(rope_factor)
}
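Usage sketch for the new ROPE_SCALING / ROPE_FACTOR environment variables (a standalone re-implementation of get_rope_scaling for illustration; the values set below are hypothetical):
import os
from typing import Dict, Optional

def rope_scaling_from_env() -> Optional[Dict]:
    # Mirrors get_rope_scaling(): an unset ROPE_SCALING means "no scaling".
    rope_scaling = os.getenv("ROPE_SCALING", None)
    if rope_scaling is None:
        return None
    return {"type": rope_scaling, "factor": float(os.getenv("ROPE_FACTOR", 1.0))}

os.environ["ROPE_SCALING"] = "linear"   # "dynamic" is the other value transformers accepts
os.environ["ROPE_FACTOR"] = "2.0"
print(rope_scaling_from_env())          # {'type': 'linear', 'factor': 2.0}
The resulting dict is forwarded as rope_scaling= to AutoConfig.from_pretrained on the DeepSpeed path and to AutoModelForCausalLM.from_pretrained on the single-card path, so RoPE-based architectures pick it up through their config.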
def setup_quantization(self, model):
if hq_env.is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
@@ -754,7 +774,7 @@ class CausalLM(Model):
input_ids,
attention_mask,
position_ids,
token_idx: Optional = None,
token_idx,
past_key_values: Optional = None,
bypass_hpu_graph: Optional = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -763,11 +783,9 @@
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"token_idx": token_idx
}
if self.is_optimized_for_gaudi:
kwargs["token_idx"] = token_idx
if self.has_position_ids:
kwargs["position_ids"] = position_ids
@@ -794,20 +812,17 @@
logits = batch.logits
past = batch.past
prefill = batch.past_key_values is None
if self.is_optimized_for_gaudi:
if prefill:
# no right padding for prefill
token_idx_scalar = batch.attention_mask.shape[-1] - 1
token_idx = torch.tensor(token_idx_scalar).to(self.device)
else:
token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
token_idx = torch.tensor(token_idx_scalar).to(self.device)
if prefill:
# no right padding for prefill
token_idx_scalar = batch.attention_mask.shape[-1] - 1
token_idx = torch.tensor(token_idx_scalar).to(self.device)
else:
token_idx = None
token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
token_idx = torch.tensor(token_idx_scalar).to(self.device)
# Select next token
input_length = batch.input_length
if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
if logits.shape[-2] > 1:
next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
)
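A toy illustration of the logits slicing that now always runs for prefill (batch, sequence and vocab sizes are made up):
import torch

batch_size, padded_seq, vocab = 2, 8, 16
logits = torch.randn(batch_size, padded_seq, vocab)   # prefill returns one row of logits per position
input_length = 5

# Select the single position at input_length - 1 for every request, as in the code above.
last_position_logits = logits[:, input_length - 1: input_length, :].squeeze(-2)
print(last_position_logits.shape)                     # torch.Size([2, 16]): one row per request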
@@ -840,20 +855,11 @@
htorch.core.mark_step()
if token_idx is None:
batch.input_ids[:, 0] = next_token_ids[:, 0]
else:
batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
# Slice unused values from prefill, use it to store next token
if token_idx is None:
batch.input_ids = batch.input_ids[:, :1]
# Add new token into input_ids
batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
# Update attention_mask as we added a new token to input_ids
if self.is_optimized_for_gaudi:
batch.attention_mask.index_fill_(1, token_idx, 1)
else:
batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.attention_mask.index_fill_(1, token_idx, 1)
# Adjust lengths
batch.input_length += 1
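A runnable toy version of the in-place token_idx updates that are now unconditional (shapes and token values are hypothetical):
import torch

batch_size, padded_len, right_padding = 2, 8, 3
input_ids = torch.zeros(batch_size, padded_len, dtype=torch.long)
input_ids[:, :padded_len - right_padding] = torch.tensor([[11, 12, 13, 14, 15],
                                                          [21, 22, 23, 24, 25]])
attention_mask = torch.zeros(batch_size, padded_len, dtype=torch.long)
attention_mask[:, :padded_len - right_padding] = 1                   # real tokens already attended

token_idx = torch.tensor(attention_mask.shape[-1] - right_padding)   # first free slot
next_token_ids = torch.tensor([101, 202])

input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))     # write the new token in place
attention_mask.index_fill_(1, token_idx, 1)                          # and mark that slot as attended
print(input_ids)
# -> row 0: [11, 12, 13, 14, 15, 101, 0, 0], row 1: [21, 22, 23, 24, 25, 202, 0, 0]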
@@ -886,37 +892,24 @@
scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
assert batch.right_padding > 0, 'No more room for next token!'
if self.is_optimized_for_gaudi:
if prefill:
# no right padding for prefill
token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
attention_mask = batch.attention_mask
else:
token_idx = None
# slice the attention mask to the correct shape
# TODO fix me!
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if not prefill and token_idx is not None:
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
else:
input_ids = batch.input_ids
# Execute batch
if prefill:
# no right padding for prefill
token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
batch.logits, batch.past = self.forward(
input_ids,
attention_mask,
batch.input_ids,
batch.attention_mask,
batch.position_ids,
token_idx,
batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
)
else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
batch.logits = self.forward(
input_ids,
attention_mask,
batch.attention_mask,
batch.position_ids,
token_idx,
batch.past_key_values,
@@ -955,10 +948,7 @@ class CausalLM(Model):
top_token_logprobs = req_data['top_token_logprobs']
# Append next token to all tokens
if self.is_optimized_for_gaudi:
all_input_ids[input_length] = next_token_id
else:
all_input_ids = torch.cat([all_input_ids, next_token_id])
all_input_ids[input_length] = next_token_id
new_input_length = input_length + 1
# Generated token
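A toy contrast showing the surviving static-shape write into the preallocated all_input_ids buffer (lengths and token ids are made up):
import torch

max_total_tokens = 8
all_input_ids = torch.zeros(max_total_tokens, 1, dtype=torch.long)   # preallocated per-request buffer
all_input_ids[:3, 0] = torch.tensor([11, 12, 13])                    # three prompt tokens already present
input_length = 3
next_token_id = torch.tensor([42])

# Static-shape path: write into the buffer instead of growing it with torch.cat.
all_input_ids[input_length] = next_token_id
new_input_length = input_length + 1
print(all_input_ids.squeeze(1))    # -> [11, 12, 13, 42, 0, 0, 0, 0]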


@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Warmup(self, request, context):
def batch_from_pb(batch):
return self.model.batch_type.from_pb(
batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
batch, self.model.tokenizer, self.model.dtype, self.model.device
)
batches = [batch_from_pb(batch) for batch in request.batches]
@@ -69,7 +69,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch = self.model.generate_token([batch])
self.cache.set(next_batch)