add option to set the watermark gamma & delta from the launcher

OlivierDehaene 2023-03-01 17:14:44 +01:00
parent 299f7367a5
commit 0be0506a7e
2 changed files with 34 additions and 10 deletions


@@ -55,6 +55,10 @@ struct Args {
     otlp_endpoint: Option<String>,
     #[clap(long, env)]
     cors_allow_origin: Vec<String>,
+    #[clap(long, env)]
+    watermark_gamma: Option<f32>,
+    #[clap(long, env)]
+    watermark_delta: Option<f32>,
 }

 fn main() -> ExitCode {
@@ -88,6 +92,8 @@ fn main() -> ExitCode {
         json_output,
         otlp_endpoint,
         cors_allow_origin,
+        watermark_gamma,
+        watermark_delta,
     } = args;

     // Signal handler
@@ -243,6 +249,8 @@ fn main() -> ExitCode {
                 huggingface_hub_cache,
                 weights_cache_override,
                 disable_custom_kernels,
+                watermark_gamma,
+                watermark_delta,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -414,6 +422,8 @@ fn shard_manager(
     huggingface_hub_cache: Option<String>,
     weights_cache_override: Option<String>,
     disable_custom_kernels: bool,
+    watermark_gamma: Option<f32>,
+    watermark_delta: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,
@@ -494,6 +504,16 @@ fn shard_manager(
         env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
     }

+    // Watermark Gamma
+    if let Some(watermark_gamma) = watermark_gamma {
+        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+    }
+
+    // Watermark Delta
+    if let Some(watermark_delta) = watermark_delta {
+        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+    }
+
     // Start process
     tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(
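Note: the launcher only exports WATERMARK_GAMMA / WATERMARK_DELTA when the corresponding flag is passed, so the Python server falls back to its own defaults otherwise. A minimal sketch of the equivalent propagation in Python (the spawn_shard helper is hypothetical; the real launcher does this in Rust via Popen as shown above):

import os
import subprocess
from typing import List, Optional


def spawn_shard(
    cmd: List[str],
    watermark_gamma: Optional[float] = None,
    watermark_delta: Optional[float] = None,
) -> subprocess.Popen:
    # Hypothetical illustration: forward optional watermark settings as env vars.
    env = os.environ.copy()
    if watermark_gamma is not None:
        env["WATERMARK_GAMMA"] = str(watermark_gamma)  # e.g. "0.25"
    if watermark_delta is not None:
        env["WATERMARK_DELTA"] = str(watermark_delta)  # e.g. "2.0"
    return subprocess.Popen(cmd, env=env)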


@@ -13,19 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 from transformers import LogitsProcessor

+GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
+DELTA = os.getenv("WATERMARK_DELTA", 2.0)


 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
         self,
         vocab_size: int,
-        gamma: float = 0.5,
-        delta: float = 2.0,
+        gamma: float = GAMMA,
+        delta: float = DELTA,
         hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
         device: str = "cpu",
     ):
         # watermarking parameters
         self.vocab_size = vocab_size
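A note on the two new module-level constants: os.getenv returns a str whenever the variable is set, so GAMMA and DELTA are only floats when the defaults apply. A small sketch of a float-safe variant (the _env_float helper is illustrative, not part of this commit):

import os


def _env_float(name: str, default: float) -> float:
    # Read an optional float from the environment, falling back to a default.
    value = os.getenv(name)
    return default if value is None else float(value)


GAMMA = _env_float("WATERMARK_GAMMA", 0.5)
DELTA = _env_float("WATERMARK_DELTA", 2.0)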
@@ -36,7 +40,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     def _seed_rng(self, input_ids: torch.LongTensor) -> None:
         assert (
             input_ids.shape[-1] >= 1
         ), "requires at least a 1 token prefix sequence to seed rng"
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
@@ -54,7 +58,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     @staticmethod
     def _calc_greenlist_mask(
         scores: torch.FloatTensor, greenlist_token_ids
     ) -> torch.BoolTensor:
         green_tokens_mask = torch.zeros_like(scores)
         green_tokens_mask[-1, greenlist_token_ids] = 1
@@ -63,13 +67,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     @staticmethod
     def _bias_greenlist_logits(
         scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
     ) -> torch.Tensor:
         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
         return scores

     def __call__(
         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
         greenlist_ids = self._get_greenlist_ids(input_ids[0])
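For context on what _calc_greenlist_mask and _bias_greenlist_logits do to the scores, here is a short self-contained sketch of the same masking and biasing on a dummy score tensor (the greenlist here is a fixed set of ids for illustration; the real processor derives it from the previous token via the seeded RNG):

import torch

# Dummy next-token scores for a single sequence: shape (1, vocab_size).
vocab_size = 10
scores = torch.zeros(1, vocab_size)

# Stand-in greenlist: a few hand-picked token ids (illustration only).
greenlist_token_ids = torch.tensor([1, 4, 7])

# Mirror of _calc_greenlist_mask: mark greenlist columns of the last row,
# then use the result as a boolean mask.
green_tokens_mask = torch.zeros_like(scores)
green_tokens_mask[-1, greenlist_token_ids] = 1
green_tokens_mask = green_tokens_mask.bool()

# Mirror of _bias_greenlist_logits: add the delta bias to greenlist logits.
delta = 2.0
scores[green_tokens_mask] = scores[green_tokens_mask] + delta

print(scores)  # greenlist positions now carry a +2.0 bias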