mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 03:44:54 +00:00
add option to set the watermark gamma & delta from the launcher
This commit is contained in:
parent
299f7367a5
commit
0be0506a7e
@ -55,6 +55,10 @@ struct Args {
|
|||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
cors_allow_origin: Vec<String>,
|
cors_allow_origin: Vec<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
watermark_gamma: Option<f32>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
watermark_delta: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> ExitCode {
|
fn main() -> ExitCode {
|
||||||
@ -88,6 +92,8 @@ fn main() -> ExitCode {
|
|||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
|
watermark_gamma,
|
||||||
|
watermark_delta,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Signal handler
|
// Signal handler
|
||||||
@ -243,6 +249,8 @@ fn main() -> ExitCode {
|
|||||||
huggingface_hub_cache,
|
huggingface_hub_cache,
|
||||||
weights_cache_override,
|
weights_cache_override,
|
||||||
disable_custom_kernels,
|
disable_custom_kernels,
|
||||||
|
watermark_gamma,
|
||||||
|
watermark_delta,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
status_sender,
|
status_sender,
|
||||||
shutdown,
|
shutdown,
|
||||||
@ -414,6 +422,8 @@ fn shard_manager(
|
|||||||
huggingface_hub_cache: Option<String>,
|
huggingface_hub_cache: Option<String>,
|
||||||
weights_cache_override: Option<String>,
|
weights_cache_override: Option<String>,
|
||||||
disable_custom_kernels: bool,
|
disable_custom_kernels: bool,
|
||||||
|
watermark_gamma: Option<f32>,
|
||||||
|
watermark_delta: Option<f32>,
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
status_sender: mpsc::Sender<ShardStatus>,
|
status_sender: mpsc::Sender<ShardStatus>,
|
||||||
shutdown: Arc<Mutex<bool>>,
|
shutdown: Arc<Mutex<bool>>,
|
||||||
@ -494,6 +504,16 @@ fn shard_manager(
|
|||||||
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Watermark Gamma
|
||||||
|
if let Some(watermark_gamma) = watermark_gamma {
|
||||||
|
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watermark Delta
|
||||||
|
if let Some(watermark_delta) = watermark_delta {
|
||||||
|
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
|
||||||
|
}
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting shard {rank}");
|
tracing::info!("Starting shard {rank}");
|
||||||
let mut p = match Popen::create(
|
let mut p = match Popen::create(
|
||||||
|
@ -13,19 +13,23 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import LogitsProcessor
|
from transformers import LogitsProcessor
|
||||||
|
|
||||||
|
GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))  # env values arrive as str; coerce so gamma is always a float
|
||||||
|
DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))  # env values arrive as str; coerce so delta is always a float
|
||||||
|
|
||||||
|
|
||||||
class WatermarkLogitsProcessor(LogitsProcessor):
|
class WatermarkLogitsProcessor(LogitsProcessor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
gamma: float = 0.5,
|
gamma: float = GAMMA,
|
||||||
delta: float = 2.0,
|
delta: float = DELTA,
|
||||||
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
|
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
# watermarking parameters
|
# watermarking parameters
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@ -36,7 +40,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
def _seed_rng(self, input_ids: torch.LongTensor) -> None:
|
def _seed_rng(self, input_ids: torch.LongTensor) -> None:
|
||||||
assert (
|
assert (
|
||||||
input_ids.shape[-1] >= 1
|
input_ids.shape[-1] >= 1
|
||||||
), "requires at least a 1 token prefix sequence to seed rng"
|
), "requires at least a 1 token prefix sequence to seed rng"
|
||||||
prev_token = input_ids[-1].item()
|
prev_token = input_ids[-1].item()
|
||||||
self.rng.manual_seed(self.hash_key * prev_token)
|
self.rng.manual_seed(self.hash_key * prev_token)
|
||||||
@ -54,7 +58,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _calc_greenlist_mask(
|
def _calc_greenlist_mask(
|
||||||
scores: torch.FloatTensor, greenlist_token_ids
|
scores: torch.FloatTensor, greenlist_token_ids
|
||||||
) -> torch.BoolTensor:
|
) -> torch.BoolTensor:
|
||||||
green_tokens_mask = torch.zeros_like(scores)
|
green_tokens_mask = torch.zeros_like(scores)
|
||||||
green_tokens_mask[-1, greenlist_token_ids] = 1
|
green_tokens_mask[-1, greenlist_token_ids] = 1
|
||||||
@ -63,13 +67,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _bias_greenlist_logits(
|
def _bias_greenlist_logits(
|
||||||
scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
|
scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
|
scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
assert len(input_ids) == 1
|
assert len(input_ids) == 1
|
||||||
greenlist_ids = self._get_greenlist_ids(input_ids[0])
|
greenlist_ids = self._get_greenlist_ids(input_ids[0])
|
||||||
|
Loading…
Reference in New Issue
Block a user