-
Notifications
You must be signed in to change notification settings - Fork 10
/
server.py
63 lines (53 loc) · 1.92 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import copy
from collections import OrderedDict
import numpy as np
import torch
class Server:
    """Federated-learning server skeleton (client selection, rounds, FedAvg).

    Only ``__init__`` and ``select_clients`` are implemented; the other
    methods are placeholders that raise ``NotImplementedError``.
    """

    def __init__(self, args, train_clients, test_clients, model, metrics):
        """Store the server configuration and snapshot the model parameters.

        :param args: configuration namespace; the code in this class reads
            ``clients_per_round`` and ``num_rounds`` from it
        :param train_clients: sequence of clients eligible for training
        :param test_clients: clients used by ``test``
        :param model: model object exposing ``state_dict()``
        :param metrics: metrics object(s); not read by the code shown here
        """
        self.args = args
        self.train_clients = train_clients
        self.test_clients = test_clients
        self.model = model
        self.metrics = metrics
        # Deep copy so the snapshot is independent of the model's live
        # state_dict (later in-place changes to the model do not leak in).
        self.model_params_dict = copy.deepcopy(self.model.state_dict())

    def select_clients(self):
        """Sample distinct training clients uniformly at random for one round.

        The number drawn is ``args.clients_per_round``, capped at the number
        of available training clients.

        :return: list of the selected client objects (no duplicates)
        """
        num_clients = min(self.args.clients_per_round, len(self.train_clients))
        # Sample *indices* instead of handing the client list to numpy:
        # np.random.choice would first convert the list to an ndarray, which
        # raises "a must be 1-dimensional" (or mangles the elements) when the
        # clients are sequence-like. Drawing from an integer population uses
        # the same permutation internally, so a given seed still selects the
        # same clients.
        chosen = np.random.choice(len(self.train_clients), num_clients, replace=False)
        return [self.train_clients[i] for i in chosen]

    def train_round(self, clients):
        """Train the model on each given client for a single round.

        :param clients: clients to train in this round (e.g. the result of
            ``select_clients``)
        :return: list of model updates gathered from the clients, to be
            aggregated by ``aggregate``
        :raises NotImplementedError: for any non-empty ``clients``; the
            per-client training step has not been written yet
        """
        updates = []
        for i, c in enumerate(clients):
            # TODO: missing code here! Train on client ``c`` and collect its
            # update into ``updates``.
            raise NotImplementedError
        return updates

    def aggregate(self, updates):
        """Combine client updates into new global parameters via FedAvg.

        :param updates: updates received from the clients (format to be
            defined by ``train_round``)
        :return: aggregated parameters
        :raises NotImplementedError: always; not implemented yet
        """
        # TODO: missing code here!
        raise NotImplementedError

    def train(self):
        """Orchestrate training, evaluation and testing over all rounds.

        Runs ``args.num_rounds`` rounds.

        :raises NotImplementedError: on the first round; the per-round logic
            has not been written yet
        """
        for r in range(self.args.num_rounds):
            # TODO: missing code here!
            raise NotImplementedError

    def eval_train(self):
        """Evaluate the current model on the training clients.

        :raises NotImplementedError: always; not implemented yet
        """
        # TODO: missing code here!
        raise NotImplementedError

    def test(self):
        """Test the current model on the test clients.

        :raises NotImplementedError: always; not implemented yet
        """
        # TODO: missing code here!
        raise NotImplementedError