NLP —— 模型优化蒸馏案例
2026/6/5 14:18:25 网站建设 项目流程

目录

一、概念

二、主流四大类技术

1. 模型量化

2. 模型剪枝

3. 低秩因式分解

4. 模型蒸馏

三、代码案例

需求

代码思路

① Config文件

② 教师模型文件

③ 学生模型文件

<1> 定义参数

<2> 搭建网络层

<3> 前向传播

④ 数据预处理文件

<1> 读取文件数据处理

<2> 自定义数据集类

<3> 数据二次处理 -> 数据张量和掩码张量

<4> 构造数据加载器

⑤ 模型蒸馏训练

<1> 创建数据加载器对象

<2> 创建教师模型对象 + 加载已训练好的模型

<3> 创建学生模型对象

<4> 损失函数

<5> 优化器

<6> 变量(训练轮次、初始化f1_score,蒸馏温度T、α系数)

<7> 设置老师模型评估模式、学生模型训练模式

<8> 训练

⑥模型预测使用


一、概念

模型压缩:在尽量不损失精度前提下,减小模型参数量、显存占用、推理耗时,方便部署 CPU / 移动端。

目标: 参数变少、模型文件变小、推理更快、显存更低。 常见落地:大 BERT→小 BiLSTM

二、主流四大类技术

1. 模型量化

pytorch中默认 float32 int64. -> float16 int8 。

降低精度。从而缩减模型,并加速推断速度。。

pytorch 中 Quantization,官网API (静态、动态)API
Quantization — PyTorch 2.4 documentation

① 训练中量化 QAT 量化感知训练

② 训练后量化

<1> 动态量化 DQNLP领域

<2> 静态两会 QTQ CV领域

特性静态量化动态量化
APIpreparequantize_dynamic
适用模型CNN(ResNet, MobileNet)NLP模型(BERT, LSTM)等

PyTorch的动态量化只能在CPU上执行

核心代码

# 定义一个模型 class Model(torch.nn.Module): def __init__(self): super().__init__() self.embedded = nn.Embedding(4, 128) self.rnn = nn.GRU(128, 1024, batch_first=True) self.linear = nn.Linear(1024, 10) self.dropout = nn.Dropout(p=0.1) def forward(self, x): x, hn = self.rnn(self.embedded(x)) return self.dropout(self.linear(x))
# 创建量化模型实例 # model:原始模型 # qconfig_spec:待量化的层参数 # dtype:量化权重的目标类型 model2 = torch.quantization.quantize_dynamic(model=model1, qconfig_spec={torch.nn.Linear, nn.GRU}, dtype=torch.qint8)


2. 模型剪枝

NLP中不用,一般在CV中用。

Pytorch中对模型剪枝的支持在torch.nn.utils.prune模块中, 分以下几种剪枝方式:

  • 随机剪枝

  • L1结构化剪枝

  • L1非结构化剪枝

  • 全局非结构化剪枝

非结构化剪枝结构化剪枝
按单个权重裁剪按神经元、通道、整行/列裁剪
剪枝后是稀疏矩阵剪枝后是稠密矩阵
类似于裁掉部门中贡献度低的个人类似于裁掉整个部门

代码:

# 演示随机非结构化剪枝 def dm01(): linear = nn.Linear(2, 3) print("linear-->", linear.weight) model = prune.random_unstructured(linear, 'weight', amount=2) print("model-->", model.weight) # 演示全局非结构化剪枝 def dm02(): net = nn.Sequential(OrderedDict([ ('first', nn.Linear(3, 4)), ('second', nn.Linear(4, 2)), ])) print("net1-->", net) for model in net: print("model-->", model.weight) parameters_to_prune = ((net.first, 'weight'), (net.second, 'weight')) # parameters_to_prune:待剪枝的参数 # pruning_method:剪枝的方式,L1Unstructured表示非结构化剪枝(常用) # amount:如果是小数,则表示比例,如果是整数,则表示数量 prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount = 0.2) print("net2-->", net) for model in net: print("model-->", model.weight)

3. 低秩因式分解

比如21128词表 * 768维度 很大,进行分解。运用矩阵分解,减少网络参数量,提升效率。

4. 模型蒸馏

复杂模型(教师模型)-> 简单模型(学生模型)

教师模型

  • 定义:复杂的、高性能的模型,通常是大型深度神经网络。

  • 特点:参数量大,能够学习复杂的特征和关系。

  • 需要提前训练好。

学生模型

  • 定义:简化的、小型的模型,可以是教师模型的子集或者简单模型。

  • 特点:参数量较小,适用于资源受限的场景。

  • 不需要提前训练好。

知识的来源:

  • 硬标签蒸馏:学生模型直接学习教师模型的分类结果。

  • 软标签蒸馏:学生模型学习教师模型对每个类别的概率分布

  • 中间层蒸馏:学生模型学习教师模型的隐藏层、特征图等。

关键点:

  1. 高温T平滑输出概率,生成软标签
  2. 效果:BERT (110M 参数) → BiLSTM (几 M 参数),体积压缩十几倍
  3. 损失 = 真实标签 CE 损失 + KL 蒸馏损失

适用:NLP 分类、文本任务。

公式:

# 计算KL散度值 p = torch.log_softmax(teacher_pred/T, dim=-1) q = torch.log_softmax(student_pred/T, dim=-1) # KL散度值,也就是软标签的值 """ 参数解释: input:是【学生模型】输出的结果 target:预测结果参考值。也就是【教师模型】输出的结果 reduction:上面两个值的计算方式。 log_target:是否对计算结果求log对数 """ kl_value = torch.nn.functional.kl_div( input=q, target=p, reduction="batchmean", log_target=True ) # 硬标签损失值 # 注意:是学生模型的预测概率,与样本的目标值算损失 hard_label_loss = loss(student_pred,labels) # 蒸馏的总损失值 # l = (1-α) * 硬标签损失值 + α * T² * KL散度值 distll_loss = (1 - alpha) * hard_label_loss + alpha * (T**2) * kl_value

q: 学生模型预测结果计算得来

p: 教师模型预测结果计算得来

CE(y,p)也就是 学生模型自己的交叉熵损失

  • 参数α:系数,控制从学生模型和教师模型学习的比例,比如α=0.8。

  • 参数T:蒸馏温度,是一个平滑系数,控制softmax的输出,比如T=4。

蒸馏总损失值 L_{KD} = (1 - α)CE(y,p) + αKL(q,p)

KLDivLoss — PyTorch 2.4 documentation

三、代码案例

需求

以文本分类任务,基于Bert模型的 教师模型,学生模型内部使用BiLstm神经网络

数据文本 ( 内容, 类别索引 )

数据源:三个内容文件,一个类别文件。

代码思路

① Config文件

配置各个文件路径(数据源,模型,批次大小,句子最大长度)

class Config(object): def __init__(self): # 1 - 设备 # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = "cpu" # 2 原始文件 self.train_datapath = 'data/train.txt' self.test_datapath = 'data/test.txt' self.dev_datapath = 'data/dev.txt' self.class_datapath = 'data/class.txt' # 3 数据加载参数 self.batch_size = 64 self.max_seq_len = 32 # 4 Bert 预训练模型路径 self.bert_path = '../Base_Bert_TMF/bert_base_model/bert-base-chinese' # 5 - 目标值 文本解析 self.classname_list = [line.strip() for line in open(self.class_datapath,mode='r',encoding='utf-8')] self.classname_len = len(self.classname_list) # 6 - 训练好的【教师模型】路径 self.teacher_model_path = 'save_model/teacher_bert.pkl' # 7 - 学生模型路径 self.student_model_path = 'save_model/student_model.pkl'
② 教师模型文件

基于Bert模型,经过线性层处理,冻结反向传播。(已训练好的模型)

线性层(in_features = Bert模型的隐藏状态大小,out_features=数据源类的总共个数)

""" 教师模型,基于Bert模型 """ import torch import torch.nn as nn from transformers import BertModel from transformers import BertConfig from config import Config config = Config() class TeacherBertModel(nn.Module): def __init__(self): super().__init__() self.bert_model = BertModel.from_pretrained(config.bert_path) temp_config = BertConfig.from_pretrained(config.bert_path) in_features = temp_config.hidden_size self.linear = nn.Linear( in_features=in_features, out_features=config.classname_len ) def forward(self, input_ids, attention_mask=None): # 教师模型不需要训练 要冻结反向传播 with torch.no_grad(): bert_output = self.bert_model( input_ids=input_ids, attention_mask=attention_mask ) # 2- 教师模型的:池化层,实际就是nn.Linear+激活函数。不用额外定义 """ 1- last_hidden_state[:,0]和pooler_output,实际是类似的东西,都表示[CLS]的隐藏状态。 区别:需要对last_hidden_state[:,0]经过nn.Linear和激活函数处理后,才能得到pooler_output 对应源代码位置:BertModel文件的697行 2- 获得池化层后的结果有两种方式: 2.1- 方式一:推荐。通过实例属性获得 bert_output.pooler_output 2.2- 方式二:通过实例属性索引获得 bert_output[1]。1的原因是pooler_output是类中的第2个实例属性 对应源代码位置:BertModel文件的1017行 """ # 因为是句子 分类问题,所以取句子的向量。 pooled_output = bert_output.pooler_output return self.linear(pooled_output)
③ 学生模型文件

定义学生模型类

<1> 定义参数

词汇表大小,词向量维度,隐藏状态,隐藏层层数

<2> 搭建网络层

词向量层、双向LSTM、随机失活层、线性层(输入 2倍的隐藏大小,输出 句子最大长度)

<3> 前向传播

<<1>> 数据张量化

<<2>> 输入原始数据处理,

过滤【CLS、SEP】特殊标识,基于transformer系列都有这个标识。

结合输入掩码张量对原始数据矩阵点乘处理

得到最终有效的词张量数据

<<3>> 调用BiLstm循环神经网络 -> 得到输出数据【batch_size,seq_len,hidden_size】

<<4>> 因为是文本分类需要的是句子,对输出数据累加->降维->记得句向量数据

<<5>>调用(随机失活 + 线性层)-> 输出

""" 学生模型 用BILSTM 双向模型 """ from torch import Tensor from config import Config import torch import torch.nn as nn from transformers import BertConfig config = Config() bert_config = BertConfig.from_pretrained(config.bert_path) class BILSTMStudentModel(nn.Module): def __init__(self): super().__init__() """ 设置参数 基于Bert模型的中文词汇表大小 """ self.vocab_size = bert_config.vocab_size self.embedding_dim = 128 self.hidden_size = 256 self.num_layers = 3 """ 搭建网络层 embedding_dim:由我们自己设置,与教师模型没有任何关系 """ self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim) self.lstm = nn.LSTM( input_size=self.embedding_dim, #输入的词向量维度,必须和embding_dim 相同 hidden_size=self.hidden_size, #隐藏层向量维度 自定义 batch_first=True, #是否batch_size开头的张量 【batch_size,seq_len,hidden_size】 num_layers=self.num_layers, #隐藏层层数 bidirectional=True #是否双向 ) self.dropout = nn.Dropout(p=0.2) """ 因为双向LSTM 所以 hidden_size*2 多分类任务,任务值是 取数据类别个数 作为输出 """ self.linear = nn.Linear(self.hidden_size*2, config.classname_len) def forward(self, input_ids, attention_mask): # 1 - 数据张量化 ebd = self.embedding(input_ids) """ 带 【CLS、SEP】特殊标识 Token:BERT 系 Transformer 编码器网络 所以数据要先把 【CLS】、【SEP】标识去除 """ # 2 - cls_token_index = 101 #句子开头 CLS固定索引值 sep_token_index = 102 #句子结尾 SEP固定索引值 # 2.1 # 对 input_ids 数据过滤 CLS 和 SEP ebd_mask = (input_ids != cls_token_index) & (input_ids != sep_token_index) # 2.2 # 过滤后的数据 与 掩码进行再次过滤 => 得到实际要用的掩码 ebd_mask:Tensor = ebd_mask & attention_mask # 2.3 # 对 edb_mask 升维 # 原始【batch_size,seq_len】 -> 【batch_size, seq_len, 1】 ebd_mask = ebd_mask.unsqueeze(-1) # 2.4 # 原始数据 与 实际掩码 进行点乘预算,得到实际有效的数据源 ebd = ebd * ebd_mask # 3 - 调用循环神经网络BiLSTM # 为什么调用lstm的时候,没有手动传递初始的细胞状态和隐藏状态:LSTM内部会自动的进行全0初始化。源代码在1056行 out_put, (hidden, c) = self.lstm(ebd) # 4 - 计算平均池化值 # 4.1 # 降维: 以为是对词向量进行 网络处理,需求做的是句子分类 # 【batch_size,seq_len,hidden_size】=> [batch_size, hidden_size] output_sum = out_put.sum(dim=1) # 4.2 # 获取所有有效词的个数 + 1e-6 为了防止个数为0 token_count = ebd_mask.sum(dim=1) + 1e-6 # 4.3 # 计算获取 最终的句子向量数据 new_output = output_sum / token_count # 5 # 调用线性层,得到预测结构,并返回 return self.linear(self.dropout(new_output))
④ 数据预处理文件
<1> 读取文件数据处理

表格数据读取 -> 得到数组 (每行的数据)

<2> 自定义数据集类

<<1>> __init__ 参数定义 self.data_list = <1>处理得到的

<<2>> __len__ 样本条数

<<3>> __getitem__ 函数,根据索引获得 对应的 文本和分类 值

<3> 数据二次处理 -> 数据张量和掩码张量

<<1>> 传入每批次数据
输入数据:[('近期新盘推荐 通州纯新别墅本周开盘', 1), ('陕西退休教师嫌弃精神病 女儿将其勒死被捕', 5)]
输出数据:[('近期新盘推荐 通州纯新别墅本周开盘', '陕西退休教师嫌弃精神病女儿 将其勒死被捕'), (1, 5)]

得到 文本内容元组 和 类别元组

<<2>> 通过 transformers 的 BertTokenizer, 把数据转换为词索引张量

<<3>> 返回 数据张量(intput_ids)、掩码张量(attention_mask)、真实类别张量(lables)

<4> 构造数据加载器

<<1>> 通过<1>、<2>、得到数据集

<<2>> 创建数据加载器对象 DataLoader

<<3>> 返回加载器对象

""" 数据处理 得到模型需要的 input_dis 和 attention_mask. 并传递 真实值 Labels # 1 读取文件获得数据 # 2 定义数据集 # 3 数据二次处理 (按batch,处理成input_dis,attention_mask 张量) # 4 构建数据加载器 """ import torch import torch.nn as nn from config import Config from torch.utils.data import Dataset,DataLoader from transformers import BertTokenizer config = Config() bert_tokenizer = BertTokenizer.from_pretrained(config.bert_path) # 1 - 数据获取,处理 def load_data(datapath): with open(datapath,mode="r",encoding="UTF-8") as f: lines = f.readlines() result_list = [] for line in lines: line = line.strip() if line=="": continue # 样本数据 # 两天价网站背后重重迷雾:做个网站究竟要多少钱 4 title, label = line.split('\t') # 【可选】健壮性代码 """ 只要是有数据类型转换的地方,基本都有健壮性代码 """ if not label.isdigit(): print(f"label的数据内容不合法,值是{label}") continue # 保存数据 result_list.append((title,int(label))) return result_list # 2 - 自定义数据集 class NewsDataset(Dataset): def __init__(self,data_list): super().__init__() self.data_list = data_list #读取数据 self.sample_len = len(self.data_list) #样本条数 def __len__(self): return self.sample_len def __getitem__(self, idx): # 防止数组越界 index = min(max(idx, 0),self.sample_len-1) title,label = self.data_list[index] return title,label # 3 - 数据二次处理,按每批次数据处理 def collate_fn(batch_data): """ zip(*)处理过程如下: 输入数据:[('近期新盘推荐 通州纯新别墅本周开盘', 1), ('陕西退休教师嫌弃精神病女儿将其勒死被捕', 5)] 输出数据:[('近期新盘推荐 通州纯新别墅本周开盘', '陕西退休教师嫌弃精神病女儿将其勒死被捕'), (1, 5)] """ titles,labels = zip(*batch_data) # 根据词索引 数据张量化 -> 获取词索引张量 title_tensor = bert_tokenizer( titles, padding="max_length", truncation=True, max_length=config.max_seq_len, return_tensors="pt" ) return ( title_tensor.input_ids, title_tensor.attention_mask, torch.tensor(labels,dtype=torch.long) ) # 4 - 构建数据加载器 def build_dataloader(datapath, shuffle=True): data = load_data(datapath) dataset = NewsDataset(data) data_loader = DataLoader( dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, collate_fn=collate_fn ) return data_loader
⑤ 模型蒸馏训练

学生模型训练边训练边预测保存

<1> 创建数据加载器对象
<2> 创建教师模型对象 + 加载已训练好的模型
<3> 创建学生模型对象
<4> 损失函数
<5> 优化器
<6> 变量(训练轮次、初始化f1_score,蒸馏温度T、α系数)
<7> 设置老师模型评估模式、学生模型训练模式
<8> 训练

<8.1> 根据数据加载器分批次 获取输入张量、掩码张量、真实类别张量

<8.2> 模型前向传播,其中老师模型冻结,不需要更新

<8.3> 计算KL散度

<8.4> 计算学生模型交叉熵损失值

<8.5> 计算蒸馏总损失值

<8.6> 梯度清零、反向传播、梯度更新

<8.7> 每固定间隔 对学生模型进行评估

<<1>> 数据加载器(加载评估数据)

<<2>> 学生模型切换评估模式

<<3>> 数据加载器分批次进行模型评估

保存真实结果和评估结果

<<4>> 计算评估指标

f1_score、accuracy(准确率)、precision(精确率)、recall(召回率)

<8.8> f1_socre > 上一次的f1_socre 值,保存模型进行覆盖。

<8.9> 学生模型切换训练模型,继续训练直到所有训练数据结束

""" 模型蒸馏 """ import torch import torch.nn as nn from tqdm import tqdm from data_preprocessing import build_dataloader from student_bilstm_model import BILSTMStudentModel from teacher_bert_model import TeacherBertModel from config import Config from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score config = Config() def eval(student_model): # 1. 数据加载器 dataloader = build_dataloader(config.dev_datapath, shuffle=False) # 2. 切换模式 student_model.eval() all_pred_result = [] # 预测结果列表 all_true_result = [] # 真实结果列表 # 3. 预测 with torch.no_grad(): for batch_idx, batch_data in enumerate(tqdm(dataloader),start=1): input_dis, attention_mask, labels = batch_data input_dis = input_dis.to(config.device) attention_mask = attention_mask.to(config.device) labels = labels.to(config.device) # 预测结果 student_pred = student_model(input_dis, attention_mask) student_pred_index = torch.argmax(student_pred, dim=-1) # cpu():因为不涉及张量的计算,因此为了节约GPU资源,可以将数据转到CPU上再处理 # .tolist() tensor([0,2,1]) → [0,2,1] # .extend() # append([1,2,3]) → [[1,2,3]](嵌套列表) # extend([1,2,3]) → [1,2,3](把元素挨个拼进去) all_pred_result.extend(student_pred_index.cpu().tolist()) all_true_result.extend(labels.cpu().tolist()) # 4 - 计算评估指标 f1score = f1_score(all_true_result,all_pred_result,average="macro") # 准确率 accuracy = accuracy_score(all_true_result,all_pred_result) precision = precision_score(all_true_result,all_pred_result,average="macro") recall = recall_score(all_true_result,all_pred_result,average="macro") return f1score, accuracy, precision, recall def train_and_eval(): # 1. 通过加载器获取数据 data_loader = build_dataloader(config.train_datapath, shuffle=True) # 2 - 教师模型 teacher_model = TeacherBertModel().to(config.device) teacher_model.load_state_dict(torch.load(config.teacher_model_path)) # 3 - 学生模型 student_model = BILSTMStudentModel().to(config.device) # 4 - 损失函数 loss_fn = nn.CrossEntropyLoss() # 5 - 优化器 optimizer = torch.optim.Adam(student_model.parameters(), lr=5e-5) # 6 - 其他变量 epochs = 1 best_f1score = 0 T = 2 #蒸馏温度 alpha = 0.7 #计算蒸馏总损失 KL散度和学生 概率比例 # 7 - 训练模式 student_model.train() teacher_model.eval() # 8 训练 for epoch in range(epochs): for batch_idx, batch_data in enumerate(tqdm(data_loader),start=1): input_dis, attention_mask, labels = batch_data # 8.1 批次训练数据 # (输入张量、掩码张量、真实张量) input_dis = input_dis.to(config.device) attention_mask = attention_mask.to(config.device) labels = labels.to(config.device) # 8.2 模型前向传播 # 老师模型冻结,不需要更新 with torch.no_grad(): teacher_pred = teacher_model(input_dis, attention_mask) teacher_pred_labels = torch.argmax(teacher_pred, dim=-1) student_pred = student_model(input_dis, attention_mask) student_pred_labels = torch.argmax(student_pred, dim=-1) # 8.3 # 计算KL散度 p = torch.log_softmax(teacher_pred/T, dim=-1) q = torch.log_softmax(student_pred/T, dim=-1) # KL散度值,也就是软标签的值 """ 注意:kl_div的包不要导错了!!! 参数解释: input:是【学生模型】输出的结果 target:预测结果参考值。也就是【教师模型】输出的结果 reduction:上面两个值的计算方式。 log_target:是否对计算结果求log对数 """ kl_value = torch.nn.functional.kl_div( input=q, target=p, reduction='batchmean', log_target=True ) # 8.4 学生模型自己的损失值 loss_value = loss_fn(student_pred, labels) # 8.5 蒸馏总损失值 固定公式 distill_loss = (1-alpha) * loss_value + alpha * kl_value * (T**2) # 8.6 梯度清零,反向传播,梯度更新 optimizer.zero_grad() distill_loss.backward() optimizer.step() # 8.7 每间隔100个批次 或者 最后一个批次,对学生模型进行验证 if batch_idx%100==0 or batch_idx==len(data_loader): f1_score, accuracy, precision, recall = eval(student_model) print(f"第{batch_idx}批次,f1score={f1_score},accuracy={accuracy},precision={precision},recall={recall}") if f1_score > best_f1score: torch.save(student_model.state_dict(), config.student_model_path) best_f1score = f1_score # 切换回训练模式 student_model.train() if __name__ == '__main__': train_and_eval()
⑥模型预测使用
""" 预测函数 提供模型服务 """ import torch from config import Config from transformers import BertTokenizer from student_bilstm_model import BILSTMStudentModel config = Config() model = BILSTMStudentModel().to(config.device) model.load_state_dict(torch.load(config.student_model_path)) model.eval() tokenizer = BertTokenizer.from_pretrained(config.bert_path) def model_predict(json_data): # 1 - 外部数据 取得句子 title = json_data['title'] # 2 - 文本转张量 获得 input_ids, attention_mask title_tensor = tokenizer( [title], padding="max_length", truncation=True, max_length=config.max_seq_len, return_tensors="pt" ) input_ids = title_tensor.input_ids.to(config.device) attention_mask = title_tensor.attention_mask.to(config.device) with torch.no_grad(): output = model(input_ids, attention_mask) output_index = torch.argmax(output, dim=-1).item() #取概率最大的索引值 pred_class_name = config.classname_list[output_index] json_data["pred_class"] = pred_class_name return json_data if __name__ == '__main__': print(model_predict({'title': '体验2D巅峰 倚天屠龙记十大创新新概览'}))

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询