Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: pass contract_name to VyperContract #338

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ def __init__(
compiler_data: CompilerData,
env: Optional[Env] = None,
filename: Optional[str] = None,
contract_name: Optional[str] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should follow the same order as the super class tbh

):
contract_name = Path(compiler_data.contract_path).stem
contract_name = (
contract_name if contract_name else Path(compiler_data.contract_path).stem
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
contract_name = (
contract_name if contract_name else Path(compiler_data.contract_path).stem
)
if contract_name is None:
contract_name = Path(compiler_data.contract_path).stem)

super().__init__(contract_name, env, filename)
self.compiler_data = compiler_data

Expand Down Expand Up @@ -518,8 +521,9 @@ def __init__(
created_from: Address = None,
filename: str = None,
gas=None,
contract_name=None,
):
super().__init__(compiler_data, env, filename)
super().__init__(compiler_data, env, filename, contract_name)

self.created_from = created_from
self._computation = None
Expand All @@ -544,7 +548,11 @@ def __init__(
addr = Address(override_address)
else:
addr = self._run_init(
*args, value=value, override_address=override_address, gas=gas
*args,
value=value,
override_address=override_address,
gas=gas,
contract_name=contract_name,
)
self._address = addr

Expand All @@ -569,7 +577,9 @@ def __init__(

self.env.register_contract(self._address, self)

def _run_init(self, *args, value=0, override_address=None, gas=None):
def _run_init(
self, *args, value=0, override_address=None, gas=None, contract_name=None
):
encoded_args = b""
if self._ctor:
encoded_args = self._ctor.prepare_calldata(*args)
Expand All @@ -582,6 +592,7 @@ def _run_init(self, *args, value=0, override_address=None, gas=None):
override_address=override_address,
gas=gas,
contract=self,
contract_name=contract_name,
)

self._computation = computation
Expand Down
1 change: 1 addition & 0 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def deploy(
override_address: Optional[_AddressType] = None,
# the calling vyper contract
contract: Any = None,
contract_name: Optional[str] = None, # TODO: This isn't used
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably don't need this, we can call contract.contract_name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If someone passes contract_name = xxx to the deploy of a non NetworkEnv boa.env, this would revert. So I think we do need this to make it compatible with the NetworkEnv object

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it would be an error to pass contract_name = xxx to env.deploy().

):
sender = self._get_sender(sender)

Expand Down
15 changes: 13 additions & 2 deletions boa/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,14 @@ def execute_code(

# OVERRIDES
def deploy(
self, sender=None, gas=None, value=0, bytecode=b"", contract=None, **kwargs
self,
sender=None,
gas=None,
value=0,
bytecode=b"",
contract=None,
contract_name=None,
**kwargs,
):
# reset to latest block for simulation
self._reset_fork()
Expand Down Expand Up @@ -402,7 +409,11 @@ def deploy(
print(f"contract deployed at {create_address}")

if (deployments_db := get_deployments_db()) is not None:
contract_name = getattr(contract, "contract_name", None)
contract_name = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would we want 2 names for the contract? This value contract_name should be equal to contract.contract_name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? This would only set 1 name. It defaults to the filename unless one is passed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea i don't think contract_name should be passed to env.deploy(). it already exists on the contract object

contract_name
if contract_name
else getattr(contract, "contract_name", None)
)
try:
source_bundle = get_verification_bundle(contract)
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pytest
pytest-xdist
pytest-cov
sphinx-rtd-theme
requests-cache

# jupyter
jupyter_server
Expand Down
32 changes: 31 additions & 1 deletion tests/integration/network/anvil/test_network_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,38 @@ def test_failed_transaction():
# XXX: probably want to test deployment revert behavior


def test_deployment_db():
def test_deployment_db_overriden_contract_name():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg, contract_name=contract_name)

# test get_deployments()
deployment = next(db.get_deployments())

initcode = contract.compiler_data.bytecode + arg.to_bytes(32, "big")

# sanity check all the fields
assert deployment.contract_address == contract.address
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name == contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
assert deployment.abi == contract.abi

# some sanity checks on tx_dict and rx_dict fields
assert to_bytes(deployment.tx_dict["data"]) == initcode
assert deployment.tx_dict["chainId"] == hex(boa.env.get_chain_id())
assert Address(deployment.receipt_dict["contractAddress"]) == contract.address


def test_deployment_db_no_overriden_name():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
non_contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg)
Expand All @@ -88,6 +117,7 @@ def test_deployment_db():
# sanity check all the fields
assert deployment.contract_address == contract.address
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name != non_contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/network/sepolia/test_sepolia_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def test_raise_exception(simple_contract, amount):
def test_deployment_db():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5
contract_name = "test_deployment"

# contract is written to deployments db
contract = boa.loads(code, arg)
contract = boa.loads(code, arg, contract_name=contract_name)

# test get_deployments()
deployment = next(db.get_deployments())
Expand All @@ -87,6 +88,7 @@ def test_deployment_db():
# sanity check all the fields
assert deployment.contract_address == contract.address
assert deployment.contract_name == contract.contract_name
assert deployment.contract_name == contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json
Expand Down
Loading