Skip to content

Commit

Permalink
Add some docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
iamgroot42 committed Feb 15, 2024
1 parent 3b7c668 commit 9d20190
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 19 deletions.
12 changes: 12 additions & 0 deletions mimir/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@


class Data:
"""
Data class to load and cache datasets.
"""
def __init__(self, name,
config: ExperimentConfig,
presampled: str = None,
Expand All @@ -36,6 +39,9 @@ def load_neighbors(
model: str = "bert",
in_place_swap: bool = False,
):
"""
Load neighbors from cache (local or from HF)
"""
data_split = "train" if train else "test"
data_split += "_neighbors"
filename = self._get_name_to_save() + "_neighbors_{}_{}".format(
Expand Down Expand Up @@ -63,6 +69,9 @@ def dump_neighbors(
model: str = "bert",
in_place_swap: bool = False,
):
"""
Dump neighbors to cache local cache.
"""
data_split = "train" if train else "test"
data_split += "_neighbors"
filename = self._get_name_to_save() + "_neighbors_{}_{}".format(
Expand Down Expand Up @@ -309,6 +318,9 @@ def pile_selection_utility(data, key: str, wanted_source: str = None):


def sourcename_process(x: str):
"""
Helper function to process source name.
"""
return x.replace(" ", "_").replace("-", "_").lower()


Expand Down
27 changes: 9 additions & 18 deletions mimir/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,16 @@ def __init__(self, config: ExperimentConfig, name: str):
self.load_model_properties()

def load(self):
"""
Load reference model noto GPU(s)
"""
if "llama" not in self.name and "alpaca" not in self.name:
super().load()

def unload(self):
"""
Unload reference model from GPU(s)
"""
if "llama" not in self.name and "alpaca" not in self.name:
super().unload()

Expand Down Expand Up @@ -369,19 +375,6 @@ def get_lls(self, texts: List[str], batch_size: int = 6):
del label_batch
del attention_mask
return losses #np.mean(losses)

@torch.no_grad()
def get_min_k_prob(self, text: str, tokens=None, probs=None, k=.2, window=1, stride=1):
all_prob = probs if probs is not None else self.get_probabilities(text, tokens=tokens)
# iterate through probabilities by ngram defined by window size at given stride
ngram_probs = []
for i in range(0, len(all_prob) - window + 1, stride):
ngram_prob = all_prob[i:i+window]
ngram_probs.append(np.mean(ngram_prob))
min_k_probs = sorted(ngram_probs)[:int(len(ngram_probs) * k)]

return -np.mean(min_k_probs)


def sample_from_model(self, texts: List[str], **kwargs):
"""
Expand Down Expand Up @@ -435,11 +428,6 @@ def get_entropy(self, text: str):
logits = self.model(**tokenized).logits[:,:-1]
neg_entropy = F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)
return -neg_entropy.sum(-1).mean().item()

@torch.no_grad()
def get_zlib_entropy(self, text: str, tokens=None, probs=None):
zlib_entropy = len(zlib.compress(bytes(text, 'utf-8')))
return self.get_ll(text, tokens=tokens, probs=probs) / zlib_entropy

@torch.no_grad()
def get_max_norm(self, text: str, context_len=None, tk_freq_map=None):
Expand Down Expand Up @@ -496,6 +484,9 @@ def __init__(self, config: ExperimentConfig, **kwargs):

@property
def api_calls(self):
"""
Get the number of tokens used in API calls
"""
return self.API_TOKEN_COUNTER

@torch.no_grad()
Expand Down
3 changes: 3 additions & 0 deletions mimir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@


def fix_seed(seed: int = 0):
"""
Fix seed for reproducibility.
"""
ch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
Expand Down
2 changes: 1 addition & 1 deletion notebooks/new_mi_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def run_blackbox_attacks(
# TODO: Instead of doing this outside, set config default to always include LOSS
sample_information[BlackBoxAttacks.LOSS].append(loss)

# TODO: Shift functionality into each attack entirely, so that this is just a for loop
# TODO: Change flow in same way as run.py
# For each attack
for attack in attacks:
if attack == BlackBoxAttacks.ZLIB:
Expand Down

0 comments on commit 9d20190

Please sign in to comment.