[ICLR 2025] BLEND: Behavior-Guided Neural Population Dynamics Modeling via Privileged Knowledge Distillation

A research collaboration between Hong Kong University of Science and Technology (HKUST) Smart Lab and Peking University has been accepted at ICLR 2025, a premier machine learning conference. The study presents BLEND (Behavior-guided Neural Population Dynamics Modeling via Privileged Knowledge Distillation), a novel framework that enhances neural population dynamics modeling by leveraging privileged knowledge distillation.

BLEND introduces a teacher-student architecture, where the teacher model integrates both neural activity and behavioral data during training, while the student model learns to generalize using only neural data. This enables high-performance inference without requiring behavioral data at test time. Experiments demonstrate significant improvements across multiple tasks, including neural activity prediction, behavior decoding, and neuron identity classification. Notably, behavior decoding performance improves by over 50%, while neuron identity prediction accuracy increases by more than 15%.

Background

Neural population dynamics modeling aims to uncover how the brain orchestrates perception, movement, and cognition through collective neuronal activity. With the advent of large-scale neural recording technologies, researchers can now capture vast amounts of neural activity data. However, relying solely on neural signals often fails to fully capture the complexity of brain function. Behavioral signals—such as animal movement trajectories or eye-tracking data—provide crucial contextual information, helping to reveal deep connections between neural activity and behavior.

Yet, in real-world scenarios, behavioral signals are not always available. For instance, during resting states, behavioral data may be missing or incomplete. Existing neural dynamics modeling approaches often require complex model designs or strong assumptions about neural-behavior relationships, limiting their flexibility in practical applications. BLEND addresses this challenge by enabling models to utilize behavioral signals during training while remaining robust in their absence during inference.


Existing approaches to neural population dynamics modeling fall into three main categories:

A. Neural-Only Modeling

Traditional methods rely solely on neural activity data, using latent variable models such as PCA, LFADS, and NDT. While these methods are relatively simple, they fail to incorporate behavioral context, limiting their ability to explain the relationship between neural activity and behavior.

B. Behavior-Prior Methods

Some approaches introduce behavioral data as explicit priors, such as pi-VAE and CEBRA, improving the interpretability of neural dynamics. However, these methods often require specialized modules or complex training strategies, making them less adaptable to real-world scenarios where behavioral data may be incomplete.

C. Privileged Knowledge Distillation (BLEND)

BLEND adopts a privileged knowledge distillation framework, where behavioral data serves as privileged information during training. The teacher model learns from both neural and behavioral data, while the student model distills this knowledge to make accurate predictions using only neural data. This approach is model-agnostic, avoids reliance on complex assumptions, and remains robust even when behavioral data is unavailable.


Methodology

BLEND employs a teacher-student framework inspired by privileged information learning, consisting of two key stages:

A. Teacher Model Training

The teacher model is trained on both neural activity and behavioral data. To strengthen its ability to capture temporal neural dynamics and their relationship to behavior, we introduce a masked sequence reconstruction task: portions of the neural data are randomly masked and reconstructed from the observed neural and behavioral inputs, so the teacher learns rich representations that generalize well even when data are missing. Formally, the teacher parameters \(\theta^*\) minimize the masked reconstruction objective \(\mathcal{L}_{\mathrm{MTM}}\):

\[\theta^*=\underset{\theta}{\arg \min } \underbrace{\mathbb{E}_{\mathbf{x} \sim p(\mathbf{x}), \mathbf{b} \sim p(\mathbf{b}), \mathbf{m} \sim p(\mathbf{m})}\left[\frac{1}{|\mathbf{m}|} \sum_{i, t} \mathbf{m}_i^t \cdot \mathcal{L}_{\mathrm{rec}}\left(f_\theta\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t, \mathbf{x}_i^t\right)\right]}_{\triangleq\, \mathcal{L}_{\mathrm{MTM}}}\]
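To make this concrete, here is a minimal PyTorch sketch of one teacher training step. The `teacher` module, tensor shapes, and masking scheme are illustrative assumptions; MSE stands in for the generic \(\mathcal{L}_{\mathrm{rec}}\) (for spike counts, a Poisson negative log-likelihood would be a natural substitute).

```python
import torch
import torch.nn.functional as F

def mtm_teacher_step(teacher, x, b, mask_ratio=0.25):
    """One masked-reconstruction (MTM) step for the teacher.

    x: neural activity, shape (batch, time, neurons)
    b: behavioral covariates, shape (batch, time, behavior_dim)
    """
    # Sample a random binary mask m over (time, neuron) entries.
    m = (torch.rand_like(x) < mask_ratio).float()
    x_masked = x * (1.0 - m)              # zero out masked entries

    # The teacher reconstructs activity from masked neural data plus behavior.
    x_hat = teacher(x_masked, b)          # same shape as x

    # Reconstruction loss averaged over masked entries only.
    per_entry = F.mse_loss(x_hat, x, reduction="none")
    return (m * per_entry).sum() / m.sum().clamp(min=1)
```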

B. Student Model Knowledge Distillation

The student model is trained to replicate the teacher’s predictions using only neural data. To achieve this, we design a hybrid loss that combines a reconstruction term with a knowledge distillation term, ensuring that the student both accurately reconstructs masked neural sequences and closely matches the teacher’s outputs. The student minimizes:

\[\mathbb{E}_{\mathbf{x}, \mathbf{b}, \mathbf{m}}[\alpha \cdot \underbrace{\frac{1}{|\mathbf{m}|} \sum_{i, t} \mathbf{m}_i^t \cdot \mathcal{L}_{\mathrm{rec}}\left(f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_i^t, \mathbf{x}_i^t\right)}_{\text {MTM Loss }}+(1-\alpha) \cdot \underbrace{\sum_{i, t} \mathcal{L}_{\text {distill }}\left(f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_i^t, f_{\theta_T}\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t\right)}_{\text {Distillation Loss }}]\]
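A minimal sketch of the corresponding student update, continuing the assumptions above; hard distillation via MSE is shown, and the teacher is kept frozen during this stage:

```python
import torch
import torch.nn.functional as F

def student_distill_step(student, teacher, x, b, alpha=0.5, mask_ratio=0.25):
    """Hybrid loss: masked reconstruction plus matching the frozen teacher."""
    m = (torch.rand_like(x) < mask_ratio).float()
    x_masked = x * (1.0 - m)

    x_hat_s = student(x_masked)            # student sees neural data only
    with torch.no_grad():                  # teacher weights stay fixed
        x_hat_t = teacher(x_masked, b)     # teacher still receives behavior

    # MTM term on masked entries, distillation term over all entries.
    per_entry = F.mse_loss(x_hat_s, x, reduction="none")
    mtm = (m * per_entry).sum() / m.sum().clamp(min=1)
    distill = F.mse_loss(x_hat_s, x_hat_t)
    return alpha * mtm + (1.0 - alpha) * distill
```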

We explore multiple knowledge distillation strategies (a code sketch of each follows the list):

  • Hard Distillation: Minimizing the mean squared error (MSE) between student and teacher outputs.
\[\mathcal{L}_{\text {distill }}\left(f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_i^t, f_{\theta_T}\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t\right)=\left\|f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_i^t-f_{\theta_T}\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t\right\|_2^2\]
  • Soft Distillation: Using temperature-scaled soft targets and minimizing the KL divergence between teacher and student distributions.
\[\mathcal{L}_{\text {distill }}\left(f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_i^t, f_{\theta_T}\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t\right)=\tau^2 \cdot \mathrm{KL}\left(\sigma\left(\frac{f_{\theta_T}\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t}{\tau}\right) \| \sigma\left(\frac{f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_i^t}{\tau}\right)\right)\]
  • Feature Distillation: Aligning intermediate feature representations between the teacher and student models.
\[\mathcal{L}_{\text {distill }}\left(f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_i^t, f_{\theta_T}\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t\right)=\sum_{l=1}^L\left\|f_{\theta_S}^l\left(\mathbf{x}_{\bar{m}}\right)_i^t-f_{\theta_T}^l\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t\right\|_2^2\]
  • Correlation Distillation: Preserving the structural relationships in teacher outputs by optimizing the similarity of correlation matrices.
\[\mathcal{L}_{\text {distill }}\left(f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_i^t, f_{\theta_T}\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_i^t\right)=\frac{1}{B} \sum_{j=1}^B\left\|\operatorname{Corr}\left(f_{\theta_S}\left(\mathbf{x}_{\bar{m}}\right)_j\right)-\operatorname{Corr}\left(f_{\theta_T}\left(\mathbf{x}_{\bar{m}}, \mathbf{b}\right)_j\right)\right\|_2^2\]
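The four \(\mathcal{L}_{\text{distill}}\) variants, sketched in PyTorch under the same assumptions as above; the tensor layout (batch, time, neurons) and the softmax axis used for soft distillation are illustrative choices, not necessarily the paper’s:

```python
import torch
import torch.nn.functional as F

def hard_distill(s_out, t_out):
    # MSE between student and teacher outputs.
    return F.mse_loss(s_out, t_out)

def soft_distill(s_out, t_out, tau=2.0):
    # Temperature-scaled KL divergence with the teacher as the target.
    p_t = F.softmax(t_out / tau, dim=-1)
    log_p_s = F.log_softmax(s_out / tau, dim=-1)
    return tau**2 * F.kl_div(log_p_s, p_t, reduction="batchmean")

def feature_distill(s_feats, t_feats):
    # Align intermediate representations layer by layer
    # (s_feats, t_feats: lists of per-layer feature tensors).
    return sum(F.mse_loss(s, t) for s, t in zip(s_feats, t_feats))

def correlation_distill(s_out, t_out):
    # Match the per-sample correlation structure across neurons.
    def corr(z):  # z: (time, neurons) -> (neurons, neurons)
        z = z - z.mean(dim=0, keepdim=True)
        z = z / (z.std(dim=0, keepdim=True) + 1e-8)
        return (z.T @ z) / z.shape[0]
    diffs = [((corr(s) - corr(t)) ** 2).sum() for s, t in zip(s_out, t_out)]
    return torch.stack(diffs).mean()
```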

Through these distillation strategies, BLEND effectively transfers behavior-informed neural representations from the teacher to the student, enabling highly accurate predictions without requiring behavioral data at inference time.

Experimental Results

BLEND was extensively evaluated on two benchmark datasets, demonstrating state-of-the-art performance:

1. Neural Latents Benchmark (NLB’21)

We tested BLEND on the MC-Maze, MC-RTT, and Area2-Bump datasets, covering neural activity prediction (co-smoothing bits per spike, Co-bps), behavior decoding (velocity R², Vel-R2), and PSTH matching (PSTH-R2); a sketch of the Co-bps computation follows the results below. Key results include:


  • MC-Maze: BLEND improves behavior decoding (Vel-R2) by 14.6%.
  • Area2-Bump: BLEND boosts behavior decoding by 51.8% and PSTH-R2 by 66.6%.
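For reference, here is a rough NumPy sketch of a bits-per-spike score in the spirit of Co-bps. It reflects our reading of the NLB metric under a Poisson observation model (the benchmark’s held-out co-smoothing protocol is omitted) and is not code from the paper:

```python
import numpy as np

def bits_per_spike(rates, spikes):
    """Poisson log-likelihood gain of predicted rates over a per-neuron
    mean-rate null model, normalized per spike (in bits).

    rates, spikes: arrays of shape (trials, time, neurons), rates > 0.
    The log(y!) terms cancel between model and null, so they are dropped.
    """
    eps = 1e-9
    null = spikes.mean(axis=(0, 1), keepdims=True) + eps  # per-neuron mean rate
    ll_model = np.sum(spikes * np.log(rates + eps) - rates)
    ll_null = np.sum(spikes * np.log(null) - null)
    return (ll_model - ll_null) / (spikes.sum() * np.log(2) + eps)
```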

2. Multimodal Calcium Imaging Dataset

We evaluated BLEND on neuron identity classification using transcriptomic data:


  • Excitatory/Inhibitory (EI) classification accuracy improved by 5.5%.
  • Inhibitory neuron subtype classification improved by 15.4%.

These results highlight BLEND’s ability to enhance neural representation learning and improve predictive accuracy across diverse tasks.

Model Analysis

To further validate BLEND’s effectiveness, we conducted qualitative evaluations on the MC-Maze dataset, focusing on behavior decoding:

A. 2D Hand Trajectory Prediction

BLEND significantly improves trajectory prediction accuracy, capturing both global movement trends and fine-grained details. Unlike baseline models, which exhibit deviations at trajectory bends and endpoints, BLEND produces smoother and more precise predictions.

B. Velocity Prediction

BLEND achieves more stable and accurate velocity predictions, effectively capturing speed fluctuations and critical transition points. Baseline models struggle with peak detection and exhibit excessive noise, while BLEND maintains better temporal consistency.


These analyses confirm that privileged knowledge distillation enables BLEND to learn richer neural representations, leading to superior behavior decoding and neural activity prediction.

Conclusion

BLEND introduces a novel privileged knowledge distillation framework for neural population dynamics modeling, enabling models to leverage behavioral data during training while remaining robust in its absence during inference. Experimental results demonstrate substantial performance gains across multiple tasks, including neural activity prediction, behavior decoding, and neuron identity classification.

Compared to traditional methods, BLEND avoids complex model designs and strong assumptions, making it highly adaptable to real-world applications. Moreover, its model-agnostic nature allows seamless integration into existing neural dynamics modeling frameworks.

Future Directions

Moving forward, we aim to:

  • Expand BLEND to more complex multimodal datasets, including non-sequential and discrete behavioral data.
  • Explore advanced neural-behavior fusion strategies to further enhance representation learning.
  • Apply BLEND to broader domains, such as medical imaging, affective computing, and neuroscience-inspired AI.

By advancing the understanding of neural population dynamics, BLEND paves the way for more interpretable and generalizable neural models, contributing to both computational neuroscience and AI research.

For more details, check out the full paper:
Guo, Zhengrui, et al. “BLEND: Behavior-guided Neural Population Dynamics Modeling via Privileged Knowledge Distillation.” arXiv preprint arXiv:2410.13872 (2024).