Skip to content

Commit

Permalink
Fix Discretization serialization when num_bins is used. (#20971)
Browse files Browse the repository at this point in the history
Previously, serialization / deserialization would fail if:
- the layer was saved / restored before `adapt` was called
- the layer was saved / restored after `adapt` was called, but the dataset was such that the number of bins learned was fewer than `num_bins`

The fix consists of adding a `from_config` method that handles `bin_boundaries` separately. This is needed because, at initial creation, `bin_boundaries` and `num_bins` cannot both be set, but when restoring the layer after `adapt` has been called, both are set.

Tightened the error checking:
- never allow `num_bins` and `bin_boundaries` to be specified at the same time, even if they match (same as `tf_keras`)
- don't allow `num_bins` and `bin_boundaries` to be `None` at the same time
- verify that `adapt` has been called in `call`

Also removed `input_bin_boundaries` as the value was never used and its presence can be inferred.
  • Loading branch information
hertschuh authored Mar 4, 2025
1 parent 19b1418 commit eb1f844
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 15 deletions.
54 changes: 39 additions & 15 deletions keras/src/layers/preprocessing/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Discretization(TFDataLayer):
and `[2., +inf)`.
If this option is set, `adapt()` should not be called.
num_bins: The integer number of bins to compute.
If this option is set,
If this option is set, `bin_boundaries` should not be set and
`adapt()` should be called to learn the bin boundaries.
epsilon: Error tolerance, typically a small fraction
close to zero (e.g. 0.01). Higher values of epsilon increase
Expand Down Expand Up @@ -130,17 +130,17 @@ def __init__(
f"Received: `num_bins={num_bins}`"
)
if num_bins is not None and bin_boundaries is not None:
if len(bin_boundaries) != num_bins - 1:
raise ValueError(
"Both `num_bins` and `bin_boundaries` should not be "
f"set. Received: `num_bins={num_bins}` and "
f"`bin_boundaries={bin_boundaries}`"
)

self.input_bin_boundaries = bin_boundaries
self.bin_boundaries = (
bin_boundaries if bin_boundaries is not None else []
)
raise ValueError(
"Both `num_bins` and `bin_boundaries` should not be set. "
f"Received: `num_bins={num_bins}` and "
f"`bin_boundaries={bin_boundaries}`"
)
if num_bins is None and bin_boundaries is None:
raise ValueError(
"You need to set either `num_bins` or `bin_boundaries`."
)

self.bin_boundaries = bin_boundaries
self.num_bins = num_bins
self.epsilon = epsilon
self.output_mode = output_mode
Expand Down Expand Up @@ -183,7 +183,7 @@ def adapt(self, data, steps=None):
repeating dataset, you must specify the `steps` argument. This
argument is not supported with array inputs or list inputs.
"""
if self.input_bin_boundaries is not None:
if self.num_bins is None:
raise ValueError(
"Cannot adapt a Discretization layer that has been initialized "
"with `bin_boundaries`, use `num_bins` instead."
Expand All @@ -204,14 +204,14 @@ def update_state(self, data):
self.summary = merge_summaries(summary, self.summary, self.epsilon)

def finalize_state(self):
if self.input_bin_boundaries is not None:
if self.num_bins is None:
return
self.bin_boundaries = get_bin_boundaries(
self.summary, self.num_bins
).tolist()

def reset_state(self):
if self.input_bin_boundaries is not None:
if self.num_bins is None:
return
self.summary = np.array([[], []], dtype="float32")

Expand All @@ -225,6 +225,13 @@ def load_own_variables(self, store):
return

def call(self, inputs):
if self.bin_boundaries is None:
raise ValueError(
"You need to either pass the `bin_boundaries` argument at "
"construction time or call `adapt(dataset)` before you can "
"start using the `Discretization` layer."
)

indices = self.backend.numpy.digitize(inputs, self.bin_boundaries)
return numerical_utils.encode_categorical_inputs(
indices,
Expand All @@ -246,6 +253,23 @@ def get_config(self):
"dtype": self.dtype,
}

@classmethod
def from_config(cls, config, custom_objects=None):
    """Build a layer instance from a config dictionary.

    When the config carries non-`None` values for both `bin_boundaries`
    and `num_bins` (the state of a layer after `adapt` was called), the
    constructor cannot receive both, so the layer is created without
    `bin_boundaries` and the boundaries are assigned afterwards.

    Args:
        config: Mapping of constructor arguments. It is not mutated.
        custom_objects: Accepted for signature compatibility; not used
            by this method.

    Returns:
        A new instance of `cls`.
    """
    learned_boundaries = config.get("bin_boundaries", None)
    if learned_boundaries is None or config.get("num_bins", None) is None:
        # At most one of the two options is set: the constructor can
        # take the config unchanged.
        return cls(**config)

    # Work on a copy so the caller's config keeps its `bin_boundaries`.
    init_kwargs = config.copy()
    del init_kwargs["bin_boundaries"]
    layer = cls(**init_kwargs)
    layer.bin_boundaries = learned_boundaries
    return layer


def summarize(values, epsilon):
"""Reduce a 1D sequence of values to a summary.
Expand Down
42 changes: 42 additions & 0 deletions keras/src/layers/preprocessing/discretization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,32 @@ def test_tf_data_compatibility(self):
for output in ds.take(1):
output.numpy()

def test_serialization(self):
    """Check that `get_config` / `from_config` round-trips are lossless.

    Covers a `num_bins` layer before `adapt`, after an `adapt` call on a
    small array, after an `adapt` call on a larger array, and a layer
    built directly from `bin_boundaries`.
    """

    def check_round_trip(original):
        # The revived layer must report the same config as the original.
        saved = original.get_config()
        restored = layers.Discretization.from_config(saved)
        self.assertEqual(saved, restored.get_config())

    adaptable = layers.Discretization(num_bins=5)

    # Serialization before `adapt` is called.
    check_round_trip(adaptable)

    # Serialization after `adapt` is called but `num_bins` was not reached.
    adaptable.adapt(np.array([0.0, 1.0, 5.0]))
    check_round_trip(adaptable)

    # Serialization after `adapt` is called and `num_bins` is reached.
    adaptable.adapt(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]))
    check_round_trip(adaptable)

    # Serialization with `bin_boundaries`.
    fixed = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0])
    check_round_trip(fixed)

def test_saving(self):
# With fixed bins
layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0])
Expand Down Expand Up @@ -163,3 +189,19 @@ def test_saving(self):
model.save(fpath)
model = saving_api.load_model(fpath)
self.assertAllClose(layer(ref_input), ref_output)

def test_init_num_bins_and_bin_boundaries_raises(self):
    """Constructor rejects setting both options, and setting neither."""
    invalid_cases = (
        # Both options given at once.
        (
            {"num_bins": 3, "bin_boundaries": [0.0, 1.0]},
            "Both `num_bins` and `bin_boundaries`",
        ),
        # Neither option given.
        ({}, "either `num_bins` or `bin_boundaries`"),
    )
    for init_kwargs, expected_message in invalid_cases:
        with self.assertRaisesRegex(ValueError, expected_message):
            layers.Discretization(**init_kwargs)

def test_call_before_adapt_raises(self):
    """Calling a `num_bins` layer before `adapt` raises a `ValueError`."""
    unadapted = layers.Discretization(num_bins=3)
    sample = [[0.1, 0.8, 0.9]]
    with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"):
        unadapted(sample)

0 comments on commit eb1f844

Please sign in to comment.