-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathserver.py
More file actions
83 lines (66 loc) · 2.48 KB
/
server.py
File metadata and controls
83 lines (66 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
# Define FCN model for tabular data
class Net(nn.Module):
    """Fully connected classifier for tabular feature vectors.

    The network is three linear layers (``input_dim -> 64 -> 32 ->
    num_classes``) with a ReLU after each of the first two. The final layer
    returns raw logits, which is the input ``nn.CrossEntropyLoss`` expects.

    Args:
        input_dim: Number of features in each input row.
        num_classes: Number of output logits. Defaults to 2.
    """

    def __init__(self, input_dim, num_classes=2):
        super().__init__()
        # The attribute names fc1..fc3 become the state_dict keys, so they
        # must stay stable if weights are exchanged with other processes.
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return the class logits for a batch ``x`` of feature rows."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No softmax here: the logits are returned unnormalized.
        return self.fc3(x)
# Load test data for server-side evaluation
def load_test_data(csv_path="data_test.csv", batch_size=32):
    """Load the evaluation CSV and wrap it in a ``DataLoader``.

    The file must contain a ``target`` column, which is read as int64 class
    labels. Every other column is treated as a numeric feature.

    Args:
        csv_path: Path of the CSV file to read. Defaults to
            ``"data_test.csv"`` in the current working directory.
        batch_size: Number of rows per evaluation batch. Defaults to 32.

    Returns:
        A ``(dataloader, input_dim)`` tuple, where ``input_dim`` is the
        number of feature columns (the value to pass to ``Net``).
    """
    df = pd.read_csv(csv_path)
    features = df.drop("target", axis=1).values.astype("float32")
    labels = df["target"].values.astype("int64")

    # NOTE(review): the scaler is fitted on this file alone. If the training
    # side standardizes with different statistics, the features seen here
    # will be on a different scale -- confirm against the client code.
    features = StandardScaler().fit_transform(features)
    # Cast again after scaling so the tensor is float32 whatever dtype the
    # scaler hands back; nn.Linear weights are float32 by default.
    features = features.astype("float32", copy=False)

    x = torch.tensor(features)
    y = torch.tensor(labels)
    dataloader = DataLoader(TensorDataset(x, y), batch_size=batch_size)
    return dataloader, x.shape[1]  # Return input_dim for Net()
# Set weights in the model
def set_weights(model, parameters):
    """Overwrite ``model``'s state with a list of arrays, in place.

    Arrays are matched to state_dict entries purely by position, so
    ``parameters`` must be ordered like ``model.state_dict().keys()``.

    Args:
        model: The ``nn.Module`` whose parameters and buffers are replaced.
        parameters: Iterable with one array-like per state_dict entry.

    Raises:
        ValueError: If more arrays are supplied than the model has
            state_dict entries. ``zip`` would otherwise drop the surplus
            silently and hide an architecture mismatch.
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when
            entries are missing or a tensor shape does not match.
    """
    keys = list(model.state_dict().keys())
    arrays = list(parameters)
    if len(arrays) > len(keys):
        raise ValueError(
            f"received {len(arrays)} parameter arrays but the model has "
            f"only {len(keys)} state_dict entries"
        )
    state_dict = {key: torch.tensor(array) for key, array in zip(keys, arrays)}
    model.load_state_dict(state_dict, strict=True)
# Evaluation function
def evaluate_fn(server_round, parameters, config):
    """Evaluate the given global weights on the server-side test set.

    A fresh ``Net`` is built, loaded with ``parameters``, and run over the
    data returned by ``load_test_data()``.

    Args:
        server_round: Round number; only used in the printed log line.
        parameters: List of arrays in state_dict order, as accepted by
            ``set_weights``.
        config: Unused here; kept because the signature has three
            positional parameters.

    Returns:
        ``(loss, {"accuracy": accuracy})`` where ``loss`` is the mean
        cross-entropy per test sample and ``accuracy`` is the fraction of
        correctly classified samples.
    """
    testloader, input_dim = load_test_data()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net(input_dim).to(device)
    set_weights(model, parameters)
    model.eval()

    # Sum per-sample losses and divide once at the end. Adding up per-batch
    # means would make the figure grow with the number of batches and give
    # a short final batch the same weight as a full one.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for x_batch, y_batch in testloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            output = model(x_batch)
            total_loss += criterion(output, y_batch).item()
            correct += (output.argmax(1) == y_batch).sum().item()

    num_samples = len(testloader.dataset)
    loss = total_loss / num_samples
    accuracy = correct / num_samples
    print(f"[Server Evaluation] Round {server_round} - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")
    return loss, {"accuracy": accuracy}
# Start Flower server
if __name__ == "__main__":
    # FedAvg strategy: fit fraction 1.0 with a minimum of three fit clients
    # and three available clients. Server-side evaluation uses evaluate_fn,
    # and the round number is forwarded in each fit config.
    fedavg = fl.server.strategy.FedAvg(
        fraction_fit=1.0,
        min_fit_clients=3,
        min_available_clients=3,
        evaluate_fn=evaluate_fn,
        on_fit_config_fn=lambda rnd: {"round": rnd},
    )

    # Run 15 rounds on the local interface, port 8080.
    server_config = fl.server.ServerConfig(num_rounds=15)
    fl.server.start_server(
        server_address="localhost:8080",
        config=server_config,
        strategy=fedavg,
    )