diff --git a/README.md b/README.md
index 1c2b213..d6307ea 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,7 @@ We include and implement the following attacks, as described in our paper.
 - [Min-K% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation.
 - [Min-K%++](https://zjysteven.github.io/mink-plus-plus/) (`min_k++`). Uses k% of tokens with minimum *normalized* likelihood for score computation.
 - [Gradient Norm](https://arxiv.org/abs/2402.17012) (`gradnorm`). Uses gradient norm of the target datapoint as score.
+- [ReCaLL](https://royxie.com/recall-project-page/) (`recall`). Uses the ratio of the target's log-likelihood conditioned on a non-member prefix to its unconditional log-likelihood as score.
 
 ## Adding your own dataset
diff --git a/configs/recall.json b/configs/recall.json
new file mode 100644
index 0000000..c6753d5
--- /dev/null
+++ b/configs/recall.json
@@ -0,0 +1,41 @@
+{
+    "experiment_name": "recall",
+    "base_model": "EleutherAI/pythia-1.4b",
+    "dataset_member": "the_pile",
+    "dataset_nonmember": "the_pile",
+    "min_words": 100,
+    "max_words": 200,
+    "max_tokens": 512,
+    "max_data": 100000,
+    "output_name": "unified_mia",
+    "specific_source": "Github_ngram_13_<0.8_truncated",
+    "n_samples": 1000,
+    "blackbox_attacks": ["loss", "ref", "zlib", "min_k", "min_k++", "recall"],
+    "env_config": {
+        "results": "results_new",
+        "device": "cuda:0",
+        "device_aux": "cuda:0"
+    },
+    "ref_config": {
+        "models": [
+            "EleutherAI/pythia-160m"
+        ]
+    },
+    "recall_config": {
+        "num_shots": 1
+    },
+    "neighborhood_config": {
+        "model": "bert",
+        "n_perturbation_list": [
+            25
+        ],
+        "pct_words_masked": 0.3,
+        "span_length": 2,
+        "dump_cache": false,
+        "load_from_cache": true,
+        "neighbor_strategy": "random"
+    },
+    "dump_cache": false,
+    "load_from_cache": false,
+    "load_from_hf": true
+}
\ No newline at end of file
diff --git a/mimir/attacks/all_attacks.py b/mimir/attacks/all_attacks.py
index 21fc613..4692695 100644
--- a/mimir/attacks/all_attacks.py
+++ b/mimir/attacks/all_attacks.py
@@ -15,6 +15,7 @@ class AllAttacks(str, Enum):
     MIN_K_PLUS_PLUS = "min_k++" # Done
     NEIGHBOR = "ne" # Done
     GRADNORM = "gradnorm" # Done
+    RECALL = "recall" # Done
     # QUANTILE = "quantile" # Uncomment when tested implementation is available
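For readers skimming the patch: the score the new attack computes is the ratio of the target's average log-likelihood conditioned on a fixed non-member prefix to its unconditional average log-likelihood. Below is a minimal self-contained sketch of that quantity, using `gpt2` purely as a stand-in model (the repo itself drives Pythia checkpoints through MIMIR's `Model` wrapper):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

def avg_ll(text: str, prefix: str = "") -> float:
    """Average per-token log-likelihood of `text`, optionally conditioned on `prefix`."""
    text_ids = tokenizer(text, return_tensors="pt").input_ids
    if prefix:
        prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids
        input_ids = torch.cat([prefix_ids, text_ids], dim=1)
        labels = input_ids.clone()
        labels[:, : prefix_ids.size(1)] = -100  # score only the target tokens
    else:
        input_ids = text_ids
        labels = input_ids.clone()
    with torch.no_grad():
        loss = model(input_ids, labels=labels).loss  # mean NLL over unmasked positions
    return -loss.item()

doc = "The quick brown fox jumps over the lazy dog."
ll = avg_ll(doc)
ll_cond = avg_ll(doc, prefix="A short passage known not to be in the training data. ")
print(ll_cond / ll)  # the ReCaLL score for `doc`
```

This mirrors `recall = ll_nonmember / lls` in `mimir/attacks/recall.py` below; the in-repo version additionally chunks long prefix+target sequences to fit the model's context window.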
diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py
new file mode 100644
index 0000000..537c32e
--- /dev/null
+++ b/mimir/attacks/recall.py
@@ -0,0 +1,132 @@
+"""
+    ReCaLL Attack: https://github.com/ruoyuxie/recall/
+"""
+import torch
+import numpy as np
+
+from mimir.attacks.all_attacks import Attack
+from mimir.models import Model
+from mimir.config import ExperimentConfig
+
+
+class ReCaLLAttack(Attack):
+    # Note: this is a suboptimal implementation of the ReCaLL attack, owing to
+    # changes needed to integrate it alongside the other attacks.
+    # For a better-performing version, please refer to: https://github.com/ruoyuxie/recall
+
+    def __init__(self, config: ExperimentConfig, target_model: Model):
+        super().__init__(config, target_model, ref_model=None)
+        self.prefix = None
+
+    @torch.no_grad()
+    def _attack(self, document, probs, tokens=None, **kwargs):
+        recall_dict: dict = kwargs.get("recall_dict", None)
+        assert recall_dict is not None, "recall_dict must be passed in via kwargs"
+
+        nonmember_prefix = recall_dict.get("prefix")
+        num_shots = recall_dict.get("num_shots")
+        avg_length = recall_dict.get("avg_length")
+
+        assert nonmember_prefix, "nonmember_prefix should not be None or empty"
+        assert num_shots, "num_shots should not be None or empty"
+        assert avg_length, "avg_length should not be None or empty"
+
+        lls = self.target_model.get_ll(document, probs=probs, tokens=tokens)
+        ll_nonmember = self.get_conditional_ll(
+            nonmember_prefix=nonmember_prefix,
+            text=document,
+            num_shots=num_shots,
+            avg_length=avg_length,
+            tokens=tokens,
+        )
+        recall = ll_nonmember / lls
+
+        assert not np.isnan(recall)
+        return recall
+
+    def process_prefix(self, prefix, avg_length, total_shots):
+        model = self.target_model
+        tokenizer = self.target_model.tokenizer
+
+        if self.prefix is not None:
+            # We only need to process the prefix once; afterwards we can just return it
+            return self.prefix
+
+        max_length = model.max_length
+        token_counts = [len(tokenizer.encode(shot)) for shot in prefix]
+
+        target_token_count = avg_length
+        total_tokens = sum(token_counts) + target_token_count
+        if total_tokens <= max_length:
+            self.prefix = prefix
+            return self.prefix
+
+        # Determine the maximum number of shots that fit within max_length
+        max_shots = 0
+        cumulative_tokens = target_token_count
+        for count in token_counts:
+            if cumulative_tokens + count <= max_length:
+                max_shots += 1
+                cumulative_tokens += count
+            else:
+                break
+        # Truncate the prefix to include only the maximum number of shots
+        truncated_prefix = prefix[-max_shots:]
+        print(f"\nToo many shots used. Initial ReCaLL number of shots was {total_shots}. Maximum number of shots is {max_shots}. Defaulting to maximum number of shots.")
+        self.prefix = truncated_prefix
+        return self.prefix
+
+    def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None):
+        assert nonmember_prefix, "nonmember_prefix should not be None or empty"
+
+        model = self.target_model
+        tokenizer = self.target_model.tokenizer
+
+        if tokens is None:
+            target_encodings = tokenizer(text=text, return_tensors="pt")
+        else:
+            target_encodings = tokens
+
+        processed_prefix = self.process_prefix(nonmember_prefix, avg_length, total_shots=num_shots)
+        input_encodings = tokenizer(text="".join(processed_prefix), return_tensors="pt")
+
+        prefix_ids = input_encodings.input_ids.to(model.device)
+        text_ids = target_encodings.input_ids.to(model.device)
+
+        max_length = model.max_length
+
+        if prefix_ids.size(1) >= max_length:
+            raise ValueError("Prefix length exceeds or equals the model's maximum context window.")
+
+        labels = torch.cat((prefix_ids, text_ids), dim=1)
+        total_length = labels.size(1)
+
+        total_loss = 0
+        total_tokens = 0
+        with torch.no_grad():
+            for i in range(0, total_length, max_length):
+                begin_loc = i
+                end_loc = min(i + max_length, total_length)
+
+                input_ids = labels[:, begin_loc:end_loc].to(model.device)
+                target_ids = input_ids.clone()
+
+                # Mask out label positions that belong to the prefix
+                if begin_loc < prefix_ids.size(1):
+                    prefix_overlap = min(prefix_ids.size(1) - begin_loc, max_length)
+                    target_ids[:, :prefix_overlap] = -100
+
+                # Keep label positions that belong to the target text
+                if end_loc > total_length - text_ids.size(1):
+                    target_overlap = min(end_loc - (total_length - text_ids.size(1)), max_length)
+                    target_ids[:, -target_overlap:] = input_ids[:, -target_overlap:]
+
+                if torch.all(target_ids == -100):
+                    continue
+
+                outputs = model.model(input_ids, labels=target_ids)
+                loss = outputs.loss
+                if torch.isnan(loss):
+                    print(f"NaN detected in loss at iteration {i}. Non-masked target_ids size is {(target_ids != -100).sum().item()}")
+                    continue
+
+                non_masked_tokens = (target_ids != -100).sum().item()
+                total_loss += loss.item() * non_masked_tokens
+                total_tokens += non_masked_tokens
+
+        average_loss = total_loss / total_tokens if total_tokens > 0 else 0
+        return -average_loss
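The chunked loss in `get_conditional_ll` leans on PyTorch's convention that label positions set to -100 are excluded from the cross-entropy average; that is what restricts the loss to target tokens while the prefix still conditions the predictions. A quick standalone check of the convention, with random logits standing in for model output:

```python
import torch
import torch.nn.functional as F

vocab = 50
logits = torch.randn(1, 6, vocab)        # (batch, seq, vocab), stand-in for LM output
labels = torch.randint(0, vocab, (1, 6))

masked = labels.clone()
masked[:, :3] = -100                     # treat the first 3 positions as "prefix"

full = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))
target_only = F.cross_entropy(logits.view(-1, vocab), masked.view(-1), ignore_index=-100)
print(full.item(), target_only.item())   # target_only averages over the last 3 positions only
```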
diff --git a/mimir/attacks/utils.py b/mimir/attacks/utils.py
index 05b66cc..766e22b 100644
--- a/mimir/attacks/utils.py
+++ b/mimir/attacks/utils.py
@@ -7,6 +7,7 @@
 from mimir.attacks.min_k_plus_plus import MinKPlusPlusAttack
 from mimir.attacks.neighborhood import NeighborhoodAttack
 from mimir.attacks.gradnorm import GradNormAttack
+from mimir.attacks.recall import ReCaLLAttack
 
 
 # TODO Use decorators to link attack implementations with enum above
@@ -19,6 +20,7 @@ def get_attacker(attack: str):
         AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
         AllAttacks.NEIGHBOR: NeighborhoodAttack,
         AllAttacks.GRADNORM: GradNormAttack,
+        AllAttacks.RECALL: ReCaLLAttack,
     }
     attack_cls = mapping.get(attack, None)
     if attack_cls is None:
diff --git a/mimir/config.py b/mimir/config.py
index 64c911e..2d9a05b 100644
--- a/mimir/config.py
+++ b/mimir/config.py
@@ -59,6 +59,13 @@ def __post_init__(self):
         if self.dump_cache and self.load_from_cache:
             raise ValueError("Cannot dump and load cache at the same time")
 
+@dataclass
+class ReCaLLConfig(Serializable):
+    """
+    Config for the ReCaLL attack
+    """
+    num_shots: Optional[int] = 1
+    """Number of non-member shots used to build the ReCaLL prefix"""
 
 @dataclass
 class EnvironmentConfig(Serializable):
@@ -194,6 +201,8 @@ class ExperimentConfig(Serializable):
     """Random seed"""
     ref_config: Optional[ReferenceConfig] = None
     """Reference model config"""
+    recall_config: Optional[ReCaLLConfig] = None
+    """ReCaLL attack config"""
     neighborhood_config: Optional[NeighborhoodConfig] = None
     """Neighborhood attack config"""
     env_config: Optional[EnvironmentConfig] = None
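A hypothetical smoke test for the new registry entry (not part of the diff; it assumes `get_attacker` returns the mapped class unchanged, as in the surrounding code):

```python
from mimir.attacks.all_attacks import AllAttacks
from mimir.attacks.recall import ReCaLLAttack
from mimir.attacks.utils import get_attacker

# The registry should now resolve the enum member to the new attack class
assert get_attacker(AllAttacks.RECALL) is ReCaLLAttack
```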
diff --git a/run.py b/run.py
index de8032a..b1f571e 100644
--- a/run.py
+++ b/run.py
@@ -19,7 +19,8 @@
     EnvironmentConfig,
     NeighborhoodConfig,
     ReferenceConfig,
-    OpenAIConfig
+    OpenAIConfig,
+    ReCaLLConfig
 )
 import mimir.data_utils as data_utils
 import mimir.plot_utils as plot_utils
@@ -77,6 +78,7 @@ def get_mia_scores(
     is_train: bool,
     n_samples: int = None,
     batch_size: int = 50,
+    **kwargs
 ):
     # Fix randomness
     fix_seed(config.random_seed)
@@ -100,6 +102,13 @@ def get_mia_scores(
             n_perturbation: [] for n_perturbation in n_perturbation_list
         }
 
+    recall_config = config.recall_config
+    if recall_config:
+        nonmember_prefix = kwargs.get("nonmember_prefix", None)
+        num_shots = recall_config.num_shots
+        avg_length = int(np.mean([len(target_model.tokenizer.encode(ex)) for ex in data["records"]]))
+        recall_dict = {"prefix": nonmember_prefix, "num_shots": num_shots, "avg_length": avg_length}
+
     # For each batch of data
     # TODO: Batch-size isn't really "batching" data - change later
     for batch in tqdm(range(math.ceil(n_samples / batch_size)), desc=f"Computing criterion"):
@@ -149,7 +158,23 @@ def get_mia_scores(
                     if attack.startswith(AllAttacks.REFERENCE_BASED) or attack == AllAttacks.LOSS:
                         continue
 
-                    if attack != AllAttacks.NEIGHBOR:
+                    if attack == AllAttacks.RECALL:
+                        score = attacker.attack(
+                            substr,
+                            probs=s_tk_probs,
+                            detokenized_sample=(
+                                detokenized_sample[i]
+                                if config.pretokenized
+                                else None
+                            ),
+                            loss=loss,
+                            all_probs=s_all_probs,
+                            recall_dict=recall_dict,
+                        )
+                        sample_information[attack].append(score)
+                    elif attack != AllAttacks.NEIGHBOR:
                         score = attacker.attack(
                             substr,
                             probs=s_tk_probs,
@@ -162,6 +187,7 @@
                             all_probs=s_all_probs,
                         )
                         sample_information[attack].append(score)
+
                     else:
                         # For each 'number of neighbors'
                         for n_perturbation in n_perturbation_list:
@@ -416,6 +442,7 @@ def main(config: ExperimentConfig):
     neigh_config: NeighborhoodConfig = config.neighborhood_config
     ref_config: ReferenceConfig = config.ref_config
     openai_config: OpenAIConfig = config.openai_config
+    recall_config: ReCaLLConfig = config.recall_config
 
     if openai_config:
         openAI_model = OpenAI_APIModel(config)
@@ -515,6 +542,15 @@ def main(config: ExperimentConfig):
         mask_model_tokenizer=mask_model.tokenizer if mask_model else None,
     )
 
+    # ReCaLL-specific: reserve the first `num_shots` non-member records as the shared prefix
+    if AllAttacks.RECALL in config.blackbox_attacks:
+        assert recall_config, "Must provide a recall_config"
+        num_shots = recall_config.num_shots
+        nonmember_prefix = data_nonmember[:num_shots]
+    else:
+        nonmember_prefix = None
+
     other_objs, other_nonmembers = None, None
     if config.dataset_nonmember_other_sources is not None:
         other_objs, other_nonmembers = [], []
@@ -628,7 +664,8 @@
         ref_models=ref_models,
         config=config,
         is_train=True,
-        n_samples=n_samples
+        n_samples=n_samples,
+        nonmember_prefix=nonmember_prefix
     )
     # Collect scores for non-members
     nonmember_preds, nonmember_samples = get_mia_scores(
@@ -640,6 +677,7 @@
         config=config,
         is_train=False,
         n_samples=n_samples,
+        nonmember_prefix=nonmember_prefix
     )
     blackbox_outputs = compute_metrics_from_scores(
         member_preds,
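With these pieces in place, the attack should run end-to-end through the usual entry point, presumably `python run.py --config configs/recall.json`. The plumbing added above reduces to the following sketch, with toy stand-ins for the tokenizer and datasets (names mirror run.py; values are illustrative only):

```python
import numpy as np

# Toy stand-ins (run.py draws these from the loaded datasets and target model)
data_nonmember = ["first non-member record ...", "second non-member record ..."]
records = ["candidate document one", "candidate document two"]
num_shots = 1

# main(): the first num_shots non-member records become the shared prefix
nonmember_prefix = data_nonmember[:num_shots]

# get_mia_scores(): average target length in tokens; run.py uses
# target_model.tokenizer.encode, approximated here by whitespace splitting
avg_length = int(np.mean([len(r.split()) for r in records]))

recall_dict = {"prefix": nonmember_prefix, "num_shots": num_shots, "avg_length": avg_length}
print(recall_dict)
```

Reserving the first `num_shots` non-member records as the prefix keeps the prefix fixed across all member and non-member scores, which is what makes the conditional/unconditional ratio comparable across samples.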