Skip to content

Commit c99b722

Browse files
add a text classification example
1 parent 1518e45 commit c99b722

File tree

1 file changed

+194
-0
lines changed

1 file changed

+194
-0
lines changed

examples/distilbert_demo.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import logging
2+
import multiprocessing
3+
from typing import Any
4+
5+
import torch
6+
from datasets import load_dataset
7+
from torch import Tensor, nn
8+
from torch.nn import Module
9+
from torch.nn.modules.loss import _Loss
10+
from torch.optim import Optimizer
11+
from torch.utils.data import DataLoader
12+
from transformers import DistilBertConfig
13+
from transformers.models.distilbert.modeling_distilbert import (
14+
DistilBertForSequenceClassification,
15+
)
16+
from transformers.models.distilbert.tokenization_distilbert_fast import (
17+
DistilBertTokenizerFast,
18+
)
19+
20+
from fedmind.algs.fedavg import FedAvg
21+
from fedmind.config import get_config
22+
from fedmind.data import ClientDataset
23+
from fedmind.utils import EasyDict, StateDict
24+
25+
26+
class FedAvgTextClassification(FedAvg):
    """Federated Averaging Algorithm for Text Classification.

    Specializes FedAvg for HuggingFace-style sequence-classification models:
    batches are dicts of tensors (``input_ids``, ``attention_mask``, ``label``)
    and the model returns an object whose ``.logits`` feeds the criterion.
    """

    def __init__(self, model, fed_loader, test_loader, criterion, config):
        super().__init__(model, fed_loader, test_loader, criterion, config)
        self.logger.info(f"Start {self.__class__.__name__}.")

    @staticmethod
    def _train_client(
        model: Module,
        gm_params: StateDict,
        train_loader: DataLoader,
        optimizer: Optimizer,
        criterion: _Loss,
        epochs: int,
        logger: logging.Logger,
        config: EasyDict,
        *args,
        **kwargs,
    ) -> dict[str, Any]:
        """Train the model with given environment.

        Args:
            model: The model to train.
            gm_params: The global model parameters.
            train_loader: The DataLoader object that contains the training data.
            optimizer: The optimizer to use.
            criterion: The loss function to use.
            epochs: The number of epochs to train the model.
            logger: The logger object to log the training process.
            config: The configuration dictionary (reads ``config.DEVICE``).

        Returns:
            A dictionary with:
                - ``model_update``: client parameters minus the global parameters
                  (relies on StateDict supporting ``-``).
                - ``train_loss``: batch loss averaged over batches and epochs.
        """
        # Start every client round from the current global parameters.
        model.load_state_dict(gm_params)
        cost = 0.0
        model.train()
        for epoch in range(epochs):
            logger.debug(f"Epoch {epoch + 1}/{epochs}")
            for batch in train_loader:
                # pop() removes the label so the rest of the batch can be
                # forwarded to the model as keyword arguments.
                labels = batch.pop("label").to(config.DEVICE)
                inputs = {k: v.to(config.DEVICE) for k, v in batch.items()}
                optimizer.zero_grad()
                outputs = model(**inputs)
                loss: Tensor = criterion(outputs.logits, labels)
                # FIX: check for NaN *before* backprop. Previously the check ran
                # after optimizer.step(), so NaN gradients had already been
                # applied to the model and the NaN poisoned the running cost.
                if torch.isnan(loss):
                    logger.warning("Loss is NaN.")
                    continue  # skip the poisoned batch update
                loss.backward()
                optimizer.step()
                cost += loss.item()

        return {
            "model_update": model.state_dict(destination=StateDict()) - gm_params,
            "train_loss": cost / len(train_loader) / epochs,
        }

    @staticmethod
    def _test_server(
        model: Module,
        gm_params: StateDict,
        test_loader: DataLoader,
        criterion: _Loss,
        logger: logging.Logger,
        config: EasyDict,
        *args,
        **kwargs,
    ) -> dict:
        """Test the model.

        Args:
            model: The model to test.
            gm_params: The global model parameters.
            test_loader: The DataLoader object that contains the test data.
            criterion: The loss function to use.
            logger: The logger object to log the testing process.
            config: The configuration dictionary (reads ``config.DEVICE``).

        Returns:
            The evaluation metrics dict with ``test_loss`` (summed over
            batches, not averaged) and ``test_accuracy``.
        """

        total_loss = 0
        correct = 0
        total = 0
        model.load_state_dict(gm_params)
        model.eval()
        with torch.no_grad():
            for batch in test_loader:
                labels = batch.pop("label").to(config.DEVICE)
                inputs = {k: v.to(config.DEVICE) for k, v in batch.items()}
                outputs = model(**inputs)
                loss: Tensor = criterion(outputs.logits, labels)
                total_loss += loss.item()
                # Predicted class = argmax over the class dimension.
                predicted = torch.argmax(outputs.logits, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
        logger.info(f"Test Loss: {total_loss:.4f}, Accuracy: {accuracy:.4f}")

        return {"test_loss": total_loss, "test_accuracy": accuracy}
129+
130+
131+
def test_fedavg():
    """Run a FedAvg text-classification demo: DistilBERT on IMDB.

    Loads config from ``config.yaml``, tokenizes the IMDB train/test splits,
    shards the training set evenly across ``config.NUM_CLIENT`` clients, and
    launches the federated simulation.
    """

    def _tokenized_split(tokenizer, split: str):
        """Load one IMDB split, tokenize it, and set torch tensor format."""
        ds = load_dataset("IMDB", split=split, cache_dir="dataset").map(
            lambda x: tokenizer(
                x["text"], truncation=True, padding="max_length", max_length=512
            ),
            batched=True,
        )
        ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])  # type: ignore
        return ds

    # 0. Prepare necessary arguments
    config = get_config("config.yaml")
    if config.SEED >= 0:
        torch.manual_seed(config.SEED)

    # "spawn" avoids a deadlock of the fast tokenizer with multiprocessing.
    multiprocessing.set_start_method("spawn")

    # 1. Prepare Federated Learning DataSets
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    org_ds = _tokenized_split(tokenizer, "train")
    test_ds = _tokenized_split(tokenizer, "test")

    # Drop the remainder so every client receives an equally sized shard.
    effective_size = len(org_ds) - len(org_ds) % config.NUM_CLIENT  # type: ignore
    idx_groups = torch.randperm(effective_size).reshape(config.NUM_CLIENT, -1)
    fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()]  # type: ignore

    # Per-client RNGs keep client shuffling reproducible when a seed is set.
    # FIX: renamed `genetors` -> `generators` (typo).
    generators = [
        torch.Generator().manual_seed(config.SEED + i) if config.SEED >= 0 else None
        for i in range(config.NUM_CLIENT)
    ]
    fed_loader = [
        DataLoader(ds, config.BATCH_SIZE, shuffle=True, generator=gen)
        for ds, gen in zip(fed_dss, generators)
    ]
    test_loader = DataLoader(test_ds, config.BATCH_SIZE * 4)  # type: ignore

    # 2. Prepare Model and Criterion
    classes = 2  # IMDB sentiment: positive / negative
    # FIX: removed the duplicated assignment (`model = model = ...`).
    model = DistilBertForSequenceClassification(DistilBertConfig(num_labels=classes))
    criterion = nn.CrossEntropyLoss()

    # 3. Run Federated Learning Simulation
    FedAvgTextClassification(
        model=model,
        fed_loader=fed_loader,
        test_loader=test_loader,
        criterion=criterion,
        config=config,
    ).fit(config.NUM_CLIENT, config.ACTIVE_CLIENT, config.SERVER_EPOCHS)
191+
192+
193+
# Script entry point: run the federated DistilBERT/IMDB demo directly.
if __name__ == "__main__":
    test_fedavg()

0 commit comments

Comments
 (0)