
add ReCaLL attack #26

Merged · 12 commits · Sep 16, 2024
1 change: 1 addition & 0 deletions README.md
@@ -55,6 +55,7 @@ We include and implement the following attacks, as described in our paper.
- [Min-K% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation.
- [Min-K%++](https://zjysteven.github.io/mink-plus-plus/) (`min_k++`). Uses k% of tokens with minimum *normalized* likelihood for score computation.
- [Gradient Norm](https://arxiv.org/abs/2402.17012) (`gradnorm`). Uses gradient norm of the target datapoint as score.
- [ReCaLL](https://royxie.com/recall-project-page/) (`recall`). Compares the target text's conditional log-likelihood, given a non-member prefix, with its unconditional log-likelihood.
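
  In the paper's notation (a sketch, matching the implementation added in `mimir/attacks/recall.py` below; $P$ is a prefix built from non-member shots, and both log-likelihoods are per-token averages):

  $$\mathrm{ReCaLL}(x) = \frac{\log p_\theta(x \mid P)}{\log p_\theta(x)}$$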

## Adding your own dataset

39 changes: 39 additions & 0 deletions configs/recall.json
@@ -0,0 +1,39 @@
{
"experiment_name": "recall",
"base_model": "EleutherAI/pythia-1.4b",
"dataset_member": "the_pile",
"dataset_nonmember": "the_pile",
"min_words": 100,
"max_words": 200,
"max_tokens": 512,
"max_data": 100000,
"output_name": "unified_mia",
"specific_source": "Github_ngram_13_<0.8_truncated",
"n_samples": 1000,
"recall_num_shots": 1,
"blackbox_attacks": ["loss", "ref", "zlib", "min_k", "min_k++", "recall"],
"env_config": {
"results": "results_new",
"device": "cuda:0",
"device_aux": "cuda:0"
},
"ref_config": {
"models": [
"EleutherAI/pythia-160m"
]
},
"neighborhood_config": {
"model": "bert",
"n_perturbation_list": [
25
],
"pct_words_masked": 0.3,
"span_length": 2,
"dump_cache": false,
"load_from_cache": true,
"neighbor_strategy": "random"
},
"dump_cache": false,
"load_from_cache": false,
"load_from_hf": true
}
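
Of the fields above, `recall_num_shots` is the ReCaLL-specific knob: it sets how many non-member examples are split off and concatenated into the conditioning prefix (see the `main()` changes in `run.py` below).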
1 change: 1 addition & 0 deletions mimir/attacks/all_attacks.py
@@ -15,6 +15,7 @@ class AllAttacks(str, Enum):
MIN_K_PLUS_PLUS = "min_k++" # Done
NEIGHBOR = "ne" # Done
GRADNORM = "gradnorm" # Done
RECALL = "recall"
# QUANTILE = "quantile" # Uncomment when tested implementation is available


127 changes: 127 additions & 0 deletions mimir/attacks/recall.py
@@ -0,0 +1,127 @@
"""
ReCaLL Attack: https://github.com/ruoyuxie/recall/
"""
import torch
import numpy as np
from mimir.attacks.all_attacks import Attack
from mimir.models import Model
from mimir.config import ExperimentConfig

class ReCaLLAttack(Attack):

    def __init__(self, config: ExperimentConfig, target_model: Model):
        super().__init__(config, target_model, ref_model=None)
        # Cache for the processed non-member prefix (built once, then reused)
        self.prefix = None

    @torch.no_grad()
    def _attack(self, document, probs, tokens=None, **kwargs):
        recall_dict: dict = kwargs.get("recall_dict", None)
        assert recall_dict is not None, "ReCaLL attack requires a recall_dict"

        nonmember_prefix = recall_dict.get("prefix")
        num_shots = recall_dict.get("num_shots")
        avg_length = recall_dict.get("avg_length")

        assert nonmember_prefix, "nonmember_prefix should not be None or empty"

        # Unconditional (average per-token) log-likelihood of the document
        lls = self.target_model.get_ll(document, probs=probs, tokens=tokens)
        # Log-likelihood of the same document, conditioned on the non-member prefix
        ll_nonmember = self.get_conditional_ll(
            nonmember_prefix=nonmember_prefix,
            text=document,
            num_shots=num_shots,
            avg_length=avg_length,
            tokens=tokens,
        )
        # ReCaLL score: ratio of conditional to unconditional log-likelihood
        recall = ll_nonmember / lls

        assert not np.isnan(recall)
        return recall

    def process_prefix(self, prefix, avg_length, total_shots):
        model = self.target_model
        tokenizer = self.target_model.tokenizer

        if self.prefix is not None:
            # We only need to process the prefix once; afterwards just return it
            return self.prefix

        max_length = model.max_length
        token_counts = [len(tokenizer.encode(shot)) for shot in prefix]

        target_token_count = avg_length
        total_tokens = sum(token_counts) + target_token_count
        if total_tokens <= max_length:
            self.prefix = prefix
            return self.prefix

        # Determine the maximum number of shots that can fit within max_length
        max_shots = 0
        cumulative_tokens = target_token_count
        for count in token_counts:
            if cumulative_tokens + count <= max_length:
                max_shots += 1
                cumulative_tokens += count
            else:
                break

        # Truncate the prefix to include only the maximum number of shots
        truncated_prefix = prefix[-max_shots:]
        print(
            f"\nToo many shots used. Initial ReCaLL number of shots was {total_shots}. "
            f"Maximum number of shots is {max_shots}. Defaulting to maximum number of shots."
        )
        self.prefix = truncated_prefix
        return self.prefix
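
A worked example of the truncation arithmetic (hypothetical numbers): with `max_length = 2048`, `avg_length = 150`, and three shots of 900 tokens each, the running total starts at 150 and admits two shots (150 + 900 + 900 = 1950 <= 2048) but not the third (2850 > 2048), so `max_shots = 2` and the last two shots are kept.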

    def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None):
        assert nonmember_prefix, "nonmember_prefix should not be None or empty"

        model = self.target_model
        tokenizer = self.target_model.tokenizer

        if tokens is None:
            target_encodings = tokenizer(text=text, return_tensors="pt")
        else:
            target_encodings = tokens

        processed_prefix = self.process_prefix(nonmember_prefix, avg_length, total_shots=num_shots)
        input_encodings = tokenizer(text="".join(processed_prefix), return_tensors="pt")

        prefix_ids = input_encodings.input_ids.to(model.device)
        text_ids = target_encodings.input_ids.to(model.device)

        max_length = model.max_length

        if prefix_ids.size(1) >= max_length:
            raise ValueError("Prefix length exceeds or equals the model's maximum context window.")

        labels = torch.cat((prefix_ids, text_ids), dim=1)
        total_length = labels.size(1)

        total_loss = 0
        total_tokens = 0
        with torch.no_grad():
            # Slide over the concatenated (prefix + text) sequence in windows of
            # at most max_length tokens, accumulating loss over text tokens only
            for i in range(0, total_length, max_length):
                begin_loc = i
                end_loc = min(i + max_length, total_length)

                input_ids = labels[:, begin_loc:end_loc].to(model.device)
                target_ids = input_ids.clone()

                if begin_loc < prefix_ids.size(1):
                    # Mask prefix positions with -100 so the loss ignores them
                    prefix_overlap = min(prefix_ids.size(1) - begin_loc, max_length)
                    target_ids[:, :prefix_overlap] = -100

                if end_loc > total_length - text_ids.size(1):
                    # Ensure positions covering the target text keep their labels
                    target_overlap = min(end_loc - (total_length - text_ids.size(1)), max_length)
                    target_ids[:, -target_overlap:] = input_ids[:, -target_overlap:]

                if torch.all(target_ids == -100):
                    continue

                outputs = model.model(input_ids, labels=target_ids)
                loss = outputs.loss
                if torch.isnan(loss):
                    print(f"NaN detected in loss at iteration {i}. Non-masked target_ids size is {(target_ids != -100).sum().item()}")
                    continue

                non_masked_tokens = (target_ids != -100).sum().item()
                total_loss += loss.item() * non_masked_tokens
                total_tokens += non_masked_tokens

        # Token-weighted average loss over the text tokens; negate so the return
        # value is an average log-likelihood, comparable to get_ll
        average_loss = total_loss / total_tokens if total_tokens > 0 else 0
        return -average_loss
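
For orientation, a minimal sketch of how this attack is driven, mirroring the `recall_dict` assembled in `run.py` below (names and values here are illustrative only):

```python
# Illustrative sketch: assumes `attacker` is an initialized ReCaLLAttack and
# `document` is one candidate string; mirrors the recall_dict built in run.py
recall_dict = {
    "prefix": nonmember_prefix,  # list of non-member shot strings
    "num_shots": 1,              # matches config.recall_num_shots
    "avg_length": 150,           # mean tokenized length of the records
}
score = attacker._attack(document, probs=None, recall_dict=recall_dict)
# The score is then thresholded downstream, like any other MIA signal
```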



2 changes: 2 additions & 0 deletions mimir/attacks/utils.py
@@ -7,6 +7,7 @@
from mimir.attacks.min_k_plus_plus import MinKPlusPlusAttack
from mimir.attacks.neighborhood import NeighborhoodAttack
from mimir.attacks.gradnorm import GradNormAttack
from mimir.attacks.recall import ReCaLLAttack


# TODO Use decorators to link attack implementations with enum above
@@ -19,6 +20,7 @@ def get_attacker(attack: str):
AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
AllAttacks.NEIGHBOR: NeighborhoodAttack,
AllAttacks.GRADNORM: GradNormAttack,
AllAttacks.RECALL: ReCaLLAttack,
}
attack_cls = mapping.get(attack, None)
if attack_cls is None:
2 changes: 2 additions & 0 deletions mimir/config.py
@@ -174,6 +174,8 @@ class ExperimentConfig(Serializable):
"""Chunk size"""
scoring_model_name: Optional[str] = None
"""Scoring model (if different from base model)"""
recall_num_shots: Optional[int] = 1
"""Number of shots for ReCaLL attacks"""

Review comment (Owner): Can you create a separate class for this configuration (just like we have a separate NeighborhoodConfig for the neighborhood attack) instead of adding it directly to ExperimentConfig?
top_k: Optional[int] = 40
"""Consider only top-k tokens"""
do_top_k: Optional[bool] = False
30 changes: 29 additions & 1 deletion run.py
@@ -77,6 +77,7 @@ def get_mia_scores(
    is_train: bool,
    n_samples: int = None,
    batch_size: int = 50,
    **kwargs
):
# Fix randomness
fix_seed(config.random_seed)
@@ -100,6 +101,14 @@
n_perturbation: [] for n_perturbation in n_perturbation_list
}

recall_dict = None  # stays None unless the ReCaLL attack is requested
nonmember_prefix = kwargs.get("nonmember_prefix", None)
if AllAttacks.RECALL in attackers_dict.keys():
    if nonmember_prefix is None:
        raise ValueError("Must include a prefix for ReCaLL attack")
    num_shots = config.recall_num_shots
    avg_length = int(np.mean([len(target_model.tokenizer.encode(ex)) for ex in data["records"]]))
    recall_dict = {"prefix": nonmember_prefix, "num_shots": num_shots, "avg_length": avg_length}

Review comment (Owner): Do we want this condition? nonmember_prefix only needs to be present for the recall attack, not all attacks. If someone runs a config with multiple attacks including recall, must all attack function calls include the nonmember prefix?

# For each batch of data
# TODO: Batch-size isn't really "batching" data - change later
for batch in tqdm(range(math.ceil(n_samples / batch_size)), desc=f"Computing criterion"):
@@ -160,8 +169,10 @@ def get_mia_scores(
),
loss=loss,
all_probs=s_all_probs,
recall_dict=recall_dict,
)
sample_information[attack].append(score)

else:
# For each 'number of neighbors'
for n_perturbation in n_perturbation_list:
@@ -515,6 +526,21 @@ def main(config: ExperimentConfig):
mask_model_tokenizer=mask_model.tokenizer if mask_model else None,
)

# ReCaLL-specific: split off the first num_shots non-member records as the prefix
if AllAttacks.RECALL in config.blackbox_attacks:
    num_shots = config.recall_num_shots
    nonmember_prefix = data_nonmember[:num_shots]
    nonmember_data = data_nonmember[num_shots:]

    member_prefix = data_member[:num_shots]
    member_data = data_member[num_shots:]

    data_nonmember = nonmember_data
    data_member = member_data
else:
    nonmember_prefix = None

Review comment (Owner): If the config has multiple attacks, data_member and data_nonmember should not be modified like this for all attacks, but right now they will modify the raw data before running any attack.
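
One way to address this (a sketch; variable names are illustrative): derive ReCaLL-specific views of the data instead of mutating the shared lists, and hand those views only to the ReCaLL scoring path:

```python
# Illustrative: keep the raw splits intact and derive ReCaLL-specific views
nonmember_prefix = None
recall_data_member, recall_data_nonmember = data_member, data_nonmember
if AllAttacks.RECALL in config.blackbox_attacks:
    num_shots = config.recall_num_shots
    nonmember_prefix = data_nonmember[:num_shots]
    recall_data_nonmember = data_nonmember[num_shots:]
    recall_data_member = data_member[num_shots:]
# Other attacks keep seeing the unmodified data_member / data_nonmember;
# only the ReCaLL path consumes recall_data_* and nonmember_prefix
```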


other_objs, other_nonmembers = None, None
if config.dataset_nonmember_other_sources is not None:
other_objs, other_nonmembers = [], []
@@ -628,7 +654,8 @@
ref_models=ref_models,
config=config,
is_train=True,
n_samples=n_samples
n_samples=n_samples,
nonmember_prefix=nonmember_prefix,
)
# Collect scores for non-members
nonmember_preds, nonmember_samples = get_mia_scores(
@@ -640,6 +667,7 @@
config=config,
is_train=False,
n_samples=n_samples,
nonmember_prefix=nonmember_prefix,
)
blackbox_outputs = compute_metrics_from_scores(
member_preds,