Add support for rope_scaling and remove is_optimized_for_gaudi (#112)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Karol Damaszke authored 2024-03-29 15:07:32 +01:00 · committed by GitHub
parent bf5263b88b
commit 7342baa2eb
3 changed files with 140 additions and 152 deletions
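
The headline change is the new get_rope_scaling helper in the CausalLM diff below: two environment variables are turned into the rope_scaling dict that is forwarded to AutoModelForCausalLM.from_pretrained (and into model_kwargs on the DeepSpeed path). A minimal standalone sketch of that mapping, mirroring the helper in the diff; the example values are illustrative, not part of this commit:

import os
from typing import Dict, Optional

def get_rope_scaling() -> Optional[Dict]:
    # ROPE_SCALING selects the scaling strategy (e.g. "linear" or "dynamic");
    # when unset, no rope_scaling is passed and the model config default is used.
    rope_scaling = os.getenv("ROPE_SCALING", None)
    if rope_scaling is None:
        return None
    # ROPE_FACTOR stretches the rotary position embeddings by this factor.
    rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
    return {'type': rope_scaling, 'factor': float(rope_factor)}

# Illustrative values: ROPE_SCALING=linear ROPE_FACTOR=2.0
os.environ["ROPE_SCALING"] = "linear"
os.environ["ROPE_FACTOR"] = "2.0"
print(get_rope_scaling())  # {'type': 'linear', 'factor': 2.0}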


@@ -17,14 +17,12 @@ class BloomCausalLMBatch(CausalLMBatch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
         batch = super().from_pb(
             pb=pb,
             tokenizer=tokenizer,
             dtype=dtype,
             device=device,
-            is_optimized_for_gaudi=is_optimized_for_gaudi,
         )
         batch.keys_head_dim_last = False
         return batch


@@ -1,28 +1,33 @@
+import bisect
+from dataclasses import dataclass
+from functools import wraps
+import itertools
+import math
 import os
 import tempfile
-import itertools
-import bisect
-import math
+from typing import Dict, List, Optional, Tuple, Type
 import torch
-from dataclasses import dataclass
+from loguru import logger
 from opentelemetry import trace
-from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, AutoConfig
-from typing import Optional, Tuple, List, Type, Dict
 import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-from contextlib import nullcontext
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
     get_repo_root,
     model_on_meta,
     write_checkpoints_json,
 )
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    PreTrainedTokenizerBase,
+    AutoConfig,
+)
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
@@ -34,11 +39,13 @@ from text_generation_server.models.types import (
     TopTokens,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, Sampling, make_tokenizer_optional, is_tokenizer_transparent
+from text_generation_server.utils import (
+    HeterogeneousNextTokenChooser,
+    StoppingCriteria,
+    make_tokenizer_optional,
+    is_tokenizer_transparent,
+)
 from text_generation_server.utils.debug import dbg_trace
-from loguru import logger
-from functools import wraps
 tracer = trace.get_tracer(__name__)
@@ -384,7 +391,7 @@ class CausalLMBatch(Batch):
         parameters = [r.data.parameters for r in flat_requests]
         if len(flat_requests) < new_bs:
             for i in range(new_bs-len(flat_requests)) :
-                #append the dummy parameters for dummy request
+                # append the dummy parameters for dummy request
                 parameters.append(parameters[0])
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -423,7 +430,6 @@ class CausalLMBatch(Batch):
         tokenizer: PreTrainedTokenizerBase,
         dtype: torch.dtype,
         device: torch.device,
-        is_optimized_for_gaudi: bool = False,
     ) -> "CausalLMBatch":
         dbg_trace('FROM_PB', f'num_reqs:{len(pb.requests)}')
         requests = [CausalLMRequest.from_pb(idx, req, tokenizer) for idx, req in enumerate(pb.requests)]
@@ -474,7 +480,6 @@ class CausalLMBatch(Batch):
         input_ids = tokenized_inputs["input_ids"]
         attention_mask = tokenized_inputs["attention_mask"]
-        if is_optimized_for_gaudi:
         # Allocate space for first token
         input_ids = torch.nn.functional.pad(
             input_ids, (left_padding, 1), value=tokenizer.pad_token_id
@@ -485,8 +490,6 @@ class CausalLMBatch(Batch):
         all_input_ids = torch.nn.functional.pad(
             input_ids, (0, max_new_tokens), value=tokenizer.pad_token_id
         ).T.split(1, dim=1)
-        else:
-            all_input_ids = input_ids.clone().T.split(1, dim=1)
         # New input length after left padding
         input_len = bucket_size
@@ -562,16 +565,9 @@ class CausalLM(Model):
         revision: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
     ):
-        device = torch.device("hpu")
-        if hq_env.is_quantization_enabled:
-            htorch.core.hpu_set_env()
-        dtype = torch.bfloat16 if dtype is None else dtype
-        from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
         adapt_transformers_to_gaudi()
-        # Create tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
@@ -580,23 +576,106 @@ class CausalLM(Model):
         )
         make_tokenizer_optional(tokenizer)
-        model_kwargs = {
-            "revision": revision,
-        }
+        # Create model
         world_size = int(os.getenv("WORLD_SIZE", "1"))
         rank = int(os.getenv("RANK", "0"))
-        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
-        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        dtype = torch.bfloat16 if dtype is None else dtype
+        device = torch.device("hpu")
+        if hq_env.is_quantization_enabled:
+            htorch.core.hpu_set_env()
         if world_size > 1:
-            import habana_frameworks.torch.hpu as torch_hpu
-            # Get world size, rank and local rank
+            model = self.get_deepspeed_model(
+                model_id, dtype, revision
+            )
+            model = self.prepare_model_for_quantization(model)
+        else:
+            get_repo_root(model_id)
+            rope_scaling = self.get_rope_scaling()
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                revision=revision,
+                torch_dtype=dtype,
+                rope_scaling=rope_scaling
+            )
+            model = self.prepare_model_for_quantization(model)
+            model = model.eval().to(device)
+        self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true"
+        self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
+        model = remove_kv_cache_from_output(model)
+        if self.enable_hpu_graph:
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+        model = self.setup_quantization(model)
+        if model.config.model_type not in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
+            raise ValueError(f"Model type {model.config.model_type} is not supported!")
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        kwargs = {
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if model.config.model_type == "llama":
+            kwargs["attn_softmax_bf16"] = True
+            kwargs["trim_logits"] = True
+        super(CausalLM, self).__init__(
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            rank=rank,
+            kwargs=kwargs,
+        )
+        # Create profiler
+        ranks_to_profile = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
+        record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
+        output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
+        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in ranks_to_profile else 0
+        self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
+        if self.profiling_steps > 0:
+            self.hb_profiler = HabanaProfile(
+                wait=self.profiling_wait_steps,
+                warmup=self.profiling_warmup_steps,
+                active=self.profiling_steps,
+                output_dir=output_dir,
+                record_shapes=record_shapes
+            )
+            self.hb_profiler.start()
+        else:
+            self.hb_profiler = None
+        self.step = 0
+    def get_deepspeed_model(
+        self,
+        model_id: str,
+        dtype: torch.dtype,
+        revision: Optional[str] = None
+    ) -> torch.nn.Module:
+        import deepspeed
         from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
         world_size, rank, local_rank = initialize_distributed_hpu()
-            import deepspeed
+        model_kwargs = {
+            "revision": revision,
+            'rope_scaling': self.get_rope_scaling()
+        }
         # Initialize process(es) for DeepSpeed
         deepspeed.init_distributed(dist_backend="hccl")
@@ -628,78 +707,19 @@ class CausalLM(Model):
             write_checkpoints_json(model_id, local_rank, checkpoints_json)
             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
-            model = model.module
-            model = self.prepare_model_for_quantization(model)
-            model = remove_kv_cache_from_output(model)
-            if self.enable_hpu_graph:
-                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
-        else:
-            get_repo_root(model_id)
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                torch_dtype=dtype,
-            )
-            model = self.prepare_model_for_quantization(model)
-            model = model.eval().to(device)
-            # wrap in hpu_graph only if self.enable_hpu_graph is set
-            model = remove_kv_cache_from_output(model)
-            if self.enable_hpu_graph:
-                model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
-        model = self.setup_quantization(model)
-        if model.config.model_type in MODELS_OPTIMIZED_WITH_STATIC_SHAPES:
-            self.is_optimized_for_gaudi = True
-        else:
-            self.is_optimized_for_gaudi = False
-        if tokenizer.pad_token_id is None:
-            if model.config.pad_token_id is not None:
-                tokenizer.pad_token_id = model.config.pad_token_id
-            elif model.config.eos_token_id is not None:
-                tokenizer.pad_token_id = model.config.eos_token_id
-            elif tokenizer.eos_token_id is not None:
-                tokenizer.pad_token_id = tokenizer.eos_token_id
-            else:
-                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        kwargs = {
-            "use_cache": True,
-            "return_dict": True,
-        }
-        if model.config.model_type == "llama":
-            kwargs["attn_softmax_bf16"] = True
-            kwargs["trim_logits"] = True
-        super(CausalLM, self).__init__(
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-            rank=rank,
-            kwargs=kwargs,
-        )
-        prof_ranks = [int(val) for val in os.getenv("PROF_RANKS", "0").split(',')]
-        self.profiling_warmup_steps = int(os.getenv("PROF_WARMUPSTEP", "0")) if rank in prof_ranks else 0
-        self.profiling_steps = int(os.getenv("PROF_STEP", "0")) if rank in prof_ranks else 0
-        self.profiling_wait_steps = int(os.getenv("PROF_WAITSTEP", "0"))
-        record_shapes = os.getenv("PROF_RECORD_SHAPES", "false").lower() == "true"
-        output_dir = os.getenv("PROF_PATH", "/tmp/hpu_profile")
-        if self.profiling_steps > 0:
-            self.hb_profiler = HabanaProfile(
-                wait=self.profiling_wait_steps,
-                warmup=self.profiling_warmup_steps,
-                active=self.profiling_steps,
-                output_dir=output_dir, record_shapes=record_shapes
-            )
-            self.hb_profiler.start()
-        else:
-            self.hb_profiler = None
-        self.step = 0
+        return model.module
+    def get_rope_scaling(self) -> Optional[Dict]:
+        rope_scaling = os.getenv("ROPE_SCALING", None)
+        if rope_scaling is None:
+            return None
+        rope_factor = float(os.getenv("ROPE_FACTOR", 1.0))
+        return {
+            'type': rope_scaling, 'factor': float(rope_factor)
+        }
     def setup_quantization(self, model):
         if hq_env.is_quantization_enabled:
             htorch.core.quantization._mark_params_as_const(model)
@@ -754,7 +774,7 @@ class CausalLM(Model):
         input_ids,
         attention_mask,
         position_ids,
-        token_idx: Optional = None,
+        token_idx,
         past_key_values: Optional = None,
         bypass_hpu_graph: Optional = None,
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -763,11 +783,9 @@ class CausalLM(Model):
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "past_key_values": past_key_values,
+            "token_idx": token_idx
         }
-        if self.is_optimized_for_gaudi:
-            kwargs["token_idx"] = token_idx
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids
@@ -794,7 +812,6 @@ class CausalLM(Model):
         logits = batch.logits
         past = batch.past
         prefill = batch.past_key_values is None
-        if self.is_optimized_for_gaudi:
         if prefill:
             # no right padding for prefill
             token_idx_scalar = batch.attention_mask.shape[-1] - 1
@@ -802,12 +819,10 @@ class CausalLM(Model):
         else:
             token_idx_scalar = batch.attention_mask.shape[-1] - batch.right_padding
         token_idx = torch.tensor(token_idx_scalar).to(self.device)
-        else:
-            token_idx = None
         # Select next token
         input_length = batch.input_length
-        if self.is_optimized_for_gaudi and logits.shape[-2] > 1:
+        if logits.shape[-2] > 1:
             next_token_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
                 batch.input_ids, logits[:, input_length - 1: input_length, :].squeeze(-2)
             )
@@ -840,20 +855,11 @@ class CausalLM(Model):
         htorch.core.mark_step()
-        if token_idx is None:
-            batch.input_ids[:, 0] = next_token_ids[:, 0]
-        else:
+        # Add new token into input_ids
         batch.input_ids.index_copy_(1, token_idx, next_token_ids.unsqueeze(1))
-        # Slice unused values from prefill, use it to store next token
-        if token_idx is None:
-            batch.input_ids = batch.input_ids[:, :1]
         # Update attention_mask as we added a new token to input_ids
-        if self.is_optimized_for_gaudi:
         batch.attention_mask.index_fill_(1, token_idx, 1)
-        else:
-            batch.attention_mask[:, -batch.padding_right_offset] = 1
         # Adjust lengths
         batch.input_length += 1
@@ -886,37 +892,24 @@ class CausalLM(Model):
             scenario, f'bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}')
         assert batch.right_padding > 0, 'No more room for next token!'
-        if self.is_optimized_for_gaudi:
+        # Execute batch
         if prefill:
             # no right padding for prefill
             token_idx = torch.tensor(batch.attention_mask.shape[-1] - 1).to(self.device)
-            else:
-                token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
-            attention_mask = batch.attention_mask
-        else:
-            token_idx = None
-            # slice the attention mask to the correct shape
-            # TODO fix me!
-            attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
-        if not prefill and token_idx is not None:
-            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-        else:
-            input_ids = batch.input_ids
-        if prefill:
             batch.logits, batch.past = self.forward(
-                input_ids,
-                attention_mask,
+                batch.input_ids,
+                batch.attention_mask,
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
                 bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None
             )
         else:
+            token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
+            input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
             batch.logits = self.forward(
                 input_ids,
-                attention_mask,
+                batch.attention_mask,
                 batch.position_ids,
                 token_idx,
                 batch.past_key_values,
@@ -955,10 +948,7 @@ class CausalLM(Model):
             top_token_logprobs = req_data['top_token_logprobs']
             # Append next token to all tokens
-            if self.is_optimized_for_gaudi:
             all_input_ids[input_length] = next_token_id
-            else:
-                all_input_ids = torch.cat([all_input_ids, next_token_id])
             new_input_length = input_length + 1
             # Generated token
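
With is_optimized_for_gaudi removed, generate_token always takes the static-shape path shown above: input_ids is pre-padded to the bucket size at prefill time and token_idx points at the next free slot, so each new token is written in place with index_copy_ instead of being concatenated. A small self-contained sketch of that update; batch size, bucket length, and token values are made up for illustration:

import torch

batch_size, bucket_len, prompt_len = 2, 6, 3        # hypothetical sizes
input_ids = torch.zeros(batch_size, bucket_len, dtype=torch.long)
attention_mask = torch.zeros(batch_size, bucket_len, dtype=torch.long)
attention_mask[:, :prompt_len] = 1                  # prompt tokens already valid

token_idx = torch.tensor([prompt_len])              # next free slot in the padded buffer
next_token_ids = torch.tensor([[11], [22]])         # one new token per sequence

# Write the new token into the pre-allocated slot instead of torch.cat(...)
input_ids.index_copy_(1, token_idx, next_token_ids)
# Mark that slot as attended, mirroring attention_mask.index_fill_ above
attention_mask.index_fill_(1, token_idx, 1)
# The next decode step feeds only the token just written (the diff selects
# token_idx - 1 after token_idx has advanced by one)
decode_input_ids = torch.index_select(input_ids, 1, token_idx)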


@@ -59,7 +59,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Warmup(self, request, context):
         def batch_from_pb(batch):
             return self.model.batch_type.from_pb(
-                batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+                batch, self.model.tokenizer, self.model.dtype, self.model.device
             )
         batches = [batch_from_pb(batch) for batch in request.batches]
@@ -69,7 +69,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Prefill(self, request, context):
         batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.dtype, self.model.device, self.model.is_optimized_for_gaudi
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )
         generations, next_batch = self.model.generate_token([batch])
         self.cache.set(next_batch)