GatherLayer on batch axis #1089
base: master
Conversation
Hi @albertz, what do you think about the way the size placeholder and dim tag are modified in general? Right now there is a failing test case, where we first do …
returnn/tf/layers/basic.py (Outdated)
  kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
  dyn_size=new_size, batch=self.output.batch,
  src_data=self.output, src_axis=axis, auto_generated=True)
self.output.size_placeholder[axis] = new_size
You don't use the Dim object you created? Instead of assigning size_placeholder, I think it would be better to set the newly created dim tag.
What is the usual way to set dim tags? I can't just reassign self.output.dim_tags. declare_same_as is used elsewhere, but I'm not sure if it applies here.
See most other layers. Usually you set dim_tags in get_out_data_from_opts. You should not assign a new dim tag in __init__. In __init__, you might just assign the dyn_size_ext or dyn_size_ext.placeholder of a dim tag which was previously newly created in get_out_data_from_opts.
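A minimal sketch of that split (illustrative, not the actual GatherLayer change): the dim tag and a template for its sizes are created at template time, and only the concrete size tensor is attached later. The constructor arguments are taken from the Dim/Data calls already visible in this diff and in the test further down; everything else (the name, the stand-in tensors) is assumed.

```python
import tensorflow as tf
from returnn.tf.util.data import Dim, Data

# --- what get_out_data_from_opts would do (template time, no tensors yet) ---
name = "my_gather"  # stand-in for the layer name
new_dim = Dim(
  kind=Dim.Types.Spatial, description="%s_gather_axis" % name,
  auto_generated=True)
# Template for the sizes of that dim: one scalar per entry of the new batch.
new_dim.dyn_size_ext = Data(
  name="%s_gather_axis_size" % name,
  batch_dim_axis=0, shape=[], dtype="int32")

# --- what __init__ would do (runtime, actual tensors available) ---
old_sizes = tf.constant([7, 5, 6], dtype=tf.int32)  # [old-B], stand-in for the old dyn sizes
position = tf.constant([2, 0])                      # [new-B], stand-in for the gather indices
new_dim.dyn_size_ext.placeholder = tf.gather(old_sizes, position)  # [new-B]
```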
# gather targets and encoder outputs
"tgt": {"class": "gather", "from": "data", "axis": "B", "position": "idx"},  # B', T (sparse)
"enc_raw": {"class": "gather", "from": "base:encoder", "axis": "B", "position": "idx"},  # B', T, F
"enc": {"class": "reinterpret_data", "size_base": "tgt", "from": "enc_raw"},  # B', T, F
Why is this needed?
returnn/tf/layers/basic.py (Outdated)
from ..util.data import Dim
Dim(
  kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
  dyn_size=new_size, batch=self.output.batch,
You are missing dyn_size_ext here.
You should not set dyn_size in case it is non-standard. Set dyn_size_ext instead.
When you modify the batch dim, you should create a new …
returnn/tf/layers/basic.py (Outdated)
@@ -1341,6 +1341,17 @@ def __init__(self, position, axis, **kwargs):
  # (BatchAxes.., InputAxesBeforeGatherAxis, PositionAxes.., InputAxesAfterGatherAxis..)
  self.output.placeholder = tf.gather(params=params, indices=indices, axis=gather_axis, batch_dims=batch_dims)

  if input_data.dim_tags[old_gather_axis].is_batch_dim():
    for axis in self.output.size_placeholder:
      new_size = tf.gather(params=self.output.size_placeholder[axis], indices=position_data.placeholder)
You assume that position_data is of shape [new-batch]?
Right, in the case I have in mind yes. But for the failing test case, this is different and we need to take this into account.
What is it in that case?
There it's of shape [B,T,F]; however, in the input, B and T are packed:
>>> input_data
Data{'flat_output', [B&Packed{'time'},F|F'feature'(5)]}
>>> self.output
Data{'output_output', [B,T|'time'[B],'other-spatial'(7),F|F'feature'(5)]}
>>> position_data
Data{'indices_flat_output', [B,T|'time'[B],F|'other-spatial'(7)], dtype='int32'}
Which test case is that? The one you added, test_rand_indices?
Why is position_data of this shape? As described, it should have some new-batch dim in it, right? Or basically just the shape [new-batch], when you gather into the batch dim? It definitely should not have the old batch dim in its shape.
I see that I need to assign it for output. But it should come from position_data, right?
Why is it None for position_data? I don't mean in the test case, I mean in the real case which motivated this test case. In the real case, you would not have such an InternalLayer.
It should never be done if the data has a batch dim, unless something is wrong. If that happens in the test case, then the test case is buggy.
Right, this is about the test case. However, in the case that I'm interested in, input_data.batch == position_data.batch is still True. This is probably because I'm using an EvalLayer to get the batch indices from a 0/1 vector with shape (B,), and that EvalLayer does not set the output correctly. Then we would need a layer which does that correctly, right?
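For reference, the computation described here (turning a 0/1 vector of shape (B,) into batch indices) is straightforward in plain TF; the open question in this thread is not the math but which layer sets the resulting output metadata (the new batch dim tag) correctly. A standalone illustration, not a RETURNN layer:

```python
import tensorflow as tf

mask = tf.constant([0, 1, 0, 1, 1])  # [B], 0/1 per batch entry
# Indices of the selected batch entries, shape [new-B].
position = tf.squeeze(tf.where(tf.cast(mask, tf.bool)), axis=-1)
print(position.numpy())  # [1 3 4]
```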
An EvalLayer should never change the shape. If it does, and you are not very careful in setting the output data, then yes, this is a bug in your config.
As I said, the fix is similar to what is done in the …
Not many layers do that. I just recall …
  kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
  dyn_size=new_size, batch=self.output.batch,
  src_data=self.output, src_axis=axis, auto_generated=True)
self.output.size_placeholder[axis] = new_size
You should not assign size_placeholder but rather the dim tags.
from ..util.data import Dim
Dim(
  kind=Dim.Types.Spatial, description="%s_gather_axis" % self.name,
  dyn_size=new_size, batch=self.output.batch,
You should not assign dyn_size but rather dyn_size_ext.
for dim_tag in self.output.dim_tags:
  if dim_tag.is_spatial_dim():
    axis = self.output.get_batch_axis_excluding_batch(self.output.get_axis_by_tag_name(dim_tag.description))
    new_size = tf.gather(params=self.output.size_placeholder[axis], indices=position_data.placeholder)
You should not access size_placeholder but rather dim_tag.dyn_size_ext.
if input_data.dim_tags[old_gather_axis].is_batch_dim():
  for dim_tag in self.output.dim_tags:
    if dim_tag.is_spatial_dim():
      axis = self.output.get_batch_axis_excluding_batch(self.output.get_axis_by_tag_name(dim_tag.description))
This is:
- way too complicated: you can simply do for axis, dim_tag in enumerate(self.output.dim_tags)
- wrong: do not rely on get_axis_by_tag_name and dim_tag.description
- not necessary: just use dim_tag.dyn_size_ext (see the sketch below)
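Putting the three points together, the loop from the diff could roughly become the sketch below. It reuses the names from the hunk above (input_data, old_gather_axis, position_data, self.output), so it is not standalone, and it assumes position_data.placeholder holds the [new-batch] indices as discussed earlier; where the gathered sizes end up (a newly created dim tag rather than size_placeholder) is only indicated in a comment.

```python
if input_data.dim_tags[old_gather_axis].is_batch_dim():
  for axis, dim_tag in enumerate(self.output.dim_tags):
    if not dim_tag.is_spatial_dim() or dim_tag.dyn_size_ext is None:
      continue  # only dims with dynamic sizes need new sizes
    # Select the sizes belonging to the gathered batch entries.
    new_size = tf.gather(
      params=dim_tag.dyn_size_ext.placeholder,
      indices=position_data.placeholder)
    # new_size would then be attached (via dyn_size_ext) to the dim tag created
    # in get_out_data_from_opts for this axis, not written into size_placeholder.
```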
position = InternalLayer(
  name="position", network=net,
  output=Data(
    name="position",
    placeholder=tf.constant(position_np, dtype=tf.int64),
    batch_dim_axis=0, shape=[], dtype="int64",
  ))
@albertz do I need to change the creation of position in order to make it have a different batch axis dim tag here?
Yes. It's actually not so simple because of the special treatment of the batch dim tag. I'm not sure it's really possible currently.
In practice, in your real code, how would you end up with position?
> In practice, in your real code, how would you end up with position?
Do you mean what dim tag I get there?
>>> position.output.dim_tags[0].description
'batch:position'
So it's not actually the global batch dim. I was just confused because I got
>>> position.output.dim_tags[0] == values.output.dim_tags[0]
True
but this is because the check does not cover this case, see comment here: #1089 (comment)
As discussed offline, it is possible to get the desired results in my use case using the …
Since that does exactly what I need, I'll close this PR and the corresponding issue.
Well, GatherLayer on batch axis is still maybe sometimes a valid thing someone wants to do. I would leave this PR open.
This PR fixes #1087. As I face the issue in the context of supervised multilingual training, I added a more general test case for that as well, which does not necessarily need to go into the main branch. The fix is similar to how the size_placeholder is modified in the ShiftAxisLayer.
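A tiny standalone illustration (plain TF, not RETURNN code) of the point behind this fix: when a subset of batch entries is gathered, the per-entry dynamic sequence lengths have to be gathered with the same indices, otherwise data and sizes go out of sync.

```python
import tensorflow as tf

data = tf.reshape(tf.range(3 * 4), (3, 4))    # [B=3, T=4] dummy data
seq_lens = tf.constant([4, 2, 3])             # [B], dynamic length per batch entry
position = tf.constant([2, 0])                # [new-B], indices into the old batch

new_data = tf.gather(data, position, axis=0)  # [new-B, T]
new_seq_lens = tf.gather(seq_lens, position)  # [new-B], must be gathered as well

print(new_data.numpy())      # rows 2 and 0 of the original data
print(new_seq_lens.numpy())  # [3 4]
```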