When using Orbax (v0.11.6) on macOS, loading a checkpoint prevents tqdm from working correctly. After the checkpoint is loaded, wrapping an iterator with tqdm causes the program to abort with the error:
```
aborting: fork() is not allowed since tensorstore uses internal threading
```
This issue does not occur on Ubuntu.
Minimal Reproducible Example
```python
import orbax.checkpoint as ocp
import jax.numpy as jnp
import os
from tqdm import trange

if __name__ == "__main__":
    losses = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]

    checkpoint_directory = "~/Downloads/checkpoints"
    with ocp.CheckpointManager(os.path.abspath(os.path.expanduser(checkpoint_directory))) as checkpoint_manager:
        checkpoint_manager.save(
            1,
            args=ocp.args.Composite(
                losses=ocp.args.ArraySave(jnp.array(losses)),
            ),
        )

    # we append to losses to simulate a new epoch
    # (list.append mutates in place and returns None, so don't reassign it)
    losses.append(0)

    with ocp.CheckpointManager(os.path.abspath(os.path.expanduser(checkpoint_directory))) as checkpoint_manager:
        assert checkpoint_manager.latest_step() is not None
        restored = checkpoint_manager.restore(
            checkpoint_manager.latest_step(),
            args=ocp.args.Composite(
                losses=ocp.args.ArrayRestore(),
            ),
        )
        losses = restored.losses
        print(losses)

    # On macOS the process aborts here, before the first iteration
    for i in trange(10):
        print(i)
```
Output:
```
[10 9 8 7 6 5 4 3 2 1]
aborting: fork() is not allowed since tensorstore uses internal threading
```
After printing the restored checkpoint, the tqdm loop does not execute and the process aborts.
I am not sure whether this issue originates from Orbax, TensorStore, or tqdm, but it appears to be related to Orbax's interaction with TensorStore's internal threading on macOS.
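A possible explanation for why only macOS is affected (an unverified assumption, not a confirmed diagnosis): tqdm's default write lock wraps a `multiprocessing.RLock()`. Under a non-`fork` start method, CPython keeps that lock's named semaphore and registers it with its resource-tracker helper process, launching the helper via fork/exec the first time. macOS has defaulted to `spawn` since Python 3.8, whereas Linux defaulted to `fork` before Python 3.14; under `fork` the semaphore is unlinked immediately and no helper is started. If TensorStore installs a fork handler that aborts once its internal threads are running, the first tqdm bar created after a restore would trigger it only on macOS. This standard-library-only sketch reports the start method on the current machine (the helper name `describe_platform` is hypothetical):

```python
import multiprocessing
import sys


def describe_platform() -> tuple:
    """Return (platform, default multiprocessing start method)."""
    # The default is "spawn" on macOS (Python 3.8+) and "fork" on Linux
    # before Python 3.14; any value other than "fork" means creating a
    # multiprocessing lock may start the resource-tracker helper process.
    return sys.platform, multiprocessing.get_start_method()


if __name__ == "__main__":
    print(*describe_platform())
```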
Is there any way to prevent Orbax from loading checkpoints in a thread?
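A possible workaround, sketched here as an untested assumption rather than an official Orbax or tqdm fix: if the abort comes from tqdm lazily creating its default multiprocessing lock (which on macOS can fork a helper process) rather than from Orbax's loading threads, then handing tqdm a plain `threading.RLock` via `tqdm.set_lock()` before the first progress bar is created should avoid the fork while leaving Orbax's threading untouched. The function `run_epochs` is a hypothetical stand-in for the `trange` loop in the reproduction:

```python
import threading

from tqdm import tqdm, trange

# Assumption: replacing tqdm's global lock with a thread-only lock stops tqdm
# from constructing multiprocessing.RLock(), and with it any helper process
# that would be forked. This must run before the first tqdm bar is created.
tqdm.set_lock(threading.RLock())


def run_epochs(num_epochs: int) -> list:
    """Iterate under a tqdm progress bar and return the visited indices."""
    seen = []
    for i in trange(num_epochs):
        seen.append(i)
    return seen


if __name__ == "__main__":
    # In the reproduction, this loop would run after checkpoint_manager.restore(...).
    print(run_epochs(10))
```

The trade-off is that tqdm no longer holds a cross-process lock, so bars written from several worker processes could interleave; for a single-process training loop that should not matter.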
This looks like a TensorStore restriction; I'd recommend opening an issue there. Orbax itself has no such restriction. Threading is integral to saving and loading arrays (and chunks of arrays), since it enables greater parallelism.