fp8 compressed tensors w8a8 support for Gaudi backend (#3242)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi 2025-05-28 20:54:20 +08:00 committed by GitHub
parent 1883a62a94
commit f14044009a
10 changed files with 548 additions and 59 deletions

View File

@ -99,6 +99,7 @@ RUN cd server && \
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
pip install . --no-cache-dir
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix
RUN pip install compressed-tensors==0.9.1
# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
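
The new loader below imports compressed_tensors at start-up, so it can be worth confirming that the wheel pinned here is what actually lands in the image. A minimal, illustrative check (run inside the built container; not part of the change):

# Illustrative sanity check; expectations are assumptions, not part of this commit.
from importlib.metadata import version

from compressed_tensors import QuantizationConfig  # parsed from config.json by the new loader

print(version("compressed-tensors"))  # expected to match the 0.9.1 pin above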

View File

@ -19,6 +19,7 @@ class Quantization(str, Enum):
gptq = "gptq"
awq = "awq"
fp8 = "fp8"
compressed_tensors = "compressed-tensors"
class Dtype(str, Enum):
@ -109,6 +110,7 @@ def serve(
"gptq",
"awq",
"fp8",
"compressed-tensors",
}:
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."

View File

@ -0,0 +1,3 @@
from .loader import CompressedTensorsLoader
__all__ = ["CompressedTensorsLoader"]

View File

@ -0,0 +1,169 @@
from typing import Any, Dict, List, Union
from compressed_tensors import QuantizationConfig, QuantizationStatus
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationType,
find_name_or_class_matches,
)
from loguru import logger
from pydantic import ValidationError
from torch import nn
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
# compressed-tensors can match modules as quantization targets. However,
# they need to be objects rather than classes or class names. Since we
# need to match `Linear` targets, make an instance that can be re-used.
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
class CompressedTensorsLoader(WeightsLoader):
"""Loader for checkpoints stored in the compressed-tensors format."""
def __init__(self, config: Dict[str, Any]):
quantization_config_raw = config.get("quantization_config")
if quantization_config_raw is None:
# `compression_config` was renamed to `quantization_config`; support
# retained for backward compatibility.
quantization_config_raw = config.get("compression_config")
if quantization_config_raw is None:
raise ValueError(
"Checkpoint does not have compressed-tensors configuration"
)
try:
quantization_config = QuantizationConfig.model_validate(
quantization_config_raw
)
except ValidationError as e:
raise ValueError("Cannot parse compressed-tensors configuration") from e
if quantization_config.quantization_status not in (
QuantizationStatus.COMPRESSED,
QuantizationStatus.FROZEN,
):
raise ValueError(
f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
)
self.ignore = (
quantization_config.ignore if quantization_config.ignore is not None else []
)
self.loaders = self._get_target_loaders(quantization_config)
for target, loader in self.loaders.items():
log_once(
logger.info,
f"Using {loader} for compressed-tensors target '{target}'",
)
def get_weights(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights(weights, prefix)
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
loader = self._lookup_loader(prefix)
return loader.get_weights_col_packed(weights, prefix, block_sizes)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights_col(weights, prefixes, dim)
def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
loader = self._lookup_loader(prefixes[0])
return loader.get_multi_weights(weights, prefixes, dim)
def get_weights_row(self, weights: Weights, prefix: str):
loader = self._lookup_loader(prefix)
return loader.get_weights_row(weights, prefix)
def _get_target_loaders(
self, quantization_config: QuantizationConfig
) -> Dict[str, WeightsLoader]:
"""
A compressed-tensors checkpoint can use different quantizations
for different targets. This method returns a dictionary with a
loader per target.
"""
loaders: Dict[str, WeightsLoader] = {}
format = quantization_config.format
for group_name, group in quantization_config.config_groups.items():
# The group configuration can be a string, but does that ever
# happen in a serialized quantization config?
assert isinstance(group, QuantizationScheme)
loader = self._create_loader_for_group(format, group_name, group)
# A quantized parameter group can have multiple targets, add the
# loader for all the targets.
for target in group.targets:
if target in loaders:
raise ValueError(
f"Target '{target} has multiple configured loaders'"
)
loaders[target] = loader
return loaders
def _create_loader_for_group(
self, format: str, group_name: str, group: QuantizationScheme
) -> WeightsLoader:
"""
Find and create a loader for the group with the given quantization
scheme.
"""
# NOTE: we ignore group.output_activations because we don't support
# output quantization yet.
input_activations = group.input_activations
weights = group.weights
if (
format
in {
CompressionFormat.float_quantized.value,
CompressionFormat.naive_quantized.value,
}
and weights is not None
and weights.type == QuantizationType.FLOAT
and weights.num_bits == 8
):
# FP W8A8 or W8A16.
return W8ANFpLoader(input_activations=input_activations, weights=weights)
else:
raise ValueError(
f"Group '{group_name}' has unsupported compressed-tensors configurtion"
)
def _lookup_loader(self, prefix: str) -> WeightsLoader:
"""
Look up the loader to use for a given parameter name (prefix).
"""
if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
return DefaultWeightsLoader(UnquantizedWeight)
# We currently only handle linear layers, so unconditionally pass
# a `Linear` instance.
targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
if len(targets) == 0:
raise ValueError(
f"Cannot find compressed-tensors target for prefix: {prefix}"
)
return self.loaders[targets[0]]
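
To make the target-to-loader mapping concrete, here is an illustrative config.json fragment (hypothetical values; field names as understood from compressed-tensors 0.9.x) that this loader accepts: a single FP8 group targeting Linear, with lm_head left unquantized.

# Hypothetical example configuration; not taken from any specific checkpoint.
example_config = {
    "quantization_config": {
        "quant_method": "compressed-tensors",
        "format": "float-quantized",
        "quantization_status": "compressed",
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {"type": "float", "num_bits": 8, "strategy": "tensor", "dynamic": False},
                "input_activations": {"type": "float", "num_bits": 8, "strategy": "tensor", "dynamic": False},
            }
        },
    }
}

# loader = CompressedTensorsLoader(example_config)
# _lookup_loader("model.layers.0.self_attn.q_proj") would return the W8ANFpLoader registered
# for the "Linear" target, while "lm_head" falls back to DefaultWeightsLoader(UnquantizedWeight).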

View File

@ -0,0 +1,253 @@
from typing import List, Optional, Union
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
from text_generation_server.layers.fp8 import (
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
class W8ANFpLoader(WeightsLoader):
"""
Loader for W8A8/W8A16 FP compressed-tensors parameters.
"""
def __init__(
self,
*,
input_activations: Optional[QuantizationArgs],
weights: QuantizationArgs,
):
assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
# We ignore the `strategy` option which sets the scales to be
# per-tensor, per-channel or per-token. Which scales are supported
# depends on the kernels used (e.g. cutlass can do tokenwise,
# Torch cannot, and FP8-Marlin does not quantize inputs at all),
# so we instead try to use the best possible configuration.
self.load_weight_scale = not weights.dynamic
self.load_input_scale = (
input_activations is not None and not input_activations.dynamic
)
self.force_w8a16 = (
input_activations is not None and input_activations.num_bits == 16
)
def __str__(self) -> str:
def scale_to_str(scale):
return "static" if scale else "dynamic"
quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
weight_scale = None
if self.load_weight_scale:
weight_scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if weight_scale.numel() > 1:
weight_scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_multi_weights(self, weights: "Weights", prefixes: List[str], dim: int):
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = [weights.get_tensor(f"{p}.weight", to_device=False) for p in prefixes]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
weight_scale = None
if self.load_weight_scale:
weight_scale = [
weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = [
weights.get_tensor(f"{p}.input_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
weight_scale = None
if self.load_weight_scale:
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, weight_scale = requantize_with_max_scale(
w,
weight_scale.unsqueeze(-1).to(weights.device),
logical_widths,
weights.dtype,
)
input_scale = None
if self.load_input_scale:
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=weight_scale,
input_scale=input_scale,
dtype=weights.dtype,
force_w8a16=self.force_w8a16,
)
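
Every load path above funnels the per-shard scales through requantize_with_max_scale so that a fused tensor (for example packed q/k/v) ends up with one shared scale. A conceptual sketch of that step (not the helper from text_generation_server.layers.fp8, which receives the scales already expanded to one entry per output row):

import torch

def requantize_with_max_scale_sketch(weight_fp8, shard_scales, logical_widths, dtype=torch.bfloat16):
    # Each logical shard was quantized with its own per-tensor scale; re-quantize every
    # shard with the shared maximum so the fused weight can use a single scale at runtime.
    max_scale = shard_scales.max()
    out = torch.empty_like(weight_fp8)
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        dequantized = weight_fp8[start:end].to(dtype) * shard_scales[idx]
        out[start:end] = (dequantized / max_scale).to(torch.float8_e4m3fn)
        start = end
    return out, max_scale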

View File

@ -207,7 +207,7 @@ def requantize_with_max_scale(
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[idx], dtype
weight[start:end, :], weight_scale[start:end, :], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
@ -270,6 +270,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
)
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
scale = scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
@ -278,10 +283,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
@ -316,6 +317,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
block_sizes=block_sizes,
to_dtype=False,
)
scale = scale.reshape(-1).expand(w.shape[0])
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
@ -330,10 +336,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
@ -380,6 +382,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
]
scale = torch.cat(scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
@ -392,11 +399,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
else None
)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -435,11 +437,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
)
scale = [
weights.get_tensor(f"{p}.weight_scale", to_dtype=False).reshape(-1)
for p in prefixes
weights.get_tensor(f"{p}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(shape[0])
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = [
weights.get_tensor(f"{p}.input_scale", to_dtype=False).reshape(-1)
for p in prefixes
@ -452,11 +461,6 @@ class HybridFP8UnquantLoader(WeightsLoader):
else None
)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -485,7 +489,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
weight_block_size=self.weight_block_size,
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
@ -494,10 +506,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -615,45 +624,32 @@ class Fp8Linear(torch.nn.Module):
weight_block_size=weight_block_size,
)
@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
if self.weight_block_size is not None or self.input_scale is None:
return apply_block_fp8_linear_hpu_dynamic(
input, self.qweight, self.scale, self.input_scale, self.bias
)
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)
output = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
x_fp8 = torch.ops.hpu.cast_to_fp8_v2(
input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn
)[0]
return torch.ops.hpu.fp8_gemm_v2(
A=x_fp8,
trans_A=False,
B=self.qweight,
trans_B=True,
D=None,
out_dtype=input.dtype,
A_scale_inv=self.input_scale,
B_scale_inv=self.scale,
bias=self.bias,
accumulate=False,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]
return output
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
return scale.reshape(-1)
return scale.reshape(-1).expand(shape[0])
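
The rewritten Fp8Linear.forward above has two paths: block-wise weights or a missing static input_scale go through apply_block_fp8_linear_hpu_dynamic, while a static input_scale is used to cast the activations to FP8 and run the Gaudi FP8 GEMM (torch.ops.hpu.cast_to_fp8_v2 followed by fp8_gemm_v2). A device-agnostic sketch of the math the static path computes (illustrative, not the HPU kernels):

import torch

def static_w8a8_linear_sketch(x, qweight_fp8, weight_scale, input_scale, bias=None):
    # Quantize activations with the stored static scale (the HPU op also saturates to the
    # FP8 range; this sketch skips that detail).
    x_fp8 = (x / input_scale).to(torch.float8_e4m3fn)
    # FP8 GEMM with dequantization folded into the output: (x/s_x) @ (W/s_w)^T * s_x * s_w ~= x @ W^T
    out = (x_fp8.to(x.dtype) @ qweight_fp8.to(x.dtype).t()) * (input_scale * weight_scale)
    return out if bias is None else out + bias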

View File

@ -276,6 +276,63 @@ class GPTQWeightsLoader(WeightsLoader):
use_exllama=use_exllama,
)
def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):
return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)
try:
qweight = torch.cat(
[weights.get_tensor(f"{p}.qweight") for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat([weights.get_tensor(f"{p}.scales") for p in prefixes], dim=1)
self._get_gptq_params(weights)
qzeros = torch.cat([weights.get_tensor(f"{p}.qzeros") for p in prefixes], dim=1)
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_awq_kernel=self.quantize == "awq",
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)

View File

@ -1412,7 +1412,6 @@ class FlashCausalLM(Model):
aliases=aliases,
weights_loader=weights_loader,
)
print(f"weights: {weights}")
prefix = None
model = model_class(prefix, config, weights)

View File

@ -122,6 +122,13 @@ def _get_quantizer_config(model_id, revision):
def get_loader(
quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
if quantize == "compressed-tensors":
config = _get_config_json(model_id, revision, "config.json")
from text_generation_server.layers.compressed_tensors import (
CompressedTensorsLoader,
)
return CompressedTensorsLoader(config)
quantizer_config = _get_quantizer_config(model_id, revision)
if quantize in {"awq", "gptq"}:
from text_generation_server.layers.gptq import GPTQWeightsLoader
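
A wiring sketch for the new branch (illustrative; assumes get_loader lives in text_generation_server.utils.quantization and uses a hypothetical model id):

from text_generation_server.utils.quantization import get_loader

weights_loader = get_loader(
    quantize="compressed-tensors",
    model_id="some-org/some-fp8-checkpoint",  # hypothetical model id
    revision=None,
)
# For a compressed-tensors checkpoint this returns a CompressedTensorsLoader built from
# config.json; Weights(...) then delegates get_weights*/get_multi_weights* calls to it per prefix.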

View File

@ -162,6 +162,11 @@ impl Allocator for SimpleAllocator {
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let mut tokens = tokens;
if self.is_hpu_device {
// need 1 slot for ping-pong optimization
tokens += 1;
}
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
@ -176,8 +181,7 @@ impl Allocator for SimpleAllocator {
let required_blocks = tokens.div_ceil(self.block_size);
(required_blocks, repeats)
};
let mut tokens = tokens as usize;
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
@ -189,8 +193,6 @@ impl Allocator for SimpleAllocator {
.split_off(self.free_blocks.len() - required_blocks as usize);
if self.is_hpu_device {
blocks.sort();
// need 1 slot for ping-pong optimization
tokens += 1;
}
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
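
Moving the tokens += 1 above the block computation matters because the extra ping-pong slot can tip a request into one more block; previously the slot was added only after required_blocks had been computed. A worked example of the arithmetic (Python used only for the numbers):

block_size = 16
tokens = 16  # decode request that exactly fills one block

required_blocks_before = -(-tokens // block_size)        # +1 applied after counting: 1 block reserved
required_blocks_after = -(-(tokens + 1) // block_size)   # +1 applied before counting: 2 blocks reserved
print(required_blocks_before, required_blocks_after)     # 1 2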