Java学习者论坛

 找回密码
 立即注册

QQ登录

只需一步,快速开始

手机号码,快捷登录

恭喜Java学习者论坛(https://www.javaxxz.com)已经为数万Java学习者服务超过8年了!积累会员资料超过10000G+
成为本站VIP会员,下载本站10000G+会员资源,购买链接:点击进入购买VIP会员
JAVA高级面试进阶视频教程Java架构师系统进阶VIP课程

分布式高可用全栈开发微服务教程

Go语言视频零基础入门到精通

Java架构师3期(课件+源码)

Java开发全终端实战租房项目视频教程

SpringBoot2.X入门到高级使用教程

大数据培训第六期全套视频教程

深度学习(CNN RNN GAN)算法原理

Java亿级流量电商系统视频教程

互联网架构师视频教程

年薪50万Spark2.0从入门到精通

年薪50万!人工智能学习路线教程

年薪50万!大数据从入门到精通学习路线年薪50万!机器学习入门到精通视频教程
仿小米商城类app和小程序视频教程深度学习数据分析基础到实战最新黑马javaEE2.1就业课程从 0到JVM实战高手教程 MySQL入门到精通教程
查看: 1100|回复: 0

[默认分类] Faster RCNN代码详解(一):算法整体结构

[复制链接]
  • TA的每日心情
    开心
    2021-12-13 21:45
  • 签到天数: 15 天

    [LV.4]偶尔看看III

    发表于 2018-5-17 11:04:50 | 显示全部楼层 |阅读模式

    本系列博客介绍Faster RCNN算法的细节,以MXNet框架的代码为例。希望可以通过该系列博客让更多同学了解Faster RCNN算法中关于RPN网络的构建、损失函数的定义、正负样本的定义等细节,这样对于理解Faster RCNN后续的延伸版本(比如R-FCN、Mask RCNN)以及其他object detection算法也有一定的帮助。接下来的讲解基本上按照训练代码的顺序进行。
    项目地址:https://github.com/precedenceguo/mx-rcnn
    该系列博客以end to end的训练方式为例来介绍Faster RCNN算法,训练代码所在脚本:~/mx-rcnn/train_end2end.py,主要包含网络结构的构建(以特征提取主网络采用resnet为例,通过~mx-rcnn/rcnn/symbol/symbol_resnet.py脚本的get_resnet_train函数构建)和数据读取(通过~mx-rcnn/rcnn/core/loader.py脚本的AnchorLoader类和~/mx-rcnn/rcnn/io/rpn.py脚本的assign_anchor函数进行读取,前者会调用后者)两部分。
    先来看看训练入口:~/mx-rcnn/train_end2end.py脚本的代码细节:
    1. [code]import argparse
    2. import pprint
    3. import mxnet as mx
    4. import numpy as np
    5. from rcnn.logger import logger
    6. from rcnn.config import config, default, generate_config
    7. from rcnn.symbol import *
    8. from rcnn.core import callback, metric
    9. from rcnn.core.loader import AnchorLoader
    10. from rcnn.core.module import MutableModule
    11. from rcnn.utils.load_data import load_gt_roidb, merge_roidb, filter_roidb
    12. from rcnn.utils.load_model import load_param
    13. def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr=0.001, lr_step="5"):
    14.     # setup config
    15.     config.TRAIN.BATCH_IMAGES = 1
    16.     config.TRAIN.BATCH_ROIS = 128
    17.     config.TRAIN.END2END = True
    18.     config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True
    19.     # load symbol
    20. # eval语句是执行字符串命令,以args.network为resnet为例,就是调用~mx-rcnn/rcnn/symbol/symbol_resnet.py
    21. # 脚本中的get_resnet_train函数来得到Faster RCNN的网络结构。
    22. # 在该函数中涉及具体的RPN网络、RPN网络的损失函数、检测网络、检测网络的损失函数细节。
    23.     sym = eval("get_" + args.network + "_train")(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
    24.     feat_sym = sym.get_internals()["rpn_cls_score_output"]
    25.     # setup multi-gpu
    26.     batch_size = len(ctx)
    27.     input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size
    28.     # print config
    29.     logger.info(pprint.pformat(config))
    30.     # load dataset and prepare imdb for training
    31. # 这部分是从yml文件读取标注信息,主要调用的接口是load_gt_roidb函数
    32.     image_sets = [iset for iset in args.image_set.split("+")]
    33.     roidbs = [load_gt_roidb(args.dataset, image_set, args.root_path, args.dataset_path, flip=not args.no_flip)
    34.               for image_set in image_sets]
    35.     roidb = merge_roidb(roidbs)
    36.     roidb = filter_roidb(roidb)
    37.     # load training data
    38. # 调用~mx-rcnn/rcnn/core/loader.py脚本中的AnchorLoader类读取数据,
    39. # 这里面包含了anchor的初始化,正负样本的确定,回归和分类的目标等。
    40.     train_data = AnchorLoader(feat_sym, roidb, batch_size=input_batch_size, shuffle=not args.no_shuffle,  ctx=ctx, work_load_list=args.work_load_list,                          feat_stride=config.RPN_FEAT_STRIDE, anchor_scales=config.ANCHOR_SCALES,
    41.                               anchor_ratios=config.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING)
    42.     # infer max shape
    43.     max_data_shape = [("data", (input_batch_size, 3, max([v[0] for v in config.SCALES]), max([v[1] for v in config.SCALES])))]
    44.     max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    45.     max_data_shape.append(("gt_boxes", (input_batch_size, 100, 5)))
    46.     logger.info("providing maximum shape %s %s" % (max_data_shape, max_label_shape))
    47.     # infer shape
    48.     data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    49.     arg_shape, out_shape, aux_shape = sym.infer_shape(**data_shape_dict)
    50.     arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
    51.     out_shape_dict = dict(zip(sym.list_outputs(), out_shape))
    52.     aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
    53.     logger.info("output shape %s" % pprint.pformat(out_shape_dict))
    54.     # load and initialize params
    55. # args.resume是True则表示断点训练,那样的话导入的模型是训到某个epoch的检测模型。
    56. # 如果args.resume是False则不采用断点训练,这样的话就只导入分类模型进行参数初始化,
    57. # 且需要对RPN网络中的层和整个网络最后的分类和回归支路的全连接层进行随机或0值初始化。
    58.     if args.resume:
    59.         arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    60.     else:
    61.         arg_params, aux_params = load_param(pretrained, epoch, convert=True)
    62.         arg_params["rpn_conv_3x3_weight"] = mx.random.normal(0, 0.01, shape=arg_shape_dict["rpn_conv_3x3_weight"])
    63.         arg_params["rpn_conv_3x3_bias"] = mx.nd.zeros(shape=arg_shape_dict["rpn_conv_3x3_bias"])
    64.         arg_params["rpn_cls_score_weight"] = mx.random.normal(0, 0.01, shape=arg_shape_dict["rpn_cls_score_weight"])
    65.         arg_params["rpn_cls_score_bias"] = mx.nd.zeros(shape=arg_shape_dict["rpn_cls_score_bias"])
    66.         arg_params["rpn_bbox_pred_weight"] = mx.random.normal(0, 0.01, shape=arg_shape_dict["rpn_bbox_pred_weight"])
    67.         arg_params["rpn_bbox_pred_bias"] = mx.nd.zeros(shape=arg_shape_dict["rpn_bbox_pred_bias"])
    68.         arg_params["cls_score_weight"] = mx.random.normal(0, 0.01, shape=arg_shape_dict["cls_score_weight"])
    69.         arg_params["cls_score_bias"] = mx.nd.zeros(shape=arg_shape_dict["cls_score_bias"])
    70.         arg_params["bbox_pred_weight"] = mx.random.normal(0, 0.001, shape=arg_shape_dict["bbox_pred_weight"])
    71.         arg_params["bbox_pred_bias"] = mx.nd.zeros(shape=arg_shape_dict["bbox_pred_bias"])
    72.     # check parameter shapes
    73.     for k in sym.list_arguments():
    74.         if k in data_shape_dict:
    75.             continue
    76.         assert k in arg_params, k + " not initialized"
    77.         assert arg_params[k].shape == arg_shape_dict[k], \
    78.             "shape inconsistent for " + k + " inferred " + str(arg_shape_dict[k]) + " provided " + str(arg_params[k].shape)
    79.     for k in sym.list_auxiliary_states():
    80.         assert k in aux_params, k + " not initialized"
    81.         assert aux_params[k].shape == aux_shape_dict[k], \
    82.             "shape inconsistent for " + k + " inferred " + str(aux_shape_dict[k]) + " provided " + str(aux_params[k].shape)
    83.     # create solver
    84. # model的初始化操作通过自定义的MutableModule类实现,该类也是继承mxnet.module.base_modle.BaseModule
    85. # 这个基类进行重写,使得该类可以处理size不同的输入数据。
    86.     fixed_param_prefix = config.FIXED_PARAMS
    87.     data_names = [k[0] for k in train_data.provide_data]
    88.     label_names = [k[0] for k in train_data.provide_label]
    89.     mod = MutableModule(sym, data_names=data_names, label_names=label_names,
    90.                         logger=logger, context=ctx, work_load_list=args.work_load_list,
    91.                         max_data_shapes=max_data_shape, max_label_shapes=max_label_shape,
    92.                         fixed_param_prefix=fixed_param_prefix)
    93.     # decide training params
    94.     # metric
    95. # 定义评价标准,具体评价函数都是在~/mx-rcnn/rcnn/core/metric.py脚本中实现。
    96. # 多个评价函数通过mx.metric.CompositeEvalMetric()类来管理,可以通过该类实例化得到的对象的add方法添加评价函数类。
    97.     rpn_eval_metric = metric.RPNAccMetric()
    98.     rpn_cls_metric = metric.RPNLogLossMetric()
    99.     rpn_bbox_metric = metric.RPNL1LossMetric()
    100.     eval_metric = metric.RCNNAccMetric()
    101.     cls_metric = metric.RCNNLogLossMetric()
    102.     bbox_metric = metric.RCNNL1LossMetric()
    103.     eval_metrics = mx.metric.CompositeEvalMetric()
    104.     for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric]:
    105.         eval_metrics.add(child_metric)
    106.     # callback
    107.     batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    108.     means = np.tile(np.array(config.TRAIN.BBOX_MEANS), config.NUM_CLASSES)
    109.     stds = np.tile(np.array(config.TRAIN.BBOX_STDS), config.NUM_CLASSES)
    110.     epoch_end_callback = callback.do_checkpoint(prefix, means, stds)
    111.     # decide learning rate
    112. # 接下来这一块都是对学习率的设置,最核心的是通过mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)
    113. # 接口来构造学习率对象,输入中lr_iters表示每隔多少个batch修改学习率,lr_factor是修改学习率时候变化的比例。
    114.     base_lr = lr
    115.     lr_factor = 0.1
    116.     lr_epoch = [int(epoch) for epoch in lr_step.split(",")]
    117.     lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    118.     lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    119.     lr_iters = [int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff]
    120.     logger.info("lr %f lr_epoch_diff %s lr_iters %s" % (lr, lr_epoch_diff, lr_iters))
    121.     lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)
    122.     # optimizer
    123. # 优化器的初始化
    124.     optimizer_params = {"momentum": 0.9,
    125.                         "wd": 0.0005,
    126.                         "learning_rate": lr,
    127.                         "lr_scheduler": lr_scheduler,
    128.                         "rescale_grad": (1.0 / batch_size),
    129.                         "clip_gradient": 5}
    130.     # train
    131. # 调用fit方法进行训练
    132.     mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
    133.             batch_end_callback=batch_end_callback, kvstore=args.kvstore,
    134.             optimizer="sgd", optimizer_params=optimizer_params,
    135.             arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
    136. def parse_args():
    137.     parser = argparse.ArgumentParser(description="Train Faster R-CNN network")
    138.     # general
    139.     parser.add_argument("--network", help="network name", default=default.network, type=str)
    140.     parser.add_argument("--dataset", help="dataset name", default=default.dataset, type=str)
    141.     args, rest = parser.parse_known_args()
    142.     generate_config(args.network, args.dataset)
    143.     parser.add_argument("--image_set", help="image_set name", default=default.image_set, type=str)
    144.     parser.add_argument("--root_path", help="output data folder", default=default.root_path, type=str)
    145.     parser.add_argument("--dataset_path", help="dataset path", default=default.dataset_path, type=str)
    146.     # training
    147.     parser.add_argument("--frequent", help="frequency of logging", default=default.frequent, type=int)
    148.     parser.add_argument("--kvstore", help="the kv-store type", default=default.kvstore, type=str)
    149.     parser.add_argument("--work_load_list", help="work load for different devices", default=None, type=list)
    150.     parser.add_argument("--no_flip", help="disable flip images", action="store_true")
    151.     parser.add_argument("--no_shuffle", help="disable random shuffle", action="store_true")
    152.     parser.add_argument("--resume", help="continue training", action="store_true")
    153.     # e2e
    154.     parser.add_argument("--gpus", help="GPU device to train with", default="0", type=str)
    155.     parser.add_argument("--pretrained", help="pretrained model prefix", default=default.pretrained, type=str)
    156.     parser.add_argument("--pretrained_epoch", help="pretrained model epoch", default=default.pretrained_epoch, type=int)
    157.     parser.add_argument("--prefix", help="new model prefix", default=default.e2e_prefix, type=str)
    158.     parser.add_argument("--begin_epoch", help="begin epoch of training, use with resume", default=0, type=int)
    159.     parser.add_argument("--end_epoch", help="end epoch of training", default=default.e2e_epoch, type=int)
    160.     parser.add_argument("--lr", help="base learning rate", default=default.e2e_lr, type=float)
    161.     parser.add_argument("--lr_step", help="learning rate steps (in epoch)", default=default.e2e_lr_step, type=str)
    162.     args = parser.parse_args()
    163.     return args
    164. def main():
    165.     args = parse_args()
    166.     logger.info("Called with argument: %s" % args)
    167.     ctx = [mx.gpu(int(i)) for i in args.gpus.split(",")]
    168.     train_net(args, ctx, args.pretrained, args.pretrained_epoch, args.prefix, args.begin_epoch, args.end_epoch,
    169.               lr=args.lr, lr_step=args.lr_step)
    170. if __name__ == "__main__":
    171.     main()
    复制代码
    [/code]
    从宏观上了解了训练代码的架构后,接下来就要详细了解每一个模块的具体实现了,接下来一篇博客就来看看网络结构是怎么构造的吧:Faster RCNN代码详解(二):网络结构构建

    回复

    使用道具 举报

    您需要登录后才可以回帖 登录 | 立即注册

    本版积分规则

    QQ|手机版|Java学习者论坛 ( 声明:本站资料整理自互联网,用于Java学习者交流学习使用,对资料版权不负任何法律责任,若有侵权请及时联系客服屏蔽删除 )

    GMT+8, 2024-3-29 23:19 , Processed in 0.351775 second(s), 37 queries .

    Powered by Discuz! X3.4

    © 2001-2017 Comsenz Inc.

    快速回复 返回顶部 返回列表