Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 20:42:06 +00:00)
Temporary implementation of torch.compile on our stuff.
commit 78f87d5a0c
parent 6f15ac60b2
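For context: torch.compile returns a wrapper that is called exactly like the wrapped module, and mode="reduce-overhead" targets small-batch workloads dominated by kernel-launch overhead, which is what per-token decode steps look like. A minimal, standalone sketch of the pattern (illustrative only, not this repository's code):

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        # Stand-in for the real model; the commit compiles self.model the same way.
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, input_ids):
            return torch.relu(self.proj(input_ids))

    model = TinyModel()
    compiled_model = torch.compile(model, mode="reduce-overhead")

    x = torch.randn(1, 16)
    y = compiled_model(x)  # same call convention as the eager module
    assert y.shape == (1, 16)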
@@ -798,11 +798,13 @@ class FlashCausalLM(Model):
             self.device,
         )
 
+        self.compiled_model = torch.compile(self.model, mode="reduce-overhead")
+
         if ENABLE_CUDA_GRAPHS:
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs
-                for bs in [1, 2, 4] + [8 * i for i in range(8)]:
+                for bs in [1]:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
@@ -881,7 +883,19 @@ class FlashCausalLM(Model):
             or cuda_graph is None
             or batch.speculative_ids is not None
         ):
-            return self.model.forward(
+            if cu_seqlen_prefill is None:
+                return self.compiled_model(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    lm_head_indices=lm_head_indices,
+                )
+            return self.model(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
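In this hunk only decode steps (cu_seqlen_prefill is None) go through the compiled model; prefill, and batches that fall through to CUDA-graph replay, keep the eager path. A standalone sketch of that dispatch with a toy module standing in for the real one (the keyword names mirror the diff, nothing else is taken from the repository):

    from typing import Optional

    import torch
    import torch.nn as nn

    class ToyDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, input_ids, cu_seqlen_prefill: Optional[torch.Tensor] = None):
            return self.proj(input_ids)

    model = ToyDecoder()
    compiled_model = torch.compile(model, mode="reduce-overhead")

    def run_forward(input_ids, cu_seqlen_prefill=None):
        # Decode (no prefill cumulative lengths): compiled path.
        if cu_seqlen_prefill is None:
            return compiled_model(input_ids, cu_seqlen_prefill=cu_seqlen_prefill)
        # Prefill: eager path, mirroring the `return self.model(...)` fallback above.
        return model(input_ids, cu_seqlen_prefill=cu_seqlen_prefill)

    decode_out = run_forward(torch.randn(1, 8))                         # compiled
    prefill_out = run_forward(torch.randn(4, 8), torch.tensor([0, 4]))  # eager

Keeping prefill on the eager module avoids recompiling over the highly variable prefill shapes; presumably that is why only the decode call is routed through compiled_model.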
@@ -68,6 +68,7 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         model = FlashLlamaForCausalLM(config, weights)
+        # model = torch.compile(model, mode="reduce-overhead")
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
@@ -495,18 +495,33 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph = self.cuda_graphs.get(padded_bs, None)
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-            )
+
+            if cu_seqlen_prefill is None:
+                logits, speculative_logits = self.compiled_model(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
+            else:
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             return logits, speculative_logits
@@ -149,7 +149,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         batch = batches[0]
         concat_ns = None
 
-        generations, next_batch, timings = self.model.generate_token(batch)
+        torch.profiler._utils._init_for_cuda_graphs()
+        # prof = torch.profiler.profile()
+        # if self.model.rank != 0:
+        if True:
+            import contextlib
+
+            prof = contextlib.nullcontext()
+        else:
+            prof = torch.profiler.profile()
+        with prof:
+            generations, next_batch, timings = self.model.generate_token(batch)
+        # if self.model.rank == 0:
+        #     prof.export_chrome_trace(f"out_rank_0.json")
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(
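The Decode handler now runs generate_token under a context manager that is either a real torch.profiler.profile() or a no-op, with the toggle hardcoded to the no-op side (`if True:`). A standalone sketch of the same pattern; the PROFILE flag is hypothetical, the diff hardcodes the choice:

    import contextlib

    import torch

    PROFILE = False  # hypothetical toggle; the diff pins this branch with `if True:`

    # contextlib.nullcontext() is a do-nothing context manager, so the `with`
    # block below is identical whether or not profiling is enabled.
    prof = torch.profiler.profile() if PROFILE else contextlib.nullcontext()

    with prof:
        result = torch.randn(4, 4) @ torch.randn(4, 4)  # stand-in for generate_token(batch)

    if PROFILE:
        prof.export_chrome_trace("trace.json")  # same export the commented-out lines hint at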
@@ -507,27 +507,27 @@ class TensorParallelHead(SuperLayer):
             return super().forward(input)
 
         world_size = self.process_group.size()
-        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
-            out_dim = self.linear.weight.shape[0]
+        # if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+        #     out_dim = self.linear.weight.shape[0]
 
-            if input.shape[0] == 1:
-                world_out = input.new_empty(1, out_dim * world_size)
-                local_out = input.new_empty(1, out_dim)
-                gather_input = local_out
-            else:
-                world_out = input.new_empty(out_dim * world_size, input.shape[0])
-                gather_input = input.new_empty(out_dim, input.shape[0])
-                local_out = gather_input.T
+        #     if input.shape[0] == 1:
+        #         world_out = input.new_empty(1, out_dim * world_size)
+        #         local_out = input.new_empty(1, out_dim)
+        #         gather_input = local_out
+        #     else:
+        #         world_out = input.new_empty(out_dim * world_size, input.shape[0])
+        #         gather_input = input.new_empty(out_dim, input.shape[0])
+        #         local_out = gather_input.T
 
-            torch.mm(input, self.linear.weight.T, out=local_out)
+        #     torch.mm(input, self.linear.weight.T, out=local_out)
 
-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+        #     torch.distributed.all_gather_into_tensor(
+        #         world_out, gather_input, group=self.process_group
+        #     )
 
-            if input.shape[0] == 1:
-                return world_out
-            return world_out.T
+        #     if input.shape[0] == 1:
+        #         return world_out
+        #     return world_out.T
 
         output = super().forward(input)
         world_output = [
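This hunk disables TensorParallelHead's fused fast path (a torch.mm into a preallocated buffer followed by torch.distributed.all_gather_into_tensor) and lets every call fall through to the generic super().forward() plus gather below; the commit does not say why. For reference, a standalone walk-through of the buffer shapes the disabled path sets up, with torch.cat standing in for the collective because no process group is initialized here:

    import torch

    world_size, out_dim, rows, hidden = 2, 4, 3, 8
    inp = torch.randn(rows, hidden)
    weight = torch.randn(out_dim, hidden)  # this rank's shard of the output projection

    if rows == 1:
        # Single row: gather directly along the feature dimension.
        local_out = inp @ weight.T                                   # (1, out_dim)
        world_out = torch.cat([local_out] * world_size, dim=1)       # (1, out_dim * world_size)
    else:
        # Multiple rows: compute into a transposed buffer so ranks can be
        # gathered contiguously along dim 0, then transpose back.
        gather_input = (inp @ weight.T).T                            # (out_dim, rows)
        world_out = torch.cat([gather_input] * world_size, dim=0).T  # (rows, out_dim * world_size)

    assert world_out.shape == (rows, out_dim * world_size)

In the real code each rank contributes a different shard, so the concatenation is an all-gather rather than copies of the same tensor.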
@@ -786,6 +786,7 @@ try:
            self._sin_k_cached = None
            self.scaling_factor = scaling_factor
            self.dynamic_args = None
+           self._update_cos_sin_cache(torch.float16, inv_freq.device, seqlen=4096)
 
        def forward(
            self,
@@ -929,8 +930,6 @@ try:
                # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
                dtype = torch.float32
 
-           self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-
            cos = torch.index_select(self._cos_cached, 0, position_ids)
            sin = torch.index_select(self._sin_cached, 0, position_ids)
            # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
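The last two hunks move the rotary cos/sin cache update out of the per-call path and pre-warm it once at construction time (float16, seqlen=4096), so the hot path is a plain index_select with no cache-growth branch left in the traced code. A minimal standalone sketch of the same idea (toy dimensions, not the repository's rotary embedding class):

    import torch

    def build_rope_cache(dim, max_seqlen, base=10000.0, dtype=torch.float32):
        # Precompute cos/sin for every position up to max_seqlen, once.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        t = torch.arange(max_seqlen, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)                      # (max_seqlen, dim // 2)
        return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

    cos_cached, sin_cached = build_rope_cache(dim=128, max_seqlen=4096)

    # Hot path: a pure gather, no resizing logic.
    position_ids = torch.tensor([0, 1, 2, 130])
    cos = torch.index_select(cos_cached, 0, position_ids)
    sin = torch.index_select(sin_cached, 0, position_ids)
    assert cos.shape == (4, 64)

The trade-off is that positions beyond the pre-warmed length are out of range here, whereas the removed forward-time update could grow the cache on demand.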