2022. 3. 23. 14:50ㆍAI 기술
참고 논문 : https://arxiv.org/pdf/1909.11875.pdf
1. Introduction
우리가 알고 있는 전통적인 cloud-centric 방식은 엣지 클라우드(클라이언트, 데이터 오너 등등)로 부터 데이터를 수집하여 중앙 서버가 데이터를 처리해 주는 방식을 이용해 왔다.
그런데 현대 사회에서는 다음과 같은 문제 때문에 decentric한 방법이 대두되고 있다.
- 데이터 소유자의 프라이버시 문제
- Center로 데이터를 전송하게 되면 데이터의 양을 감당하기가 힘듬
분산학습 vs 연합학습
둘의 학습되는 과정은 다를게 없지만, 가정이 다르다.
분산학습의 경우 하나의 모델을 병렬적으로 학습하고 각 데이터가 독립적이며 동일한 분포를 갖는다고 가정한다 (IID)
연합학습은 Non-IID를 가정하기에 중앙서버에서 Aggregation을 함으로써 데이터의 불균형을 맞춘다고 볼 수 있다.
그래서 등장한 개념이 Federated Learning이다. 이 Decentralized ML 방식은 오로지 학습된 model의 weight만을 서버로 전송하게 된다.
- 네트워크와 메모리의 효율적인 사용 : 데이터를 전송시켜 학습하는 것이 아닌 모델의 파라미터만 전송하기에 네트워킹 및 메모리 비용이 크게 감소한다
- 보안 : 사용자의 개인적인 정보가 넘어가는 것이 아니므로 사용자의 개인정보를 보호한다
하지만 Real World Federated Learning을 성공시킨 사례가 매우 드물고, 왜 이제서야 이러한 개념이 화제가 되고 있는 것일까?
- 최소 2명의 client, 많게는 수백, 수천의 client가 참가한다고 가정하면, 각 사용자간의 다양성(데이터의 질과 양, 컴퓨터의 성능, 등)의 관리가 필요하다.
2. FL Step
- Step 1 (Task Initialization) : 학습에 참여할 client 정의 / 학습에 사용할 model 배포
- Step 2 (Client model training) : 각 client는 자신의 데이터로 베포된 model을 학습시킨다. 각 사용자들의 목표는 한 round동안 train loss가 최소가 되는 weight를 서버에게 전달하는 것
- Step 3 (Server model update) : FL의 핵심이 되는 과정, 각 client에게 받은 weight들을 알고리즘을 통해 aggregation하고, update된 weight를 다시 클라이언트에게 전송
3. Challenge of FL
원래 분산학습이 된 ML 모델은 center server가 모든 데이터에 접근가능했다. 즉, 데이터셋을 서로 비슷한 분포를 가지게 나눌수 있었다. 하지만, FL 방식은 center server는 어떠한 데이터도 열람할 수 없으므로 문제를 야기한다.
실제로 FedAvg논문 저자는 Non-IID 한 데이터에서 학습된 CNN 모델은 중앙집중학습된 모델보다 51%나 더 낮은 정확도를 보였다고 밝혔다.
이렇게 각 client마다 서로 다른 데이터 분포를 가지고 있는 경우에, multi task learning의 개념을 수용하려고 많은 노력이 이루어졌다. 일반적인 loss를 사용하지 않고, task간의 관계를 고려하여 loss function을 수정하는 것이다.
* mulit task learning : 의료 분야 같이 한정적이고 부족한 데이터셋을 보완하기 위해 -> 연관성 있는 여러가지 task를 동시에 학습시켜 활용하는 방법
FedPer Algorithm
통계적 이질성을 처리하기 위해 multi task learning 개념을 이용, 각 client들은 FedAvg 알고리즘을 이용하여 base layer을 학습한다. 그 다음, 각자의 데이터를 가진 사용자는 자신만의 데이터를 가지고 personalization layer를 학습하게 된다.
FedProx
training loss가 증가할 때, 모델 업데이트가 현재 client의 파라미터에 영향을 적게 받도록 adaptive하게 조정한다
LoAdaBoostFedAvg
client들로 부터 받은 가중치를 aggregation하기 전에 전 round의 loss와 현 round loss와 비교한다. 만약 지금의 손실값이 더 크다면 aggregation을 바로 하는 것이 아니라, 재학습을 시킨다.
4. Communication Cost
사실 분산학습에서 연합학습으로 넘어 갈때 가장 난항을 겪는 문제 중 하나이다. 이 부분은 현재까지 나온 코드나 논문을 봐도 제대로 해결 된 것 같아 보이진 않는다.
(실제로 제가 연합학습 네트워크를 구축하는 과정에서 지금까지도 어려움을 겪고 있습니다. Server와 Client의 연결이 수시로 끊긴 다거나, Real-World로 연결 되는 라이브러리조차도 거의 없습니다.)
이러한 통신문제가 발생하는 이유는,
- 각 Client의 불안정한 네트워크 상태(WIFI가 중간에 끊긴다던가..)
- client의 가중치 업로드 속도가 다운로드 속도보다 빠른 경우
Cost를 어떻게 낮출 수 있을지?
- Edge and End Computation: FL 셋업에서 종종 통신 바용이 계산 비용을 뛰어 넘을 때가 있다. 왜냐하면 디바이스 내의 데이터셋은 상대적으로 작고 점점 참가자들의 모바일 디바이스의 프로세서는 빨라지고 있기 때문이다. 반면에 참가자들은 모델 학습을 오직 Wi-Fi에 연결되어 있을 때만 수행하길 원한다. 따라서 모델 학습에 필요한 통신 라운드 수를 줄이기 위해 각 aggregation 전에 edge node 또는 최종 디바이스에서 더 많은 계산을 수행할 수 있다. 또한 알고리즘의 빠른 수렴이 보장되면 edge server 및 최종 장치에서 더 많은 계산을 수행하는 대신 관련된 라운드 수를 줄일 수 있다.
- Model Compression: 분산 학습에 공통적으로 쓰이는 기법이다. 모델 또는 gradient 압축은 업데이트 관련 통신을 간결하게 할 수 있다. 완전한 업데이트 통신 보다는 quantization, subsampling 등으로 압축하여 업데이트한다. 하지만 이런 압축으로 인해 노이즈가 발생할 수 있으므로 각 라운드 동안 전송되는 업데이트의 사이즈를 줄이면서도 학습 모델의 품질을 유지하는 것이 목표이다.
- Importance-based Updating: 각 라운드에서 오직 중요하거나 관련 있는 업데이트만을 선택하여 통신하는 전략이다. 실제로 통신 비용을 절약하는 것 이외에도 참가자 일부 업데이트를 생략하면 global 모델의 성능을 향상시킬 수도 있다.
5. FLOWER - Federated GAN code
from collections import OrderedDict
import warnings
import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
from torch.nn import GroupNorm
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from imutils import build_montages
from model import *
warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
def train(netG, netD, trainloader, epochs):
"""Train the network on the training set."""
netG, netD = netG.to(DEVICE), netD.to(DEVICE)
netD.apply(weights_init)
netG.apply(weights_init)
criterion = torch.nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr = 0.001, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.001, betas=(0.5, 0.999))
real_label = 1
fake_label = 0
netG.train()
netD.train()
for _ in range(epochs):
lossG_item = 0
lossD_item = 0
print(len(trainloader))
for images, labels in trainloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
netD.zero_grad()
output = netD(images)
b_size = images.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=DEVICE)
lossD_real = criterion(output, label)
lossD_real.backward()
noise = torch.randn(b_size,100,1,1, device=DEVICE)
fake = netG(noise)
output = netD(fake.detach())
#label = torch.zeros_like(output)
label.fill_(fake_label)
lossD_fake = criterion(output, label)
lossD_fake.backward()
lossD = lossD_fake + lossD_real
optimizerD.step()
netG.zero_grad()
label.fill_(real_label)
output = netD(fake)
lossG = criterion(output,label)
lossG.backward()
optimizerG.step()
print(lossG.item())
lossG_item += lossG.item()
if(epochs == 1):
return lossG_item
def test(netG, netD, testloader):
"""Validate the network on the entire test set."""
criterion = torch.nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr = 0.001, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.001, betas=(0.5, 0.999))
real_label = 1
fake_label = 0
netD.eval()
netG.eval()
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
outputs = netG(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loss /= len(testloader.dataset)
accuracy = correct / total
return loss, accuracy
def load_data():
"""Load CIFAR-10 (training and test set)."""
transform = transforms.Compose(
[transforms.Scale(64),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
testset = CIFAR10("./dataset", train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32)
num_examples = {"trainset": len(trainset), "testset": len(testset)}
return trainloader, testloader, num_examples
def main():
"""Create model, load data, define Flower client, start Flower client."""
netG = Generator()
netD = Discriminator()
# Load data (CIFAR-10)
trainloader, testloader, num_examples = load_data()
# Flower client
class CifarClient(fl.client.NumPyClient):
def get_parameters(self):
return [val.cpu().numpy() for _, val in netG.state_dict().items()]
def set_parameters(self, parameters):
params_dict = zip(netG.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
netG.load_state_dict(state_dict, strict=True)
def fit(self, parameters, config):
self.set_parameters(parameters)
_ = train(netG, netD, trainloader, epochs=1)
return self.get_parameters(), num_examples["trainset"], {}
def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss = train(netG, netD, trainloader, epochs=1)
return float(loss), num_examples["testset"]
# Start client
fl.client.start_numpy_client("[::]:8080", client=CifarClient())
if __name__ == "__main__":
main()
import flwr as fl
from typing import List,Tuple
import numpy as np
class AggregateCustomMetricStrategy(fl.server.strategy.FedAvg):
def aggregate_evaluate(
self,
rnd: int,
results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]],
failures: List[BaseException],
):
"""Aggregate evaluation losses using weighted average."""
if not results:
return None
accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
examples = [r.num_examples for _, r in results]
accuracy_aggregated = sum(accuracies) / sum(examples)
print(f"Round {rnd} accuracy aggregated from client results: {accuracy_aggregated}")
return super().aggregate_evaluate(rnd, results, failures)
strategy = AggregateCustomMetricStrategy(
fraction_fit=0.1, # Sample 10% of available clients for the next round
min_fit_clients=2, # Minimum number of clients to be sampled for the next round
min_available_clients=2, # Minimum number of clients that need to be connected to the server before a training round can start
)
fl.server.start_server(server_address="[::]:8080", config={"num_rounds": 3}, strategy=strategy)
6. Block chain - Based Federated Learning
영상은 연구보다는 좀더 상용성에 중점을 맞추어 설명되어 있습니다.
https://www.youtube.com/watch?v=HMEUb78E3CQ
'AI 기술' 카테고리의 다른 글
mmsegmentation 공략하기(설치부터 custom dataset학습까지) (0) | 2023.06.07 |
---|---|
내가 필요해서 정리하는 JAX(2)(JAX의 설치 / 직접 설치해보고 적는 방법) (2) | 2023.04.24 |
내가 필요해서 정리하는 JAX(1)(JAX의 기초부터 XLA까지) (0) | 2023.04.18 |
[Linux] 특정 디렉토리의 Tree 구조만 복사하기 (0) | 2022.11.17 |
Federated Learning에 다양한 GAN 적용시키기 (1) | 2022.05.13 |