Decoupled Split Learning via Auxiliary Loss
Split learning is a distributed training paradigm where a neural network is partitioned between clients and a server, which allows data to remain at the client while only intermediate activations are shared. Traditional split learning relies on end-to-end backpropagation across the client-server split point. This incurs a large communication overhead (forward activations and backward gradients must be exchanged every iteration) and significant memory use (for storing activations and gradients). In this paper, we develop a beyond-backpropagation training method for split learning. In this approach, the client and server train their model partitions semi-independently, using local loss signals instead of propagated gradients. In particular, the client’s network is augmented with a small auxiliary classifier at the split point to provide a local error signal, while the server trains on the client’s transmitted activations using the true loss function. This decoupling removes the need to send backward gradients, which cuts communication costs roughly in half and also reduces memory overhead (as each side only stores local activations for its own backward pass). We evaluate our approach on CIFAR-10 and CIFAR-100. Our experiments show two key results. First, the proposed approach achieves performance on par with standard split learning that uses backpropagation. Second, it reduces the volume of data transmitted across the split (activations and gradients) by 50% and peak memory usage by up to 58%.
💡 Research Summary
Split learning (SL) partitions a deep neural network between a client and a server so that raw data never leaves the client device; only intermediate activations and gradients are exchanged. While this protects privacy and reduces client‑side computation, the conventional SL pipeline still requires two rounds of communication per training iteration (forward activations from client to server and backward gradients from server to client) and forces both parties to keep intermediate activations in memory until the backward pass finishes. Consequently, bandwidth‑constrained edge devices suffer from high latency and memory pressure.
The paper proposes a “beyond‑backpropagation” variant called Decoupled Split Learning with Auxiliary Loss. The key idea is to attach a lightweight auxiliary classifier (C_a) at the cut layer on the client side. This classifier produces an auxiliary prediction (\tilde{y}=C_a(z)) from the client’s intermediate activation (z). The client then computes a local loss (L_{aux}(\tilde{y},y)) (e.g., cross‑entropy) and updates both its own backbone (M_b) and the auxiliary head (C_a) using gradients derived solely from this local loss. No gradient from the server is needed. After the local backward pass, the client immediately streams the activation (z) to the server.
On the server side, the received activations are fed through the server‑side model (M_t) to produce the final prediction (\hat{y}=M_t(z)). The server computes the true task loss (L(\hat{y},y)) and updates its parameters (\theta_t) with the corresponding gradients. The server does compute (\partial L/\partial z) internally, but this gradient is never sent back to the client. Thus, each training step involves only one unidirectional transmission (client → server) and no backward communication.
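The client and server steps described above can be sketched as a toy, pure-Python training loop. This is a minimal illustration under stated assumptions, not the authors' implementation: the backbone is a single linear map, the auxiliary head and the server model are logistic units, the task is a synthetic binary problem, and the names `Client`, `Server`, and `train` are hypothetical. The label is assumed to be available on the server, since the summary has the server compute the true loss.

```python
import math
import random


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def bce(p, y, eps=1e-12):
    # Binary cross-entropy for a single example.
    return -(y * math.log(p + eps) + (1.0 - y) * math.log(1.0 - p + eps))


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


class Client:
    """Toy backbone M_b (linear map x -> z) plus auxiliary head C_a (logistic unit)."""

    def __init__(self, d_in, d_z, rng):
        self.W = [[rng.uniform(-0.5, 0.5) for _ in range(d_in)] for _ in range(d_z)]
        self.a = [rng.uniform(-0.5, 0.5) for _ in range(d_z)]
        self.c = 0.0

    def local_step(self, x, y, lr):
        # Forward pass: cut-layer activation z and auxiliary prediction y_tilde.
        z = [dot(row, x) for row in self.W]
        y_tilde = sigmoid(dot(self.a, z) + self.c)
        # Backward pass through the auxiliary loss only (no server gradient).
        err = y_tilde - y                       # dL_aux/dlogit for sigmoid + BCE
        grad_z = [err * a_i for a_i in self.a]  # dL_aux/dz, used locally
        for i in range(len(self.a)):
            self.a[i] -= lr * err * z[i]
        self.c -= lr * err
        for i, row in enumerate(self.W):
            for j in range(len(row)):
                row[j] -= lr * grad_z[i] * x[j]
        return z, bce(y_tilde, y)


class Server:
    """Toy server model M_t (logistic unit on the received activation)."""

    def __init__(self, d_z, rng):
        self.w = [rng.uniform(-0.5, 0.5) for _ in range(d_z)]
        self.b = 0.0

    def step(self, z, y, lr):
        y_hat = sigmoid(dot(self.w, z) + self.b)
        err = y_hat - y
        # Only the server's own parameters are updated; nothing is returned
        # to the client except (here) the loss value for logging.
        for i in range(len(self.w)):
            self.w[i] -= lr * err * z[i]
        self.b -= lr * err
        return bce(y_hat, y)


def train(steps=200, lr=0.1, seed=0, d_in=2, d_z=3):
    rng = random.Random(seed)
    client, server = Client(d_in, d_z, rng), Server(d_z, rng)
    uplink_floats, downlink_floats = 0, 0
    server_losses = []
    for _ in range(steps):
        x = [rng.uniform(-1.0, 1.0) for _ in range(d_in)]
        y = 1.0 if sum(x) > 0 else 0.0       # synthetic label
        z, _ = client.local_step(x, y, lr)   # client updates locally first
        uplink_floats += len(z)              # client -> server: activation only
        server_losses.append(server.step(z, y, lr))
        # downlink_floats is never incremented: no gradient is sent back
    return uplink_floats, downlink_floats, server_losses


uplink, downlink, losses = train()
print(uplink, downlink)
```

The counters make the communication pattern explicit: every step adds one activation to the uplink, and the downlink stays at zero because the server never returns a gradient.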
The authors discuss several design choices. The auxiliary classifier must be kept small; otherwise the client backbone may over‑fit to the auxiliary head, producing representations that make the head’s job easy but are of little use to the server. They also introduce a scaling factor (\lambda) that optionally blends the auxiliary loss with the global loss, allowing a weak feedback signal from the server when desired. In the experiments, (\lambda) is set to zero to achieve maximal communication and memory savings.
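The summary does not state how the two losses are blended. One plausible form, given here purely as an assumption for illustration, adds the server's task loss to the client's objective with weight (\lambda):

```latex
\mathcal{L}_{\text{client}}
  = L_{aux}(\tilde{y}, y) + \lambda \, L(\hat{y}, y),
  \qquad \lambda \ge 0 .
```

Under this form, (\lambda = 0) recovers the fully decoupled setting used in the experiments, while any (\lambda > 0) requires the server to return some gradient information with respect to (z), which is where the extra communication would come from.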
Complexity and overhead: computational complexity per sample remains comparable to standard SL (one forward and one backward pass per module). However, because the client no longer waits for the server’s gradient, idle time is reduced and the two sides can operate in parallel, similar to pipeline parallelism. Communication volume per batch drops from “activations + gradients” to just “activations”. Because the cut‑layer gradient has the same shape as the activation it is taken with respect to, the two messages are roughly equal in size, and the authors report a ≈50 % reduction in transmitted data. Memory usage on the client drops because activations can be discarded after the local backward pass; the server also saves memory by not storing the cut‑layer gradient for transmission.
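The ≈50 % figure follows from simple arithmetic, sketched below. The batch size, the cut‑layer activation shape, and the 4‑byte element size are hypothetical values chosen for illustration, and the function names are not from the paper. Labels and protocol overhead are ignored.

```python
def activation_bytes(batch_size, act_shape, bytes_per_elem=4):
    # Size of one batch of cut-layer activations.
    n = batch_size
    for d in act_shape:
        n *= d
    return n * bytes_per_elem


def comm_per_iteration(batch_size, act_shape, decoupled, bytes_per_elem=4):
    act = activation_bytes(batch_size, act_shape, bytes_per_elem)
    # The cut-layer gradient has the same shape as the activation,
    # so standard split learning sends the same amount back.
    grad = 0 if decoupled else act
    return act + grad


# Hypothetical setting: batch of 128, activation of shape (16, 32, 32), float32.
standard = comm_per_iteration(128, (16, 32, 32), decoupled=False)
decoupled = comm_per_iteration(128, (16, 32, 32), decoupled=True)
reduction = 1 - decoupled / standard
print(standard, decoupled, reduction)
```

The ratio does not depend on the chosen shape or batch size, which is why the reduction is the same at every cut position in this idealized count.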
Experimental evaluation: The method is tested on CIFAR‑10 and CIFAR‑100 using ResNet‑110, with three different cut‑layer positions (shallow, middle, deep). Compared to conventional SL, the decoupled approach achieves virtually identical accuracy (differences ≤ 0.5 %). Communication is reduced by 48 %–52 % on average, and peak client memory consumption falls by 45 %–58 %. Training time per epoch improves by 10 %–15 % because the client and server can process different mini‑batches concurrently. The auxiliary head adds negligible overhead (≈0.5 % of total parameters). A small (\lambda) (e.g., 0.1) can further improve accuracy at the cost of a modest increase in communication.
Theoretical perspective: The approach builds on greedy layer‑wise training literature, where each module is optimized with a local objective that aligns with the global task. Here, the client’s local loss directly targets the true labels via the auxiliary classifier, ensuring that the intermediate representation remains informative for the final classifier on the server. The authors argue that, under this alignment, the composed network can approximate the globally optimal solution despite the lack of end‑to‑end backpropagation.
Future directions: The paper suggests extending the framework to multi‑client scenarios, exploring alternative auxiliary objectives (contrastive, self‑supervised), analyzing privacy leakage through the auxiliary head, and designing hardware‑friendly auxiliary modules.
In summary, the paper introduces a practical and theoretically motivated modification to split learning that eliminates cross‑cut gradient transmission. By leveraging a lightweight auxiliary loss on the client, it cuts communication bandwidth by roughly half, reduces peak client memory by up to 58 %, and retains comparable predictive performance. This makes split learning more viable for bandwidth‑limited edge environments and opens avenues for further research into fully decoupled distributed deep learning.