server.py
#import flwr as fl
# Start Flower server
#fl.server.start_server(
#    server_address="0.0.0.0:8080",
#    config=fl.server.ServerConfig(num_rounds=3),
#)
from typing import Dict, List, Tuple

import numpy as np
import flwr as fl
from flwr.common import Metrics

import plotHistory  # local plotting helper used at the end of this script (assumed to live alongside server.py)
NUM_CLIENTS = 3
numRounds = 15
#def get_evaluate_fn(testset: testSet):
#    """Return an evaluation function for server-side (i.e. centralised) evaluation."""
#    # The `evaluate` function will be called after every round by the strategy
#    def evaluate(
#        server_round: int,
#        parameters: fl.common.NDArrays,
#        config: Dict[str, fl.common.Scalar],
#    ):
#        convLayersList = []
#        convLayersList.append([16, 3, 'same', 'relu', 1, True, 'max', None, False])
#        convLayersList.append([32, 3, 'same', 'relu', 1, True, 'max', None, False])
#        convLayersList.append([64, 3, 'same', 'relu', 1, True, 'max', None, False])
#
#        convLayers, denseLayers = createCNNLayers(convLayersList, getDataset.getTrainLoaders()[2])
#        learningRateList = [1e-8, 0.0001, 0.001, 0.01, 0.1]
#
#        initialModel = Sequential(layers.Rescaling(scale = 1./255, input_shape=getDataset.getTrainLoaders()[1]))
#        model = makeModel(initialModel, convLayers, denseLayers, learningRateList[2])
#        model.compile(optimizer='adam', #Adam(learning_rate=learningRate),
#                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#                      metrics=['accuracy'])
#        #model = get_model() # Construct the model
#        model.set_weights(parameters) # Update model with the latest parameters
#        loss, accuracy = model.evaluate(testset, verbose=VERBOSE)
#        return loss, {"accuracy": accuracy}
#
#    return evaluate
# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Compute an example-weighted average of the accuracy reported by each client."""
    # Multiply the accuracy of each client by the number of examples it used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    # Aggregate and return the custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}
# Define strategy
strategy = fl.server.strategy.FedAvg(
    evaluate_metrics_aggregation_fn=weighted_average,
    #evaluate_fn=get_evaluate_fn(testSet),  # global evaluation function
)
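# Note: FedAvg passes the list of (num_examples, metrics) pairs returned by the
# clients' evaluate() calls to evaluate_metrics_aggregation_fn after each round,
# so the weighted accuracy appears in the returned History under metrics_distributed.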
# Start Flower server
history = fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=numRounds),
    strategy=strategy,  # use the strategy defined above so weighted_average is actually applied
)
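# For reference, a minimal client-side sketch that would talk to this server
# (a hypothetical illustration, not part of this repository; it assumes a compiled
# Keras `model` and NumPy arrays x_train/y_train/x_test/y_test exist in the client script):
#
#   class FlowerClient(fl.client.NumPyClient):
#       def get_parameters(self, config):
#           return model.get_weights()
#
#       def fit(self, parameters, config):
#           model.set_weights(parameters)
#           model.fit(x_train, y_train, epochs=1, verbose=0)
#           return model.get_weights(), len(x_train), {}
#
#       def evaluate(self, parameters, config):
#           model.set_weights(parameters)
#           loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
#           # These (num_examples, {"accuracy": ...}) pairs are what weighted_average aggregates
#           return loss, len(x_test), {"accuracy": accuracy}
#
#   fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=FlowerClient())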
#print(getDataset.datasetDir)
#print(history)
# Centralised evaluation (evaluate_fn) is commented out above, so the aggregated
# client-side metrics are found under metrics_distributed
print(f"{history.metrics_distributed = }")
plotHistory.plot(history)