Abstract

In parallel with the rapid adoption of deep learning in multimedia data analysis, awareness of and concern about data security and privacy have grown. Recent advances in federated learning enable many network clients to collaboratively train a model under the orchestration of a central server while preserving clients’ privacy. However, the standard assumption of independent and identically distributed (IID) data may not hold in federated learning because data label preferences may vary across clients. Recent efforts address this issue either by adapting a strong global model to each local model or by jointly training individual local models for similar clients. However, both strategies degrade in highly non-IID scenarios. This work introduces a novel method, deep cooperative learning (DCL), to address this problem. It leverages the reciprocal structure between deep learning tasks in different clients to obtain effective feedback signals that enhance the learning process of personalized local models. To the best of our knowledge, this is the first time the non-IID problem has been addressed under the principle of task interactions. We demonstrate the effectiveness of DCL on two medical multimedia data analysis tasks. The results show that our method achieves a significant performance improvement over the standard federated learning method. In conclusion, this work develops a method for addressing non-IID problems in deep-learning-based privacy-preserving learning. It allows highly non-IID data to be used to improve local model performance.

1. Introduction

In recent years, with the wide availability of medical imaging and computing devices, convolutional neural networks (CNNs) [1] have proven to be powerful tools for medical image segmentation [2–6] and registration [7, 8]. Segmentation is considered the most essential medical image processing step, as it divides an image into regions of interest based on anatomical structures or pathological tumors. Registration is the process of identifying a spatial transformation that maps two imaging modalities, such as CT (computed tomography) and MRI (magnetic resonance imaging), to a common coordinate system such that corresponding anatomical structures are optimally aligned. The resulting pixelwise correspondence is fundamental for multimodality image analysis applications. Typically, training a CNN model requires patient scans to be transferred to a centralized data server, where comprehensive analyses can be performed using the parallel computing capacity of the center. Given the increasing volumes of imaging data, such massive data collection and processing may be infeasible in practice because of the high throughput demands and growing data privacy concerns. Federated learning [9] trains a global model collaboratively among a set of hospitals under the orchestration of a central server, without sharing their private raw data, so that a global model such as a CNN-based segmentation network can achieve better performance than models trained by each hospital alone. Also, since the data never leave the owner, concerns about disclosing sensitive patient information and violating legal regulations are mitigated.

While federated learning works well on independent and identically distributed (IID) data, its performance degrades on non-IID data [10, 11], that is, when the data distributions of individual hospitals differ substantially from each other. Heterogeneous data distributions prevent the global model from converging because they support conflicting update directions. Unfortunately, non-IID data occur often in real-world applications [12, 13]. For example, consider image segmentation with two hospitals that have different label preferences, as illustrated in Figure 1. The two hospitals annotate the temporal lobe differently because of individual preferences, although the underlying CT scan is the same. Global model aggregation becomes extremely hard in this case, since a correct prediction for hospital 1 is incorrect for hospital 2. A single global model is insufficient here; it is more appropriate to train a personalized model for each hospital.

Recent efforts to address the non-IID issue can be classified into two strategies. The first strategy attempts to personalize a trained global model for each hospital with different label preferences. Personalization techniques in this category are classified into data-based and model-based approaches. Data-based approaches seek to reduce the divergence between local distributions by balancing them with a small amount of public [14, 15] or synthetic [16] data. These methods generally need to modify the local data distributions, which disturbs the local label bias, and are therefore not suited to our case. Instead of changing the data, model-based approaches learn a general global model for later personalization in individual hospitals, either by domain adaptation learning that reduces the domain discrepancy between the global and local models [17–20] or by meta-learning that enables the global model to adapt to the private data quickly and effectively [21–24]. However, these methods presume access to a public proxy dataset on which the global model is trained, which is unavailable in our case. In contrast to the first strategy, which trains a single global model, the second strategy trains personalized models individually. Its personalization techniques are classified into architecture-based and similarity-based approaches. The former achieve personalization by decoupling the local private model parameters from the shared global parameters [25], while the latter improve personalized model performance by enforcing stronger pairwise collaboration among hospitals with similar data distributions [26–28]. Both approaches exploit pairwise data similarities between hospitals to improve local model performance, but other pairwise relations, such as the reciprocal structure between tasks, remain unexplored in current work.

We propose a new non-IID federated learning paradigm, deep cooperative learning (DCL), which leverages the reciprocal structure between federated learning tasks to obtain effective feedback signals that enhance the learning process of personalized models. We use medical image segmentation and registration, two tasks with inherently complementary structures, to build the cooperative learning loop. The principle of DCL is simple. Consider two hospitals, each assigned one of the two tasks. If the two task models work well, the segmentation results, i.e., anatomical structures, can be combined with the input of the registration model to boost its performance, since the extra anatomical information helps the registration model find the right alignment of the anatomies. Similarly, since some anatomies are visible only on MRI, the aligned MRI produced by the registration model can be combined with CT to provide an extra modality for the segmentation model. More importantly, DCL shares the models rather than the model outputs among hospitals during the cooperative training loop, thereby achieving personalization and preserving data privacy simultaneously.

2. Methods

Deep cooperative learning consists of two steps. First, a reward mechanism is designed to promote mutual benefits between two tasks from different hospitals. The gain that one task produces for the other is regarded as a reward and fed back to the task model to adjust its subsequent behavior for better performance. In this way, information from the non-IID labels is exchanged among hospitals, which shares task experience and improves model generalization. Second, a cooperative training mechanism between task models is created, which treats each model as a parameterized agent that maximizes its long-term reward. Below, we provide detailed explanations of the two steps.

2.1. Reward

We design the reward mechanism via deep discriminator networks. Let $x_{\mathrm{CT}}$ and $x_{\mathrm{MR}}$ be unlabeled CT and MRI images, respectively, and let $R$ and $S$ be the trained registration and segmentation networks/models, respectively. The circulation between task models can be summarized as follows (see Figure 2). $R$ registers CT and MRI, and $S$ segments the output of the registration model. $S$ also segments CT and MRI, respectively, and $R$ then registers the results of the segmentation model. Suppose $D_S$ and $D_R$ are discriminator networks that, after adversarial training, measure the confidence of the outputs of the segmentation and registration networks, respectively. We define reward 1 as

$$r_1 = D_R\big(R(S(x_{\mathrm{CT}}), S(x_{\mathrm{MR}}))\big) - D_R\big(R(x_{\mathrm{CT}}, x_{\mathrm{MR}})\big),$$

where the subtraction measures the difference between direct registration and segmentation-then-registration, i.e., the promotion derived from the segmentation results. Similarly, we define reward 2 as

$$r_2 = D_S\big(S(R(x_{\mathrm{CT}}, x_{\mathrm{MR}}))\big) - D_S\big(S(x)\big), \quad x \in \{x_{\mathrm{CT}}, x_{\mathrm{MR}}\},$$

where $x$ is the unregistered input image, to measure the difference between direct segmentation and registration-then-segmentation, i.e., the promotion derived from the registration results.
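The two rewards can be illustrated with a minimal Python sketch. It is not the authors' implementation: the function names (`reward_1`, `reward_2`), the argument names, and the toy stand-ins at the bottom are all hypothetical, and the networks and discriminators are represented by plain callables. The sketch also assumes that the "direct segmentation" baseline in reward 2 is computed on the unregistered MRI, which the text does not state explicitly.

```python
from typing import Callable

# Hypothetical type aliases: any object can stand in for an image or label map.
Segmenter = Callable[[float], float]
Registrar = Callable[[float, float], float]
Critic = Callable[[float], float]


def reward_1(ct, mr, segment: Segmenter, register: Registrar, d_reg: Critic) -> float:
    """Gain in registration confidence when the inputs are segmented first."""
    direct = d_reg(register(ct, mr))
    via_segmentation = d_reg(register(segment(ct), segment(mr)))
    return via_segmentation - direct


def reward_2(ct, mr, segment: Segmenter, register: Registrar, d_seg: Critic) -> float:
    """Gain in segmentation confidence when the inputs are registered first.

    Assumption: the direct baseline segments the unregistered MRI.
    """
    direct = d_seg(segment(mr))
    via_registration = d_seg(segment(register(ct, mr)))
    return via_registration - direct


# Toy stand-ins (scalars instead of images), purely to exercise the functions.
toy_segment = lambda x: 2.0 * x
toy_register = lambda fixed, moving: moving - fixed
toy_d_reg = lambda y: -abs(y)
toy_d_seg = lambda y: y

r1 = reward_1(1.0, 3.0, toy_segment, toy_register, toy_d_reg)
r2 = reward_2(1.0, 3.0, toy_segment, toy_register, toy_d_seg)
```

A positive reward means the upstream task helped the downstream task, and a negative reward means it hurt. In the toy setup both rewards are negative, because the stand-in functions are arbitrary.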

2.2. Cooperative Training

With rewards 1 and 2 defined, we next design the cooperative training mechanism. The segmentation and registration networks are treated as parametric representations of the policy, and we use a policy gradient algorithm [29] to update their parameters alternately through federated learning to achieve cooperative learning. If a large or positive reward is observed after performing an action (a parameter update), its gradient is added to the parameters of the current policy function to increase the probability of performing this action in this state. Conversely, if a small or negative reward is observed after performing an action, its gradient is subtracted from the parameters of the current policy function to decrease the probability of performing this kind of action in this state. Formally, let the parameters of the segmentation [2] and registration networks be $\theta_S$ and $\theta_R$, respectively, and let the number of samples in a mini-batch be $K$. The stochastic gradient of the expected reward is then estimated over the mini-batch, and $\theta_S$ and $\theta_R$ are updated according to the policy gradient. The cooperative training algorithm is summarized as Algorithm 1 in Figure 3.
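The reward-weighted update described above follows the general pattern of a REINFORCE-style policy gradient. The sketch below illustrates that generic pattern only and is not the paper's Algorithm 1. A one-parameter Bernoulli policy stands in for a network, and the names `policy_gradient_step` and `toy_reward`, the learning rate, and the batch size are hypothetical choices. The gradient is estimated as the mini-batch average of reward times the gradient of the log-probability of the sampled action.

```python
import math
import random


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def policy_gradient_step(theta, reward_fn, batch_size, lr, rng):
    """One REINFORCE-style ascent step for a Bernoulli policy with logit theta."""
    p = sigmoid(theta)  # probability of choosing action 1
    grad = 0.0
    for _ in range(batch_size):
        action = 1 if rng.random() < p else 0
        # d/dtheta log pi(action) for a Bernoulli policy parameterized by a logit
        score = action - p
        # Positive rewards push theta toward the sampled action, negative ones away
        grad += reward_fn(action) * score
    grad /= batch_size
    return theta + lr * grad


def toy_reward(action: int) -> float:
    # Hypothetical reward: action 1 is beneficial, action 0 is harmful
    return 1.0 if action == 1 else -1.0


rng = random.Random(0)
theta = 0.0
for _ in range(200):
    theta = policy_gradient_step(theta, toy_reward, batch_size=8, lr=0.5, rng=rng)
```

After training, `sigmoid(theta)` is above 0.5, so the rewarded action has become more probable. In DCL, the two networks would take turns playing the role of this policy, each driven by its own reward.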

2.3. Implementation Details

We use the U-Net structure [30] for the segmentation network, as illustrated in Figure 4. U-Net is considered one of the standard CNN architectures for image segmentation. Its skip-connection layers capture image features at multiple scales while avoiding the loss of high-frequency details. We further improve the performance of U-Net with two modifications. First, a squeeze-and-excitation block is introduced after each convolution in the U-Net encoder to adaptively extract image features. Second, to avoid the resolution degradation caused by pooling and downsampling, the last pooling layer and downsampling layer of the network are replaced with an atrous spatial pyramid pooling (ASPP) [31] block, which uses different receptive field sizes around a single pixel and fuses the convolution results to detect small targets at multiple resolutions.

We use a two-stream regression network for the registration network, as illustrated in Figure 5. Each stream takes one imaging modality as input and outputs its feature map. The subsequent regression layers predict the shifts between the two images based on their feature maps. We further use an attention mechanism [32, 33], which mimics human attention, to improve its performance. The attention mechanism emphasizes the unaligned parts of the two images, thereby concentrating the limited computational resources on them.

In the original generative adversarial network [34], the generated images or the real images are alternately fed into the discriminator network, and the generated images are not preprocessed in any way. Considering the characteristics of the segmentation and alignment tasks, we preprocess the generated images based on the attention mechanism to strengthen the relationship between the generated images and the real tokens, emphasizing their higher-order semantic inconsistencies for the evaluation of the generated images by the discriminator.

To prepare the input images for the segmentation and registration models, we first resize the CT-MRI slice pairs from 512 × 512 to 480 × 480 pixels and then randomly crop the resized images to 384 × 384 pixels. We use random cropping to increase the sample size. The 384 × 384 size facilitates the downsampling operations in the networks because 384 = 3 × 2^7 can be halved seven times with no remainder. Before being fed into the networks, images are normalized to zero mean and unit variance and augmented with a random horizontal flip.
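The preprocessing steps above can be sketched in plain Python on images stored as nested lists. This is an illustration under stated assumptions rather than the authors' pipeline: the function names are hypothetical, and nearest-neighbor sampling is assumed for the resize because the text does not name an interpolation method. For a CT-MRI pair, the same crop offsets and flip decision would have to be applied to both slices, which this single-image sketch omits.

```python
import random
from statistics import fmean, pstdev


def resize_nearest(img, size):
    """Resize a 2D image to size x size by nearest-neighbor sampling (assumed method)."""
    h, w = len(img), len(img[0])
    return [[img[r * h // size][c * w // size] for c in range(size)]
            for r in range(size)]


def random_crop(img, size, rng):
    """Cut a random size x size window out of the image."""
    top = rng.randrange(len(img) - size + 1)
    left = rng.randrange(len(img[0]) - size + 1)
    return [row[left:left + size] for row in img[top:top + size]]


def normalize(img):
    """Shift and scale intensities to zero mean and unit variance."""
    flat = [v for row in img for v in row]
    mean = fmean(flat)
    std = pstdev(flat) or 1.0  # guard against constant images
    return [[(v - mean) / std for v in row] for row in img]


def random_hflip(img, rng, p=0.5):
    """Mirror the image left to right with probability p."""
    return [row[::-1] for row in img] if rng.random() < p else img


def preprocess(img, rng, resized=480, cropped=384):
    out = resize_nearest(img, resized)
    out = random_crop(out, cropped, rng)
    out = normalize(out)
    return random_hflip(out, rng)


def halvings(n):
    """Count how many times n can be divided by 2 with no remainder."""
    count = 0
    while n % 2 == 0:
        n //= 2
        count += 1
    return count
```

`halvings(384)` returns 7, whereas `halvings(480)` returns 5, so the cropped size tolerates two more 2× downsampling stages than the resized image would.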

2.4. Evaluation Metrics

We use the Dice coefficient and the Hausdorff distance (HD) to evaluate the quality of segmentation. The Dice coefficient is computed as twice the area of overlap between the prediction (pred) and the ground truth (GT) divided by the total number of pixels in the prediction and the ground truth:

$$\mathrm{Dice} = \frac{2\,|\mathrm{pred} \cap \mathrm{GT}|}{|\mathrm{pred}| + |\mathrm{GT}|}.$$

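The HD measures the boundary distance (D) between predictions and ground truth and is defined as

$$\mathrm{HD}(\mathrm{pred}, \mathrm{GT}) = \max\Big\{ \max_{p \in \mathrm{pred}} \min_{g \in \mathrm{GT}} D(p, g),\; \max_{g \in \mathrm{GT}} \min_{p \in \mathrm{pred}} D(g, p) \Big\},$$

where $p$ and $g$ range over the boundary points of the prediction and the ground truth, respectively.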

Since the HD metric is sensitive to outliers, we report 95th-percentile HD (HD95) instead.
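Both segmentation metrics can be sketched in a few lines of Python. This is a minimal sketch, not the evaluation code used in the study: masks and boundaries are represented as sets of `(row, col)` pixel coordinates, the function names are hypothetical, and distances are in pixels rather than millimetres. HD95 conventions also vary between toolkits. This version takes the 95th percentile of each directed distance list (with linear interpolation) and returns the larger of the two, which is an assumption.

```python
import math


def dice(pred, gt):
    """Dice coefficient of two pixel sets: 2 * overlap / total size."""
    if not pred and not gt:
        return 1.0  # two empty masks agree perfectly
    return 2.0 * len(pred & gt) / (len(pred) + len(gt))


def directed_distances(a, b):
    """For every point in a, the Euclidean distance to its nearest point in b."""
    return [min(math.dist(p, q) for q in b) for p in a]


def percentile(values, q):
    """q-th percentile with linear interpolation between closest ranks."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def hd95(pred_boundary, gt_boundary):
    """95th-percentile Hausdorff distance between two non-empty point sets."""
    forward = percentile(directed_distances(pred_boundary, gt_boundary), 95)
    backward = percentile(directed_distances(gt_boundary, pred_boundary), 95)
    return max(forward, backward)


# Two 2 x 2 squares that overlap in one column
pred = {(0, 0), (0, 1), (1, 0), (1, 1)}
gt = {(0, 1), (0, 2), (1, 1), (1, 2)}
```

For these toy squares, `dice(pred, gt)` is 0.5 and `hd95(pred, gt)` is 1.0 pixel. Replacing the percentile with the maximum recovers the plain HD, which a single outlier pixel can dominate.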

We regard a registration prediction as successful if the shift differences are <3 mm in both the x and y directions. For $N$ image pairs, we let $\Delta x$ and $\Delta y$ denote the shift differences in the x and y directions, respectively. First, we count the image pairs that satisfy $|\Delta x| < 3$ mm and $|\Delta y| < 3$ mm and denote that number as $n$. The registration accuracy is then defined as

$$\mathrm{RegAcc} = \frac{n}{N}.$$
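A minimal Python sketch of this success criterion follows. It assumes that the "shift difference" is the gap between the predicted and the reference shift for each image pair. The function name `registration_accuracy` and its parameters are hypothetical.

```python
def registration_accuracy(pred_shifts, true_shifts, tol_mm=3.0):
    """Fraction of image pairs whose x and y shift errors are both below tol_mm.

    pred_shifts and true_shifts are equal-length sequences of (x, y) shifts in mm.
    """
    if len(pred_shifts) != len(true_shifts) or not pred_shifts:
        raise ValueError("need two non-empty sequences of equal length")
    successes = sum(
        1
        for (px, py), (tx, ty) in zip(pred_shifts, true_shifts)
        if abs(px - tx) < tol_mm and abs(py - ty) < tol_mm
    )
    return successes / len(pred_shifts)


# Toy shifts in mm; the second pair misses the tolerance in x
predicted = [(1.0, 2.0), (5.0, 0.0), (0.0, 0.0), (2.9, -2.9)]
reference = [(0.0, 0.0)] * 4
accuracy = registration_accuracy(predicted, reference)
```

Here three of the four toy pairs fall within tolerance, so `accuracy` is 0.75. The comparison is strict, so an error of exactly 3 mm counts as a failure.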

3. Results

We collected data from 178 and 81 patients with head and neck cancer from two hospitals, respectively, for this study. The training dataset consists of 142 and 64 patients, respectively, for the two clients, and the remaining patients (36 and 17, respectively) were used to evaluate the DCL performance. We preprocess the images before feeding them into the models. The DCL training protocol was implemented with the TensorFlow federated learning framework [35] on an NVIDIA TITAN XP GPU. All networks are initialized with the Xavier initializer and trained with the Adam optimizer, a learning rate of 1e-4, a batch size of 4, and a total of 25 k updates [36]. An appropriate learning rate is critical in our experiments: we found that a learning rate larger than 1e-4 causes the loss to oscillate.

We updated the parameters of networks with a stochastic gradient descent method where the initial learning rate was set to 1e-4, and a total of ∼25 k updates are used to train the networks.

We first compared the results of the segmentation network for hospital 1 with and without the DCL method qualitatively. Twenty-four anatomical structures were used in the study, including brain, spinal cord, spinal cord cavity, pituitary, parotid glands, oral cavity, mandible, mandible joint, temporal lobes, and so on. Figure 6 shows the segmentation results of standard federated learning (left) and of federated learning with DCL (right). Since hospital 2 prefers smaller temporal lobes while hospital 1 prefers larger ones, the label conflict causes the model of hospital 1 to produce an undesired small temporal lobe (arrow). However, with DCL, a temporal lobe consistent with hospital 1's preference (arrow) is predicted by the segmentation model of hospital 1.

We also show the segmentation results from hospital 2 in Figure 7. We find that federated learning with DCL outperforms standard federated learning in small organs such as eyeballs (arrow). The reason could be attributed to the fact that federated learning could increase the relatively insufficient training samples for the small organs.

Table 1 provides the Dice values of the segmentation results for hospital 2. DCL improved the average Dice value by 5.49% over standard federated learning. We performed Student’s t-test on the paired groups of standard federated learning and DCL for all organs. The p value of 0.02 (<0.05) leads to the conclusion that DCL outperforms standard federated learning significantly in terms of the Dice metric. Compared with standard federated learning, DCL helped the segmentation network recognize small organs such as the pituitary and optic nerves.
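For readers unfamiliar with the paired test, the statistic behind it can be sketched as follows. This is a generic illustration, not the authors' analysis script: the function name `paired_t_statistic` is hypothetical, and the per-organ Dice values below are made-up placeholders, not numbers from Table 1. The sketch computes only the t statistic and the degrees of freedom. In practice a library routine such as `scipy.stats.ttest_rel` also returns the p value.

```python
import math
from statistics import fmean, stdev


def paired_t_statistic(baseline, treated):
    """t statistic and degrees of freedom for paired samples.

    t = mean(d) / (sd(d) / sqrt(n)), where d are the per-pair differences.
    """
    if len(baseline) != len(treated) or len(baseline) < 2:
        raise ValueError("need two equal-length samples with at least 2 pairs")
    diffs = [b - a for a, b in zip(baseline, treated)]
    spread = stdev(diffs)  # sample standard deviation (n - 1 in the denominator)
    if spread == 0.0:
        raise ValueError("all differences are identical; t is undefined")
    t = fmean(diffs) / (spread / math.sqrt(len(diffs)))
    return t, len(diffs) - 1


# Placeholder per-organ Dice values (hypothetical, for illustration only)
dice_standard = [0.80, 0.70, 0.90]
dice_dcl = [0.85, 0.78, 0.92]
t_value, dof = paired_t_statistic(dice_standard, dice_dcl)
```

Pairing by organ matters here: it removes the large organ-to-organ variation in Dice and tests only whether the per-organ differences are consistently positive.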

Table 2 compares the HD95 values of federated learning with and without the proposed DCL method. As shown in the table, the mean HD95 value is improved by 2.2 mm when federated learning is used with the DCL method. Student’s t-test shows that DCL outperforms standard federated learning significantly in terms of the HD95 metric. It is also noted that the HD95 of small-volume organs such as crystal, optic chiasma, optical nerves, and pituitary is much smaller than that of the large organs. This means that the label inconsistency issue is more significant in small organs and that DCL can alleviate it substantially.

We illustrate an example of the registration result for an image pair in Figure 8. We find that the corresponding anatomical structures such as cranium and brain are aligned correctly.

We further provide the numerical comparison in successful registration rates in Table 3. We also performed a chi-square test on RegAcc between the standard federated learning and the federated learning with the DCL method. 53 registration result pairs are involved in the test. It is observed that the registration network with DCL outperforms the standard federated learning.

In Figure 9, we plot the loss values as the function of training steps. The top subfigure shows the segmentation cross-entropy loss and the bottom subfigure shows the registration mean squared error. The standard federated learning and the DCL are plotted with gray color and red color, respectively. As illustrated in the figure, we find that the extra supervised signal from DCL prevents the segmentation training from overfitting. The segmentation quality is steadily improved after 10 k training steps. In registration training, the overfitting phenomenon is not observed, but the training of standard federated learning is stuck at a high error level.

4. Discussion

In the previous section, we demonstrated the feasibility of exploiting the reciprocal structure between the segmentation and registration tasks of different hospitals to improve the performance of the segmentation and registration models in individual hospitals. The experimental results suggest that the proposed DCL method is beneficial and contributes significantly to performance improvement. The proposed DCL method outperforms standard federated learning by 5.49%, 2.2 mm, and 1.8% in terms of Dice, 95th-percentile HD, and registration accuracy, respectively. The superior performance of DCL could be attributed to the model cooperation among a set of hospitals under the orchestration of DCL.

The proposed DCL method is different from the global model personalization methods that personalize a single global model for each client through data or model adaptations that involve additional training on each local dataset [14, 15, 37, 38]. While these methods aim to collaboratively train a shared model without sharing private data, DCL is designed to enhance the local model with the help of other hospitals but still preserve the personalization of the local model. In contrast to global model personalization methods, personalization or preference is never lost for each client.

The proposed DCL method can be classified into the category of personalized-model learning, which builds personalized models by modifying the federated learning model aggregation process. Our method is closest to the similarity-based approaches in this category, which leverage pairwise client similarities to improve personalized model performance by building similar personalized models for related clients. FedAMP [28] excels at capturing pairwise client relationships to learn similar models for related clients. It may be sensitive to poor data quality, whereas DCL leverages the complementarity of tasks to provide extra supervision signals and is therefore less affected by data quality. Model interpolation methods [39] learn personalized models using a mixture of global and local models. However, they are likely to degrade in highly non-IID scenarios because they use a single global model as the basis for personalization. In contrast to the model interpolation methods, DCL does not rely on similar data distributions: in this study, the data could be totally different for the segmentation task and the registration task. Overall, the major novelty of DCL is the exploitation of the reciprocal structure between tasks, whereas current non-IID approaches mainly leverage data similarities. The cooperative relationship is exploited to provide mutual rewards or pseudo-labels for the tasks of different hospitals. Since the reward is passed from the first hospital’s task A to the second hospital’s task B, the effect of biased data distributions is mitigated during the learning loop. The primary benefit of task cooperation is robustness to differences in data distributions.

While task cooperation improves local model performance in highly non-IID scenarios, it also excludes tasks without reciprocal structures from the DCL protocol. Additionally, since the trained models are shared between hospitals, the label preferences of one hospital may be leaked to others, which may increase the risk of privacy exposure. Furthermore, it is still not clear to what extent these methods harm data privacy, and there are no quantitative measures of the degree of privacy leakage. Finally, while DCL is an effective method for the non-IID problem, other issues remain open for future work on federated learning in healthcare, including decentralized online optimization [40], unbalanced data [41], limited communication bandwidth [42], and unreliable and limited device availability [43].

5. Conclusions

We developed a method named deep cooperative learning (DCL) to address the non-IID problem in federated learning. Comprehensive experiments have been carried out on CT and MRI segmentation and registration tasks and datasets to demonstrate the effectiveness of DCL. The results obtained from head-neck cancer patients of two hospitals show that the method outperforms the standard federated learning in segmentation and registration tasks. The method is, therefore, a solution for leveraging biased labels across hospitals.

Data Availability

The experiment data used to support the findings of this study are available from the corresponding author upon request.

Conflicts of Interest

The authors declare that there are no conflicts of interest regarding the publication of this paper.

Acknowledgments

This study was supported by the Natural Science Foundation of Shanghai (20ZR1440300).