Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: treat bit rate optimization as degree of freedom optimization problem #500

Merged
merged 2 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 65 additions & 112 deletions includes/acl/compression/impl/quantize.transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,19 @@ namespace acl
// until our error is acceptable.
// We try permutations from the lowest memory footprint to the highest.

const uint8_t* const bit_rate_permutations_per_dofs[] =
{
&acl_impl::k_local_bit_rate_permutations_1_dof[0][0],
&acl_impl::k_local_bit_rate_permutations_2_dof[0][0],
&acl_impl::k_local_bit_rate_permutations_3_dof[0][0],
};
const size_t num_bit_rate_permutations_per_dofs[] =
{
get_array_size(acl_impl::k_local_bit_rate_permutations_1_dof),
get_array_size(acl_impl::k_local_bit_rate_permutations_2_dof),
get_array_size(acl_impl::k_local_bit_rate_permutations_3_dof),
};

const uint32_t num_bones = context.num_bones;
for (uint32_t bone_index = 0; bone_index < num_bones; ++bone_index)
{
Expand All @@ -897,134 +910,73 @@ namespace acl
uint32_t prev_transform_size = ~0U;
bool is_error_good_enough = false;

if (context.has_scale)
{
const size_t num_permutations = get_array_size(acl_impl::k_local_bit_rate_permutations);
for (size_t permutation_index = 0; permutation_index < num_permutations; ++permutation_index)
{
const uint8_t rotation_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][0];
if (bone_bit_rates.rotation == 1)
{
if (rotation_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.rotation == k_invalid_bit_rate)
{
if (rotation_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}
// Determine how many degrees of freedom we have to optimize our bit rates
uint32_t num_dof = 0;
num_dof += bone_bit_rates.rotation != k_invalid_bit_rate ? 1 : 0;
num_dof += bone_bit_rates.translation != k_invalid_bit_rate ? 1 : 0;
num_dof += bone_bit_rates.scale != k_invalid_bit_rate ? 1 : 0;

const uint8_t translation_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][1];
if (bone_bit_rates.translation == 1)
{
if (translation_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.translation == k_invalid_bit_rate)
{
if (translation_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}
const uint8_t* bit_rate_permutations_per_dof = bit_rate_permutations_per_dofs[num_dof - 1];
const size_t num_bit_rate_permutations = num_bit_rate_permutations_per_dofs[num_dof - 1];

const uint8_t scale_bit_rate = acl_impl::k_local_bit_rate_permutations[permutation_index][2];
if (bone_bit_rates.scale == 1)
{
if (scale_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.scale == k_invalid_bit_rate)
{
if (scale_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}

const uint32_t rotation_size = get_num_bits_at_bit_rate(rotation_bit_rate);
const uint32_t translation_size = get_num_bits_at_bit_rate(translation_bit_rate);
const uint32_t scale_size = get_num_bits_at_bit_rate(scale_bit_rate);
const uint32_t transform_size = rotation_size + translation_size + scale_size;

if (transform_size != prev_transform_size && is_error_good_enough)
{
// We already found the lowest transform size and we tried every permutation with that same size
break;
}

prev_transform_size = transform_size;
// Our desired bit rates start with the initial value
transform_bit_rates desired_bit_rates = bone_bit_rates;

context.bit_rate_per_bone[bone_index].rotation = bone_bit_rates.rotation != k_invalid_bit_rate ? rotation_bit_rate : k_invalid_bit_rate;
context.bit_rate_per_bone[bone_index].translation = bone_bit_rates.translation != k_invalid_bit_rate ? translation_bit_rate : k_invalid_bit_rate;
context.bit_rate_per_bone[bone_index].scale = bone_bit_rates.scale != k_invalid_bit_rate ? scale_bit_rate : k_invalid_bit_rate;

const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high);
size_t permutation_offset = 0;
for (size_t permutation_index = 0; permutation_index < num_bit_rate_permutations; ++permutation_index)
{
// If a bit rate is variable, grab a permutation for it
// We'll only consume as many bit rates as we have degrees of freedom

#if ACL_IMPL_DEBUG_VARIABLE_QUANTIZATION >= ACL_IMPL_DEBUG_LEVEL_VERBOSE_INFO
printf("%u: %u | %u | %u (%u) = %f\n", bone_index, rotation_bit_rate, translation_bit_rate, scale_bit_rate, transform_size, error);
#endif
uint32_t transform_size = 0; // In bits

if (error < best_error)
{
best_error = error;
best_bit_rates = context.bit_rate_per_bone[bone_index];
is_error_good_enough = error < error_threshold;
}
}
}
else
{
const size_t num_permutations = get_array_size(acl_impl::k_local_bit_rate_permutations_no_scale);
for (size_t permutation_index = 0; permutation_index < num_permutations; ++permutation_index)
if (desired_bit_rates.rotation != k_invalid_bit_rate)
{
const uint8_t rotation_bit_rate = acl_impl::k_local_bit_rate_permutations_no_scale[permutation_index][0];
if (bone_bit_rates.rotation == 1)
{
if (rotation_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.rotation == k_invalid_bit_rate)
{
if (rotation_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}
desired_bit_rates.rotation = bit_rate_permutations_per_dof[permutation_offset++];
transform_size += get_num_bits_at_bit_rate(desired_bit_rates.rotation);
}

const uint8_t translation_bit_rate = acl_impl::k_local_bit_rate_permutations_no_scale[permutation_index][1];
if (bone_bit_rates.translation == 1)
{
if (translation_bit_rate == 0)
continue; // Skip permutations we aren't interested in
}
else if (bone_bit_rates.translation == k_invalid_bit_rate)
{
if (translation_bit_rate != 0)
continue; // Skip permutations we aren't interested in
}
if (desired_bit_rates.translation != k_invalid_bit_rate)
{
desired_bit_rates.translation = bit_rate_permutations_per_dof[permutation_offset++];
transform_size += get_num_bits_at_bit_rate(desired_bit_rates.translation);
}

const uint32_t rotation_size = get_num_bits_at_bit_rate(rotation_bit_rate);
const uint32_t translation_size = get_num_bits_at_bit_rate(translation_bit_rate);
const uint32_t transform_size = rotation_size + translation_size;
if (desired_bit_rates.scale != k_invalid_bit_rate)
{
desired_bit_rates.scale = bit_rate_permutations_per_dof[permutation_offset++];
transform_size += get_num_bits_at_bit_rate(desired_bit_rates.scale);
}

if (transform_size != prev_transform_size && is_error_good_enough)
{
// We already found the lowest transform size and we tried every permutation with that same size
break;
}
// If our inputs aren't normalized per segment, we can't store them on 0 bits because we'll have no
// segment range information. This occurs when we have a single segment. Skip those permutations.
if (bone_bit_rates.rotation == k_lowest_bit_rate && desired_bit_rates.rotation == 0)
continue;
else if (bone_bit_rates.translation == k_lowest_bit_rate && desired_bit_rates.translation == 0)
continue;
else if (bone_bit_rates.scale == k_lowest_bit_rate && desired_bit_rates.scale == 0)
continue;

// If we already found a permutation that is good enough, we test all the others
// that have the same size. Once the size changes, we stop.
if (is_error_good_enough && transform_size != prev_transform_size)
break;

prev_transform_size = transform_size;
prev_transform_size = transform_size;

context.bit_rate_per_bone[bone_index].rotation = bone_bit_rates.rotation != k_invalid_bit_rate ? rotation_bit_rate : k_invalid_bit_rate;
context.bit_rate_per_bone[bone_index].translation = bone_bit_rates.translation != k_invalid_bit_rate ? translation_bit_rate : k_invalid_bit_rate;
context.bit_rate_per_bone[bone_index] = desired_bit_rates;

const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high);
const float error = calculate_max_error_at_bit_rate_local(context, bone_index, error_scan_stop_condition::until_error_too_high);

#if ACL_IMPL_DEBUG_VARIABLE_QUANTIZATION >= ACL_IMPL_DEBUG_LEVEL_VERBOSE_INFO
printf("%u: %u | %u | %u (%u) = %f\n", bone_index, rotation_bit_rate, translation_bit_rate, k_invalid_bit_rate, transform_size, error);
printf("%u: %u | %u | %u (%u) = %f\n", bone_index, desired_bit_rates.rotation, desired_bit_rates.translation, desired_bit_rates.scale, transform_size, error);
#endif

if (error < best_error)
{
best_error = error;
best_bit_rates = context.bit_rate_per_bone[bone_index];
is_error_good_enough = error < error_threshold;
}
if (error < best_error)
{
best_error = error;
best_bit_rates = desired_bit_rates;
is_error_good_enough = error < error_threshold;
}
}

Expand All @@ -1038,6 +990,7 @@ namespace acl

// Advances 'bit_rate' by 'increment', saturating at k_highest_bit_rate.
// A bit rate that is already at or above the highest one (e.g. 255 when constant)
// is returned untouched.
constexpr uint32_t increment_and_clamp_bit_rate(uint32_t bit_rate, uint32_t increment)
{
	return bit_rate < k_highest_bit_rate
		? std::min<uint32_t>(bit_rate + increment, k_highest_bit_rate)
		: bit_rate;
}

Expand Down
36 changes: 34 additions & 2 deletions includes/acl/compression/impl/transform_bit_rate_permutations.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,38 @@ namespace acl

namespace acl_impl
{
constexpr uint8_t k_local_bit_rate_permutations_no_scale[625][2] =
// Bit rate permutations to try when a transform has a single degree of freedom,
// i.e. exactly one of its rotation/translation/scale bit rates is not k_invalid_bit_rate.
// Each entry holds the bit rate to assign to that one variable component.
// Entries are listed by ascending transform size so that the search can walk them
// from the lowest memory footprint to the highest.
// The trailing comment on each entry is the resulting size in bits (3 components per track).
// This table is generated by tools/calc_local_bit_rates.py, do not edit it by hand.
// Buffer size in bytes: 25
constexpr uint8_t k_local_bit_rate_permutations_1_dof[25][1] =
{
{ 0 }, // 0 bits per transform
{ 1 }, // 3 bits per transform
{ 2 }, // 6 bits per transform
{ 3 }, // 9 bits per transform
{ 4 }, // 12 bits per transform
{ 5 }, // 15 bits per transform
{ 6 }, // 18 bits per transform
{ 7 }, // 21 bits per transform
{ 8 }, // 24 bits per transform
{ 9 }, // 27 bits per transform
{ 10 }, // 30 bits per transform
{ 11 }, // 33 bits per transform
{ 12 }, // 36 bits per transform
{ 13 }, // 39 bits per transform
{ 14 }, // 42 bits per transform
{ 15 }, // 45 bits per transform
{ 16 }, // 48 bits per transform
{ 17 }, // 51 bits per transform
{ 18 }, // 54 bits per transform
{ 19 }, // 57 bits per transform
{ 20 }, // 60 bits per transform
{ 21 }, // 63 bits per transform
{ 22 }, // 66 bits per transform
{ 23 }, // 69 bits per transform
{ 24 }, // 96 bits per transform
};

// Buffer size in bytes: 1250
constexpr uint8_t k_local_bit_rate_permutations_2_dof[625][2] =
{
{ 0, 0 }, // 0 bits per transform
{ 0, 1 }, // 3 bits per transform
Expand Down Expand Up @@ -668,7 +699,8 @@ namespace acl
{ 24, 24 }, // 192 bits per transform
};

constexpr uint8_t k_local_bit_rate_permutations[15625][3] =
// Buffer size in bytes: 46875
constexpr uint8_t k_local_bit_rate_permutations_3_dof[15625][3] =
{
{ 0, 0, 0 }, // 0 bits per transform
{ 0, 0, 1 }, // 3 bits per transform
Expand Down
49 changes: 32 additions & 17 deletions tools/calc_local_bit_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,45 @@
print('Python 3.4 or higher needed to run this script')
sys.exit(1)

permutation_tries = []
permutation_tries_no_scale = []
permutation_dof_1 = []
permutation_dof_2 = []
permutation_dof_3 = []

for rotation_bit_rate in range(k_num_bit_rates):
for translation_bit_rate in range(k_num_bit_rates):
transform_size = k_bit_rate_num_bits[rotation_bit_rate] * 3 + k_bit_rate_num_bits[translation_bit_rate] * 3
permutation_tries_no_scale.append((transform_size, rotation_bit_rate, translation_bit_rate))
for dof_1 in range(k_num_bit_rates):
dof_1_size = k_bit_rate_num_bits[dof_1] * 3;
permutation_dof_1.append((dof_1_size, dof_1))

for scale_bit_rate in range(k_num_bit_rates):
transform_size = k_bit_rate_num_bits[rotation_bit_rate] * 3 + k_bit_rate_num_bits[translation_bit_rate] * 3 + k_bit_rate_num_bits[scale_bit_rate] * 3
permutation_tries.append((transform_size, rotation_bit_rate, translation_bit_rate, scale_bit_rate))
for dof_2 in range(k_num_bit_rates):
dof_2_size = dof_1_size + k_bit_rate_num_bits[dof_2] * 3
permutation_dof_2.append((dof_2_size, dof_1, dof_2))

for dof_3 in range(k_num_bit_rates):
dof_3_size = dof_2_size + k_bit_rate_num_bits[dof_3] * 3
permutation_dof_3.append((dof_3_size, dof_1, dof_2, dof_3))

# Sort by transform size, then by each bit rate
permutation_tries.sort()
permutation_tries_no_scale.sort()
permutation_dof_1.sort()
permutation_dof_2.sort()
permutation_dof_3.sort()

print('constexpr uint8_t k_local_bit_rate_permutations_no_scale[{}][2] ='.format(len(permutation_tries_no_scale)))
print('// Buffer size in bytes: {}'.format(len(permutation_dof_1) * 1));
print('constexpr uint8_t k_local_bit_rate_permutations_1_dof[{}][1] ='.format(len(permutation_dof_1)))
print('{')
for transform_size, rotation_bit_rate, translation_bit_rate in permutation_tries_no_scale:
print('\t{{ {}, {} }},\t\t// {} bits per transform'.format(rotation_bit_rate, translation_bit_rate, transform_size))
for transform_size, dof_1 in permutation_dof_1:
print('\t{{ {} }},\t\t// {} bits per transform'.format(dof_1, transform_size))
print('};')
print()
print('constexpr uint8_t k_local_bit_rate_permutations[{}][3] ='.format(len(permutation_tries)))
print('// Buffer size in bytes: {}'.format(len(permutation_dof_2) * 2));
print('constexpr uint8_t k_local_bit_rate_permutations_2_dof[{}][2] ='.format(len(permutation_dof_2)))
print('{')
for transform_size, rotation_bit_rate, translation_bit_rate, scale_bit_rate in permutation_tries:
print('\t{{ {}, {}, {} }},\t\t// {} bits per transform'.format(rotation_bit_rate, translation_bit_rate, scale_bit_rate, transform_size))
for transform_size, dof_1, dof_2 in permutation_dof_2:
print('\t{{ {}, {} }},\t\t// {} bits per transform'.format(dof_1, dof_2, transform_size))
print('};')
print()
print('// Buffer size in bytes: {}'.format(len(permutation_dof_3) * 3));
print('constexpr uint8_t k_local_bit_rate_permutations_3_dof[{}][3] ='.format(len(permutation_dof_3)))
print('{')
for transform_size, dof_1, dof_2, dof_3 in permutation_dof_3:
print('\t{{ {}, {}, {} }},\t\t// {} bits per transform'.format(dof_1, dof_2, dof_3, transform_size))
print('};')
print()
Loading