Federated Learning: Your Favorite Guide


Apps from Netflix, Amazon, and Google, as well as fraud detection and healthcare algorithms, use federated learning. With it, edge devices like mobile phones help update ML models while keeping all the data local, so raw data never needs to reach a central server. Federated learning improves app performance and strengthens user privacy. Read on to find out how.

Federated learning: Why does it matter?

Federated learning (FL) is a form of collaborative machine learning that trains models without centralizing the training data, which comes in handy in several situations:

  1. It preserves users' data by sharing only the insights needed for a better user experience, never the sensitive data itself. This is crucial for products that handle financial, health, or location data.
  2. Machine learning is data-hungry and pricey, but FL can reduce the cost of moving substantial data sets around.
  3. It can help eliminate or reduce network latency.

But privacy seems to be the critical reason. Facebook recently announced changes to its advertising policies, adopting FL out of privacy considerations, a response to Apple's new strategy of emphasizing user privacy while building its own advertising business. So let's stock up on popcorn and enjoy, but first delve into the nuts and bolts of FL.

When does federated learning apply?

  • On-device data is more efficient and relevant to use than server-side proxy data
  • On-device data cannot be sent to a server due to privacy considerations
  • User interaction provides labeled data naturally (e.g., next-word prediction applications)

How does federated learning actually work?

Usually, an ML-based application ships data from edge devices (phones, laptops, tablets, or any IoT devices) to a centralized server, where the insights are crunched. Advancements in edge AI, however, make it possible to update copies of the ML model locally. Only the model updates are sent to the central server, and the improved model is sent back to the local devices.

There are two types of FL. In single-party systems, a single entity manages the overall network of devices. In multi-party systems, two or more entities manage it together.

By the way, we have covered edge analytics before. For details, check out Can Edge Analytics Become a Game Changer?

Federated learning vs. traditional distributed learning inside a data center

Simply put, federated learning brings the models to the data sources, the reverse of centralized, traditional machine learning. Moreover, in a data center the data sets are balanced in size and stored on cloud-based platforms.

In the federated setting, by contrast, each device holds only its own user's data, and some users use certain features or applications more often than others, so dataset sizes differ. Data in a federated computation is also highly self-correlated: each device holds a single user's data, which is not a representative sample of all users' behavior.

Five rounds of federated learning

First, you need to create and train a generic machine learning model on a central server, so that the model has the logic and functionality to work with the insights gathered from edge devices. Once the generic model exists, each round consists of five sequential steps:

Sending the generic model to each device to store it locally

Let’s imagine that an engineer wants to know which users have seen a daily high over 70 degrees Fahrenheit on their devices.

Here is the algorithm:

  1. The engineer puts that threshold on the server.
  2. The threshold is broadcast to a subset of available devices and compared to the local temperature data to compute a value.
  3. These values are aggregated using an aggregation operator.

In this case, it is a federated mean, which encodes a protocol for computing the average value over the participating devices.
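The three steps above can be sketched in a few lines of plain Python. This is a toy simulation of the temperature query, not a real FL framework: each hypothetical "device" holds its own daily-high readings, the server never sees them, and only the locally computed values are combined with a weighted federated mean.

```python
# Toy simulation of the temperature query (illustrative only).

def local_value(daily_highs, threshold=70.0):
    """Computed on-device: fraction of days whose high exceeded the threshold."""
    return sum(1 for t in daily_highs if t > threshold) / len(daily_highs)

def federated_mean(values, weights):
    """Server-side aggregation: weighted average of the per-device values."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)

# Three devices, each with private local data the server never receives
device_data = [[68.0, 75.0, 80.0, 66.0], [71.0, 72.0], [60.0, 65.0, 69.0]]
values = [local_value(d) for d in device_data]   # [0.5, 1.0, 0.0]
weights = [len(d) for d in device_data]          # weight by local dataset size
print(federated_mean(values, weights))           # roughly 0.444
```

Weighting by dataset size mirrors the fact that, in federated settings, devices hold unequal amounts of data.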

A baseline model is trained on a central server and sent to the sample of devices.

(adapted from TensorFlow’s talk Federated Learning: Machine Learning on Decentralized Data)

Each local model is trained on the device, improving the performance locally

Now we come to the critical stage of federated learning: training the ML model on decentralized data. The main training loop runs on the devices themselves, which is what protects user privacy.

Engineers may only access combined device reports, never an individual report itself. Moreover, per-device messages are discarded as soon as the updates have been collected (the ephemeral reports principle, as TensorFlow calls it). Such rounds are repeated several times to provide the best answer to the engineer.

The main training loop is performed on the decentralized data
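The on-device training step can be sketched with a hypothetical one-parameter linear model (y = w * x with squared-error loss): the device receives the global weight, runs a few gradient descent steps on its private examples, and reports only the resulting delta. All names and numbers here are made up for illustration.

```python
# Minimal sketch of on-device training for a hypothetical 1-D linear model.

def local_update(w_global, xs, ys, lr=0.01, steps=5):
    """Run a few local gradient steps; return only the model delta."""
    w = w_global
    for _ in range(steps):
        # gradient of mean squared error (w*x - y)^2 with respect to w
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= lr * grad
    return w - w_global  # only the delta leaves the device, never (xs, ys)

# Private data consistent with y = 2x; the delta moves the weight toward 2
delta = local_update(0.0, xs=[1.0, 2.0, 3.0], ys=[2.0, 4.0, 6.0])
print(delta)
```

The raw examples `xs` and `ys` stay on the device; only `delta` is ever transmitted.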

Updates, not private data, are then sent to the central server

The following principle, focused collection, means that only the information needed for the computation is collected from devices. Thanks to the secure aggregation technique, the server computes only a sum of vectors from encrypted device reports; it is not able to decrypt the individual messages.

Focused collection — only needed information for computation is collected from devices

Thanks to the secure aggregation technique, the server computes only a sum of vectors from encrypted device reports, and it’s not able to decrypt the individual messages
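The idea behind secure aggregation can be illustrated with a toy additive-masking scheme. Real protocols are cryptographically far more involved; this sketch only shows the core trick: paired devices add cancelling random masks, so each individual report looks random to the server, yet the masks vanish in the sum.

```python
# Toy additive-masking sketch of secure aggregation (not a real protocol).
import random

def mask_reports(reports, seed=0):
    """Each pair of devices (i, j) shares a random mask: +m for i, -m for j."""
    rng = random.Random(seed)
    masked = list(reports)
    for i in range(len(reports)):
        for j in range(i + 1, len(reports)):
            m = rng.uniform(-100, 100)
            masked[i] += m
            masked[j] -= m
    return masked

reports = [1.5, 2.5, 4.0]
masked = mask_reports(reports)
# Individual masked reports reveal nothing, but the pairwise masks cancel,
# so the server still recovers the correct sum
print(sum(masked), sum(reports))
```

The server only ever sees `masked`, whose entries are dominated by random noise, yet `sum(masked)` equals `sum(reports)` up to floating-point error.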

Updating central ML model using the insights from the edge devices

The generic ML model is then updated with the insights gathered from the sample of edge devices. Some devices might drop out during decentralized training without affecting the result. For instance, you can target 500 devices but work with a sample of 100 of them that are available at the moment of the federated computation.

The updated model is sent back to the network of devices

The server aggregates user updates into a new model by averaging the model updates, optionally using secure aggregation. The engineer monitors federated training through metrics that are aggregated along with the model. Training continues until the model's performance is satisfactory; then the server chooses a different subset of devices and gives it the new model parameters. It is an iterative process that continues through many training rounds.

Federated model averaging enters the picture during this round. Federated averaging works by computing a data-weighted average of the model updates from many gradient descent steps on the device.

Federated averaging works by computing a data-weighted average of the model updates from many steps of the gradient descent on the device
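The data-weighted average at the heart of federated averaging can be sketched as follows. The function names and numbers are hypothetical; each delta stands for the result of a device's local gradient descent steps, weighted by how many examples that device trained on.

```python
# Sketch of the server-side step of federated averaging.

def federated_averaging(global_weights, deltas, num_examples):
    """Combine per-device model deltas, weighted by local dataset size."""
    total = sum(num_examples)
    new_weights = []
    for k in range(len(global_weights)):
        # data-weighted average of the k-th weight's updates
        avg_delta = sum(d[k] * n for d, n in zip(deltas, num_examples)) / total
        new_weights.append(global_weights[k] + avg_delta)
    return new_weights

# Two devices, a two-weight model; the second device has 3x more data,
# so its delta dominates the average
updated = federated_averaging(
    global_weights=[0.0, 1.0],
    deltas=[[0.2, -0.1], [0.4, 0.1]],
    num_examples=[10, 30],
)
print(updated)  # approximately [0.35, 1.05]
```

Weighting by `num_examples` is what makes the average "data-weighted": a device that trained on more examples pulls the global model further in its direction.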

Moreover, differentially private model averaging protects user privacy: the server learns the common patterns in the data without memorizing individual examples. Devices “clip” updates that are too large, and the server adds noise when combining them.
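The two ingredients just mentioned, clipping on the device and noise on the server, can be sketched in isolation. This is a toy illustration; the norm bound, noise level, and function names are made up, and a real deployment would calibrate them to a formal privacy budget.

```python
# Toy sketch of the two differential-privacy ingredients: clip, then add noise.
import math
import random

def clip_update(update, max_norm=1.0):
    """On-device: scale the update down if its L2 norm exceeds max_norm."""
    norm = math.sqrt(sum(u * u for u in update))
    if norm > max_norm:
        return [u * max_norm / norm for u in update]
    return update

def noisy_average(updates, noise_std=0.1, seed=42):
    """Server-side: average the clipped updates, then add Gaussian noise."""
    rng = random.Random(seed)
    n = len(updates)
    return [sum(col) / n + rng.gauss(0.0, noise_std) for col in zip(*updates)]

clipped = clip_update([3.0, 4.0], max_norm=1.0)  # norm 5 scaled down to 1
print(clipped)                                    # [0.6, 0.8]
print(noisy_average([clipped, [0.0, 0.0]]))
```

Clipping bounds how much any single user can influence the average, and the added noise masks whatever individual contribution remains.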

Federated learning: Limitations

Meanwhile, federated learning does not apply to every task that machine learning can deal with:

  • Data relevance

Traditional ML practice includes data cleaning as a prerequisite to efficient model development. Since the engineer has no access to users' data to ensure its relevance, this method is limited to data that needs little or no preprocessing.

  • Data labeling

The majority of ML models are supervised, i.e., they require data labeled by humans, and not all user interactions produce labels on their own. Federated learning is therefore best suited to applications where labels arise naturally from user interaction, such as next-word prediction, or to self-supervised tasks like language modeling.

  • Intermittent data availability

Data provided by edge devices is only intermittently available, since devices generally participate only when idle, charging, and on an unmetered network.

  • Data access

Finally, devices participate only after the user grants permission, depending on the app's policies.

Bottom line

Unlike centralized machine learning, federated learning brings the models to the data sources, which nowadays have enough computing power. The ML model is trained locally, and only updates, not private data, are sent to the central server.

Because of its data labeling limitations, federated learning generally applies to tasks where labels arise naturally from user interaction or to self-supervised applications. It comes in handy whenever using on-device data is more relevant and efficient for the application. At the same time, federated learning offers techniques that contribute to solid user privacy: on-device datasets, federated and secure aggregation, federated model averaging, and differentially private model averaging.
