第三章:代码实现解析

3.1 项目结构与依赖

3.1.1 项目目录结构

Boomda 的代码仓库组织如下:

Boomda/
├── train_boomda.py          # 训练入口脚本
├── run_boomda.sh            # 运行示例
├── config_dir.json          # 数据路径配置
├── models/                  # 模型定义
│   ├── Boomda_model.py      # 核心模型(BoomdaModel)
│   ├── base_model.py        # 基类模型
│   ├── min_norm_solvers.py  # MGDA 求解器
│   ├── coralutils/          # CORAL 实现
│   │   └── coral.py
│   ├── networks/            # 网络组件
│   │   ├── classifier.py    # 分类器
│   │   ├── discriminator.py # 域判别器
│   │   └── msa.py           # 模态特定网络(Bert, APViT, Wavlm)
│   ├── func.py              # 辅助函数
│   └── opt.py               # 优化相关函数
├── data/                    # 数据加载与预处理
├── opts/                    # 命令行参数解析
├── utils/                   # 工具函数(日志、评估等)
└── scripts/                 # 训练脚本
    └── Boomda.sh

3.1.2 主要依赖


3.2 核心模型 BoomdaModel

3.2.1 模型初始化

BoomdaModel 继承自 BaseModel,在 __init__ 中完成网络组件的构建:

class BoomdaModel(BaseModel):
    def __init__(self, opt):
        super().__init__(opt)
        self.opt = opt
        self.modality = opt.modality  # 例如 'AVL'、'VL' 等
        # ...

        # 多模态融合分类器
        cls_input_size = opt.embd_size_a * int("A" in self.modality) + \
                         opt.embd_size_v * int("V" in self.modality) + \
                         opt.embd_size_l * int("L" in self.modality)
        self.netC = SimpleFC2(input_dim=cls_input_size, hidden=[128], 
                              output_dim=opt.output_dim, d_f=self.d_f)

        # 各模态编码器与分类器
        if 'A' in self.modality:
            self.netA = Wavlm(self.device, self.lora_r, opt.embd_size_a)
            self.netCA = SimpleFC2(input_dim=opt.embd_size_a, hidden=[64], 
                                   output_dim=opt.output_dim, d_f=self.d_f)
        if 'V' in self.modality:
            self.netV = APViT_video(self.device, self.lora_r, opt.embd_size_v)
            self.netCV = SimpleFC2(input_dim=opt.embd_size_v, hidden=[64], 
                                   output_dim=opt.output_dim, d_f=self.d_f)
        if 'L' in self.modality:
            self.netL = Bert(self.device, self.lora_r, opt.embd_size_l)
            self.netCL = SimpleFC2(input_dim=opt.embd_size_l, hidden=[128], 
                                   output_dim=opt.output_dim, d_f=self.d_f)

关键设计
- 每个模态拥有独立的编码器(netAnetVnetL)和分类器(netCAnetCVnetCL
- 融合分类器 netC 接收所有模态表示的拼接
- 预训练模型参数分层设置学习率:底层冻结,顶层微调

3.2.2 前向传播

forward() 方法实现了源域和目标域的特征提取与预测:

def forward(self):
    # 源域特征提取
    if 'A' in self.modality:
        self.feat_A = self.netA(**self.acoustic)
        self.final_embd.append(self.feat_A)
    # ... 视觉和文本类似

    # 目标域特征提取
    if 'A' in self.modality:
        self.feat_A_t = self.netA(**self.acoustic_t)
        self.final_embd_target.append(self.feat_A_t)

    # 拼接多模态表示
    self.feat = torch.cat(self.final_embd, dim=-1)
    self.feat_t = torch.cat(self.final_embd_target, dim=-1)
    self.feat_all = torch.cat([self.feat, self.feat_t], dim=0)
    self.feat_all.register_hook(self.save_gradients)

    # 分类预测
    self.logits, self.fall = self.netC(self.feat_r)
    self.pred = F.softmax(self.logits, dim=-1)

关键设计
- feat_all 注册了梯度钩子 save_gradients,用于后续 MGDA 权重计算
- 源域和目标域的特征通过同一编码器提取,确保表示空间一致

3.2.3 伪标签投票

# 各模态目标域预测转为 one-hot 并累加
self.pred_tf = 1.5 * self.pred_toh  # 融合模态权重为 1.5
if 'A' in self.modality:
    self.pred_tf += self.pred_A_toh
if 'V' in self.modality:
    self.pred_tf += self.pred_V_toh
if 'L' in self.modality:
    self.pred_tf += self.pred_L_toh

# 筛选高置信度样本
self.index_pl = index_rely(self.pred_tf)  # 票数 >= Mv 的样本索引
self.pl = self.pred_tf.argmax(dim=1)[self.index_pl]

函数 index_rely 实现阈值筛选:

def index_rely(matrix, threshold=3):
    large_elements = matrix >= threshold
    rows_with_large_elements = torch.any(large_elements, dim=1)
    row_indices = torch.nonzero(rows_with_large_elements).squeeze(1)
    return row_indices

3.2.4 权重计算(weight_cal)

这是 Boomda 的核心方法,实现多目标优化权重的计算:

def weight_cal(self):
    # 1. 计算各模态的 CORAL 损失
    if 'A' in self.modality:
        self.loss_coral_a = mul * coral(self.fa, self.fa_t)
    if 'V' in self.modality:
        self.loss_coral_v = mul * coral(self.fv, self.fv_t)
    if 'L' in self.modality:
        self.loss_coral_l = mul * coral(self.fl, self.fl_t)

    # 2. 第一次反向传播:各模态独立对齐损失
    self.scaler.scale(self.loss_coral_all).backward(retain_graph=True)
    grads_all['all'].append(copy.deepcopy(self.feat_all_gradients))

    # 3. 第二次反向传播:融合模态对齐损失
    self.loss_coral_com = mul * coral(self.fall, self.fall_t)
    self.scaler.scale(self.loss_coral_com).backward(retain_graph=True)
    grads_com['all'].append(copy.deepcopy(self.feat_all_gradients))

    # 4. 分离各模态梯度
    for t in self.modality:
        if t == 'A':
            grads['A'] = grads_all['all'][0: self.opt.embd_size_a]
            grads_m['A'] = grads_com['all'][0: self.opt.embd_size_a]
        # ... V, L 类似

    # 5. 梯度归一化(可选)
    gn = gradient_normalizers(grads, loss_data, 'loss')

    # 6. 计算闭式解(use_full=1 时使用对角近似)
    if self.use_full == 1:
        pd_1 = torch.tensor([torch.dot(grads[t], grads[t]) for t in modality])
        pd_2 = torch.tensor([torch.dot(grads_m[t], grads_m[t]) for t in modality])
        pd_x = torch.tensor([torch.dot(grads[t], grads_m[t]) for t in modality])
        diag = [i for i in pd_1]
        diag.append(sum(pd_2))
        Q_t = torch.tensor(diag)
        Q_t_inv = 1 / Q_t
        sol = Q_t_inv / sum(Q_t_inv)

    return scale

关键流程
1. 计算各模态独立 CORAL 损失和融合模态 CORAL 损失
2. 两次反向传播获取梯度(保留计算图)
3. 从 feat_all 的梯度中分离各模态的梯度片段
4. 梯度归一化(按损失值归一化)
5. 计算对角近似矩阵 $\tilde{\mathbf{Q}}$ 的闭式解


3.3 CORAL 对齐实现

3.3.1 核心 CORAL 函数

models/coralutils/coral.py 实现了 CORAL 损失:

import torch
import numpy as np
import torch.nn.functional as F

mse = torch.nn.MSELoss()

def coral(source, target):
    d = source.size(1)  # 特征维度

    source_c, source_mu = compute_covariance(source)
    target_c, target_mu = compute_covariance(target)

    # 协方差差异(Frobenius 范数平方)
    loss_c = torch.sum(torch.mul((source_c - target_c), (source_c - target_c)))
    loss_c = loss_c / (4 * d * d)

    # 均值差异(MSE)
    loss_mu = mse(source_mu, target_mu)

    return loss_c  # 当前实现仅返回协方差差异

3.3.2 协方差计算

def compute_covariance(input_data):
    n = input_data.size(0)  # batch size
    device = input_data.device

    id_row = torch.ones(n).resize(1, n).to(device=device)
    sum_column = torch.mm(id_row, input_data)
    mean_column = torch.div(sum_column, n)
    term_mul_2 = torch.mm(mean_column.t(), mean_column)
    d_t_d = torch.mm(input_data.t(), input_data)
    c = torch.add(d_t_d, (-1 * term_mul_2)) * 1 / (n - 1)

    return c, mean_column

实现细节
- 使用矩阵运算高效计算协方差矩阵
- 公式 $C = \frac{1}{n-1}(X^T X - \frac{1}{n}(\mathbf{1}^T X)^T (\mathbf{1}^T X))$
- 同时返回均值向量用于可选的均值对齐

3.3.3 信息瓶颈中的熵计算

代码中还实现了 log_var 函数用于估计表示的熵(信息瓶颈第一项):

def log_var(input_data):
    n = input_data.size(0)
    d = input_data.size(1)

    input_data = F.normalize(input_data, p=2, dim=1)
    covar = torch.cov(input_data.T)
    det_covar = torch.linalg.det(covar)
    log_det_covar = torch.log10(det_covar / d**2 + 1)
    return log_det_covar

注意:代码实现与论文公式略有不同,使用了归一化后的协方差矩阵行列式。


3.4 多目标权重求解

3.4.1 MGDA 求解器

models/min_norm_solvers.py 提供了 MGDA 的完整实现,包括:

(1)最小范数元素求解(两任务情况)

def _min_norm_element_from2(v1v1, v1v2, v2v2):
    if v1v2 >= v1v1:
        gamma = 0.999
        cost = v1v1
        return gamma, cost
    if v1v2 >= v2v2:
        gamma = 0.001
        cost = v2v2
        return gamma, cost
    gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2*v1v2))
    cost = v2v2 + gamma * (v1v2 - v2v2)
    return gamma, cost

(2)Frank-Wolfe 迭代

def find_min_norm_element_FW(vecs):
    # 初始化解为最优两任务组合
    init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

    # Frank-Wolfe 迭代直到收敛
    while iter_count < MinNormSolver.MAX_ITER:
        t_iter = np.argmin(np.dot(grad_mat, sol_vec))

        v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
        v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
        v2v2 = grad_mat[t_iter, t_iter]

        nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
        new_sol_vec = nc * sol_vec
        new_sol_vec[t_iter] += 1 - nc

        if np.sum(np.abs(new_sol_vec - sol_vec)) < MinNormSolver.STOP_CRIT:
            return sol_vec, nd
        sol_vec = new_sol_vec
        iter_count += 1

(3)梯度归一化

def gradient_normalizers(grads, losses, normalization_type):
    gn = {}
    if normalization_type == 'loss':
        for t in grads:
            gn[t] = losses[t]
    elif normalization_type == 'l2':
        for t in grads:
            gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
    return gn

3.4.2 实际使用的闭式解

BoomdaModel.weight_cal() 中,实际使用的是论文提出的闭式解(use_full=1),而非迭代求解 MGDA。这保证了训练的高效性:

pd_1 = torch.tensor([torch.dot(grads[t], grads[t]) for t in modality])
pd_2 = torch.tensor([torch.dot(grads_m[t], grads_m[t]) for t in modality])
pd_x = torch.tensor([torch.dot(grads[t], grads_m[t]) for t in modality])
diag = [i for i in pd_1]
diag.append(sum(pd_2))
Q_t = torch.tensor(diag)
Q_t_inv = 1 / Q_t
sol = Q_t_inv / sum(Q_t_inv)

3.5 训练流程与入口

3.5.1 训练循环

train_boomda.py 实现了完整的训练流程:

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    for i, (data, data_t) in enumerate(zip(dataset, dataset_t)):
        # 计算退火系数 alpha
        p = float(i + start_steps) / total_steps
        alpha = p ** 2

        # 设置输入并优化
        model.set_input(data, data_t, alpha)
        model.optimize_parameters(epoch)

        # 打印损失
        if total_iters % opt.print_freq == 0:
            losses = model.get_current_losses()
            logger.info(...)

    # 更新学习率
    model.update_learning_rate(logger)

3.5.2 optimize_parameters

def optimize_parameters(self, epoch):
    self.optimizer.zero_grad()
    with autocast():  # 自动混合精度
        self.forward()
        scale = self.weight_cal()   # 计算 MGDA 权重
        self.loss_cal(scale)        # 计算整体损失

    self.optimizer.zero_grad()
    self.scaler.scale(self.loss).backward()
    self.scaler.step(self.optimizer)
    self.scaler.update()

关键设计
- 使用 PyTorch 的 autocast 进行自动混合精度训练
- weight_cal() 内部包含两次反向传播(保留计算图)
- 整体损失计算后进行第三次反向传播更新参数

3.5.3 损失计算

def loss_cal(self, scale):
    # 指数移动平均平滑权重
    beta = 0.2
    self.scale = (1 - beta) * self.scale + beta * scale_current

    # 各损失项
    self.loss_CE = self.criterion_ce(self.logits, self.label)
    self.loss_CE_t = self.criterion_ce(self.logits_t, self.label_t)
    if self.index_pl.shape[0] >= 1:
        self.loss_CE_tpl = self.criterion_ce(self.logits_t[self.index_pl], self.pl)

    # 信息瓶颈熵正则
    self.loss_ee = mul2 * entropy_re(self.fall)

    # 各模态 CORAL 损失
    self.loss_coral = mul * coral(self.fall, self.fall_t)
    if 'A' in self.modality:
        self.loss_coral_a = mul * coral(self.fa, self.fa_t)
    # ... V, L 类似

    # 加权求和
    self.loss_coralall = scalea['J'] * self.loss_coral + \
                        scalea['A'] * self.loss_coral_a + \
                        scalea['V'] * self.loss_coral_v + \
                        scalea['L'] * self.loss_coral_l

    # 总损失
    self.loss = self.loss_CE + 0.5 * self.alpha * self.loss_CE_tpl + \
                0.1 * (self.loss_CE_A + self.loss_CE_V + self.loss_CE_L) + \
                0.1 * self.loss_coralall + 0.00001 * (self.loss_ee + ...)

3.6 运行示例与配置

3.6.1 运行脚本

scripts/Boomda.sh 提供了运行示例:

# IEMOCAP -> IEMOCAP (VL 模态)
bash scripts/Boomda.sh VL 1 0 IEMOCAP IEMOCAP logs/0813VL_ie2ie label.csv 2e-3 "0.1" "'-1 1 0 0'" "'1 0.1 0.1 0.4'"

# MSP -> MSP (AVL 模态)
bash scripts/Boomda.sh AVL 1 0 MSP MSP logs/1501cAVL_ms2ms label.csv 2e-3 "0.1" "'-1 1 0 0'" "'1 0.1 0.18 0.4'"

参数说明:
- $1 modality:使用模态(A/V/L/AV/VL/AVL)
- $2 run_idx:运行索引
- $3 gpu:GPU 编号
- $4 source:源域数据集
- $5 target:目标域数据集
- $6 log_dir:日志目录
- $7 csv_name:标签 CSV 文件
- $8 lr:学习率
- $9 weights:权重列表
- ${10} change_weight1:源域数据变换权重
- ${11} change_weight2:目标域数据变换权重

3.6.2 关键超参数

参数 说明
embd_size_a/v/l 768 表示维度(实际投影到 256)
output_dim 4 类别数(neutral, happy, sad, angry)
cls_layers 128,128 分类器隐藏层
dropout_rate 0.3 Dropout 比率
niter 4 初始 epoch 数
niter_decay 4 衰减 epoch 数
batch_size 48 批次大小
lr 1e-3 学习率
beta1 0.9 Adam 动量系数

本节小结


思考题

  1. 代码中 weight_cal() 进行了两次反向传播,为什么需要保留计算图(retain_graph=True)?
  2. entropy_re 函数在代码中实现了信息瓶颈的哪一项?查看源码并分析。
  3. 如果要在 Boomda 的基础上增加一个新的模态(如生理信号),需要修改哪些代码?
  4. 闭式解和迭代 MGDA 求解在实际训练中的速度差异可能有多大?如何量化?