Federated Learning (Aalto University CS-E4740 2024)
Content
Machine Learning (ML) Basics
ML starts with data. Each data point is an object represented by a vector of features, together with a label. The goal is to predict some quantity of interest, often the label, by finding a hypothesis (function) that maps the feature vector to the label. The hypothesis comes from a hypothesis space which we define, e.g. the space of linear functions. The data and the hypothesis are related through a loss function, which measures how far the prediction of a hypothesis for some feature vector is from the true label. These three components (data, hypothesis space, loss function) define Empirical Risk Minimization (ERM). E.g. in a weather measurement example, the features can be the timestamp, latitude and longitude, and the label can be the temperature.
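In linear models with a squared loss, the optimization problem is a convex quadratic and hence can be solved in closed form. This convexity property sets linear regression apart from nonlinear regression, where the problem is generally non-convex. Note however that the closed form solution contains a matrix inverse, which exists only if the matrix \(\mathbf{X}^T\mathbf{X}\) built from the feature matrix \(\mathbf{X}\) is invertible. One workaround is to use a pseudo-inverse; another is to use iterative methods like gradient descent, which (with a suitable step size) converge to a minimizer regardless of matrix invertibility.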
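The three routes mentioned above can be compared numerically. The following is a minimal sketch on synthetic data, not taken from the lecture notes: the data, the step size choice (one over the gradient's Lipschitz constant) and the helper name `gradient_descent` are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 50, 3
X = rng.normal(size=(m, d))                      # synthetic feature matrix
w_true = np.array([1.0, -2.0, 0.5])              # hypothetical true weights
y = X @ w_true + 0.1 * rng.normal(size=m)

# Closed form via the normal equations; requires X^T X to be invertible.
w_closed = np.linalg.solve(X.T @ X, X.T @ y)

# Pseudo-inverse; also defined when X^T X is singular.
w_pinv = np.linalg.pinv(X) @ y


def gradient_descent(X, y, n_steps=5000):
    """Minimize f(w) = (1/m) * ||y - X w||^2 starting from w = 0."""
    m, d = X.shape
    # Step size 1/L, where L = 2 * lambda_max(X^T X / m) bounds the curvature.
    lipschitz = 2.0 * np.linalg.eigvalsh(X.T @ X / m).max()
    w = np.zeros(d)
    for _ in range(n_steps):
        grad = (2.0 / m) * X.T @ (X @ w - y)
        w -= grad / lipschitz
    return w


w_gd = gradient_descent(X, y)

# Rank-deficient case: duplicating a column makes X^T X singular,
# so the plain inverse does not exist, but the other two routes still work.
X_dup = np.hstack([X, X[:, :1]])
w_pinv_dup = np.linalg.pinv(X_dup) @ y
w_gd_dup = gradient_descent(X_dup, y)
```

In the rank-deficient case there are infinitely many minimizers; started from zero, gradient descent stays in the row space of the feature matrix and therefore approaches the same minimum-norm solution that the pseudo-inverse returns.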
The 2 main questions of a Machine Learning scientist (for mathematicians, the analogous questions are (a) does the problem have a solution and (b) is it unique; a student also suggested: how to interpret deep learning mathematically):
- How can I solve this optimization problem, and what can be said about its computational complexity?
- How useful is the solution, i.e. what are its statistical properties?
In federated learning, data and models are represented as networks, also called empirical graphs.
The data is assumed to be samples from some probability distribution \(p(x,y)\) over the joint space of features \(\mathcal{X}\) and labels \(\mathcal{Y}\). We often make the assumption that the observed data are independent and identically distributed (iid). Now, given a probability distribution over the data, we can define the performance (measured by a loss function \(L\)) of a hypothesis (parametrized by \(\mathbf{w}\)) by \(\mathbb{E}_{x,y\sim p(x,y)}[L(x,y,\mathbf{w})]:=R(\mathbf{w})\) - we call this expectation the risk.
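For example, assuming a linear model \(y = \mathbf{w}^{*T}\mathbf{x} + \epsilon\) (aleatoric noise \(\epsilon \sim \mathcal{N}(0,\sigma^2)\), independent of \(\mathbf{x}\)) with a squared loss function and features satisfying \(\mathbb{E}[\mathbf{x}\mathbf{x}^T]=\mathbf{I}\), we get \(R(\mathbf{w})=||\mathbf{w}^*-\mathbf{w}||^2 + \sigma^2\). See Proposition 2.1 in the lecture notes for bounds under these assumptions. More specifically, for the ERM solution \(\hat{\mathbf{w}}\) computed from \(m\) samples with feature matrix \(\mathbf{X}\), the estimation error \(||\hat{\mathbf{w}}-\mathbf{w}^*||^2\), and hence the excess risk \(R(\hat{\mathbf{w}})-\sigma^2\), is bounded by \(\frac{4}{m^2}||(\epsilon_1,\cdots,\epsilon_m)\mathbf{X}||^2/\lambda_1^2\), where \(\lambda_1\) is the smallest eigenvalue of \(\mathbf{Q}=\frac{1}{m}\mathbf{X}^T\mathbf{X}\) (the bound comes from inverting \(\mathbf{Q}\), so its smallest eigenvalue is the one that matters). This implies that we can tighten the bound by increasing \(\lambda_1\) - this motivates whitening methods, where feature vectors are transformed to look like realizations of white noise.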
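The risk formula for the linear model can be sanity-checked by Monte Carlo. This is a rough sketch under stated assumptions: features are drawn as standard Gaussians (so that \(\mathbb{E}[\mathbf{x}\mathbf{x}^T]=\mathbf{I}\), an assumption the formula needs), and the particular values of `w_star`, `w` and `sigma` are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
d, sigma = 3, 0.5
w_star = np.array([1.0, -1.0, 2.0])   # hypothetical true parameter vector
w = np.array([0.5, 0.0, 1.0])         # hypothesis whose risk we evaluate

# Draw many i.i.d. samples from the assumed model y = w_star^T x + eps.
n = 200_000
X = rng.normal(size=(n, d))           # standard Gaussian features: E[x x^T] = I
y = X @ w_star + sigma * rng.normal(size=n)

# Sample average of the squared loss approximates the risk (expectation).
empirical_risk = np.mean((y - X @ w) ** 2)

# Closed-form risk: ||w_star - w||^2 + sigma^2.
predicted_risk = np.sum((w_star - w) ** 2) + sigma ** 2
```

With this many samples the two numbers should agree to within a few hundredths, which is the Law of Large Numbers at work.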
We would want to answer the following questions:
- Do the assumptions make sense? We can test this via tools like correlation tests.
- Did we overfit to the data? For this we can use validation sets.
  - The training-validation split can be determined by the Law of Large Numbers (LLN).
  - If we did overfit, we could:
    - Make the model smaller.
    - Do regularization.
Suppose we only have 1 sample \((\mathbf{x}_1,y_1)\) (in a regression setting) - we could do data augmentation by sampling many feature vectors \(\mathbf{x}_i\) centered around \(\mathbf{x}_1\), with standard deviation \(\sqrt{\alpha}\) per coordinate, each keeping the label \(y_i=y_1\). This replaces the loss with the average \(\frac{1}{l}\sum_{i=1}^{l}(y_i-\mathbf{w}^T\mathbf{x}_i)^2\). Via the LLN (see lecture notes) this converges to \((y_1-\mathbf{w}^T\mathbf{x}_1)^2+\alpha ||\mathbf{w}||^2\) as \(l\to\infty\), i.e. the original loss plus an extra regularization term, which reframes the problem as ridge regression. This problem can also be framed as optimizing over a constrained set for \(\mathbf{w}\) via Lagrangian duality. (There is also the probabilistic modelling perspective of adding a Gaussian prior over \(\mathbf{w}\).)
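The link between augmentation and ridge regression can be checked numerically. The sketch below assumes Gaussian perturbations with variance \(\alpha\) per coordinate and augmented points that all keep the original label; the sample values are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
d, alpha, l = 2, 0.3, 200_000

x1 = np.array([1.0, 2.0])      # the single observed feature vector
y1 = 3.0                       # its label
w = np.array([0.5, -1.0])      # an arbitrary parameter vector to evaluate

# Augmented features: x1 plus zero-mean Gaussian noise with variance alpha.
X_aug = x1 + np.sqrt(alpha) * rng.normal(size=(l, d))

# Average squared loss over the augmented points (all share the label y1).
augmented_loss = np.mean((y1 - X_aug @ w) ** 2)

# Ridge-style objective: original loss plus alpha * ||w||^2.
ridge_objective = (y1 - w @ x1) ** 2 + alpha * np.sum(w ** 2)
```

For large `l` the two quantities agree closely, so the perturbation variance plays the role of the ridge regularization strength.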
A specific choice for the regularization term can be used for federated learning where we can couple the training of multiple datasets.
Federated Learning Design Principle
In a Federated Learning (FL) setting, we have separate data sources within "nodes" and train a local model within each node - this allows us, e.g., to learn a model without sharing the data between the nodes.
In practice, we can use Python's NetworkX package which allows us to store arbitrary Python objects like machine learning models per node.
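As a minimal sketch of this idea, the snippet below builds a small graph whose nodes carry a local dataset and a parameter vector as node attributes. The attribute names `X`, `y` and `w`, the synthetic data and the edge weights are assumptions for illustration; NetworkX itself does not prescribe them.

```python
import networkx as nx
import numpy as np

rng = np.random.default_rng(3)
G = nx.Graph()  # undirected graph

# Each node stores its own local dataset and model parameters as attributes.
for i in range(1, 4):
    X = rng.normal(size=(10, 2))
    y = X @ np.array([1.0, -1.0])          # noise-free labels for simplicity
    G.add_node(i, X=X, y=y, w=np.zeros(2))

# Weighted edges encode how strongly two nodes are coupled.
G.add_edge(1, 2, weight=1.0)
G.add_edge(2, 3, weight=0.5)

# Train each local model independently (plain least squares, no coupling).
for node, data in G.nodes(data=True):
    data["w"] = np.linalg.lstsq(data["X"], data["y"], rcond=None)[0]
```

Any Python object could be stored in place of the NumPy arrays, for example a fitted estimator from another library.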
The main mathematical object of concern is the empirical graph, where we assume:
- Nodes are labelled from \(1,\cdots,n\).
- Each node \(i\) is associated with a model parametrized by \(\mathbf{w}_i\).
- Edges are undirected.
- Each edge \((i,j)\) has a non-negative weight \(w_{ij}\) which represents the coupling between nodes \(i\) and \(j\).
As the weights represent coupling between data sources, there is a lot of focus on heuristic (e.g. distance based weights for weather stations), statistical (e.g. t-tests for bioinformatics) and graph methods (e.g. connectivity measures) which aim to compute these weights, and thus get the empirical graph.
The Graph Laplacian for a graph with \(n\) nodes is an \(n\times n\) matrix defined to be \(\mathbf{L}=\mathbf{D}-\mathbf{A}\), where \(\mathbf{D}\) is the degree matrix - a diagonal matrix whose \(k^{th}\) diagonal entry is the (weighted) degree of node \(k\) - and \(\mathbf{A}\) is the graph adjacency matrix. As we are working with undirected graphs, \(\mathbf{A}\) is symmetric and thus so is \(\mathbf{L}\). The Laplacian is very informative of the global connectivity of the graph. Some properties of the Laplacian include:
- Each row sums to 0.
- The all-ones vector \(\mathbf{1}\) is an eigenvector of \(\mathbf{L}\) with eigenvalue 0 - thus \(\mathbf{L}\) is non-invertible.
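These properties are easy to verify on a toy example. The sketch below uses a hypothetical weighted path graph on three nodes; the adjacency values are made up.

```python
import numpy as np

# Symmetric weighted adjacency matrix of a path graph 1 - 2 - 3.
A = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 2.0],
              [0.0, 2.0, 0.0]])

D = np.diag(A.sum(axis=1))   # degree matrix: weighted degrees on the diagonal
L = D - A                    # graph Laplacian

row_sums = L.sum(axis=1)             # each row sums to zero
L_times_ones = L @ np.ones(3)        # so L @ 1 = 0 * 1
rank = np.linalg.matrix_rank(L)      # rank < n, hence L is not invertible
```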
The Laplacian allows us to measure the total variation of the parameter vectors associated with nodes in the empirical graph, i.e. the sum over edges of the weighted squared Euclidean distance between the parameter vectors at the two endpoints. The total variation can be written as a quadratic form in \(\mathbf{L}\); since it is a sum of non-negative terms, \(\mathbf{L}\) is positive semi-definite (PSD).
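A small numerical check of this identity, as a sketch: the parameter vectors are stacked as rows of a matrix `W` (a layout chosen here for convenience), in which case the quadratic form can be written as \(\mathrm{tr}(\mathbf{W}^T\mathbf{L}\mathbf{W})\). The graph and the random parameters are illustrative only.

```python
import numpy as np

# Weighted undirected graph on 4 nodes (hypothetical edge weights).
A = np.array([[0.0, 1.0, 0.5, 0.0],
              [1.0, 0.0, 0.0, 2.0],
              [0.5, 0.0, 0.0, 1.0],
              [0.0, 2.0, 1.0, 0.0]])
L = np.diag(A.sum(axis=1)) - A
n = A.shape[0]

rng = np.random.default_rng(4)
W = rng.normal(size=(n, 2))   # row i holds the parameter vector of node i

# Total variation summed edge by edge (each undirected edge counted once).
tv_edges = sum(A[i, j] * np.sum((W[i] - W[j]) ** 2)
               for i in range(n) for j in range(i + 1, n))

# The same quantity as a quadratic form in the Laplacian.
tv_quadratic = np.trace(W.T @ L @ W)

# PSD check: all eigenvalues of L are non-negative (up to round-off).
smallest_eigenvalue = np.linalg.eigvalsh(L).min()
```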
We already know that 0 is an eigenvalue of \(\mathbf{L}\). If the next smallest eigenvalue \(\lambda_2=0\), there are at least 2 disconnected components; equivalently, if \(\lambda_2\neq 0\), the graph is connected (qualitative statement), i.e. each node is reachable from any other node. \(\lambda_2\) also measures how well connected the graph is (quantitative statement) - namely we have the relation \(\sum_{(i,j)\in E}A_{ij}||\mathbf{w}_i-\mathbf{w}_j||^2\geq \lambda_2 \sum_{i}||\mathbf{w}_i-\bar{\mathbf{w}}||^2\), where \(\bar{\mathbf{w}}=\frac{1}{n}\sum_{i}\mathbf{w}_i\) is the mean of the parameter vectors over all nodes. #TODO Understand this relation and derive it.
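Both statements can be probed numerically. The sketch below compares \(\lambda_2\) for a connected path graph and for a graph with two components, and checks the inequality on random parameter vectors, taking the mean over all nodes as the reference point. The graphs and parameters are toy choices, and a numerical check is of course not a derivation.

```python
import numpy as np


def laplacian(A):
    """Graph Laplacian L = D - A for a symmetric adjacency matrix A."""
    return np.diag(A.sum(axis=1)) - A


# Connected path graph 1 - 2 - 3 - 4.
A_path = np.array([[0.0, 1.0, 0.0, 0.0],
                   [1.0, 0.0, 1.0, 0.0],
                   [0.0, 1.0, 0.0, 1.0],
                   [0.0, 0.0, 1.0, 0.0]])

# Two separate components: edge 1 - 2 and edge 3 - 4.
A_split = np.array([[0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 1.0, 0.0]])

# eigvalsh returns eigenvalues in ascending order, so index 1 is lambda_2.
lam2_path = np.linalg.eigvalsh(laplacian(A_path))[1]
lam2_split = np.linalg.eigvalsh(laplacian(A_split))[1]

# Check the inequality on the connected graph with random parameters.
rng = np.random.default_rng(5)
W = rng.normal(size=(4, 3))                      # row i = parameters of node i
total_variation = np.trace(W.T @ laplacian(A_path) @ W)
spread = np.sum((W - W.mean(axis=0)) ** 2)       # deviation from the global mean
```

For the path graph \(\lambda_2\) is strictly positive, while for the split graph it vanishes up to round-off, matching the connectivity statement.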
The FL principle: For each node, we can minimize the training error in parallel. We now want to couple nodes in the empirical graph, which is equivalent to penalizing the variation over the edges. This is achieved by supplementing the loss function with the total variation multiplied by a constant \(\alpha\), called the GTVMin parameter. The result is an instance of regularized ERM.
Large \(\alpha\) values result in all nodes learning the same model - this is called single-model FL, and is common in applications like healthcare. Moderate \(\alpha\) values result in similar models for nodes in the same cluster, where information is pooled between similar nodes. Small \(\alpha\) values give more personalized training for each local dataset. This can be an issue if some nodes have few data samples, and may result in problems such as overfitting.
Note that more pooling results in training models more for the average case, which results in less personalization. As a result, selecting \(\alpha\) is important, and this can be guided by structural properties of the empirical graph.
GTV Minimization is equivalent to minimizing a quadratic form, see lecture notes.
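To make the quadratic form concrete, here is a rough sketch for the special case of linear local models with a squared loss, which is an assumption made here and not necessarily the setting of the lecture notes. Stacking all local parameter vectors into one long vector, the objective is quadratic and its minimizer solves a linear system involving the block-diagonal matrix of local Gram matrices plus \(\alpha\) times the Kronecker product of the Laplacian with the identity. The function name `gtvmin_linear` and the toy data are hypothetical.

```python
import numpy as np


def gtvmin_linear(Xs, ys, A, alpha):
    """Minimize sum_i (1/m_i)||y_i - X_i w_i||^2 + alpha * total variation.

    Returns an (n, d) array whose row i is the learned w_i.
    Assumes every local Gram matrix X_i^T X_i is invertible.
    """
    n, d = len(Xs), Xs[0].shape[1]
    L = np.diag(A.sum(axis=1)) - A
    Q = np.zeros((n * d, n * d))   # block-diagonal local Gram matrices
    q = np.zeros(n * d)            # stacked local correlation vectors
    for i, (X, y) in enumerate(zip(Xs, ys)):
        m = X.shape[0]
        block = slice(i * d, (i + 1) * d)
        Q[block, block] = X.T @ X / m
        q[block] = X.T @ y / m
    # Setting the gradient of the quadratic objective to zero gives
    # (Q + alpha * (L kron I_d)) w = q.
    M = Q + alpha * np.kron(L, np.eye(d))
    return np.linalg.solve(M, q).reshape(n, d)


rng = np.random.default_rng(6)
true_ws = np.array([[1.0, 0.0], [1.2, 0.1], [-1.0, 2.0]])   # made-up local truths
Xs = [rng.normal(size=(20, 2)) for _ in true_ws]
ys = [X @ w for X, w in zip(Xs, true_ws)]                    # noise-free labels
A = np.array([[0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])                              # path graph 1 - 2 - 3

W_local = gtvmin_linear(Xs, ys, A, alpha=0.0)    # no coupling: purely local fits
W_shared = gtvmin_linear(Xs, ys, A, alpha=1e6)   # strong coupling: near-identical rows
```

With `alpha=0.0` each node recovers its own parameters, while a very large `alpha` pushes all rows towards a single shared model, mirroring the personalized and single-model regimes described above.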
[#TODO Read lecture notes Section 1,2,3].