mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 07:42:06 +00:00

[gaudi] Refine logging for Gaudi warmup (#3222)

* Refine logging for Gaudi warmup
* Make style
* Make style 2
* Flash causal LM case
* Add log_master & VLM cases
* Black

This commit is contained in:
parent b4d17f18ff
commit f13e28c98d
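
All of the new log lines below go through log_master. A minimal sketch of what such a helper does, assuming the usual rank-0 gating (the RANK environment lookup here is an illustration, not the exact implementation in the server's logging utilities):

    import os

    # Assumed sketch: only the master shard (rank 0) emits the message,
    # so warmup logs are not duplicated once per card.
    RANK = int(os.getenv("RANK", "0"))

    def log_master(log, msg: str):
        if RANK == 0:
            log(msg)

Called as log_master(logger.info, msg), it forwards msg to the bound logger method on rank 0 and drops it on every other shard.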
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Llava-NeXT model."""
+"""PyTorch Llava-NeXT model."""
 
 from typing import List, Optional, Tuple
 

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics2 model."""
+"""PyTorch Idefics2 model."""
 
 from typing import List, Optional, Tuple
 

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics3 model."""
+"""PyTorch Idefics3 model."""
 
 from typing import List, Optional, Tuple
 

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Idefics model configuration"""
+"""Idefics model configuration"""
 import copy
 
 from transformers import PretrainedConfig

@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Idefics model."""
+"""PyTorch Idefics model."""
 from typing import List, Optional, Tuple, Union
 
 import torch

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
+"""PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
 
 
 from dataclasses import dataclass

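The six hunks above are mechanical style fixes from the "Make style"/"Black" steps in the commit message: the formatter strips the stray space after the opening triple quotes of each module docstring. The remaining hunks carry the actual logging changes.
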
@@ -1721,6 +1721,7 @@ class FlashCausalLM(Model):
             f"{dim}:{seq_len} "
             f"bypass:{bypass} "
             f"free_mem:{free_mem}"
+            ", this may take a while..."
         )
         log_master(logger.info, msg)
 

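Because adjacent string literals concatenate in Python, the added line needs no f-prefix and simply extends the existing message. A tiny sketch with hypothetical values:

    # Hypothetical values; the real ones come from the warmup loop.
    dim, seq_len, bypass, free_mem = "seq", 2048, False, "11.2 GiB"
    msg = (
        f"{dim}:{seq_len} "
        f"bypass:{bypass} "
        f"free_mem:{free_mem}"
        ", this may take a while..."
    )
    print(msg)  # seq:2048 bypass:False free_mem:11.2 GiB, this may take a while...
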
@@ -1772,6 +1773,11 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
+        msg = (
+            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+        )
+        log_master(logger.info, msg)
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue

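The new summary prints every prefill bucket up front instead of leaving the shapes to be discovered one warmup at a time. With hypothetical (batch_size, seq_len) buckets it renders like this:

    buckets = [(1, 512), (2, 1024), (4, 2048)]  # hypothetical buckets
    msg = (
        f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
        f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
    )
    print(msg, end="")
    # Prefill batch size list:[1, 2, 4]
    # Prefill sequence length list:[512, 1024, 2048]
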
@@ -1798,6 +1804,8 @@ class FlashCausalLM(Model):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        log_master(logger.info, "Prefill warmup successful.\n")
+
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
 

@@ -1809,6 +1817,9 @@ class FlashCausalLM(Model):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue

@@ -1833,6 +1844,8 @@ class FlashCausalLM(Model):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",

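The closing line reports total graph-warmup wall time and how many shapes were compiled. A self-contained sketch of the timing pattern (the loop body is a stand-in for the real per-shape warmup):

    import time

    start_time = time.time()
    warmup_shape_count = 0
    for batch_size, block_num in [(1, 64), (2, 128)]:  # hypothetical shapes
        warmup_shape_count += 1  # one HPU graph warmed per shape in the real code

    print(
        f"warmup hpu graph time {int(time.time() - start_time)}s "
        f"warmup shape count {warmup_shape_count}"
    )
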
@@ -822,6 +822,9 @@ class FlashVlmCausalLM(FlashCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = decode_available_memory
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue

@@ -847,6 +850,8 @@ class FlashVlmCausalLM(FlashCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",

@@ -398,6 +398,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = prompt_available_memory
+        msg = (
+            f"Prefill batch size list:{[bsz[0] for bsz in buckets]}\n"
+            f"Prefill sequence length list:{[seq[1] for seq in buckets]}\n"
+        )
+        log_master(logger.info, msg)
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue

@@ -424,6 +429,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        log_master(logger.info, "Prefill warmup successful.\n")
+
         def ordering_function_max_bs(b):
             return (-b[0], b[1])
 

@@ -435,6 +442,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         total_batch_seq = 0.001
         total_mem = 0
         available_mem = free_mem - self.mem_reserved
+        log_master(
+            logger.info, f"Decode batch size list:{[bsz[0] for bsz in buckets]}\n"
+        )
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue

@@ -459,6 +469,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             total_mem += used_mem
             total_batch_seq += batch_seq
 
+        log_master(logger.info, "Decode warmup successful.\n")
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",

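The FlashVlmCausalLM and FlashMllamaCausalLM hunks above mirror the FlashCausalLM changes: the same bucket summaries, "warmup successful" markers, and final timing line, per the "Add log_master & VLM cases" step in the commit message. The last hunk below is a separate Black formatting fix.
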
@@ -8,7 +8,7 @@ import torch
 
 
 def find_segments(
-    adapter_indices: Union[torch.Tensor, List[int]]
+    adapter_indices: Union[torch.Tensor, List[int]],
 ) -> Tuple[List[int], List[int]]:
     segments = [0]
     segment_indices = []

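Black adds a trailing comma when each element of a split parameter list sits on its own line; behavior is unchanged. A hypothetical call, assuming from the visible body that the function collapses runs of equal adapter indices into segment boundaries plus the adapter index owning each run:

    # Assumed behavior, for illustration only:
    segments, segment_indices = find_segments([0, 0, 1, 1, 1, 2])
    # segments        -> [0, 2, 5, 6]   (boundaries of equal-index runs)
    # segment_indices -> [0, 1, 2]      (adapter index for each run)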