Merge pull request #1 from delta-mpc/0.3.0

0.3.0

mh739025250 authored Jan 5, 2022
2 parents eeda1b3 + 185d63c commit 57b8ca2
Showing 17 changed files with 151 additions and 293 deletions.
7 changes: 7 additions & 0 deletions .gitignore

@@ -31,3 +31,10 @@ MANIFEST
 
 # vscode
 .vscode
+
+whls
+mnist/
+mnist.npz
+.python-version
+*.text
+.pypirc
10 changes: 7 additions & 3 deletions delta/__init__.py

@@ -1,10 +1,14 @@
+from tempfile import TemporaryDirectory
+
 from .delta_node import DeltaNode
 from .node import DebugNode
 from .task import Task
 
-__all__ = ["DeltaNode", "debug"]
+__all__ = ["DeltaNode", "debug", "Task"]
 
 
 def debug(task: Task):
-    debug_node = DebugNode(1)
-    task.run(debug_node)
+    with TemporaryDirectory() as tmp_dir:
+        for round in range(1, 3):
+            debug_node = DebugNode("1", round, tmp_dir)
+            task.run(debug_node)
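The debug helper now simulates two federated rounds locally instead of one, sharing a temporary directory so that state written in round 1 can be read back in round 2. A minimal usage sketch; the task class here is a hypothetical stand-in for whatever Task subclass you define:

    from delta import debug

    task = MyTask()  # hypothetical: any delta Task subclass with a run(node) method
    debug(task)      # runs the task locally for rounds 1 and 2 in a temp directory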
16 changes: 15 additions & 1 deletion delta/algorithm/horizontal/base.py

@@ -1,10 +1,20 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, IO
+from typing import List, Literal, Optional, IO
 
 import numpy as np
 import torch
 
 
+CURVE_TYPE = Literal[
+    "secp192r1",
+    "secp224r1",
+    "secp256k1",
+    "secp256r1",
+    "secp384r1",
+    "secp521r1"
+]
+
+
 class HorizontalAlgorithm(ABC):
     def __init__(
         self,
@@ -16,6 +26,8 @@ def __init__(
         wait_timeout: Optional[float] = None,
         connection_timeout: Optional[float] = None,
         fault_tolerant: bool = False,
+        precision: int = 8,
+        curve: CURVE_TYPE = "secp256k1"
     ):
         assert (
             merge_interval_epoch * merge_interval_iter == 0
@@ -31,6 +43,8 @@ def __init__(
        self.wait_timeout = wait_timeout
        self.connnection_timeout = connection_timeout
        self.fault_tolerant = fault_tolerant
+        self.precision = precision
+        self.curve = curve
 
     def should_merge(self, epoch: int, iteration: int, epoch_finished: bool):
         if epoch_finished:
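HorizontalAlgorithm gains two secure-aggregation parameters: precision and curve, the latter constrained to the named elliptic curves by the new CURVE_TYPE literal. The diff does not show how precision is consumed; a plausible reading, given purely as an assumption, is fixed-point encoding of floating-point weights before they are masked:

    # Assumption, not from this diff: fixed-point encoding with `precision`
    # decimal digits, as commonly done before secret-sharing float weights.
    import numpy as np

    def encode(arr: np.ndarray, precision: int = 8) -> np.ndarray:
        # scale by 10**precision and round to integers
        return np.round(arr * 10 ** precision).astype(np.int64)

    def decode(arr: np.ndarray, precision: int = 8) -> np.ndarray:
        return arr.astype(np.float64) / 10 ** precision

    x = np.array([0.123456789, -1.5])
    assert np.allclose(decode(encode(x)), x, atol=1e-8)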
8 changes: 6 additions & 2 deletions delta/algorithm/horizontal/fault_tolerant_fedavg.py

@@ -1,9 +1,9 @@
-from typing import IO, List, Optional
+from typing import List, Optional
 
 import numpy as np
 import torch
 
-from .base import HorizontalAlgorithm
+from .base import CURVE_TYPE, HorizontalAlgorithm
 
 
 class FaultTolerantFedAvg(HorizontalAlgorithm):
@@ -15,6 +15,8 @@ def __init__(
         max_clients: int = 2,
         wait_timeout: Optional[float] = None,
         connection_timeout: Optional[float] = None,
+        precision: int = 8,
+        curve: CURVE_TYPE = "secp256k1",
     ):
         super().__init__(
             "FaultTolerantFedAvg",
@@ -25,6 +27,8 @@ def __init__(
             wait_timeout=wait_timeout,
             connection_timeout=connection_timeout,
             fault_tolerant=True,
+            precision=precision,
+            curve=curve,
         )
 
     def params_to_result(self, params: List[torch.Tensor]) -> np.ndarray:
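FaultTolerantFedAvg simply forwards the new keywords to the base class and pins fault_tolerant=True. A hedged construction sketch; the constructor parameters not shown in this hunk are assumed to keep their defaults:

    from delta.algorithm.horizontal.fault_tolerant_fedavg import FaultTolerantFedAvg

    # Assumes the remaining __init__ parameters (merge intervals, client
    # bounds, timeouts) all have defaults; only the new knobs are set here.
    alg = FaultTolerantFedAvg(precision=8, curve="secp256r1")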
8 changes: 6 additions & 2 deletions delta/algorithm/horizontal/fedavg.py

@@ -1,9 +1,9 @@
-from typing import IO, List, Optional
+from typing import List, Optional
 
 import numpy as np
 import torch
 
-from .base import HorizontalAlgorithm
+from .base import CURVE_TYPE, HorizontalAlgorithm
 
 
 class FedAvg(HorizontalAlgorithm):
@@ -15,6 +15,8 @@ def __init__(
         max_clients: int = 2,
         wait_timeout: Optional[float] = None,
         connection_timeout: Optional[float] = None,
+        precision: int = 8,
+        curve: CURVE_TYPE = "secp256k1",
     ):
         super().__init__(
             "FedAvg",
@@ -25,6 +27,8 @@ def __init__(
             wait_timeout=wait_timeout,
             connection_timeout=connection_timeout,
             fault_tolerant=False,
+            precision=precision,
+            curve=curve,
        )
 
     def params_to_result(self, params: List[torch.Tensor]) -> np.ndarray:
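FedAvg receives the identical change with fault_tolerant=False. Because curve is typed as a Literal, a static type checker rejects misspelled curve names; the same check can be mirrored at runtime:

    # Runtime mirror of the static Literal check on curve names.
    from typing import Literal, get_args

    CURVE_TYPE = Literal[
        "secp192r1", "secp224r1", "secp256k1",
        "secp256r1", "secp384r1", "secp521r1",
    ]

    def check_curve(name: str) -> None:
        if name not in get_args(CURVE_TYPE):
            raise ValueError(f"unsupported curve: {name}")

    check_curve("secp256k1")  # ok; check_curve("secp256k2") would raise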
11 changes: 5 additions & 6 deletions delta/delta_node.py

@@ -1,10 +1,9 @@
-import json
-import requests
 from tempfile import TemporaryFile
-from zipfile import ZipFile
-from delta import serialize
 
-from .task import Task
+import httpx
+
+from delta import serialize
+from delta.task import Task
 
 
 class DeltaNode(object):
@@ -16,7 +15,7 @@ def create_task(self, task: Task) -> int:
         with TemporaryFile(mode="w+b") as file:
             serialize.dump_task(file, task)
             file.seek(0)
-            resp = requests.post(url, files={"file": (f"{task.name}.zip", file)})
+            resp = httpx.post(url, files={"file": file})
             resp.raise_for_status()
             data = resp.json()
             task_id = data["task_id"]
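The HTTP client moves from requests to httpx, and the upload no longer forces a "{task.name}.zip" filename: httpx derives the multipart filename from the file object itself. Submitting a task looks the same from the caller's side; the node address below is hypothetical:

    from delta import DeltaNode

    node = DeltaNode("http://127.0.0.1:6700")  # hypothetical node URL
    task_id = node.create_task(task)           # task: any delta Task instance
    print(f"task created with id {task_id}")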
47 changes: 18 additions & 29 deletions delta/node/debug.py

@@ -1,38 +1,33 @@
 import logging
 import os.path
 import shutil
-from tempfile import TemporaryDirectory
 from typing import IO, Any, Callable, Dict, Iterable, Tuple
 
 from ..data import new_dataloader
 from .node import Node
 
 
 class DebugNode(Node):
-    def __init__(self, task_id: int):
+    def __init__(self, task_id: str, round: int, dirname: str):
         self._logger = logging.getLogger(__name__)
 
-        self._temp_dir = TemporaryDirectory()
+        self._dir = dirname
 
         self._task_id = task_id
-        self._state_count = 0
-        self._weight_count = 0
-        self._metrics_count = 0
+        self._round = round
 
-    def state_file(self) -> str:
-        return os.path.join(
-            self._temp_dir.name, f"{self._task_id}.state.{self._state_count}"
-        )
+    @property
+    def round(self) -> int:
+        return self._round
 
-    def weight_file(self) -> str:
-        return os.path.join(
-            self._temp_dir.name, f"{self._task_id}.weight.{self._weight_count}"
-        )
+    def state_file(self, round: int) -> str:
+        return os.path.join(self._dir, f"{self._task_id}.{round}.state")
 
-    def metrics_file(self) -> str:
-        return os.path.join(
-            self._temp_dir.name, f"{self._task_id}.metrics.{self._metrics_count}"
-        )
+    def weight_file(self, round: int) -> str:
+        return os.path.join(self._dir, f"{self._task_id}.{round}.weight")
+
+    def metrics_file(self, round: int) -> str:
+        return os.path.join(self._dir, f"{self._task_id}.{round}.metrics")
 
     def new_dataloader(
         self,
@@ -65,7 +60,7 @@ def upload(self, type: str, src: IO[bytes]):
             raise ValueError(f"unknown upload type {type}")
 
     def download_state(self, dst: IO[bytes]) -> bool:
-        filename = self.state_file()
+        filename = self.state_file(self.round - 1)
         if os.path.exists(filename):
             self._logger.info(f"load state {filename} for task {self._task_id}")
             with open(filename, mode="rb") as f:
@@ -76,28 +71,25 @@ def upload_state(self, file: IO[bytes]):
         return False
 
     def upload_state(self, file: IO[bytes]):
-        self._state_count += 1
-        filename = self.state_file()
+        filename = self.state_file(self.round)
         with open(filename, mode="wb") as f:
             shutil.copyfileobj(file, f)
         self._logger.info(f"dump state {filename} for task {self._task_id}")
 
     def upload_result(self, file: IO[bytes]):
-        self._weight_count += 1
-        filename = self.weight_file()
+        filename = self.weight_file(self.round)
         with open(filename, mode="wb") as f:
             shutil.copyfileobj(file, f)
         self._logger.info(f"upload weight {filename} for task {self._task_id}")
 
     def upload_metrics(self, file: IO[bytes]):
-        self._metrics_count += 1
-        filename = self.metrics_file()
+        filename = self.metrics_file(self.round)
         with open(filename, mode="wb") as f:
             shutil.copyfileobj(file, f)
         self._logger.info(f"upload metrics {filename} for task {self._task_id}")
 
     def download_weight(self, dst: IO[bytes]) -> bool:
-        filename = self.weight_file()
+        filename = self.weight_file(self.round - 1)
         if os.path.exists(filename):
             self._logger.info(f"download weight {filename} for task {self._task_id}")
             with open(filename, mode="rb") as f:
@@ -106,6 +98,3 @@ def download_weight(self, dst: IO[bytes]) -> bool:
         else:
             self._logger.info(f"initial weight for task {self._task_id}")
             return False
-
-    def finish(self):
-        self._logger.info(f"task {self._task_id} finished")
17 changes: 12 additions & 5 deletions delta/node/node.py

@@ -1,12 +1,18 @@
 from abc import ABC, abstractmethod
-from typing import IO, Any, Callable, Dict, Iterable, Optional, Tuple
+from typing import IO, Any, Callable, Dict, Tuple
+
+from torch.utils.data import DataLoader
 
 
 class Node(ABC):
     @abstractmethod
     def new_dataloader(
-        self, dataset: str, validate_frac: float, cfg: Dict[str, Any], preprocess: Callable
-    ) -> Tuple[Iterable, Iterable]:
+        self,
+        dataset: str,
+        validate_frac: float,
+        cfg: Dict[str, Any],
+        preprocess: Callable,
+    ) -> Tuple[DataLoader, DataLoader]:
         pass
 
     @abstractmethod
@@ -17,6 +23,7 @@ def download(self, type: str, dst: IO[bytes]) -> bool:
     def upload(self, type: str, src: IO[bytes]):
         ...
 
+    @property
     @abstractmethod
-    def finish(self):
-        pass
+    def round(self) -> int:
+        ...
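The Node ABC drops finish() and instead requires a read-only round property, and dataloaders are now typed as a pair of torch DataLoader objects. A minimal conforming skeleton; the subclass and its placeholder bodies are for illustration only:

    from typing import IO, Any, Callable, Dict, Tuple

    from torch.utils.data import DataLoader

    from delta.node.node import Node

    class LocalNode(Node):  # hypothetical subclass
        def __init__(self, round: int):
            self._round = round

        @property
        def round(self) -> int:
            return self._round

        def new_dataloader(
            self,
            dataset: str,
            validate_frac: float,
            cfg: Dict[str, Any],
            preprocess: Callable,
        ) -> Tuple[DataLoader, DataLoader]:
            raise NotImplementedError  # placeholder

        def download(self, type: str, dst: IO[bytes]) -> bool:
            return False  # placeholder: nothing to download

        def upload(self, type: str, src: IO[bytes]):
            pass  # placeholder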
4 changes: 1 addition & 3 deletions delta/serialize/task.py

@@ -1,9 +1,7 @@
 from os import PathLike
 from typing import IO, Union
-from zipfile import ZipFile
 
-from delta import serialize
-from delta.task import HorizontalTask, Task
+from delta.task import Task
 
 
 def dump_task(file: Union[str, PathLike, IO[bytes]], task: Task):
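Per its signature, dump_task accepts either a filesystem path or an open binary file, which is how DeltaNode.create_task uses it above:

    from tempfile import TemporaryFile

    from delta import serialize

    # task: a delta Task instance constructed elsewhere
    with TemporaryFile(mode="w+b") as f:
        serialize.dump_task(f, task)
        f.seek(0)  # rewind before re-reading or uploading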
59 changes: 20 additions & 39 deletions delta/task/horizontal.py

@@ -42,7 +42,6 @@ def __init__(
         self._state = {
             "epoch": 1,
             "iteration": 1,
-            "round": 1,
         }
 
     @property
@@ -65,14 +64,6 @@ def iteration(self) -> int:
     def iteration(self, iteration: int):
         self._state["iteration"] = iteration
 
-    @property
-    def round(self) -> int:
-        return self._state["round"]
-
-    @round.setter
-    def round(self, round: int):
-        self._state["round"] = round
-
     @abstractmethod
     def train(self, dataloader: Iterable):
         ...
@@ -163,49 +154,39 @@ def run(self, node: Node):
         )
 
         def train_context():
-            uploaded = False
+            _logger.info(f"start round {node.round}")
             finished = False
 
-            _logger.info(f"start round {self.round}")
-            while self.round <= self.max_rounds:
+            while not finished:
                 for batch in train_loader:
-                    if uploaded:
-                        self._load_weight(node)
-                        if self.round % self.validate_interval == 0:
-                            _logger.info(f"round {self.round}, start to validate")
-                            metrics = self.validate(val_loader)
-                            _logger.info(f"metrics: {metrics}")
-                            self._upload_metrics(node, metrics)
-                        uploaded = False
-                        self.round += 1
-                        if self.round > self.max_rounds:
-                            finished = True
-                            break
-                        _logger.info(f"start round {self.round}")
-
-                    _logger.info(f"epoch {self.epoch} iteration {self.iteration}")
+                    if finished:
+                        break
+
+                    _logger.info(f"epoch {self.epoch} iteration {self.iteration}")
 
                     yield batch
 
                     if self._alg.should_merge(self.epoch, self.iteration, False):
                         _logger.info(f"iteration {self.iteration}, start to merge")
-                        self._save_state(node)
                         self._upload_result(node)
-                        uploaded = True
+                        finished = True
                     self.iteration += 1
 
                 if finished:
                     break
 
                 if self._alg.should_merge(self.epoch, self.iteration, True):
                     _logger.info(f"epoch {self.epoch}, start to merge")
-                    self._save_state(node)
                     self._upload_result(node)
-                    uploaded = True
-
-                self.epoch += 1
-            node.finish()
-            _logger.info(f"training finished, total epochs {self.epoch - 1}")
+                    finished = True
+                if (self.iteration - 1) % len(train_loader) == 0:
+                    self.epoch += 1
+
+            if node.round % self.validate_interval == 0:
+                _logger.info(f"round {node.round} start to validate")
+                metrics = self.validate(val_loader)
+                _logger.info(f"metrics: {metrics}")
+                self._upload_metrics(node, metrics)
+
+            self._save_state(node)
+            _logger.info(f"training round {node.round} finished")
 
         self.train(train_context())

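run() now trains exactly one merge round per invocation: the round number comes from the node rather than task state, validation and state saving happen once at the end of the round, and the epoch counter is derived from the iteration count. The boundary test works because iteration is incremented once per yielded batch, so after k full passes over a loader of length n the value of iteration - 1 equals k * n:

    # Illustration of the epoch-boundary check used above.
    n = 5                   # stand-in for len(train_loader)
    iteration = 1
    for _ in range(2 * n):  # two epochs' worth of batches
        iteration += 1      # incremented after each batch, as in run()
        if (iteration - 1) % n == 0:
            print(f"epoch completed at iteration {iteration - 1}")
    # prints at iterations 5 and 10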
Expand Down
Loading

0 comments on commit 57b8ca2

Please sign in to comment.