tqdm causes fork() error on macOS after loading checkpoint with Orbax #1658

Closed
plainerman opened this issue Mar 5, 2025 · 1 comment
Labels: checkpoint, type:bug (Something isn't working)


@plainerman

When using Orbax (v0.11.6) on macOS, loading a checkpoint prevents tqdm from working correctly. After the checkpoint is loaded, wrapping an iterator with tqdm causes the program to abort with the error:

aborting: fork() is not allowed since tensorstore uses internal threading

This issue does not occur on Ubuntu.
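
The error message suggests that TensorStore refuses fork() once its internal worker threads exist. The toy model below sketches that pattern with the standard library only. `ThreadedBackend` and `check_fork_allowed` are hypothetical names. This is not TensorStore's implementation, which, judging by the message, aborts the process instead of raising an exception.

```python
import threading


class ThreadedBackend:
    """Hypothetical stand-in for a library that starts worker threads lazily
    and refuses fork() once they exist (a toy model, not TensorStore)."""

    def __init__(self) -> None:
        self._workers: list[threading.Thread] = []
        self._stop = threading.Event()

    def load(self) -> None:
        # Workers start only on first use, so fork() is still safe before
        # anything has been loaded.
        if not self._workers:
            worker = threading.Thread(target=self._stop.wait, daemon=True)
            worker.start()
            self._workers.append(worker)

    def check_fork_allowed(self) -> None:
        # Meant to be called right before fork(); a real library might abort here.
        if any(w.is_alive() for w in self._workers):
            raise RuntimeError(
                "fork() is not allowed since this backend uses internal threading"
            )

    def shutdown(self) -> None:
        # Wake the workers, wait for them to exit, and forget them.
        self._stop.set()
        for w in self._workers:
            w.join()
        self._workers.clear()


if __name__ == "__main__":
    backend = ThreadedBackend()
    backend.check_fork_allowed()  # no worker threads yet, so this passes
    backend.load()
    try:
        backend.check_fork_allowed()
    except RuntimeError as err:
        print(err)  # the guard now refuses, mirroring the abort in this issue
    finally:
        backend.shutdown()
```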

Minimal Reproducible Example

```python
import os

import jax.numpy as jnp
import orbax.checkpoint as ocp
from tqdm import trange


if __name__ == "__main__":
    losses = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
    checkpoint_directory = os.path.abspath(os.path.expanduser("~/Downloads/checkpoints"))

    # Save the losses as step 1.
    with ocp.CheckpointManager(checkpoint_directory) as checkpoint_manager:
        checkpoint_manager.save(
            1,
            args=ocp.args.Composite(
                losses=ocp.args.ArraySave(jnp.array(losses)),
            ),
        )

    # Append to losses to simulate a new epoch. list.append mutates in place
    # and returns None, so its result must not be assigned back to losses.
    losses.append(0)

    # Restore the latest checkpoint with a fresh manager.
    with ocp.CheckpointManager(checkpoint_directory) as checkpoint_manager:
        assert checkpoint_manager.latest_step() is not None
        restored = checkpoint_manager.restore(
            checkpoint_manager.latest_step(),
            args=ocp.args.Composite(
                losses=ocp.args.ArrayRestore(),
            ),
        )
        losses = restored.losses

    print(losses)

    # On macOS, the process aborts here as soon as tqdm creates its progress bar.
    for i in trange(10):
        print(i)
```

Output:

```
[10  9  8  7  6  5  4  3  2  1]
aborting: fork() is not allowed since tensorstore uses internal threading
```

After printing the restored checkpoint, the tqdm loop does not execute and the process aborts.

I am not sure whether this issue originates from Orbax, TensorStore, or tqdm, but it appears to be related to Orbax's interaction with TensorStore's internal threading on macOS.
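
One plausible chain of events is offered here as an assumption; the thread does not confirm it. tqdm's default write lock creates a `multiprocessing.RLock` when the first progress bar is created. In CPython, if the multiprocessing start method is not "fork", that lock's named semaphore is registered with a resource-tracker helper process. Launching the helper forks the interpreter, which TensorStore refuses once its threads are running. macOS has defaulted to "spawn" since Python 3.8, whereas Linux defaulted to "fork" before Python 3.14. That difference would be consistent with the issue appearing only on macOS. The hypothetical helper `needs_resource_tracker` below encodes the rule. It mirrors CPython internals that may change between versions.

```python
import multiprocessing
import sys


def needs_resource_tracker(start_method: str, platform: str = sys.platform) -> bool:
    """Hypothetical helper: return True if, under CPython's current rules,
    creating a multiprocessing lock registers it with the resource tracker,
    whose helper process is started by forking."""
    # Windows does not use named POSIX semaphores, and under "fork" the
    # semaphore is unlinked immediately, so neither case needs the tracker.
    return not (platform == "win32" or start_method == "fork")


if __name__ == "__main__":
    method = multiprocessing.get_start_method()
    print(
        f"{sys.platform}: start method {method!r}, "
        f"resource tracker needed: {needs_resource_tracker(method)}"
    )
```

If this explanation holds, two untested workarounds follow from it. First, call `tqdm.get_lock()` at the top of the script, so the lock and its helper process exist before Orbax starts TensorStore. Second, swap in a thread-only lock with `tqdm.set_lock(threading.RLock())`. This avoids multiprocessing locks entirely, but it is only appropriate when no other processes draw progress bars.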

Is there a way to make Orbax load checkpoints without using threads?

@cpgaffney1
Collaborator

This looks like a TensorStore restriction; I'd recommend opening an issue there. Orbax itself has no such restriction. Threading is an integral part of saving and loading arrays (or chunks of arrays), since it enables greater parallelism.
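
The sketch below illustrates why threads help here. Reading an array that is stored as separate chunks is largely I/O-bound, so issuing the reads from a thread pool lets them overlap. It uses only the standard library. The JSON chunk-file layout and the names `save_chunks` and `load_chunks` are hypothetical. They do not reflect Orbax's or TensorStore's actual storage format or API.

```python
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def save_chunks(values: list[int], directory: str, chunk_size: int) -> list[str]:
    """Write `values` as numbered JSON chunk files and return their paths in order."""
    paths = []
    for index, start in enumerate(range(0, len(values), chunk_size)):
        path = os.path.join(directory, f"chunk_{index}.json")
        with open(path, "w") as f:
            json.dump(values[start:start + chunk_size], f)
        paths.append(path)
    return paths


def _read_chunk(path: str) -> list[int]:
    with open(path) as f:
        return json.load(f)


def load_chunks(paths: list[str], max_workers: int = 4) -> list[int]:
    """Read all chunk files concurrently and reassemble them in order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in input order, even if reads finish out of order.
        chunks = pool.map(_read_chunk, paths)
        return [value for chunk in chunks for value in chunk]


if __name__ == "__main__":
    losses = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
    with tempfile.TemporaryDirectory() as tmp:
        paths = save_chunks(losses, tmp, chunk_size=3)
        print(load_chunks(paths))  # [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
```

This sketch joins its worker threads when the `with` block exits. A library that keeps a long-lived thread pool instead stays multi-threaded for the rest of the process, and that is what makes a later fork() unsafe.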
