第三章:代码实现解析
3.1 项目结构与依赖
3.1.1 项目目录结构
Boomda 的代码仓库组织如下:
Boomda/
├── train_boomda.py # 训练入口脚本
├── run_boomda.sh # 运行示例
├── config_dir.json # 数据路径配置
├── models/ # 模型定义
│ ├── Boomda_model.py # 核心模型(BoomdaModel)
│ ├── base_model.py # 基类模型
│ ├── min_norm_solvers.py # MGDA 求解器
│ ├── coralutils/ # CORAL 实现
│ │ └── coral.py
│ ├── networks/ # 网络组件
│ │ ├── classifier.py # 分类器
│ │ ├── discriminator.py # 域判别器
│ │ └── msa.py # 模态特定网络(Bert, APViT, Wavlm)
│ ├── func.py # 辅助函数
│ └── opt.py # 优化相关函数
├── data/ # 数据加载与预处理
├── opts/ # 命令行参数解析
├── utils/ # 工具函数(日志、评估等)
└── scripts/ # 训练脚本
└── Boomda.sh
3.1.2 主要依赖
- PyTorch(深度学习框架)
- Transformers(预训练模型:WavLM、BERT、APViT)
- NumPy、SciPy
- scikit-learn(评估指标)
3.2 核心模型 BoomdaModel
3.2.1 模型初始化
BoomdaModel 继承自 BaseModel,在 __init__ 中完成网络组件的构建:
class BoomdaModel(BaseModel):
def __init__(self, opt):
super().__init__(opt)
self.opt = opt
self.modality = opt.modality # 例如 'AVL'、'VL' 等
# ...
# 多模态融合分类器
cls_input_size = opt.embd_size_a * int("A" in self.modality) + \
opt.embd_size_v * int("V" in self.modality) + \
opt.embd_size_l * int("L" in self.modality)
self.netC = SimpleFC2(input_dim=cls_input_size, hidden=[128],
output_dim=opt.output_dim, d_f=self.d_f)
# 各模态编码器与分类器
if 'A' in self.modality:
self.netA = Wavlm(self.device, self.lora_r, opt.embd_size_a)
self.netCA = SimpleFC2(input_dim=opt.embd_size_a, hidden=[64],
output_dim=opt.output_dim, d_f=self.d_f)
if 'V' in self.modality:
self.netV = APViT_video(self.device, self.lora_r, opt.embd_size_v)
self.netCV = SimpleFC2(input_dim=opt.embd_size_v, hidden=[64],
output_dim=opt.output_dim, d_f=self.d_f)
if 'L' in self.modality:
self.netL = Bert(self.device, self.lora_r, opt.embd_size_l)
self.netCL = SimpleFC2(input_dim=opt.embd_size_l, hidden=[128],
output_dim=opt.output_dim, d_f=self.d_f)
关键设计:
- 每个模态拥有独立的编码器(netA、netV、netL)和分类器(netCA、netCV、netCL)
- 融合分类器 netC 接收所有模态表示的拼接
- 预训练模型参数分层设置学习率:底层冻结,顶层微调
3.2.2 前向传播
forward() 方法实现了源域和目标域的特征提取与预测:
def forward(self):
# 源域特征提取
if 'A' in self.modality:
self.feat_A = self.netA(**self.acoustic)
self.final_embd.append(self.feat_A)
# ... 视觉和文本类似
# 目标域特征提取
if 'A' in self.modality:
self.feat_A_t = self.netA(**self.acoustic_t)
self.final_embd_target.append(self.feat_A_t)
# 拼接多模态表示
self.feat = torch.cat(self.final_embd, dim=-1)
self.feat_t = torch.cat(self.final_embd_target, dim=-1)
self.feat_all = torch.cat([self.feat, self.feat_t], dim=0)
self.feat_all.register_hook(self.save_gradients)
# 分类预测
self.logits, self.fall = self.netC(self.feat_r)
self.pred = F.softmax(self.logits, dim=-1)
关键设计:
- feat_all 注册了梯度钩子 save_gradients,用于后续 MGDA 权重计算
- 源域和目标域的特征通过同一编码器提取,确保表示空间一致
3.2.3 伪标签投票
# 各模态目标域预测转为 one-hot 并累加
self.pred_tf = 1.5 * self.pred_toh # 融合模态权重为 1.5
if 'A' in self.modality:
self.pred_tf += self.pred_A_toh
if 'V' in self.modality:
self.pred_tf += self.pred_V_toh
if 'L' in self.modality:
self.pred_tf += self.pred_L_toh
# 筛选高置信度样本
self.index_pl = index_rely(self.pred_tf) # 票数 >= Mv 的样本索引
self.pl = self.pred_tf.argmax(dim=1)[self.index_pl]
函数 index_rely 实现阈值筛选:
def index_rely(matrix, threshold=3):
large_elements = matrix >= threshold
rows_with_large_elements = torch.any(large_elements, dim=1)
row_indices = torch.nonzero(rows_with_large_elements).squeeze(1)
return row_indices
3.2.4 权重计算(weight_cal)
这是 Boomda 的核心方法,实现多目标优化权重的计算:
def weight_cal(self):
# 1. 计算各模态的 CORAL 损失
if 'A' in self.modality:
self.loss_coral_a = mul * coral(self.fa, self.fa_t)
if 'V' in self.modality:
self.loss_coral_v = mul * coral(self.fv, self.fv_t)
if 'L' in self.modality:
self.loss_coral_l = mul * coral(self.fl, self.fl_t)
# 2. 第一次反向传播:各模态独立对齐损失
self.scaler.scale(self.loss_coral_all).backward(retain_graph=True)
grads_all['all'].append(copy.deepcopy(self.feat_all_gradients))
# 3. 第二次反向传播:融合模态对齐损失
self.loss_coral_com = mul * coral(self.fall, self.fall_t)
self.scaler.scale(self.loss_coral_com).backward(retain_graph=True)
grads_com['all'].append(copy.deepcopy(self.feat_all_gradients))
# 4. 分离各模态梯度
for t in self.modality:
if t == 'A':
grads['A'] = grads_all['all'][0: self.opt.embd_size_a]
grads_m['A'] = grads_com['all'][0: self.opt.embd_size_a]
# ... V, L 类似
# 5. 梯度归一化(可选)
gn = gradient_normalizers(grads, loss_data, 'loss')
# 6. 计算闭式解(use_full=1 时使用对角近似)
if self.use_full == 1:
pd_1 = torch.tensor([torch.dot(grads[t], grads[t]) for t in modality])
pd_2 = torch.tensor([torch.dot(grads_m[t], grads_m[t]) for t in modality])
pd_x = torch.tensor([torch.dot(grads[t], grads_m[t]) for t in modality])
diag = [i for i in pd_1]
diag.append(sum(pd_2))
Q_t = torch.tensor(diag)
Q_t_inv = 1 / Q_t
sol = Q_t_inv / sum(Q_t_inv)
return scale
关键流程:
1. 计算各模态独立 CORAL 损失和融合模态 CORAL 损失
2. 两次反向传播获取梯度(保留计算图)
3. 从 feat_all 的梯度中分离各模态的梯度片段
4. 梯度归一化(按损失值归一化)
5. 计算对角近似矩阵 $\tilde{\mathbf{Q}}$ 的闭式解
3.3 CORAL 对齐实现
3.3.1 核心 CORAL 函数
models/coralutils/coral.py 实现了 CORAL 损失:
import torch
import numpy as np
import torch.nn.functional as F
mse = torch.nn.MSELoss()
def coral(source, target):
d = source.size(1) # 特征维度
source_c, source_mu = compute_covariance(source)
target_c, target_mu = compute_covariance(target)
# 协方差差异(Frobenius 范数平方)
loss_c = torch.sum(torch.mul((source_c - target_c), (source_c - target_c)))
loss_c = loss_c / (4 * d * d)
# 均值差异(MSE)
loss_mu = mse(source_mu, target_mu)
return loss_c # 当前实现仅返回协方差差异
3.3.2 协方差计算
def compute_covariance(input_data):
n = input_data.size(0) # batch size
device = input_data.device
id_row = torch.ones(n).resize(1, n).to(device=device)
sum_column = torch.mm(id_row, input_data)
mean_column = torch.div(sum_column, n)
term_mul_2 = torch.mm(mean_column.t(), mean_column)
d_t_d = torch.mm(input_data.t(), input_data)
c = torch.add(d_t_d, (-1 * term_mul_2)) * 1 / (n - 1)
return c, mean_column
实现细节:
- 使用矩阵运算高效计算协方差矩阵
- 公式 $C = \frac{1}{n-1}(X^T X - \frac{1}{n}(\mathbf{1}^T X)^T (\mathbf{1}^T X))$
- 同时返回均值向量用于可选的均值对齐
3.3.3 信息瓶颈中的熵计算
代码中还实现了 log_var 函数用于估计表示的熵(信息瓶颈第一项):
def log_var(input_data):
n = input_data.size(0)
d = input_data.size(1)
input_data = F.normalize(input_data, p=2, dim=1)
covar = torch.cov(input_data.T)
det_covar = torch.linalg.det(covar)
log_det_covar = torch.log10(det_covar / d**2 + 1)
return log_det_covar
注意:代码实现与论文公式略有不同,使用了归一化后的协方差矩阵行列式。
3.4 多目标权重求解
3.4.1 MGDA 求解器
models/min_norm_solvers.py 提供了 MGDA 的完整实现,包括:
(1)最小范数元素求解(两任务情况)
def _min_norm_element_from2(v1v1, v1v2, v2v2):
if v1v2 >= v1v1:
gamma = 0.999
cost = v1v1
return gamma, cost
if v1v2 >= v2v2:
gamma = 0.001
cost = v2v2
return gamma, cost
gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2*v1v2))
cost = v2v2 + gamma * (v1v2 - v2v2)
return gamma, cost
(2)Frank-Wolfe 迭代
def find_min_norm_element_FW(vecs):
# 初始化解为最优两任务组合
init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
# Frank-Wolfe 迭代直到收敛
while iter_count < MinNormSolver.MAX_ITER:
t_iter = np.argmin(np.dot(grad_mat, sol_vec))
v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
v2v2 = grad_mat[t_iter, t_iter]
nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
new_sol_vec = nc * sol_vec
new_sol_vec[t_iter] += 1 - nc
if np.sum(np.abs(new_sol_vec - sol_vec)) < MinNormSolver.STOP_CRIT:
return sol_vec, nd
sol_vec = new_sol_vec
iter_count += 1
(3)梯度归一化
def gradient_normalizers(grads, losses, normalization_type):
gn = {}
if normalization_type == 'loss':
for t in grads:
gn[t] = losses[t]
elif normalization_type == 'l2':
for t in grads:
gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
return gn
3.4.2 实际使用的闭式解
在 BoomdaModel.weight_cal() 中,实际使用的是论文提出的闭式解(use_full=1),而非迭代求解 MGDA。这保证了训练的高效性:
pd_1 = torch.tensor([torch.dot(grads[t], grads[t]) for t in modality])
pd_2 = torch.tensor([torch.dot(grads_m[t], grads_m[t]) for t in modality])
pd_x = torch.tensor([torch.dot(grads[t], grads_m[t]) for t in modality])
diag = [i for i in pd_1]
diag.append(sum(pd_2))
Q_t = torch.tensor(diag)
Q_t_inv = 1 / Q_t
sol = Q_t_inv / sum(Q_t_inv)
3.5 训练流程与入口
3.5.1 训练循环
train_boomda.py 实现了完整的训练流程:
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
for i, (data, data_t) in enumerate(zip(dataset, dataset_t)):
# 计算退火系数 alpha
p = float(i + start_steps) / total_steps
alpha = p ** 2
# 设置输入并优化
model.set_input(data, data_t, alpha)
model.optimize_parameters(epoch)
# 打印损失
if total_iters % opt.print_freq == 0:
losses = model.get_current_losses()
logger.info(...)
# 更新学习率
model.update_learning_rate(logger)
3.5.2 optimize_parameters
def optimize_parameters(self, epoch):
self.optimizer.zero_grad()
with autocast(): # 自动混合精度
self.forward()
scale = self.weight_cal() # 计算 MGDA 权重
self.loss_cal(scale) # 计算整体损失
self.optimizer.zero_grad()
self.scaler.scale(self.loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
关键设计:
- 使用 PyTorch 的 autocast 进行自动混合精度训练
- weight_cal() 内部包含两次反向传播(保留计算图)
- 整体损失计算后进行第三次反向传播更新参数
3.5.3 损失计算
def loss_cal(self, scale):
# 指数移动平均平滑权重
beta = 0.2
self.scale = (1 - beta) * self.scale + beta * scale_current
# 各损失项
self.loss_CE = self.criterion_ce(self.logits, self.label)
self.loss_CE_t = self.criterion_ce(self.logits_t, self.label_t)
if self.index_pl.shape[0] >= 1:
self.loss_CE_tpl = self.criterion_ce(self.logits_t[self.index_pl], self.pl)
# 信息瓶颈熵正则
self.loss_ee = mul2 * entropy_re(self.fall)
# 各模态 CORAL 损失
self.loss_coral = mul * coral(self.fall, self.fall_t)
if 'A' in self.modality:
self.loss_coral_a = mul * coral(self.fa, self.fa_t)
# ... V, L 类似
# 加权求和
self.loss_coralall = scalea['J'] * self.loss_coral + \
scalea['A'] * self.loss_coral_a + \
scalea['V'] * self.loss_coral_v + \
scalea['L'] * self.loss_coral_l
# 总损失
self.loss = self.loss_CE + 0.5 * self.alpha * self.loss_CE_tpl + \
0.1 * (self.loss_CE_A + self.loss_CE_V + self.loss_CE_L) + \
0.1 * self.loss_coralall + 0.00001 * (self.loss_ee + ...)
3.6 运行示例与配置
3.6.1 运行脚本
scripts/Boomda.sh 提供了运行示例:
# IEMOCAP -> IEMOCAP (VL 模态)
bash scripts/Boomda.sh VL 1 0 IEMOCAP IEMOCAP logs/0813VL_ie2ie label.csv 2e-3 "0.1" "'-1 1 0 0'" "'1 0.1 0.1 0.4'"
# MSP -> MSP (AVL 模态)
bash scripts/Boomda.sh AVL 1 0 MSP MSP logs/1501cAVL_ms2ms label.csv 2e-3 "0.1" "'-1 1 0 0'" "'1 0.1 0.18 0.4'"
参数说明:
- $1 modality:使用模态(A/V/L/AV/VL/AVL)
- $2 run_idx:运行索引
- $3 gpu:GPU 编号
- $4 source:源域数据集
- $5 target:目标域数据集
- $6 log_dir:日志目录
- $7 csv_name:标签 CSV 文件
- $8 lr:学习率
- $9 weights:权重列表
- ${10} change_weight1:源域数据变换权重
- ${11} change_weight2:目标域数据变换权重
3.6.2 关键超参数
| 参数 | 值 | 说明 |
|---|---|---|
| embd_size_a/v/l | 768 | 表示维度(实际投影到 256) |
| output_dim | 4 | 类别数(neutral, happy, sad, angry) |
| cls_layers | 128,128 | 分类器隐藏层 |
| dropout_rate | 0.3 | Dropout 比率 |
| niter | 4 | 初始 epoch 数 |
| niter_decay | 4 | 衰减 epoch 数 |
| batch_size | 48 | 批次大小 |
| lr | 1e-3 | 学习率 |
| beta1 | 0.9 | Adam 动量系数 |
本节小结
- Boomda 代码采用模块化设计,模型、数据、损失、求解器分离
BoomdaModel是核心类,实现了前向传播、伪标签投票、权重计算和损失计算- CORAL 损失通过矩阵运算高效实现协方差对齐
- MGDA 求解器提供了完整的迭代求解方法,但实际使用闭式解提升效率
- 训练使用自动混合精度(AMP),三次反向传播(两次 MGDA + 一次整体更新)
思考题
- 代码中
weight_cal()进行了两次反向传播,为什么需要保留计算图(retain_graph=True)? entropy_re函数在代码中实现了信息瓶颈的哪一项?查看源码并分析。- 如果要在 Boomda 的基础上增加一个新的模态(如生理信号),需要修改哪些代码?
- 闭式解和迭代 MGDA 求解在实际训练中的速度差异可能有多大?如何量化?