Add support for rope_scaling and remove is_optimized_for_gaudi (#112)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Karol Damaszke authored 2024-03-29 15:07:32 +01:00, committed by GitHub
parent bf5263b88b
commit 7342baa2eb
3 changed files with 140 additions and 152 deletions
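This commit does two things: it drops the is_optimized_for_gaudi flag (models optimized for static shapes are now required, and the non-optimized code paths are removed) and it adds RoPE scaling configured through environment variables. Below is a minimal, self-contained sketch of that lookup, mirroring the get_rope_scaling helper added in causal_lm.py later in this diff; ROPE_SCALING and ROPE_FACTOR are the variables the commit reads, everything else is illustrative:

import os
from typing import Dict, Optional

def get_rope_scaling() -> Optional[Dict]:
    # ROPE_SCALING selects the scaling type (e.g. "linear" or "dynamic");
    # when unset, no rope_scaling dict is passed to the model.
    rope_scaling = os.getenv("ROPE_SCALING", None)
    if rope_scaling is None:
        return None
    # ROPE_FACTOR is the scaling factor, defaulting to 1.0.
    rope_factor = float(os.getenv("ROPE_FACTOR", "1.0"))
    return {"type": rope_scaling, "factor": rope_factor}

The resulting dict is forwarded as rope_scaling to AutoModelForCausalLM.from_pretrained on the single-card path and via model_kwargs on the DeepSpeed path, as shown in the hunks below.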

View File

@@ -17,14 +17,12 @@ class BloomCausalLMBatch(CausalLMBatch):
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
is_optimized_for_gaudi: bool = False,
) -> "CausalLMBatch":
batch = super().from_pb(
pb=pb,
tokenizer=tokenizer,
dtype=dtype,
device=device,
is_optimized_for_gaudi=is_optimized_for_gaudi,
)
batch.keys_head_dim_last = False
return batch

View File

@@ -1,28 +1,33 @@
import bisect
from dataclasses import dataclass
from functools import wraps
import itertools
import math
import os
import tempfile
import itertools
import bisect
import math
from typing import Dict, List, Optional, Tuple, Type
import torch
from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Dict
import text_generation_server.habana_quantization_env as hq_env
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
from contextlib import nullcontext
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.utils import HabanaProfile
from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
from optimum.habana.checkpoint_utils import (
get_repo_root,
model_on_meta,
write_checkpoints_json,
)
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
AutoConfig,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
@@ -34,11 +39,13 @@ from text_generation_server.models.types import (
TopTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
from text_generation_server.utils import (
HeterogeneousNextTokenChooser,
StoppingCriteria,
make_tokenizer_optional,
is_tokenizer_transparent,
)
from text_generation_server.utils.debug import dbg_trace
from loguru import logger
from functools import wraps
tracer = trace.get_tracer(__name__)
@@ -384,7 +391,7 @@ class CausalLMBatch(Batch):
parameters = [r.data.parameters for r in flat_requests]
if len(flat_requests) < new_bs:
for i in range(new_bs-len(flat_requests)) :
#append the dummy parameters for dummy request
# append the dummy parameters for dummy request
parameters.append(parameters[0])
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -423,7 +430,6 @@ class CausalLMBatch(Batch):
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
is_optimized_for_gaudi: bool = False,
) -> "CausalLMBatch":
dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -474,7 +480,6 @@ class CausalLMBatch(Batch):
input_ids = tokenized_inputs["input_ids"]
attention_mask = tokenized_inputs["attention_mask"]
if is_optimized_for_gaudi:
# Allocate space for first token
input_ids = torch.nn.functional.pad(
input_ids, (left_padding, 1), value=tokenizer.pad_token_id
@@ -485,8 +490,6 @@ class CausalLMBatch(Batch):
all_input_ids = torch.nn.functional.pad(
input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
).T.split(1, dim=1)
else:
all_input_ids = input_ids.clone().T.split(1, dim=1)
# New input length after left padding
input_len = bucket_size
@@ -562,16 +565,9 @@ class CausalLM(Model):
revision: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
device = torch.device("hpu")
if hq_env.is_quantization_enabled:
htorch.core.hpu_set_env()
dtype = torch.bfloat16 if dtype is None else dtype
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
adapt_transformers_to_gaudi()
# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
@@ -580,23 +576,106 @@ class CausalLM(Model):
)
make_tokenizer_optional(tokenizer)
model_kwargs = {
"revision": revision,
}
# Create model
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
dtype = torch.bfloat16 if dtype is None else dtype
device = torch.device("hpu")
if hq_env.is_quantization_enabled:
htorch.core.hpu_set_env()
if world_size > 1:
import habana_frameworks.torch.hpu as torch_hpu
model = self.get_deepspeed_model(
model_id, dtype, revision
)
model = self.prepare_model_for_quantization(model)
else:
get_repo_root(model_id)
rope_scaling = self.get_rope_scaling()
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
rope_scaling=rope_scaling
)
model = self.prepare_model_for_quantization(model)
model = model.eval().to(device)
# Get world size, rank and local rank
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
model = self.setup_quantization(model)
if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
raise ValueError(f"Model type {model.config.model_type} is not supported!")
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
kwargs = {
"use_cache": True,
"return_dict": True,
}
if model.config.model_type == "llama":
kwargs["attn_softmax_bf16"] = True
kwargs["trim_logits"] = True
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
kwargs=kwargs,
)
# Create profiler
ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
if self.profiling_steps > 0:
self.hb_profiler = HabanaProfile(
wait=self.profiling_wait_steps,
warmup=self.profiling_warmup_steps,
active=self.profiling_steps,
output_dir=output_dir,
record_shapes=record_shapes
)
self.hb_profiler.start()
else:
self.hb_profiler = None
self.step = 0
def get_deepspeed_model(
self,
model_id: str,
dtype: torch.dtype,
revision: Optional[str] = None
) -> torch.nn.Module:
import deepspeed
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu()
import deepspeed
model_kwargs = {
"revision": revision,
'rope_scaling': self.get_rope_scaling()
}
# Initialize process(es) for DeepSpeed
deepspeed.init_distributed(dist_backend="hccl")
@@ -628,78 +707,19 @@ class CausalLM(Model):
write_checkpoints_json(model_id, local_rank, checkpoints_json)
ds_inference_kwargs["checkpoint"] = checkpoints_json.name
model = deepspeed.init_inference(model, **ds_inference_kwargs)
model = model.module
model = self.prepare_model_for_quantization(model)
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
else:
get_repo_root(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
)
model = self.prepare_model_for_quantization(model)
model = model.eval().to(device)
# wrap in hpu_graph only if self.enable_hpu_graph is set
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
model = self.setup_quantization(model)
return model.module
if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
self.is_optimized_for_gaudi = True
else:
self.is_optimized_for_gaudi = False
def get_rope_scaling(self) -> Optional[Dict]:
rope_scaling = os.getenv("ROPE_SCALING", None)
if rope_scaling is None:
return None
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
kwargs = {
"use_cache": True,
"return_dict": True,
rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
return {
'type': rope_scaling, 'factor': float(rope_factor)
}
if model.config.model_type == "llama":
kwargs["attn_softmax_bf16"] = True
kwargs["trim_logits"] = True
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
kwargs=kwargs,
)
prof_ranks = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in prof_ranks else 0
self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in prof_ranks else 0
self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
if self.profiling_steps > 0:
self.hb_profiler = HabanaProfile(
wait=self.profiling_wait_steps,
warmup=self.profiling_warmup_steps,
active=self.profiling_steps,
output_dir=output_dir, record_shapes=record_shapes
)
self.hb_profiler.start()
else:
self.hb_profiler = None
self.step = 0
def setup_quantization(self, model):
if hq_env.is_quantization_enabled:
htorch.core.quantization._mark_params_as_const(model)
@@ -754,7 +774,7 @@ class CausalLM(Model):
input_ids,
attention_mask,
position_ids,
token_idx: Optional = None,
token_idx,
past_key_values: Optional = None,
bypass_hpu_graph: Optional = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -763,11 +783,9 @@
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"token_idx": token_idx
}
if self.is_optimized_for_gaudi:
kwargs["token_idx"] = token_idx
if self.has_position_ids:
kwargs["position_ids"] = position_ids
@@ -794,7 +812,6 @@
logits = batch.logits
past = batch.past
prefill = batch.past_key_values is None
if self.is_optimized_for_gaudi:
if prefill:
# no right padding for prefill
token_idx_scalar = batch.attention_mask.shape[-1] - 1
@@ -802,12 +819,10 @@
else:
token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
token_idx = torch.tensor(token_idx_scalar).to(self.device)
else:
token_idx = None
# Select next token
input_length = batch.input_length
if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
if logits.shape[-2] > 1:
next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
)
@@ -840,20 +855,11 @@
htorch.core.mark_step()
if token_idx is None:
batch.input_ids[:, 0] = next_token_ids[:, 0]
else:
# Add new token into input_ids
batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
# Slice unused values from prefill, use it to store next token
if token_idx is None:
batch.input_ids = batch.input_ids[:, :1]
# Update attention_mask as we added a new token to input_ids
if self.is_optimized_for_gaudi:
batch.attention_mask.index_fill_(1, token_idx, 1)
else:
batch.attention_mask[:, -batch.padding_right_offset] = 1
# Adjust lengths
batch.input_length += 1
@@ -886,37 +892,24 @@
scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
assert batch.right_padding > 0, 'No more room for next token!'
if self.is_optimized_for_gaudi:
# Execute batch
if prefill:
# no right padding for prefill
token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
attention_mask = batch.attention_mask
else:
token_idx = None
# slice the attention mask to the correct shape
# TODO fix me!
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if not prefill and token_idx is not None:
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
else:
input_ids = batch.input_ids
if prefill:
batch.logits, batch.past = self.forward(
input_ids,
attention_mask,
batch.input_ids,
batch.attention_mask,
batch.position_ids,
token_idx,
batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
)
else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
batch.logits = self.forward(
input_ids,
attention_mask,
batch.attention_mask,
batch.position_ids,
token_idx,
batch.past_key_values,
@@ -955,10 +948,7 @@ class CausalLM(Model):
top_token_logprobs = req_data['top_token_logprobs']
# Append next token to all tokens
if self.is_optimized_for_gaudi:
all_input_ids[input_length] = next_token_id
else:
all_input_ids = torch.cat([all_input_ids, next_token_id])
new_input_length = input_length + 1
# Generated token

View File

@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Warmup(self, request, context):
def batch_from_pb(batch):
return self.model.batch_type.from_pb(
batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
batch, self.model.tokenizer, self.model.dtype, self.model.device
)
batches = [batch_from_pb(batch) for batch in request.batches]
@@ -69,7 +69,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch = self.model.generate_token([batch])
self.cache.set(next_batch)
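With is_optimized_for_gaudi gone, token_idx is always defined during decoding: the freshly sampled token is written into the right-padding slot with index_copy_ and the attention mask is extended with index_fill_, as in the generate_token hunks above. A small, self-contained illustration of that tensor pattern follows; the tensor values are made up for the example and do not come from real batch state:

import torch

# Two sequences, right-padded to a static length of 6 (pad id 0).
input_ids = torch.tensor([[5, 6, 7, 0, 0, 0],
                          [8, 9, 10, 0, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 0, 0, 0]])
# token_idx points at the first free (right-padding) column for the batch.
token_idx = torch.tensor([3])
next_token_ids = torch.tensor([[11], [12]])

# Mirrors batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1)).
input_ids.index_copy_(1, token_idx, next_token_ids)
# Mirrors batch.attention_mask.index_fill_(1, token_idx, 1).
attention_mask.index_fill_(1, token_idx, 1)

print(input_ids)       # [[5, 6, 7, 11, 0, 0], [8, 9, 10, 12, 0, 0]]
print(attention_mask)  # [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]]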