This repository was archived by the owner on Nov 21, 2023. It is now read-only.

Train Faster R-CNN with different backbone #184

Closed
marinaKollmitz opened this issue Feb 22, 2018 · 12 comments

@marinaKollmitz

marinaKollmitz commented Feb 22, 2018

Thank you very much for this great project! Could you please help me train Faster R-CNN with another backbone architecture, for example GoogLeNet? I used the pickle_caffe_blobs.py script to convert the pretrained bvlc_googlenet.caffemodel to a .pkl model, as mentioned in issue 77. But how can I train with it? Do I need to set up a .py file for the model, like the VGG16.py in lib/modeling? Or is there a way to generate this file as well? Thank you for your help!

@youngwanLEE

+1

@filipetrocadoferreira

+2

@marinaKollmitz
Author

marinaKollmitz commented Apr 11, 2018

I have now implemented the conv body for the network in Python and used the converted .pkl file as pretrained weights. To convert the .caffemodel weights to the pickle file, I kept only the conv body part of the network in the deploy.prototxt file. Here is my Python implementation, in case anyone is interested. It is a smaller version of GoogLeNet, called GoogLeNet-xxs, described here. The original network has more inception blocks:

def add_googlenet_xxs_conv5_body(model):

    model.Conv('data', 'conv1-7x7_s2', 3, 64, 7, pad=3, stride=2)
    model.Relu('conv1-7x7_s2', 'conv1-7x7_s2')
    model.MaxPool('conv1-7x7_s2', 'pool1-3x3_s2', kernel=3, pad=0, stride=2)     
    model.LRN('pool1-3x3_s2', 'pool1-norm1', size=5, alpha=0.0001, beta=0.75)

    #stop gradient?
    #model.StopGradient('pool1-norm1', 'pool1-norm1')
    model.Conv('pool1-norm1', 'conv2-3x3_reduce', 64, 64, 1, pad=0, stride=1) #input size=?
    model.Relu('conv2-3x3_reduce', 'conv2-3x3_reduce')
    model.Conv('conv2-3x3_reduce', 'conv2-3x3', 64, 192, 3, pad=1, stride=1)
    model.Relu('conv2-3x3', 'conv2-3x3')
    model.LRN('conv2-3x3', 'conv2-norm2', size=5, alpha=0.0001, beta=0.75)
    model.MaxPool('conv2-norm2', 'pool2-3x3_s2', kernel=3, stride=2, pad=0)

    #inception 3a
    model.Conv('pool2-3x3_s2', 'inception_3a-1x1', 192, 64, 1, pad=0, stride=1)
    model.Relu('inception_3a-1x1', 'inception_3a-1x1')
    model.Conv('pool2-3x3_s2', 'inception_3a-3x3_reduce', 192, 96, 1, pad=0, stride=1)
    model.Relu('inception_3a-3x3_reduce', 'inception_3a-3x3_reduce')
    model.Conv('inception_3a-3x3_reduce', 'inception_3a-3x3', 96, 128, 3, pad=1, stride=1)
    model.Relu('inception_3a-3x3', 'inception_3a-3x3')
    model.Conv('pool2-3x3_s2', 'inception_3a-5x5_reduce', 192, 16, 1, pad=0, stride=1)
    model.Relu('inception_3a-5x5_reduce', 'inception_3a-5x5_reduce')
    model.Conv('inception_3a-5x5_reduce', 'inception_3a-5x5', 16, 32, 5, pad=2, stride=1)
    model.Relu('inception_3a-5x5', 'inception_3a-5x5')
    model.MaxPool('pool2-3x3_s2', 'inception_3a-pool', kernel=3, stride=1, pad=1)
    model.Conv('inception_3a-pool', 'inception_3a-pool_proj', 192, 32, 1, stride=1, pad=0)
    model.Relu('inception_3a-pool_proj', 'inception_3a-pool_proj')
    model.Concat(['inception_3a-1x1','inception_3a-3x3', 'inception_3a-5x5', 'inception_3a-pool_proj'], 'inception_3a-output')

    #inception 3b
    model.Conv('inception_3a-output', 'inception_3b-1x1', 256, 128, 1, pad=0, stride=1)
    model.Relu('inception_3b-1x1', 'inception_3b-1x1')
    model.Conv('inception_3a-output', 'inception_3b-3x3_reduce', 256, 128, 1, pad=0, stride=1)
    model.Relu('inception_3b-3x3_reduce', 'inception_3b-3x3_reduce')
    model.Conv('inception_3b-3x3_reduce', 'inception_3b-3x3', 128, 192, 3, pad=1, stride=1)
    model.Relu('inception_3b-3x3', 'inception_3b-3x3')
    model.Conv('inception_3a-output', 'inception_3b-5x5_reduce', 256, 32, 1, pad=0, stride=1)
    model.Relu('inception_3b-5x5_reduce', 'inception_3b-5x5_reduce')
    model.Conv('inception_3b-5x5_reduce', 'inception_3b-5x5', 32, 96, 5, pad=2, stride=1)
    model.Relu('inception_3b-5x5', 'inception_3b-5x5')
    model.MaxPool('inception_3a-output', 'inception_3b-pool', kernel=3, stride=1, pad=1)
    model.Conv('inception_3b-pool', 'inception_3b-pool_proj', 256, 64, 1, pad=0, stride=1)
    model.Relu('inception_3b-pool_proj', 'inception_3b-pool_proj')
    model.Concat(['inception_3b-1x1', 'inception_3b-3x3', 'inception_3b-5x5', 'inception_3b-pool_proj'], 'inception_3b-output')

    model.MaxPool('inception_3b-output', 'pool3-3x3_s2', kernel=3, stride=2, pad=0)

    #inception 4a
    model.Conv('pool3-3x3_s2', 'inception_4a-1x1', 480, 192, 1, pad=0, stride=1)
    model.Relu('inception_4a-1x1', 'inception_4a-1x1')
    model.Conv('pool3-3x3_s2', 'inception_4a-3x3_reduce', 480, 96, 1, pad=0, stride=1)
    model.Relu('inception_4a-3x3_reduce', 'inception_4a-3x3_reduce')
    model.Conv('inception_4a-3x3_reduce', 'inception_4a-3x3', 96, 208, 3, pad=1, stride=1)
    model.Relu('inception_4a-3x3', 'inception_4a-3x3')
    model.Conv('pool3-3x3_s2', 'inception_4a-5x5_reduce', 480, 16, 1, pad=0, stride=1)
    model.Relu('inception_4a-5x5_reduce', 'inception_4a-5x5_reduce')
    model.Conv('inception_4a-5x5_reduce', 'inception_4a-5x5', 16, 48, 5, pad=2, stride=1)
    model.Relu('inception_4a-5x5', 'inception_4a-5x5')
    model.MaxPool('pool3-3x3_s2', 'inception_4a-pool', kernel=3, stride=1, pad=1)
    model.Conv('inception_4a-pool', 'inception_4a-pool_proj', 480, 64, 1, pad=0, stride=1)
    model.Relu('inception_4a-pool_proj', 'inception_4a-pool_proj')
    blob_out = model.Concat(['inception_4a-1x1', 'inception_4a-3x3', 'inception_4a-5x5', 'inception_4a-pool_proj'], 'inception_4a-output')

    return blob_out, 512, 1. / 16.

I am still curious if there is a way to load a network without implementing the network structure in python. Since issue #339 targets this question a little more in detail, I will close this issue here.
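The `512` and `1. / 16.` returned by the function above follow from the layer definitions, and a small standalone check can make that explicit. This is only a sketch: the per-branch channel counts and strides are copied by hand from the `Conv` and `MaxPool` calls above, and the names `INCEPTION_BRANCHES`, `DOWNSAMPLING_STRIDES`, `concat_channels` and `spatial_scale` are made up for illustration and are not part of Detectron.

```python
# Output channels of the four branches of each inception block
# (1x1, 3x3, 5x5, pool_proj), read off the Conv calls above.
INCEPTION_BRANCHES = {
    "inception_3a": [64, 128, 32, 32],
    "inception_3b": [128, 192, 96, 64],
    "inception_4a": [192, 208, 48, 64],
}

# Strides of the layers that downsample: conv1, pool1, pool2, pool3.
DOWNSAMPLING_STRIDES = [2, 2, 2, 2]


def concat_channels(branches):
    """Channel count after concatenating branch outputs along the channel axis."""
    return sum(branches)


def spatial_scale(strides):
    """Ratio of feature-map size to input size after all strided layers."""
    total = 1
    for stride in strides:
        total *= stride
    return 1.0 / total


dims = {name: concat_channels(b) for name, b in INCEPTION_BRANCHES.items()}
scale = spatial_scale(DOWNSAMPLING_STRIDES)
print(dims)
print(scale)
```

The sums for `inception_3a` and `inception_3b` should also equal the input channel counts that the following block passes to `model.Conv`, which is a quick way to catch a mistyped dimension when adapting the body to a different network.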

@filipetrocadoferreira

where can we get the weights?

@marinaKollmitz
Author

Short answer: You can download my generated pretrained weights for the GoogLeNet-xxs here

Longer answer, how I generated the file: In general, what I did was download the .caffemodel and deploy.prototxt from the caffe model zoo here. Then I deleted everything from the deploy.prototxt except the convolution part I want to train with. Next I fed it to the detectron/tools/pickle_caffe_blobs.py script.

Some details: I had to rename the layers in the implementation, because detectron strips off the layer names before the last / when saving the blobs to model_final.pkl. I replaced all "/"s with "-"s in the layer names. As a consequence, I also had to rename the blobs in the pretrained.pkl file. I changed one part in the pickle_caffe_blobs.py script to rename the blobs when saving to file:

def normalize_googlenet_name(name):
    name = name.replace('/', '-')
    return name

def pickle_weights(out_file_name, weights):
    blobs = {
        normalize_googlenet_name(blob.name): utils.Caffe2TensorToNumpyArray(blob)
        for blob in weights.protos
    }
    # Open in binary mode: HIGHEST_PROTOCOL writes a binary pickle.
    with open(out_file_name, 'wb') as f:
        pickle.dump(blobs, f, protocol=pickle.HIGHEST_PROTOCOL)
    print('Wrote blobs:')
    print(sorted(blobs.keys()))

I will upload everything to my fork soon. In the meantime, I hope this helps.
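To illustrate why the renaming described above matters, here is a small standalone sketch. The function `strip_scope` is a hypothetical stand-in for the behaviour described in this comment (dropping everything before the last `/` when blobs are saved); it is not Detectron's actual code. The sample names assume the `_w` suffix that appears in the log output later in this thread.

```python
def strip_scope(name):
    """Hypothetical model of the stripping: keep only the text after the last '/'."""
    return name[name.rfind('/') + 1:]


def normalize_googlenet_name(name):
    """Same renaming as in the snippet above: replace every '/' with '-'."""
    return name.replace('/', '-')


# Weight blob names as they would come out of the original GoogLeNet layers.
caffe_names = ['inception_3a/1x1_w', 'inception_3b/1x1_w', 'conv1/7x7_s2_w']

# Without renaming, two different layers end up with the same saved name.
stripped = [strip_scope(n) for n in caffe_names]

# With renaming, there is no '/' left to strip, so the names stay unique.
normalized = [strip_scope(normalize_googlenet_name(n)) for n in caffe_names]

print(stripped)
print(normalized)
```

Under these assumptions the un-renamed names collide (`1x1_w` appears twice), while the renamed ones keep the block prefix and stay distinct.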

@filipetrocadoferreira

filipetrocadoferreira commented Apr 11, 2018

Thank you so much. Good work. I'll give it a try.

Edit: It would be nice, though, to use the net definition from '.pb' files directly to retrain/finetune. This question applies to caffe2 in general (not just detectron).

@rohitbhio

Hi @marinaKollmitz ,

Awesome work. Thanks for defining the conv body method. Have you also defined a similar method for the fully connected layers?

Does the method defining the fully connected layers have a structure similar to that of VGG?

Thanks again

@marinaKollmitz
Author

I just copied the VGG fully connected layers.

@rohitbhio

@marinaKollmitz: Thanks for the confirmation. I did use the VGG FC layers.

@srikanth-kilaru

Problem solved by following the recommendation in #32 to set USE_NCCL True

@srikanth-kilaru

After training the model, I get the following errors during test.
command:
python2 tools/test_net.py --cfg configs/getting_started/ml349_2gpu_e2e_faster_rcnn_Inception_ResNetv2.yaml --multi-gpu-testing TEST.WEIGHTS /tmp/detectron-output/train/coco_2014_train/generalized_rcnn/model_final.pkl NUM_GPUS 2

Errors

==================
INFO net.py: 88: conv1-7x7_s2_w not found
INFO net.py: 88: conv1-7x7_s2_b not found
INFO net.py: 88: conv2-3x3_reduce_w not found
INFO net.py: 88: conv2-3x3_reduce_b not found
INFO net.py: 88: conv2-3x3_w not found
INFO net.py: 88: conv2-3x3_b not found
INFO net.py: 88: inception_3a-1x1_w not found
INFO net.py: 88: inception_3a-1x1_b not found
INFO net.py: 88: inception_3a-3x3_reduce_w not found
INFO net.py: 88: inception_3a-3x3_reduce_b not found
INFO net.py: 88: inception_3a-3x3_w not found
INFO net.py: 88: inception_3a-3x3_b not found
INFO net.py: 88: inception_3a-5x5_reduce_w not found
INFO net.py: 88: inception_3a-5x5_reduce_b not found
INFO net.py: 88: inception_3a-5x5_w not found
INFO net.py: 88: inception_3a-5x5_b not found
INFO net.py: 88: inception_3a-pool_proj_w not found
INFO net.py: 88: inception_3a-pool_proj_b not found
INFO net.py: 88: inception_3b-1x1_w not found
INFO net.py: 88: inception_3b-1x1_b not found
INFO net.py: 88: inception_3b-3x3_reduce_w not found
INFO net.py: 88: inception_3b-3x3_reduce_b not found
INFO net.py: 88: inception_3b-3x3_w not found
INFO net.py: 88: inception_3b-3x3_b not found
INFO net.py: 88: inception_3b-5x5_reduce_w not found
INFO net.py: 88: inception_3b-5x5_reduce_b not found
INFO net.py: 88: inception_3b-5x5_w not found
INFO net.py: 88: inception_3b-5x5_b not found
INFO net.py: 88: inception_3b-pool_proj_w not found
INFO net.py: 88: inception_3b-pool_proj_b not found
INFO net.py: 88: inception_4a-1x1_w not found
INFO net.py: 88: inception_4a-1x1_b not found
INFO net.py: 88: inception_4a-3x3_reduce_w not found
INFO net.py: 88: inception_4a-3x3_reduce_b not found
INFO net.py: 88: inception_4a-3x3_w not found
INFO net.py: 88: inception_4a-3x3_b not found
INFO net.py: 88: inception_4a-5x5_reduce_w not found
INFO net.py: 88: inception_4a-5x5_reduce_b not found
INFO net.py: 88: inception_4a-5x5_w not found
INFO net.py: 88: inception_4a-5x5_b not found
INFO net.py: 88: inception_4a-pool_proj_w not found
INFO net.py: 88: inception_4a-pool_proj_b not found
INFO net.py: 88: conv_rpn_w not found
INFO net.py: 88: conv_rpn_b not found
INFO net.py: 88: rpn_cls_logits_w not found
INFO net.py: 88: rpn_cls_logits_b not found
INFO net.py: 88: rpn_bbox_pred_w not found
INFO net.py: 88: rpn_bbox_pred_b not found
INFO net.py: 88: head_conv1_w not found
INFO net.py: 88: head_conv1_gn_s not found
INFO net.py: 88: head_conv1_gn_b not found
INFO net.py: 88: head_conv2_w not found
INFO net.py: 88: head_conv2_gn_s not found
INFO net.py: 88: head_conv2_gn_b not found
INFO net.py: 88: head_conv3_w not found
INFO net.py: 88: head_conv3_gn_s not found
INFO net.py: 88: head_conv3_gn_b not found
INFO net.py: 88: head_conv4_w not found
INFO net.py: 88: head_conv4_gn_s not found
INFO net.py: 88: head_conv4_gn_b not found
INFO net.py: 88: _mask_fcn1_w not found
INFO net.py: 88: _mask_fcn1_gn_s not found
INFO net.py: 88: _mask_fcn1_gn_b not found
INFO net.py: 88: _mask_fcn2_w not found
INFO net.py: 88: _mask_fcn2_gn_s not found
INFO net.py: 88: _mask_fcn2_gn_b not found
INFO net.py: 88: _mask_fcn3_w not found
INFO net.py: 88: _mask_fcn3_gn_s not found
INFO net.py: 88: _mask_fcn3_gn_b not found
INFO net.py: 88: _mask_fcn4_w not found
INFO net.py: 88: _mask_fcn4_gn_s not found
INFO net.py: 88: _mask_fcn4_gn_b not found
INFO net.py: 88: conv5_mask_w not found
INFO net.py: 88: conv5_mask_b not found
INFO net.py: 88: mask_fcn_logits_w not found
INFO net.py: 88: mask_fcn_logits_b not found
I0529 11:23:29.991972 25591 net_dag_utils.cc:102] Operator graph pruning prior to chain compute took: 5.1028e-05 secs
I0529 11:23:29.992153 25591 net_dag.cc:46] Number of parallel execution chains 25 Number of operators = 78
I0529 11:23:29.995748 25591 net_dag_utils.cc:102] Operator graph pruning prior to chain compute took: 3.4177e-05 secs
I0529 11:23:29.995859 25591 net_dag.cc:46] Number of parallel execution chains 18 Number of operators = 53
I0529 11:23:29.997187 25591 net_dag_utils.cc:102] Operator graph pruning prior to chain compute took: 1.3739e-05 secs
I0529 11:23:29.997264 25591 net_dag.cc:46] Number of parallel execution chains 1 Number of operators = 17
E0529 11:23:54.638962 26563 net_dag.cc:195] Exception from operator chain starting at '' (type 'Conv'): caffe2::EnforceNotMet: [enforce fail at blob.h:84] IsType<T>(). wrong type for the Blob instance. Blob contains nullptr (uninitialized) while caller expects caffe2::Tensor<caffe2::CUDAContext> .
Offending Blob name: gpu_0/conv1-7x7_s2_w.
Error from operator:
input: "gpu_0/data" input: "gpu_0/conv1-7x7_s2_w" input: "gpu_0/conv1-7x7_s2_b" output: "gpu_0/conv1-7x7_s2" name: "" type: "Conv" arg { name: "kernel" i: 7 } arg { name: "exhaustive_search" i: 0 } arg { name: "pad" i: 3 } arg { name: "order" s: "NCHW" } arg { name: "stride" i: 2 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN"
WARNING workspace.py: 185: Original python traceback for operator 0 in network generalized_rcnn in exception above (most recent call last):
WARNING workspace.py: 190: File "/home/srikilaru/detectron/tools/test_net.py", line 116, in <module>
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 128, in run_inference
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 125, in result_getter
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 235, in test_net
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 328, in initialize_model_from_cfg
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/modeling/model_builder.py", line 124, in create
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/modeling/model_builder.py", line 89, in generalized_rcnn
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/modeling/model_builder.py", line 230, in build_generic_detection_model
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/modeling/optimizer.py", line 54, in build_data_parallel_model
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/modeling/model_builder.py", line 170, in _single_gpu_build_func
WARNING workspace.py: 190: File "/home/srikilaru/detectron/detectron/modeling/Inception_ResNetv2.py", line 27, in add_inception_resnetv2_xxs_conv5_body
WARNING workspace.py: 190: File "/home/srikilaru/pytorch/build/caffe2/python/cnn.py", line 97, in Conv
WARNING workspace.py: 190: File "/home/srikilaru/pytorch/build/caffe2/python/brew.py", line 107, in scope_wrapper
WARNING workspace.py: 190: File "/home/srikilaru/pytorch/build/caffe2/python/helpers/conv.py", line 186, in conv
WARNING workspace.py: 190: File "/home/srikilaru/pytorch/build/caffe2/python/helpers/conv.py", line 139, in _Conv
INFO net.py: 88: inception_3b-pool_proj_b not found
Base
Traceback (most recent call last):
File "/home/srikilaru/detectron/tools/test_net.py", line 116, in <module>
check_expected_results=True,
File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 128, in run_inference
all_results = result_getter()
File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 125, in result_getter
gpu_id=gpu_id
File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 258, in test_net
model, im, box_proposals, timers
File "/home/srikilaru/detectron/detectron/core/test.py", line 66, in im_detect_all
model, im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE, boxes=box_proposals
File "/home/srikilaru/detectron/detectron/core/test.py", line 158, in im_detect_bbox
workspace.RunNet(model.net.Proto().name)
File "/home/srikilaru/pytorch/build/caffe2/python/workspace.py", line 217, in RunNet
StringifyNetName(name), num_iter, allow_fail,
File "/home/srikilaru/pytorch/build/caffe2/python/workspace.py", line 178, in CallWithExceptionIntercept
return func(*args, **kwargs)
RuntimeError: [enforce fail at blob.h:84] IsType<T>(). wrong type for the Blob instance. Blob contains nullptr (uninitialized) while caller expects caffe2::Tensor<caffe2::CUDAContext> .
Offending Blob name: gpu_0/conv1-7x7_s2_w.
Error from operator:
input: "gpu_0/data" input: "gpu_0/conv1-7x7_s2_w" input: "gpu_0/conv1-7x7_s2_b" output: "gpu_0/conv1-7x7_s2" name: "" type: "Conv" arg { name: "kernel" i: 7 } arg { name: "exhaustive_search" i: 0 } arg { name: "pad" i: 3 } arg { name: "order" s: "NCHW" } arg { name: "stride" i: 2 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN"
Traceback (most recent call last):
File "tools/test_net.py", line 116, in <module>
check_expected_results=True,
File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 128, in run_inference
all_results = result_getter()
File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 108, in result_getter
multi_gpu=multi_gpu_testing
File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 155, in test_net_on_dataset
weights_file, dataset_name, proposal_file, num_images, output_dir
File "/home/srikilaru/detectron/detectron/core/test_engine.py", line 188, in multi_gpu_test_net_on_dataset
'detection', num_images, binary, output_dir, opts
File "/home/srikilaru/detectron/detectron/utils/subprocess.py", line 95, in process_in_parallel
log_subprocess_output(i, p, output_dir, tag, start, end)
File "/home/srikilaru/detectron/detectron/utils/subprocess.py", line 133, in log_subprocess_output
assert ret == 0, 'Range subprocess failed (exit code: {})'.format(ret)
AssertionError: Range subprocess failed (exit code: 1)
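The "not found" lines in the log above suggest that none of the blob names the network expects were present in the weights file that was loaded. One way to narrow this down is to open the .pkl file and compare its keys against the expected names. The following is a minimal sketch, not a Detectron utility: `missing_blobs` is a made-up helper, the demo file and its contents are placeholders, and the handling of a nested `'blobs'` key is an assumption about how trained weights files are laid out.

```python
import os
import pickle
import tempfile


def missing_blobs(weights_path, expected_names):
    """Return the expected blob names that are absent from a pickled weights file."""
    with open(weights_path, 'rb') as f:
        data = pickle.load(f)
    # Assumption: a file saved after training nests the arrays under a 'blobs'
    # key, while a converted pretrained file is a flat {name: array} dict.
    blobs = data['blobs'] if 'blobs' in data else data
    return sorted(name for name in expected_names if name not in blobs)


# Demo with a throwaway file standing in for model_final.pkl.
tmp_dir = tempfile.mkdtemp()
demo_path = os.path.join(tmp_dir, 'demo_weights.pkl')
with open(demo_path, 'wb') as f:
    pickle.dump({'blobs': {'conv1-7x7_s2_w': [0.0], 'conv1-7x7_s2_b': [0.0]}}, f)

expected = ['conv1-7x7_s2_w', 'conv1-7x7_s2_b', 'conv_rpn_w']
result = missing_blobs(demo_path, expected)
print(result)
```

If every expected name comes back as missing, the path passed as TEST.WEIGHTS is worth double-checking before looking at the model definition. Note that a pickle written under Python 2 may need `encoding='latin1'` passed to `pickle.load` when it is read from Python 3.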

@johannathiemich

Could you please explain how I can make use of that .pkl file you provided?
Just loading the .pkl file without changing anything in the config file does not seem to work.

I tried adding the code from your post (add_googlenet_xxs_conv5_body(model)) to the FPN.py file and referenced the method in the corresponding config file when using the .pkl file you provided.
But I got an error saying that the blob returned by the add_googlenet_xxs_conv5_body method does not have the attribute "len", so an assertion failed. It seems to expect a list of tensors, but since the last layer is a Concat layer, I think only one tensor is returned.

So how did you make all of this work? Which changes do I have to make?
