Skip to content

Commit

Permalink
Merge pull request #290 from backend-developers-ltd/no-sig-payload
Browse files Browse the repository at this point in the history
Removed duplicate signed job payload from V2JobRequest
  • Loading branch information
adal-chiriliuc-reef authored Oct 25, 2024
2 parents 67a0980 + 49c6f18 commit 4d5d35a
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 13 deletions.
25 changes: 17 additions & 8 deletions compute_horde/compute_horde/fv_protocol/facilitator_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class Response(BaseModel, extra="forbid"):
errors: list[Error] = []


class SignedRequest(BaseModel, extra="forbid"):
signature_type: str
signatory: str
timestamp_ns: int
signature: str
signed_payload: JsonValue
class Signature(BaseModel, extra="forbid"):
# has defaults to allow easy instantiation
signature_type: str = ""
signatory: str = ""
timestamp_ns: int = 0
signature: str = ""


class V0JobRequest(BaseModel, extra="forbid"):
Expand Down Expand Up @@ -102,8 +102,10 @@ class V2JobRequest(BaseModel, extra="forbid"):
# this points to a `ValidatorConsumer.job_new` handler (fuck you django-channels!)
type: Literal["job.new"] = "job.new"
message_type: Literal["V2JobRequest"] = "V2JobRequest"
signature: Signature | None = None

# !!! all fields below are included in the signed json payload
uuid: str
miner_hotkey: str | None
executor_class: ExecutorClass
docker_image: str
raw_script: str
Expand All @@ -112,11 +114,18 @@ class V2JobRequest(BaseModel, extra="forbid"):
use_gpu: bool
volume: Volume | None = None
output_upload: OutputUpload | None = None
signed_request: SignedRequest
# !!! all fields above are included in the signed json payload

def get_args(self):
return self.args

def json_for_signing(self) -> JsonValue:
payload = self.model_dump(mode="json")
del payload["type"]
del payload["message_type"]
del payload["signature"]
return payload

@model_validator(mode="after")
def validate_at_least_docker_image_or_raw_script(self) -> Self:
if not (bool(self.docker_image) or bool(self.raw_script)):
Expand Down
4 changes: 2 additions & 2 deletions compute_horde/compute_horde/mv_protocol/validator_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..base.volume import Volume, VolumeType
from ..base_requests import BaseRequest, JobMixin
from ..executor_class import ExecutorClass
from ..utils import MachineSpecs, _json_dumps_default
from ..utils import MachineSpecs, json_dumps_default

SAFE_DOMAIN_REGEX = re.compile(r".*")

Expand Down Expand Up @@ -99,7 +99,7 @@ class ReceiptPayload(pydantic.BaseModel):

def blob_for_signing(self):
# pydantic v2 does not support sort_keys anymore.
return json.dumps(self.model_dump(), sort_keys=True, default=_json_dumps_default)
return json.dumps(self.model_dump(), sort_keys=True, default=json_dumps_default)


class JobFinishedReceiptPayload(ReceiptPayload):
Expand Down
2 changes: 1 addition & 1 deletion compute_horde/compute_horde/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_validators(netuid=12, network="finney", block: int | None = None) -> lis
return neurons[:VALIDATORS_LIMIT]


def _json_dumps_default(obj):
def json_dumps_default(obj):
if isinstance(obj, datetime.datetime):
return obj.isoformat()

Expand Down
53 changes: 53 additions & 0 deletions compute_horde/tests/test_job_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import base64
import uuid

from compute_horde.base.volume import VolumeType, ZipUrlVolume
from compute_horde.executor_class import DEFAULT_EXECUTOR_CLASS
from compute_horde.fv_protocol.facilitator_requests import Signature, V2JobRequest
from compute_horde.signature import BittensorWalletSigner, BittensorWalletVerifier
from compute_horde.signature import Signature as RawSignature


def test_signed_job_roundtrip(signature_wallet):
volume = ZipUrlVolume(
volume_type=VolumeType.zip_url,
contents="https://example.com/input.zip",
relative_path="input",
)
job = V2JobRequest(
uuid=str(uuid.uuid4()),
executor_class=DEFAULT_EXECUTOR_CLASS,
docker_image="hello-world",
raw_script="bash",
args=["--verbose", "--dry-run"],
env={"CUDA": "1"},
use_gpu=False,
volume=volume,
output_upload=None,
)

signer = BittensorWalletSigner(signature_wallet)
payload = job.json_for_signing()
raw_signature = signer.sign(payload)

job.signature = Signature(
signature_type=raw_signature.signature_type,
signatory=raw_signature.signatory,
timestamp_ns=raw_signature.timestamp_ns,
signature=base64.b64encode(raw_signature.signature).decode("utf8"),
)

job_json = job.model_dump_json()
deserialized_job = V2JobRequest.model_validate_json(job_json)

assert deserialized_job.signature is not None
deserialized_raw_signature = RawSignature(
signature_type=deserialized_job.signature.signature_type,
signatory=deserialized_job.signature.signatory,
timestamp_ns=deserialized_job.signature.timestamp_ns,
signature=base64.b64decode(deserialized_job.signature.signature),
)

deserialized_payload = deserialized_job.json_for_signing()
verifier = BittensorWalletVerifier()
verifier.verify(deserialized_payload, deserialized_raw_signature)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tenacity
import websockets
from channels.layers import get_channel_layer
from compute_horde.fv_protocol.facilitator_requests import Error, JobRequest, Response
from compute_horde.fv_protocol.facilitator_requests import Error, JobRequest, Response, V2JobRequest
from compute_horde.fv_protocol.validator_requests import (
V0AuthenticationRequest,
V0Heartbeat,
Expand Down Expand Up @@ -243,7 +243,7 @@ async def get_miner_axon_info(self, hotkey: str) -> bittensor.AxonInfo:

async def miner_driver(self, job_request: JobRequest):
"""drive a miner client from job start to completion, then close miner connection"""
assert job_request.miner_hotkey is not None
assert not isinstance(job_request, V2JobRequest)
miner, _ = await Miner.objects.aget_or_create(hotkey=job_request.miner_hotkey)
miner_axon_info = await self.get_miner_axon_info(job_request.miner_hotkey)
job = await OrganicJob.objects.acreate(
Expand Down

0 comments on commit 4d5d35a

Please sign in to comment.