Skip to content

Commit

Permalink
Merge pull request #185 from stanfordnlp/zen/dependency_clean
Browse files Browse the repository at this point in the history
[Minor] Update dependency
  • Loading branch information
frankaging authored Aug 24, 2024
2 parents 4f70e10 + 4b14b6e commit d4ca094
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
22 changes: 14 additions & 8 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json, logging, torch, types
import nnsight
import numpy as np
from collections import OrderedDict
from typing import List, Optional, Tuple, Union, Dict, Any
Expand Down Expand Up @@ -27,6 +26,12 @@
from transformers.utils import ModelOutput
from tqdm import tqdm, trange

try:
import nnsight
except:
print("nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.")


@dataclass
class IntervenableModelOutput(ModelOutput):
"""
Expand Down Expand Up @@ -226,7 +231,7 @@ def __init__(self, config, model, backend, **kwargs):
# cached swapped activations (hot)
self.hot_activations = {}

self.aux_loss = []
self.full_intervention_outputs = []

# temp fields should not be accessed outside
self._batched_setter_activation_select = {}
Expand Down Expand Up @@ -1558,16 +1563,17 @@ def hook_callback(model, args, kwargs, output=None):
else:
if not isinstance(self.interventions[key][0], types.FunctionType):
if intervention.is_source_constant:
intervened_representation = do_intervention(
raw_intervened_representation = do_intervention(
selected_output,
None,
intervention,
subspaces[key_i] if subspaces is not None else None,
)
if isinstance(intervened_representation, InterventionOutput):
if intervened_representation.loss is not None:
self.aux_loss.append(intervened_representation.loss)
intervened_representation = intervened_representation.output
if isinstance(raw_intervened_representation, InterventionOutput):
self.full_intervention_outputs.append(raw_intervened_representation)
intervened_representation = raw_intervened_representation.output
else:
intervened_representation = raw_intervened_representation
else:
intervened_representation = do_intervention(
selected_output,
Expand Down Expand Up @@ -1866,7 +1872,7 @@ def forward(
if sources is not None and not isinstance(sources, list):
sources = [sources]

self.aux_loss.clear()
self.full_intervention_outputs.clear()

self._cleanup_states()

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@ numpy>=1.23.5
fsspec>=2023.6.0
accelerate>=0.29.1
sentencepiece>=0.1.96
nnsight>=0.1.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setup(
name="pyvene",
version="0.1.4",
version="0.1.5",
description="Use Activation Intervention to Interpret Causal Mechanism of Model",
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down

0 comments on commit d4ca094

Please sign in to comment.