Skip to content
This repository was archived by the owner on Jul 2, 2021. It is now read-only.

add FCIS ResNet101 #568

Merged
merged 33 commits into from
Apr 25, 2018
Merged

add FCIS ResNet101 #568

merged 33 commits into from
Apr 25, 2018

Conversation

knorth55
Copy link
Contributor

@knorth55 knorth55 commented Apr 13, 2018

Merge after #560

  • add mask_voting
  • add FCIS
  • add FCISResNet101
  • add tests
  • add docs
cd chainercv/examples/fcis
wget https://raw.githubusercontent.com/knorth55/chainer-fcis/master/examples/voc/images/SBD_test2008_000090.jpg
python demo.py SBD_test2008_000090.jpg --gpu 0

fcis_voc_vis

@knorth55 knorth55 self-assigned this Apr 13, 2018
@knorth55 knorth55 force-pushed the experimental-fcis branch from 0232a41 to 0bb0cef Compare April 13, 2018 17:23
@knorth55 knorth55 changed the title [WIP] add FCIS ResNet101 add FCIS ResNet101 Apr 13, 2018
@knorth55 knorth55 force-pushed the experimental-fcis branch from 0bb0cef to bb7e2d4 Compare April 13, 2018 17:47
@knorth55 knorth55 changed the title add FCIS ResNet101 [WIP] add FCIS ResNet101 Apr 14, 2018

class ResNet101Extractor(chainer.Chain):

def __init__(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change the feature extractor to take initialW as an argument, and pass the zero-initializer when a pretrained model is specified?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK


def __call__(self, x):
with chainer.using_config('train', False):
with chainer.function.no_backprop_mode():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer setting related backpropagation in the training code.
Can you remove with chainer.function.no_backprop_mode from here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, in that case, I cannot reuse this model in training code...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use disable_update to not update weights.

Copy link
Contributor Author

@knorth55 knorth55 Apr 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, you prefer setting disable_update() by hand in training code?
OK, i will do that.
In fact, no_backprop_mode and disable_update are confusing me many times.

self.psroi_conv2 = L.Convolution2D(
1024, group_size * group_size * n_class * 2,
1, 1, 0, initialW=initialW)
self.psroi_conv3 = L.Convolution2D(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about changing the name of the links to conv1, cls_seg and ag_loc?
This naming convention looks as if three convolutions are applied sequentially to the input.

Alternatively, psroi_conv1, psroi_cls_seg, psroi_ag_loc is OK (I prefer the first choice for simplicity).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

def __call__(self, x, rois, roi_indices, img_size):
h = F.relu(self.psroi_conv1(x))
h_cls_seg = self.psroi_conv2(h)
h_ag_locs = self.psroi_conv3(h)
Copy link
Member

@yuyu2172 yuyu2172 Apr 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about h_ag_loc instead of h_ag_locs.
We usually do not put s with variables starting with h_ (e.g.., not hs).
Also, this looks more consistent with h_cls_seg, which you did not put s.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

initialW = chainer.initializers.Normal(0.01)
with self.init_scope():
self.psroi_conv1 = L.Convolution2D(
2048, 1024, 1, 1, 0, initialW=initialW)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using Conv2DActiv?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm, i'm don't see any merits of using Conv2DActiv...
It is easy to use default one for me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I think this is just a matter of preference. Please leave this as is.

roi_ag_locs = roi_ag_locs.reshape((n_roi, 2, 4))

# Mask Regression
# shape: (n_rois, n_class, 2, roi_size, roi_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

# shape: (n_rois, n_class, 2, roi_size, roi_size)
# Group Pick by Score
max_cls_indices = roi_cls_scores.array.argmax(axis=1)
# shape: (n_rois, 2, roi_size, roi_size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto



def _global_average_pooling_2d(x):
n_rois, n_channel, H, W = x.array.shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

# Group Max
# shape: (n_rois, n_class, roi_size, roi_size)
h_cls = pool_cls_seg.transpose((0, 1, 3, 4, 2))
h_cls = F.max(h_cls, axis=4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remvoe transpose by using the right axis when taking maximum.

max_cls_indices = roi_cls_scores.array.argmax(axis=1)
# shape: (n_rois, 2, roi_size, roi_size)
roi_seg_scores = pool_cls_seg[
np.arange(len(max_cls_indices)), max_cls_indices]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use self.xp.arange.
This is faster because np.arange need to transfer data from CPU to GPU.

@knorth55 knorth55 force-pushed the experimental-fcis branch from 017b6b5 to a860287 Compare April 15, 2018 08:04
self, h_cls_seg, h_ag_locs, rois, roi_indices):
# PSROI Pooling
# shape: (n_rois, n_class*2, roi_size, roi_size)
pool_cls_seg = psroi_pooling_2d(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I suggest to use roi_seg_cls_scores instead of pool_cls_seg?

roi_seg_scores is (n_roi, 2, R, R) and roi_cls_scores is(n_roi, n_class).
For a variable with shape (n_roi, n_class, 2, R, R), I think roi_seg_cls_scores is the appropriate name.
Basically, _cls_ indicates that we have extra dimension that corresponds to categories.

pool_cls_seg = pool_cls_seg.reshape(
(-1, self.n_class, 2, self.roi_size, self.roi_size))
# shape: (n_rois, 2*4, roi_size, roi_size)
pool_ag_locs = psroi_pooling_2d(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I suggest to name (n_roi, 2 * 4, R, R) as roi_seg_ag_locs instead of pool_ag_locs?
This seems like an consistent usage of _seg_, which is used for all variables with (n_roi, *, R, R).


_models = {
'sbd': {
'n_fg_class': 20,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is setting the name of the pretrained weight voc inappropriate?

Copy link
Contributor Author

@knorth55 knorth55 Apr 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VOC images with SBD annotations, so I prefer SBD.
Also, the model is trained with SBD training image list, which differs from VOC train image list.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

'sbd': {
'n_fg_class': 20,
'url': 'https://github.com/yuyu2172/share-weights/releases/'
'download/0.0.6/fcis_resnet101_2018_04_14.npz'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you name the weight in the following protocol?
fcis_resnet101_DATASETNAME_trained_DATE.npz?
(e.g., faster_rcnn_vgg16_voc07_trained_2017_08_06.npz).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, i will update

@knorth55 knorth55 force-pushed the experimental-fcis branch 3 times, most recently from e8b2522 to 5899fdd Compare April 15, 2018 13:54
@knorth55 knorth55 changed the title [WIP] add FCIS ResNet101 add FCIS ResNet101 Apr 15, 2018
@yuyu2172
Copy link
Member

You haven't followed the comments?
May I ask you to tell me why?

#568 (comment)
I asked roi_seg_ag_locs (in your code roi_ag_loc_scores)

#568 (comment)
I asked roi_seg_cls_scores (in your code roi_cls_seg_scores)



def mask_voting(
mask_score, bbox, score, size,
Copy link
Member

@yuyu2172 yuyu2172 Apr 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have two comments on the names of variables.

First, if the range of the value is [0, 1], use _prob instead of _score.

Second, the name mask_score can be misunderstood as an array with shape (R, H, W). How about naming a variable (R, RH, RW) as roi_mask_prob. (depending on the range of its values). This name is consistent with names in fcis.py.
Also, how about renaming bbox to roi so that it is clear that roi_mask_prob is defined in the region of roi.
When iterating over roi, either for bb in roi: or for r in roi is OK.
In addition to that, score can be renamed to roi_prob to make its shape more precise (this modification is only optional).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You actually do use something similar to what I suggested in test_mask_voting.py.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First, if the range of the value is [0, 1], use _prob instead of _score.

I look in some chainercv codes, and score and prob are mixed.
For example, object_detection returns scores, which is probability.
In order to clean up, we first set the rule first.

Second, the name mask_score can be misunderstood as an array with shape (R, H, W). How about naming a variable (R, RH, RW) as roi_mask_prob. (depending on the range of its values). This name is consistent with names in fcis.py.
Also, how about renaming bbox to roi so that it is clear that roi_mask_prob is defined in the region of roi.

OK, i will fix it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I look in some chainercv codes, and score and prob are mixed.
For example, object_detection returns scores, which is probability.
In order to clean up, we first set the rule first.

I missed this comment. sorry.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We used scores because we wanted to relax the interface of detection models.
Hmm... I totally forgot this...

You can leave this part in the way you like for now. We can later discuss the naming conventions. I personally think prob is a lot more readable (that is probably why Faster R-CNN sometimes uses *_prob).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally agree that prob is easier to understand.
But, for API, score is consistent to object detection API, so I will use scores :(.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can consistently use prob up until the end of prediction.
In the end, you do scores.append(prob).

Copy link
Contributor Author

@knorth55 knorth55 Apr 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, that sounds good

cls_score.append(score_l)

sorted_score = np.sort(np.concatenate(cls_score))[::-1]
keep_n = min(len(sorted_score), limit)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n_keep


roi_seg_score = roi_seg_scores.array
roi_score = roi_scores.array
bbox = rois / scale
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about not introducing bbox.
I mean use roi = rois / scale.

The rationale to prefer roi is explained in the second suggestion of #568 (comment).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

roi and bbox is complicated, too.
In this case, rois = bboxes so it doesn't matter.
What's wrong with it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By definition, rois and bboxes are different (bboxes is a list of bbox).
But, as you said, in the case when batchsize=1, rois = bboxes.

Having said that, I think that continuing to use rois in predict is better because it explicitly indicates that roi_mask_prob is defined relative to rois.
This is the same reason why we use rois elsewhere, such as __call__.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the difference of bboxes and rois.
In my understanding, bboxes and rois have same shape (R, 4).
In FCIS, we don't use roi_ag_locs (because it is already applied in __call__ with iter2=True) so that roi is same as bbox.
For the return output, the predict should return bbox as output not roi, so that I rename here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. bboxes is a list of (R, 4), whereas rois is (R', 4).

You mean bbox is after non maximum suppression and other processing?
In that case, this rois are already bboxes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

roi_mask are complete different from other roi_*.
What you want is removing other roi_* variables. Maybe i understand.
In that case, the rule below sounds good to me

roi_mask -> roi_mask (no change)
rois -> bboxes
roi_seg_scores -> seg_scores
roi_scores -> scores

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I was saying is this.

roi_mask -> roimask
rois -> bboxes
roi_seg_scores -> seg_scores
roi_scores -> scores

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean bbox is after non maximum suppression and other processing?
In that case, this rois are already bboxes.

No no. The types are different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No no. The types are different.

Types are different? Can you tell me more detailed information?

roi_mask -> roimask

And I prefer roi_mask to roimask because it is easy to read.

bbox[:, 0::2] = self.xp.clip(bbox[:, 0::2], 0, size[0])
bbox[:, 1::2] = self.xp.clip(bbox[:, 1::2], 0, size[1])

roi_seg_prob = F.softmax(roi_seg_score).array
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about roi_seg_prob = F.softmax(roi_seg_score).array[:, 1, :, :] so that the naming convention is consistent with mask_voting.py.

roi_prob = chainer.cuda.to_cpu(roi_prob)
bbox = chainer.cuda.to_cpu(bbox)

roi_mask_score, bbox, label, score = mask_voting(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the returned mask a probability?
In that case, can you rename roi_mask_score to roi_mask_prob.

# concat 1st and 2nd iteration results
rois = self.xp.concatenate((rois, rois2))
roi_indices = self.xp.concatenate((roi_indices, roi_indices))
roi_seg_scores = F.concat(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any particular reason to prefer using *_seg_* for names?
I think it is the same as *_mask_*, so how about using only one convention.
Since mask is used for the output, how about choosing *_mask_*.

roi_ag_locs = roi_ag_locs.array
mean = self.xp.array(self.loc_normalize_mean)
std = self.xp.array(self.loc_normalize_std)
roi_loc = roi_ag_locs[:, 1, :]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

roi_loc --> roi_locs

scale = img_var.shape[3] / size[1]
roi_seg_scores, _, roi_scores, rois, _ = \
self.__call__(img_var, scale)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yuyu2172
yuyu2172 previously approved these changes Apr 22, 2018
@yuyu2172 yuyu2172 dismissed their stale review April 22, 2018 10:00

mistake

Copy link
Member

@yuyu2172 yuyu2172 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comments on the documentation.

2. **Region Proposal Networks**: Given the feature maps calculated in \
the previous stage, produce set of RoIs around objects.
3. **Localization, Segmentation and Classification Heads**: Using feature \
maps that belong to the proposed RoIs, segment region of the objects, \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

segment region --> segment regions

:class:`chainer.Chain` objects :obj:`feature`, :obj:`rpn` and :obj:`head`.
There are two functions :meth:`predict` and :meth:`__call__` to conduct
instance segmentation.
:meth:`predict` takes images and returns masks, object label and its score.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

object labels and their scores

Please refer to the documentation found there.
head (callable Chain): A callable that takes a BCHW array,
RoIs and batch indices for RoIs.
This returns class dependent segmentation scores, class-agnostic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class agnostic segmentation scores ?

indices for RoIs.
mean (numpy.ndarray): A value to be subtracted from an image
in :meth:`prepare`.
min_size (int): A preprocessing paramter for :meth:`prepare`. Please
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameter

in :meth:`prepare`.
min_size (int): A preprocessing paramter for :meth:`prepare`. Please
refer to a docstring found for :meth:`prepare`.
max_size (int): A preprocessing paramter for :meth:`prepare`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto


Args:
mask_scores (array): A mask score array whose shape is
:math:`(R, RH, RW)`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgetting the second argument

* :math:`RW` is the height of pooled image.

Args:
mask_scores (array): A mask score array whose shape is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mask_scores --> mask_score

Args:
mask_scores (array): A mask score array whose shape is
:math:`(R, RH, RW)`.
scores (array): A class score array whose shape is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scores --> score

for mask merging.
binary_thresh (float): A threshold value of mask score
for mask merging.
limit (int): A limit number of outputs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A limit number of outputs --> The maximum number of outputs.

binary_thresh (float): A threshold value of mask score
for mask merging.
limit (int): A limit number of outputs.
bg_label (int): A background label.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The id of the background label.



class TestFCISResNet101Pretrained(unittest.TestCase):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add @attr.disk like other dataset tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is because they don't use pretrained models.

Copy link
Contributor Author

@knorth55 knorth55 Apr 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SSD uses pretrained model!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. That is a mistake.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, i will update

def test_pretrained(self):
link = FCISResNet101(pretrained_model='sbd')
self.assertIsInstance(link, FCIS)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@knorth55
Copy link
Contributor Author

knorth55 commented Apr 22, 2018

You haven't followed the comments?
May I ask you to tell me why?

Because I found better solution.

#568 (comment)
I asked roi_seg_ag_locs (in your code roi_ag_loc_scores)

it is nonsense to add _seg_ in not segmentation variables.
cls_seg and loc is different.

#568 (comment)
I asked roi_seg_cls_scores (in your code roi_cls_seg_scores)

h_cls_seg -> roi_cls_seg_scores -> roi_cls_scores, roi_seg_scores
It is easier to understand.

@yuyu2172
Copy link
Member

yuyu2172 commented Apr 22, 2018

it is nonsense to add seg in not segmentation variables.

This variable has shape (n_roi, 2 *4, roi_size, roi_size), so the shape is mask. That is why adding _seg_ or _mask_ is needed. See #568 (comment)
roi_ag_mask_locs (R, 2*4,RH,RW) --(pooling)--> roi_ag_locs (R, 2*4) looks natural to me.

Also, this is not scores, but locs right?
Or, am I missing something?
https://github.com/knorth55/chainercv/blob/5899fddc6cef8e28aacfd08b8294ccf52f4cf801/chainercv/experimental/links/model/fcis/fcis_resnet101.py#L333

h_cls_seg -> roi_cls_seg_scores -> roi_cls_scores, roi_seg_scores

OK. I am fine with how you order cls and seg (or mask), but be consistent.

@knorth55
Copy link
Contributor Author

knorth55 commented Apr 23, 2018

This variable has shape (n_roi, 2 *4, roi_size, roi_size), so the shape is mask. That is why adding seg or mask is needed. See #568 (comment)

i still cannot understand why we need mask or seg when shape is (R, L, H, W).
seg and maskis not shape symbol but semantic symbol and have different meanings, so i cannot use in this case.
in my opinion, shape does not matter because it will be pooled to (R, 2, 4) roi_ag_locs.

Also, this is not scores, but locs right?
Or, am I missing something?

This variable is a kind of scores.
Also, roi_ag_locs is calculated from the variable, so roi_ag_loc_scores sounds nice to me.

OK. I am fine with how you order cls and seg (or mask), but be consistent.

Maybe you misunderstand something, but cls, seg and locis what the model predict.
cls_seg is the combined output for prediction of cls and seg.
In addition, I define seg is multi-class segmentation result and mask is class-agnostic one.

@yuyu2172
Copy link
Member

yuyu2172 commented Apr 23, 2018

This variable is a kind of scores.

OK. Maybe I am misunderstanding something.
The naming convention inside _pool is fine right now.
We can merge this for now because it is only an implementation detail.
I would appreciate if you explain this to me when we meet face to face.

Having said that, what do you think about stop using _seg_ and always use _mask_? I don't see any difference between the two, and using both of them is confusing.

@knorth55
Copy link
Contributor Author

knorth55 commented Apr 23, 2018

Having said that, what do you think about stop using seg and always use mask? I don't see any difference between the two, and using both of them is confusing.

_seg_ is multi-class including background and _mask_ is class-agnostic, background and foreground only.
In FCIS, _seg_ is quite important and unique point, and it propose to use not (R, L, H, W) but (R, L, 2, H, W) as segmentation.
In detailed, the paper proposes to predict foreground and background for all semantic labels.
That is why I use both (R, 2, H, W) for mask and (R, L, 2, H, W) for seg.
Maybe, in some part, my code is not consistent so that you got confused.

@knorth55
Copy link
Contributor Author

knorth55 commented Apr 23, 2018

OK. Maybe I am misunderstanding something.

At this point, I want to clarify the difference of score and prob.
I use score as network raw output prob as [0,1] stochastic values.
Is is correct?
I got confused, and for me, it looks scores in predict and score in __call__ are different in others such as object detection.

@knorth55
Copy link
Contributor Author

And also, another question.
In recent PR, chainer_experimental is added in chainercv.
FCIS should be in chainer_experimental, chainercv_experimental or experimental dir?
Which one do you prefer?

@yuyu2172
Copy link
Member

FCIS should be in chainer_experimental, chainercv_experimental or experimental dir?

experimental

@knorth55 knorth55 force-pushed the experimental-fcis branch 2 times, most recently from 11af374 to 0db468d Compare April 23, 2018 10:22
@knorth55
Copy link
Contributor Author

I updated Variable names as below.

roi_mask -> roi_mask (no change)
roi_seg_scores -> seg_scores
roi_seg_prob -> seg_prob
roi_mask_scores -> mask_scores
roi_cls_scores -> cls_scores
roi_cls_prob -> cls_prob

@knorth55
Copy link
Contributor Author

@yuyu2172
By the way, do we really need @attr.disk for pretrained model?
I read the PR #493 , and it is only added in Dataset test not pretrained model.

@knorth55 knorth55 force-pushed the experimental-fcis branch from 0db468d to d849730 Compare April 23, 2018 11:39
@Hakuyume
Copy link
Member

@knorth55 I think we don't need attr.disk. attr.disk is for tests that require a huge disk space.

@yuyu2172
Copy link
Member

I consider pretrained weight to require a huge disk space (but the notion of huge disk is subjective).
I missed SSD test in #493.

@yuyu2172
Copy link
Member

#568 (comment)

Types are different? Can you tell me more detailed information?

Assume that we have two images, and two sets of bounding boxes are assigned to them.
For the first image, we have bbox of shape (R1, 4), and for the second image, we have bbox of shape (R2, 4).
bboxes would be [(R1, 4), (R2, 4)], but rois would be (R1 + R2, 4).

@yuyu2172
Copy link
Member

roi_mask -> roi_mask (no change)
roi_seg_scores -> seg_scores
roi_seg_prob -> seg_prob
roi_mask_scores -> mask_scores
roi_cls_scores -> cls_scores
roi_cls_prob -> cls_prob

roi_* contained quite important information in the case when the batch size is more than one.
For instance, with two images, roi_cls_probs is (R1 + R2, L, 4), but cls_probs would be (2, L, 4).
Since __call__ can take batchsize > 1, can you please put back roi_* back to all variables at least in __call__?

For predict, we assume the batchsize to be 1, so it is really a matter of taste how we name variables.
Thus, it is ok to leave them as is.

Inside predict, I would prefer roimask to distinguish from roi_*. #568 (comment)

@Hakuyume
Copy link
Member

I consider pretrained weight to require a huge disk space (but the notion of huge disk is subjective).

Yes. The definition of huge is unclear. I think we should decide proper level considering the capacity of the CI system.


"""
h = F.relu(self.conv1(x))
h_cls_seg = self.cls_seg(h)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cls_seg --> seg_score and h_cls_seg --> h_seg_score

return roi_mask_scores, ag_locs, cls_scores, rois, roi_indices

def _pool(
self, h_cls_seg, h_ag_loc, rois, roi_indices):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@knorth55
Copy link
Contributor Author

Assume that we have two images, and two sets of bounding boxes are assigned to them.
For the first image, we have bbox of shape (R1, 4), and for the second image, we have bbox of shape (R2, 4).
bboxes would be [(R1, 4), (R2, 4)], but rois would be (R1 + R2, 4).

I see.
predict method is for batch_size=1, so rois = bboxes, which does not matter.

@knorth55
Copy link
Contributor Author

roi_* contained quite important information in the case when the batch size is more than one.
For instance, with two images, roi_cls_probs is (R1 + R2, L, 4), but cls_probs would be (2, L, 4).
Since call can take batchsize > 1, can you please put back roi_* back to all variables at least in call?
For predict, we assume the batchsize to be 1, so it is really a matter of taste how we name variables.
Thus, it is ok to leave them as is.
Inside predict, I would prefer roimask to distinguish from roi_*. #568 (comment)

Thinking of training code, we should define the variable name now.
mask: whole masks (R, H, W)
<=>
c_mask: clipped mask (R, RH, RW)
p_mask, pooled_mask, clipped_mask, roi_mask, ins_mask, obj_mask?
I prefer c_mask.

@yuyu2172
Copy link
Member

yuyu2172 commented Apr 24, 2018

predict method is for batch_size=1, so rois = bboxes, which does not matter.

More precisely, rois = bbox. (bboxes = [(R,4)] or (B, R, 4), rois=(R,4)).
That is why bbox = rois is correct (only in the case when batchsize=1), but bboxes = rois is not.

@yuyu2172
Copy link
Member

yuyu2172 commented Apr 24, 2018

I see. How about crop_mask or cropped_mask? The name improves readability.

crop_mask or cropped_mask:  (R, BH, BW) 
roimask:   (R, RH, RW)

Edit:
c_mask is ok.

if len(y_indices) == 0 or len(x_indices) == 0:
return None, None
else:
y_max = y_indices.max() + 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 is not necessary

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, wait. It was ok

assert bbox.shape[0] == len(mask_prob)
assert bbox.shape[0] == mask_weight.shape[0]

mask = np.zeros(size, dtype=np.float32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this msk?

x_max = x_indices.max() + 1
x_min = x_indices.min()

c_bbox = np.array(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*_bb?

mask_weight = mask_weight / mask_weight.sum()
mask_prob_i = roi_mask_prob[keep_indices]
bbox_i = bbox[keep_indices]
c_mask, c_bbox = _mask_aggregation(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@yuyu2172
Copy link
Member

This PR is pretty big, and implementation is good except for some naming conventions.
Since naming convention of ChainerCV has not been included in the documentation yet, I would like to postpone a discussion regarding names in another PR.
For now, I will merge this PR.

Related:
#159

@yuyu2172 yuyu2172 merged commit ecb1953 into chainer:master Apr 25, 2018
@yuyu2172 yuyu2172 mentioned this pull request Apr 25, 2018
21 tasks
@yuyu2172 yuyu2172 added this to the 0.10 milestone Apr 25, 2018
@knorth55 knorth55 deleted the experimental-fcis branch April 25, 2018 04:58
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants