Skip to content

zhenhao-huang/paddlehub_ernie_emotion_analysis

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

6 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

paddlehub_ernie_emotion_analysis

百度PaddleHub-ERNIE微调中文情感分析(文本二分类)。详细流程参考博文

PaddlePaddle-PaddleHub

飞桨(PaddlePaddle)以百度多年的深度学习技术研究和业务应用为基础,是中国首个自主研发、功能完备、 开源开放的产业级深度学习平台,集深度学习核心训练和推理框架、基础模型库、端到端开发套件和丰富的工具组件于一体。PaddleHub旨在为开发者提供丰富的、高质量的、直接可用的预训练模型。

ERNIE

ERNIE(Enhanced Representation through kNowledge IntEgration)是百度提出的知识增强的语义表示模型,通过对词、实体等语义单元的掩码,使得模型学习完整概念的语义表示。在语言推断语义相似度命名实体识别情感分析问答匹配等自然语言处理(NLP)各类中文任务上的验证显示,模型效果全面超越BERTERNIE ERNIE 更多详情请参考ERNIE论文

一、环境安装

# CPU
pip install paddlepaddle
# GPU
pip install paddlepaddle-gpu
pip install paddlehub

有gpu的,建议安装paddlepaddle-gpu版(训练速度会提升好几倍)。paddlepaddle-gpu默认安装的是cuda10.2,如果需要安装其他cuda版本,到官方网站查找命令。(注意,从1.8.0开始采用动态图,所以paddlepaddlepaddlehub版本最好从1.8.0开始使用。)

二、数据预处理

这里使用的数据是二分类数据集weibo_senti_100k.csv,即情感倾向只有正向负向,下载地址:https://github.com/SophonPlus/ChineseNlpCorpus,已存放至data/weibo_senti_100k该目录下。由于PaddleHub用的是tsv格式的数据集,所以需要运行to_tsv.py该脚本将csv格式转成tsv格式。

三、微调

运行finetune_ernie.py。使用自己的数据集,需要修改base_path数据存放目录,label_list修改为实际的数据集标签。

选择模型

model = hub.Module(name='ernie_tiny', task='seq-cls', num_classes=len(MyDataset.label_list))
  • name:模型名称,可以选择ernie,ernie_tiny,bert-base-cased, bert-base-chinese, roberta-wwm-ext,roberta-wwm-ext-large等。
  • version:module版本号
  • task:fine-tune任务。seq-cls(文本分类任务)或token-cls(序列标注任务)。
  • num_classes:表示当前文本分类任务的类别数,根据具体使用的数据集确定,默认为2。 PaddleHub还提供BERT等模型可供选择, 当前支持文本分类任务的模型对应的加载示例如下: 
模型名 PaddleHub Module
ERNIE, Chinese hub.Module(name='ernie')
ERNIE tiny, Chinese hub.Module(name='ernie_tiny')
ERNIE 2.0 Base, English hub.Module(name='ernie_v2_eng_base')
ERNIE 2.0 Large, English hub.Module(name='ernie_v2_eng_large')
BERT-Base, English Cased hub.Module(name='bert-base-cased')
BERT-Base, English Uncased hub.Module(name='bert-base-uncased')
BERT-Large, English Cased hub.Module(name='bert-large-cased')
BERT-Large, English Uncased hub.Module(name='bert-large-uncased')
BERT-Base, Multilingual Cased hub.Module(nane='bert-base-multilingual-cased')
BERT-Base, Multilingual Uncased hub.Module(nane='bert-base-multilingual-uncased')
BERT-Base, Chinese hub.Module(name='bert-base-chinese')
BERT-wwm, Chinese hub.Module(name='chinese-bert-wwm')
BERT-wwm-ext, Chinese hub.Module(name='chinese-bert-wwm-ext')
RoBERTa-wwm-ext, Chinese hub.Module(name='roberta-wwm-ext')
RoBERTa-wwm-ext-large, Chinese hub.Module(name='roberta-wwm-ext-large')
RBT3, Chinese hub.Module(name='rbt3')
RBTL3, Chinese hub.Module(name='rbtl3')
ELECTRA-Small, English hub.Module(name='electra-small')
ELECTRA-Base, English hub.Module(name='electra-base')
ELECTRA-Large, English hub.Module(name='electra-large')
ELECTRA-Base, Chinese hub.Module(name='chinese-electra-base')
ELECTRA-Small, Chinese hub.Module(name='chinese-electra-small')

选择优化策略和运行配置

optimizer = paddle.optimizer.Adam(learning_rate=args.learning_rate, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
trainer.train(train_dataset, epochs=args.num_epoch, batch_size=args.batch_size, eval_dataset=dev_dataset,
              save_interval=args.save_interval)
# 在测试集上评估当前训练模型
trainer.evaluate(test_dataset, batch_size=args.batch_size)

优化策略

Paddle提供了多种优化器选择,如SGDAdamAdamax等,详细参见策略。其中Adam:

  • learning_rate:全局学习率。默认为1e-3;
  • parameters:待优化模型参数。

运行配置

hub.Trainer主要控制Fine-tune的训练,包含以下可控制的参数: 

  • model:被优化模型;
  • optimizer:优化器选择;
  • checkpoint_dir:保存模型参数的地址;
  • use_gpu:是否使用gpu,默认为False。对于GPU用户,建议开启use_gpu。

trainer.train主要控制具体的训练过程,包含以下可控制的参数:

  • train_dataset:训练时所用的数据集;
  • epochs:训练轮数;
  • batch_size:训练的批大小,如果使用GPU,请根据实际情况调整batch_size;
  • num_workers:works的数量,默认为0;
  • eval_dataset:验证集;
  • log_interval:打印日志的间隔, 单位为执行批训练的次数。
  • save_interval:保存模型的间隔频次,单位为执行训练的轮数。

四、模型预测

完成Fine-tune后,Fine-tune过程在验证集上表现最优的模型会被保存在${CHECKPOINT_DIR}/best_model目录下,其中${CHECKPOINT_DIR}目录为Fine-tune时所选择的保存checkpoint的目录。运行脚本predict.py

五、结果

训练集: 测试集: 二分类数据集weibo_senti_100k.csv上,训练集准确率可以达到98%测试集准确率同样可以达到98%

About

百度PaddleHub-ERNIE微调中文情感分析(文本分类)

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages