Skip to content

Commit 5012fce

Browse files
authored
Merge pull request #191 from XinghaoWu/master
Add FedCAC algorithm
2 parents 097b6ed + 8725a06 commit 5012fce

4 files changed

Lines changed: 251 additions & 2 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ Thanks to [@Stonesjtu](https://github.com/Stonesjtu/pytorch_memlab/blob/d590c489
9898
- **GPFL**[GPFL: Simultaneously Learning Generic and Personalized Feature Information for Personalized Federated Learning](https://arxiv.org/pdf/2308.10279v3.pdf) *ICCV 2023*
9999
- **FedGH**[FedGH: Heterogeneous Federated Learning with Generalized Global Header](https://dl.acm.org/doi/10.1145/3581783.3611781) *ACM MM 2023*
100100
- **DBE**[Eliminating Domain Bias for Federated Learning in Representation Space](https://openreview.net/forum?id=nO5i1XdUS0) *NeurIPS 2023*
101+
- **FedCAC**[Bold but Cautious: Unlocking the Potential of Personalized Federated Learning through Cautiously Aggressive Collaboration](https://arxiv.org/abs/2309.11103) *ICCV 2023*
101102

102103
***Knowledge-distillation-based pFL***
103104

system/flcore/clients/clientcac.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import numpy as np
2+
import time
3+
import torch
4+
import torch.nn as nn
5+
import copy
6+
from flcore.clients.clientbase import Client
7+
8+
class clientCAC(Client):
    """
    FedCAC client (ICCV 2023, "Bold but Cautious").

    After each round of local training, the client marks its "critical"
    parameters — the entries with the largest per-layer sensitivity
    |local_update * value| (top-tau fraction). The server uses these masks
    to build a customized global model; on receiving models, critical
    entries are taken from the customized model and non-critical entries
    from the plain global model.
    """

    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        super().__init__(args, id, train_samples, test_samples, **kwargs)
        self.args = args
        # customized global model aggregated for this client by the server
        self.customized_model = copy.deepcopy(self.model)
        # critical_parameter: flat 0/1 vector over all parameters;
        # local_mask / global_mask: per-layer 0/1 masks of critical /
        # non-critical entries (None until the first local training round)
        self.critical_parameter, self.global_mask, self.local_mask = None, None, None

    def train(self):
        """Run local training, then refresh the critical-parameter masks."""
        trainloader = self.load_train_data()

        start_time = time.time()

        # record the model before local updating; critical-parameter
        # selection needs the update direction (w_after - w_before)
        initial_model = copy.deepcopy(self.model)

        self.model.train()

        max_local_epochs = self.local_epochs
        if self.train_slow:
            # NOTE(review): raises if local_epochs < 4 (randint(1, 0 or 1));
            # kept as-is for consistency with the rest of the pipeline
            max_local_epochs = np.random.randint(1, max_local_epochs // 2)

        for epoch in range(max_local_epochs):
            for i, (x, y) in enumerate(trainloader):
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                if self.train_slow:
                    time.sleep(0.1 * np.abs(np.random.rand()))
                output = self.model(x)
                loss = self.loss(output, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if self.learning_rate_decay:
            self.learning_rate_scheduler.step()

        # select the critical parameters based on this round's update
        self.critical_parameter, self.global_mask, self.local_mask = self.evaluate_critical_parameter(
            prevModel=initial_model, model=self.model, tau=self.args.tau
        )

        self.train_time_cost['num_rounds'] += 1
        self.train_time_cost['total_cost'] += time.time() - start_time

    def evaluate_critical_parameter(self, prevModel: nn.Module, model: nn.Module, tau: float):
        r"""
        Select the critical parameters of ``model`` relative to ``prevModel``.

        For each layer, the sensitivity of every entry is |update * value|;
        the top ``tau`` fraction of entries are marked critical.

        Returns:
            critical_parameter: flat int tensor (on CPU), 1 at critical entries.
            global_mask: per-layer int tensors, 1 at non-critical entries.
            local_mask: per-layer int tensors, 1 at critical entries.
        """
        global_mask = []   # 1 marks a non-critical entry
        local_mask = []    # 1 marks a critical entry
        critical_parameter = []

        # select critical parameters in each layer
        for (name1, prevparam), (name2, param) in zip(prevModel.named_parameters(), model.named_parameters()):
            g = (param.data - prevparam.data)   # local update direction
            v = param.data
            c = torch.abs(g * v)                # sensitivity metric

            metric = c.view(-1)
            num_params = metric.size(0)
            nz = int(tau * num_params)
            top_values, _ = torch.topk(metric, nz)
            # nz may be 0 for tiny layers; an inf threshold marks nothing critical
            thresh = top_values[-1] if len(top_values) > 0 else np.inf
            # if threshold equals 0, select minimal nonzero element as threshold
            if thresh <= 1e-10:
                new_metric = metric[metric > 1e-20]
                if len(new_metric) == 0:  # this means all items in metric are zero
                    print(f'Abnormal!!! metric:{metric}')
                else:
                    thresh = new_metric.sort()[0][0]

            # get the local mask and global mask (kept on CPU to save device memory)
            mask = (c >= thresh).int().to('cpu')
            global_mask.append((c < thresh).int().to('cpu'))
            local_mask.append(mask)
            critical_parameter.append(mask.view(-1))
        model.zero_grad()
        critical_parameter = torch.cat(critical_parameter)

        return critical_parameter, global_mask, local_mask

    def set_parameters(self, model):
        """
        Receive a global model. Once masks exist, blend it with the
        customized model: critical entries come from the customized model,
        non-critical entries from the received global model.
        """
        if self.local_mask is not None:
            index = 0
            for (name1, param1), (name2, param2), (name3, param3) in zip(
                    self.model.named_parameters(), model.named_parameters(),
                    self.customized_model.named_parameters()):
                # use self.device consistently (original mixed self.device
                # and self.args.device — same value, but one source of truth)
                param1.data = self.local_mask[index].to(self.device).float() * param3.data + \
                              self.global_mask[index].to(self.device).float() * param2.data
                index += 1
        else:
            # before the first local round there are no masks yet:
            # fall back to plain global-model assignment
            super().set_parameters(model)

system/flcore/servers/servercac.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import time
2+
import torch
3+
import copy
4+
5+
from flcore.clients.clientcac import clientCAC
6+
from flcore.servers.serverbase import Server
7+
from utils.data_utils import read_client_data
8+
9+
class FedCAC(Server):
    """
    FedCAC server (ICCV 2023, "Bold but Cautious").

    Besides the standard FedAvg-style aggregation, the server builds a
    *customized* global model for each client by averaging the models of
    clients whose critical-parameter locations overlap sufficiently.
    """

    def __init__(self, args, times):
        super().__init__(args, times)
        # the shared CLI parses beta as float; FedCAC uses it as an
        # integer round count in the threshold schedule
        args.beta = int(args.beta)
        # select slow clients
        self.set_slow_clients()
        self.set_clients(clientCAC)

        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
        print("Finished creating server and clients.")

        self.Budget = []

        # To be consistent with the existing pipeline interface. Maintaining an epoch counter.
        self.epoch = -1

    def train(self):
        """Main federated training loop."""
        for i in range(self.global_rounds + 1):
            self.epoch = i
            s_t = time.time()
            self.selected_clients = self.select_clients()
            self.send_models()

            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate personalized models")
                self.evaluate()

            for client in self.selected_clients:
                client.train()

            self.receive_models()
            self.aggregate_parameters()

            self.Budget.append(time.time() - s_t)
            print('-' * 25, 'time cost', '-' * 25, self.Budget[-1])

            if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt):
                break

        print("\nBest accuracy.")
        print(max(self.rs_test_acc))
        print("\nAverage time cost per round.")
        # Budget[1:] is empty after a single round; guard the division
        # (original raised ZeroDivisionError in that case)
        if len(self.Budget) > 1:
            print(sum(self.Budget[1:]) / len(self.Budget[1:]))

        self.save_results()

        if self.num_new_clients > 0:
            self.eval_new_clients = True
            self.set_new_clients(clientCAC)
            print(f"\n-------------Fine tuning round-------------")
            print("\nEvaluate new clients")
            self.evaluate()

    def get_customized_global_models(self):
        r"""
        Overview:
            Aggregate a customized global model for each client so that
            clients with similar critical-parameter locations collaborate
            on those parameters.
        """
        assert isinstance(self.args.beta, int) and self.args.beta >= 1
        num_clients = self.args.num_clients
        overlap_buffer = [[] for _ in range(num_clients)]

        # calculate overlap rate between client i and client j
        # (fraction of critical positions shared, stored as plain floats —
        # avoids building torch.tensor from a nested list of 0-dim tensors)
        for i in range(num_clients):
            ci = self.clients[i].critical_parameter.to(self.device)
            for j in range(num_clients):
                if i == j:
                    continue
                cj = self.clients[j].critical_parameter.to(self.device)
                overlap_rate = 1 - torch.sum(torch.abs(ci - cj)) / float(torch.sum(ci).cpu() * 2)
                overlap_buffer[i].append(overlap_rate.item())

        # calculate the global threshold: interpolates from the average
        # overlap toward the max overlap as rounds progress, reaching the
        # max at round beta
        overlap_buffer_tensor = torch.tensor(overlap_buffer)
        overlap_sum = overlap_buffer_tensor.sum()
        overlap_avg = overlap_sum / ((num_clients - 1) * num_clients)
        overlap_max = overlap_buffer_tensor.max()
        threshold = overlap_avg + (self.epoch + 1) / self.args.beta * (overlap_max - overlap_avg)

        # calculate the customized global model for each client
        for i in range(num_clients):
            w_customized_global = copy.deepcopy(self.clients[i].model.state_dict())
            collaboration_clients = [i]
            # find clients whose critical parameter locations are similar to client i
            index = 0
            for j in range(num_clients):
                if i == j:
                    continue
                if overlap_buffer[i][index] >= threshold:
                    collaboration_clients.append(j)
                index += 1

            # snapshot collaborators' state dicts once, not once per key
            collab_states = [self.clients[c].model.state_dict()
                             for c in collaboration_clients if c != i]
            for key in w_customized_global.keys():
                for state in collab_states:
                    w_customized_global[key] += state[key]
                w_customized_global[key] = torch.div(w_customized_global[key], float(len(collaboration_clients)))
            # send the customized global model to client i
            self.clients[i].customized_model.load_state_dict(w_customized_global)

    def send_models(self):
        """Build customized models (except at round 0, when no masks exist), then broadcast."""
        if self.epoch != 0:
            self.get_customized_global_models()

        super().send_models()

system/main.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from flcore.servers.serverntd import FedNTD
6262
from flcore.servers.servergh import FedGH
6363
from flcore.servers.serveravgDBE import FedAvgDBE
64+
from flcore.servers.servercac import FedCAC
6465

6566
from flcore.trainmodel.models import *
6667

@@ -355,6 +356,10 @@ def run(args):
355356
args.model.fc = nn.Identity()
356357
args.model = BaseHeadSplit(args.model, args.head)
357358
server = FedAvgDBE(args, i)
359+
360+
elif args.algorithm == 'FedCAC':
361+
server = FedCAC(args, i)
362+
358363

359364
else:
360365
raise NotImplementedError
@@ -429,7 +434,7 @@ def run(args):
429434
help="Whether to group and select clients at each round according to time cost")
430435
parser.add_argument('-tth', "--time_threthold", type=float, default=10000,
431436
help="The threthold for droping slow clients")
432-
# pFedMe / PerAvg / FedProx / FedAMP / FedPHP / GPFL
437+
# pFedMe / PerAvg / FedProx / FedAMP / FedPHP / GPFL / FedCAC
433438
parser.add_argument('-bt', "--beta", type=float, default=0.0)
434439
parser.add_argument('-lam', "--lamda", type=float, default=1.0,
435440
help="Regularization weight")
@@ -452,7 +457,7 @@ def run(args):
452457
parser.add_argument('-al', "--alpha", type=float, default=1.0)
453458
# Ditto / FedRep
454459
parser.add_argument('-pls', "--plocal_epochs", type=int, default=1)
455-
# MOON
460+
# MOON / FedCAC
456461
parser.add_argument('-tau', "--tau", type=float, default=1.0)
457462
# FedBABU
458463
parser.add_argument('-fte', "--fine_tuning_epochs", type=int, default=10)

0 commit comments

Comments
 (0)