-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathclient.py
More file actions
103 lines (77 loc) · 3.11 KB
/
client.py
File metadata and controls
103 lines (77 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset, DataLoader
from collections import OrderedDict
# Define FCN model
class Net(nn.Module):
    """Small fully connected classifier: input_dim -> 64 -> 32 -> 2 logits.

    The layer attribute names (fc1, fc2, fc3) determine the state_dict keys
    that are exchanged as weight arrays, so every participant must build the
    model with the same attribute names and the same ``input_dim``.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_a, hidden_b, n_out = 64, 32, 2
        self.fc1 = nn.Linear(input_dim, hidden_a)
        self.fc2 = nn.Linear(hidden_a, hidden_b)
        self.fc3 = nn.Linear(hidden_b, n_out)

    def forward(self, x):
        """Return unnormalized class scores (logits) for the batch ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits
# Convert weights to/from Flower format
def get_weights(model):
    """Export ``model``'s state_dict values as a list of NumPy arrays.

    The arrays are returned in state_dict order, which is the order
    ``set_weights`` assumes when loading them back.
    """
    state = model.state_dict()
    return [tensor.cpu().numpy() for tensor in state.values()]
def set_weights(model, parameters):
    """Load a list of weight arrays into ``model`` in state_dict order.

    Args:
        model: torch module whose state_dict is overwritten in place.
        parameters: arrays ordered like ``model.state_dict()`` (the layout
            produced by ``get_weights``).

    Raises:
        ValueError: if the number of arrays differs from the number of
            state_dict entries. Without this check ``zip`` would silently
            discard surplus arrays and hide a model/weights mismatch.
        RuntimeError: from ``load_state_dict`` when an array's shape does
            not match its parameter.
    """
    keys = list(model.state_dict().keys())
    parameters = list(parameters)
    if len(parameters) != len(keys):
        raise ValueError(
            f"Expected {len(keys)} weight arrays for this model, got {len(parameters)}"
        )
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
    model.load_state_dict(state_dict, strict=True)
# Load client-specific data
def load_data(client_id, train_fraction=0.8, shuffle=True, seed=0):
    """Load, split and standardize this client's CSV partition.

    Reads ``heart_{client_id + 1}.csv`` from the working directory. Every
    column except ``target`` is used as a feature; ``target`` is the label.

    Args:
        client_id: zero-based client index (client 0 reads ``heart_1.csv``).
        train_fraction: fraction of rows used for training; the rest is the
            test split (default 0.8, i.e. an 80/20 split).
        shuffle: shuffle rows before splitting so the test split is not just
            the tail of the file.
        seed: random seed for the shuffle, keeping the split reproducible.

    Returns:
        ``((x_train, y_train), (x_test, y_test), n_features)`` where the x
        tensors are float32 and the y tensors are int64.
    """
    df = pd.read_csv(f"heart_{client_id + 1}.csv")
    if shuffle:
        # A tail slice of a file ordered by label (or by anything else)
        # would give an unrepresentative test split.
        df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)

    x = df.drop("target", axis=1).values.astype("float32")
    y = df["target"].values.astype("int64")

    split = int(train_fraction * len(x))
    x_train, x_test = x[:split], x[split:]
    y_train, y_test = y[:split], y[split:]

    # Fit the scaler on training rows only; fitting on all rows leaks
    # test-set statistics into training.
    scaler = StandardScaler().fit(x_train)
    x_train = scaler.transform(x_train)
    if len(x_test):  # transform() rejects zero-row input
        x_test = scaler.transform(x_test)

    x_train = torch.tensor(x_train, dtype=torch.float32)
    x_test = torch.tensor(x_test, dtype=torch.float32)
    y_train = torch.tensor(y_train)
    y_test = torch.tensor(y_test)
    return (x_train, y_train), (x_test, y_test), x.shape[1]
# Flower client implementation
class HeartClient(fl.client.NumPyClient):
    """Flower NumPy client that trains ``Net`` on one heart-data partition.

    Each instance loads its partition via ``load_data(client_id)`` and keeps
    a local model, loss function and Adam optimizer for its lifetime.
    """

    def __init__(self, client_id):
        train, test, input_dim = load_data(client_id)
        self.x_train, self.y_train = train
        self.x_test, self.y_test = test
        self.model = Net(input_dim)
        self.loss_fn = nn.CrossEntropyLoss()
        # NOTE(review): this optimizer (and its Adam moment estimates) lives
        # across federated rounds, while fit() overwrites the weights with the
        # server's parameters every round; confirm that carrying the optimizer
        # state over is intended.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

    def get_parameters(self, config):
        """Return the local model weights as a list of NumPy arrays."""
        return get_weights(self.model)

    def fit(self, parameters, config):
        """Train locally, starting from the server-supplied ``parameters``.

        Optional ``config`` keys:
            ``local_epochs`` (default 1): passes over the training split.
            ``batch_size`` (default 16): mini-batch size.

        Returns:
            ``(weights, num_train_examples, {})``.
        """
        set_weights(self.model, parameters)
        config = config or {}
        local_epochs = int(config.get("local_epochs", 1))
        batch_size = int(config.get("batch_size", 16))

        self.model.train()
        loader = DataLoader(
            TensorDataset(self.x_train, self.y_train),
            batch_size=batch_size,
            shuffle=True,
        )
        for _ in range(local_epochs):
            for x_batch, y_batch in loader:
                self.optimizer.zero_grad()
                loss = self.loss_fn(self.model(x_batch), y_batch)
                loss.backward()
                self.optimizer.step()
        return get_weights(self.model), len(self.x_train), {}

    def evaluate(self, parameters, config):
        """Evaluate the server-supplied ``parameters`` on the local test split.

        Returns:
            ``(loss, num_test_examples, {"accuracy": accuracy})``. An empty
            test split yields ``(0.0, 0, {"accuracy": 0.0})`` rather than NaN.
        """
        set_weights(self.model, parameters)
        num_examples = len(self.x_test)
        if num_examples == 0:
            # Mean loss / accuracy over zero rows are NaN, and a NaN loss
            # would poison an example-weighted average even with weight 0.
            return 0.0, 0, {"accuracy": 0.0}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(self.x_test)
            loss = self.loss_fn(outputs, self.y_test).item()
            accuracy = (outputs.argmax(1) == self.y_test).float().mean().item()
        return loss, num_examples, {"accuracy": accuracy}
# Start the client
def start_client(client_id, server_address="localhost:8080"):
    """Create a ``HeartClient`` for ``client_id`` and connect it to the server.

    Args:
        client_id: zero-based index selecting the client's CSV partition.
        server_address: ``host:port`` of the Flower server (defaults to the
            previously hard-coded ``"localhost:8080"``).
    """
    # NOTE(review): start_numpy_client is deprecated in newer Flower releases
    # in favour of fl.client.start_client(client=HeartClient(...).to_client());
    # switch once the installed Flower version is confirmed.
    fl.client.start_numpy_client(
        server_address=server_address,
        client=HeartClient(client_id),
    )
if __name__ == "__main__":
    # argparse reports a missing or non-integer client id with a usage
    # message instead of a bare IndexError/ValueError traceback.
    import argparse

    parser = argparse.ArgumentParser(
        description="Start a Flower client for one heart-data partition."
    )
    parser.add_argument(
        "client_id",
        type=int,
        help="zero-based client index (client N reads heart_{N+1}.csv)",
    )
    args = parser.parse_args()
    start_client(args.client_id)