add option to set the watermark gamma & delta from the launcher

OlivierDehaene 2023-03-01 17:14:44 +01:00
parent 299f7367a5
commit 0be0506a7e
2 changed files with 34 additions and 10 deletions

launcher/src/main.rs

@@ -55,6 +55,10 @@ struct Args {
     otlp_endpoint: Option<String>,
     #[clap(long, env)]
     cors_allow_origin: Vec<String>,
+    #[clap(long, env)]
+    watermark_gamma: Option<f32>,
+    #[clap(long, env)]
+    watermark_delta: Option<f32>,
 }

 fn main() -> ExitCode {
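Because both new fields are declared with `#[clap(long, env)]`, clap accepts them either as `--watermark-gamma` / `--watermark-delta` flags or as `WATERMARK_GAMMA` / `WATERMARK_DELTA` environment variables. A minimal sketch of the two equivalent invocations, assuming the launcher binary is called `text-generation-launcher` (the binary name does not appear in this diff):

```python
import os
import subprocess

# Both invocations should behave the same, since #[clap(long, env)] derives
# the flag name and the env var name from the same field name.
subprocess.run(
    ["text-generation-launcher", "--watermark-gamma", "0.5", "--watermark-delta", "2.0"]
)
subprocess.run(
    ["text-generation-launcher"],
    env={**os.environ, "WATERMARK_GAMMA": "0.5", "WATERMARK_DELTA": "2.0"},
)
```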
@@ -88,6 +92,8 @@ fn main() -> ExitCode {
         json_output,
         otlp_endpoint,
         cors_allow_origin,
+        watermark_gamma,
+        watermark_delta,
     } = args;

     // Signal handler
@@ -243,6 +249,8 @@ fn main() -> ExitCode {
                 huggingface_hub_cache,
                 weights_cache_override,
                 disable_custom_kernels,
+                watermark_gamma,
+                watermark_delta,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -414,6 +422,8 @@ fn shard_manager(
     huggingface_hub_cache: Option<String>,
     weights_cache_override: Option<String>,
     disable_custom_kernels: bool,
+    watermark_gamma: Option<f32>,
+    watermark_delta: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,
@@ -494,6 +504,16 @@ fn shard_manager(
         env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
     }

+    // Watermark Gamma
+    if let Some(watermark_gamma) = watermark_gamma {
+        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+    }
+
+    // Watermark Delta
+    if let Some(watermark_delta) = watermark_delta {
+        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+    }
+
     // Start process
     tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(
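The shard is a separate Python process, so the launcher can only hand the values across the process boundary as environment strings: `f32::to_string()` renders e.g. 0.25 as "0.25", and nothing is exported at all when a flag was omitted. A sketch of the shard-side contract this implies:

```python
import os

# Implied by the `if let Some(...)` guards above: the variable is absent when
# the flag was omitted, and a string (never a float) when it was set.
gamma = os.getenv("WATERMARK_GAMMA")  # None unless --watermark-gamma was passed
if gamma is not None:
    gamma = float(gamma)  # env values always arrive as strings
```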

server/text_generation/utils/watermark.py

@@ -13,19 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import torch
 from transformers import LogitsProcessor

+GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
+DELTA = os.getenv("WATERMARK_DELTA", 2.0)
+

 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
-            self,
-            vocab_size: int,
-            gamma: float = 0.5,
-            delta: float = 2.0,
-            hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
-            device: str = "cpu",
+        self,
+        vocab_size: int,
+        gamma: float = GAMMA,
+        delta: float = DELTA,
+        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
+        device: str = "cpu",
     ):
         # watermarking parameters
         self.vocab_size = vocab_size
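One caveat worth flagging in the lines above: `os.getenv` returns the default object (here the float `0.5`) when the variable is unset, but the raw string when it is set, so `GAMMA` and `DELTA` end up as `float` or `str` depending on the environment. Downstream arithmetic such as `gamma * vocab_size` then relies on the value being numeric. A more defensive variant (an editorial sketch, not what this commit ships) casts explicitly:

```python
import os

# Editorial sketch: force GAMMA/DELTA to float whether or not the
# environment variable is set, instead of leaving a latent str/float split.
GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
```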
@@ -36,7 +40,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     def _seed_rng(self, input_ids: torch.LongTensor) -> None:
         assert (
-                input_ids.shape[-1] >= 1
+            input_ids.shape[-1] >= 1
         ), "requires at least a 1 token prefix sequence to seed rng"
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
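Seeding the generator with `hash_key * prev_token` makes the greenlist a deterministic function of the immediately preceding token, which is what lets a detector that knows `hash_key` recompute it later. A standalone sketch of that property (`_get_greenlist_ids` itself is not part of this diff; the `randperm` selection below is an assumption about its body):

```python
import torch

hash_key = 15485863  # same prime as in the diff above
rng = torch.Generator(device="cpu")

def greenlist(prev_token: int, vocab_size: int, gamma: float = 0.5) -> torch.Tensor:
    # Re-seed from the previous token, then take the first gamma * vocab_size
    # entries of a permutation as the "green" token ids.
    rng.manual_seed(hash_key * prev_token)
    greenlist_size = int(gamma * vocab_size)
    return torch.randperm(vocab_size, generator=rng)[:greenlist_size]

# The same previous token yields an identical greenlist on every call.
assert torch.equal(greenlist(42, 1000), greenlist(42, 1000))
```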
@@ -54,7 +58,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     @staticmethod
     def _calc_greenlist_mask(
-            scores: torch.FloatTensor, greenlist_token_ids
+        scores: torch.FloatTensor, greenlist_token_ids
     ) -> torch.BoolTensor:
         green_tokens_mask = torch.zeros_like(scores)
         green_tokens_mask[-1, greenlist_token_ids] = 1
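The mask construction above writes 1s into the last row of a zeros tensor at the greenlist columns; given the declared `torch.BoolTensor` return type, the method presumably converts it with `.bool()` before returning (the tail of the method is outside this hunk). A small sketch on a toy vocabulary:

```python
import torch

scores = torch.zeros(1, 8)                    # one sequence, vocab of 8
greenlist_token_ids = torch.tensor([1, 4, 6])
mask = torch.zeros_like(scores)
mask[-1, greenlist_token_ids] = 1             # mark green columns in the last row
mask = mask.bool()                            # assumed conversion to a BoolTensor
assert mask[0].tolist() == [False, True, False, False, True, False, True, False]
```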
@@ -63,13 +67,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
     @staticmethod
     def _bias_greenlist_logits(
-            scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
+        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
     ) -> torch.Tensor:
         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
         return scores

     def __call__(
-            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
         greenlist_ids = self._get_greenlist_ids(input_ids[0])
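Putting the pieces together, the processor is meant to be called once per decoding step on a batch of exactly one sequence (hence the `assert len(input_ids) == 1`). A usage sketch, assuming the class lives at `server/text_generation/utils/watermark.py` as shown above and that the gamma/delta defaults now come from the environment:

```python
import torch
from text_generation.utils.watermark import WatermarkLogitsProcessor

# With WATERMARK_GAMMA / WATERMARK_DELTA unset, gamma=0.5 and delta=2.0 apply.
processor = WatermarkLogitsProcessor(vocab_size=32000, device="cpu")

input_ids = torch.tensor([[1, 42, 7]])  # batch of exactly one sequence
scores = torch.randn(1, 32000)          # next-token logits for that sequence
scores = processor(input_ids, scores)   # greenlist logits are biased by +delta
```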