
python3Packages.jax: fix libstdc++ mismatch when built with CUDA #225661

Merged
merged 2 commits into from
Apr 13, 2023

Conversation

SomeoneSerge
Contributor

@SomeoneSerge SomeoneSerge commented Apr 11, 2023

Description of changes

Hopefully this will address #220341 for jax{,lib}

Things done
  • Built on platform(s)
    • x86_64-linux
    • aarch64-linux
    • x86_64-darwin
    • aarch64-darwin
  • For non-Linux: Is sandbox = true set in nix.conf? (See Nix manual)
  • Tested, as applicable:
  • Tested compilation of all packages that depend on this change using nix-shell -p nixpkgs-review --run "nixpkgs-review rev HEAD". Note: all changes have to be committed, also see nixpkgs-review usage
  • Tested basic functionality of all binary files (usually in ./result/bin/)
  • 23.05 Release Notes (or backporting 22.11 Release notes)
    • (Package updates) Added a release notes entry if the change is major or breaking
    • (Module updates) Added a release notes entry if the change is significant
    • (Module addition) Added a release notes entry if adding a new NixOS module
  • Fits CONTRIBUTING.md.

@SomeoneSerge SomeoneSerge added the 6.topic: cuda Parallel computing platform and API label Apr 11, 2023
@ofborg ofborg bot requested a review from ndl April 11, 2023 01:29
@SomeoneSerge
Contributor Author

Result of nixpkgs-review pr 225661 --extra-nixpkgs-config '{ cudaCapabilities = [ "8.6" ]; }' run on x86_64-linux

4 packages built:
  • python310Packages.jaxlibWithCuda
  • python310Packages.jaxlibWithCuda.dist
  • python311Packages.jaxlibWithCuda
  • python311Packages.jaxlibWithCuda.dist

@SomeoneSerge
Contributor Author

❯ nix-build with-my-cuda.nix -A python3Packages.jax
/nix/store/i3sg8xpx1fzva7yjp2wdxcprd63c9kg8-python3.10-jax-0.4.1

nixpkgs-review with cudaSupport is on the way (the cvxpy tests are throttling again, idk why they get stuck)

@SomeoneSerge SomeoneSerge marked this pull request as ready for review April 11, 2023 18:51
@SomeoneSerge
Contributor Author

CC @NixOS/cuda-maintainers

@SomeoneSerge SomeoneSerge requested a review from samuela April 11, 2023 18:51
Member

@samuela samuela left a comment

woohoo! can't wait to have JAX fixed!

@@ -256,6 +256,7 @@ let
sed -i 's@include/pybind11@pybind11@g' $src
done
'' + lib.optionalString cudaSupport ''
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
Member

Does the order matter here? I'm assuming stdenv's (undesirable) libstdc++ will also be in NIX_LDFLAGS somewhere?

Contributor Author

I don't think there's any extra -L in NIX_LDFLAGS, but maybe I should quickly check it by adding an echo somewhere
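
One low-effort way to do that check (my sketch, not part of the PR) is to dump the variable in an early build phase of jaxlib and look for pre-existing -L entries:

```nix
# Hypothetical debugging aid: print NIX_LDFLAGS during the build so any
# pre-existing -L entries pointing at another libstdc++ become visible
# in the build log before our export prepends its own path.
preConfigure = ''
  echo "NIX_LDFLAGS before our export: $NIX_LDFLAGS"
'';
```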

@@ -18,8 +18,8 @@ final: prev: let
# E.g. for cudaPackages_11_8 we use gcc11 with gcc12's libstdc++
# Cf. https://github.com/NixOS/nixpkgs/pull/218265 for context
backendStdenv = final.callPackage ./stdenv.nix {
nixpkgsStdenv = prev.pkgs.stdenv;
nvccCompatibleStdenv = prev.pkgs.buildPackages."${finalVersion.gcc}Stdenv";
nixpkgsCompatibleLibstdcxx = prev.pkgs.buildPackages.gcc.cc.lib; # Or is it targetPackages?
Member

🤔

Contributor Author

Well, what can I say: either at some point we get a comment from someone who understands cross-compilation, or eventually we start cross-compiling ourselves and find the right way empirically 😆 I tried asking on matrix, but I was clumsy and failed to attract the attention of the right people

Member

Yeah cross-compilation is a mess! I honestly have no idea either :P The current comment looks a bit like a TODO that slipped through code review. Perhaps we could leave a comment to the effect of yours above ^

Contributor Author

Fair

Contributor

IMO the way that cross-compilation works in Nix doesn't fit modern projects anymore. I think the autotools-based approach comes from a time when there wasn't really remote execution, platforms/config split, offloading, etc. This is why e.g. Bazel and Buck use a different model that makes it easier to transition between toolchains.

Seems like my concerns raised in #225074 weren't too far off after all 😅

Contributor Author

That's interesting, I'll have a look at the issue! All I can say about Bazel's toolchains is that there was some "tutorial: make a toolchain for X" page in their documentation that many a time I started reading, and I have not managed to reach the end even once 🤣

Contributor Author

@SomeoneSerge SomeoneSerge Apr 12, 2023

In practice, I don't know of any attempts at cross-compiling CUDA packages in nixpkgs. Personally, I haven't even had a reason to try until we've addressed #225915

Contributor

I might also still not be understanding the Nix toolchains well enough though. I'm trying to come up with something that makes it easy to swap out LLVM/GCC/libstdc++/libcxx etc. without needing all those different cross-compilation targets, but I also haven't found a good solution yet.

@SomeoneSerge
Contributor Author

Result of nixpkgs-review pr 225661 --extra-nixpkgs-config '{ cudaCapabilities = [ "8.6" ]; cudaSupport = true; }' run on x86_64-linux

25 packages failed to build:
  • python310Packages.arviz
  • python310Packages.arviz.dist
  • python310Packages.bambi
  • python310Packages.bambi.dist
  • python310Packages.blackjax
  • python310Packages.blackjax.dist
  • python310Packages.jaxopt
  • python310Packages.jaxopt.dist
  • python310Packages.numpyro
  • python310Packages.numpyro.dist
  • python310Packages.pymc
  • python310Packages.pymc.dist
  • python310Packages.pytensor
  • python310Packages.pytensor.dist
  • python311Packages.blackjax
  • python311Packages.blackjax.dist
  • python311Packages.dm-haiku
  • python311Packages.dm-haiku.dist
  • python311Packages.dm-haiku.testsout
  • python311Packages.jaxopt
  • python311Packages.jaxopt.dist
  • python311Packages.jmp
  • python311Packages.jmp.dist
  • python311Packages.numpyro
  • python311Packages.numpyro.dist
39 packages built:
  • python310Packages.aeppl
  • python310Packages.aeppl.dist
  • python310Packages.aesara
  • python310Packages.aesara.dist
  • python310Packages.augmax
  • python310Packages.augmax.dist
  • python310Packages.chex
  • python310Packages.chex.dist
  • python310Packages.dm-haiku
  • python310Packages.dm-haiku.dist
  • python310Packages.dm-haiku.testsout
  • python310Packages.jax
  • python310Packages.jax.dist
  • python310Packages.jaxlib (python310Packages.jaxlib-build, python310Packages.jaxlibWithCuda)
  • python310Packages.jaxlib.dist (python310Packages.jaxlib-build.dist, python310Packages.jaxlibWithCuda.dist)
  • python310Packages.jmp
  • python310Packages.jmp.dist
  • python310Packages.objax
  • python310Packages.objax.dist
  • python310Packages.optax
  • python310Packages.optax.dist
  • python310Packages.optax.testsout
  • python310Packages.tensorflow-bin
  • python310Packages.tensorflow-bin.dist
  • python310Packages.treeo
  • python310Packages.treeo.dist
  • python311Packages.augmax
  • python311Packages.augmax.dist
  • python311Packages.chex
  • python311Packages.chex.dist
  • python311Packages.jax
  • python311Packages.jax.dist
  • python311Packages.jaxlib (python311Packages.jaxlib-build, python311Packages.jaxlibWithCuda)
  • python311Packages.jaxlib.dist (python311Packages.jaxlib-build.dist, python311Packages.jaxlibWithCuda.dist)
  • python311Packages.optax
  • python311Packages.optax.dist
  • python311Packages.optax.testsout
  • python311Packages.treeo
  • python311Packages.treeo.dist

@SomeoneSerge
Contributor Author

Failed derivations

Comment on lines 21 to 22
nixpkgsCompatibleLibstdcxx = prev.pkgs.buildPackages.gcc.cc.lib; # Or is it targetPackages?
nvccCompatibleCC = prev.pkgs.buildPackages."${finalVersion.gcc}".cc;
Contributor

I believe buildPackages is correct here, in both cases.

For nvccCompatibleCC: we want a copy of the compiler that runs on our build platform. I.e. it should come from a stage where hostPlatform is equal to our package's buildPlatform and where targetPlatform is equal to our package's hostPlatform; aka buildPackages.

nixpkgsCompatibleLibstdcxx is a little weirder; normally we grab libraries that will run on a package's hostPlatform from the current stage (aka pkgsHostTarget) but because libstdc++ comes from gcc which is a compiler, the libstdc++ that's produced actually is built for use on the gcc package's targetPlatform. So buildPackages here is correct.
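
The reasoning above can be summarized next to the attributes themselves (the comments are mine, not from the PR):

```nix
backendStdenv = final.callPackage ./stdenv.nix {
  # A compiler that runs on our buildPlatform and emits code for our
  # hostPlatform lives in buildPackages.
  nvccCompatibleCC = prev.pkgs.buildPackages."${finalVersion.gcc}".cc;
  # gcc is a compiler, so the libstdc++ it ships is built for gcc's
  # *target* platform -- which equals our package's hostPlatform when
  # gcc is taken from buildPackages. Hence buildPackages here too.
  nixpkgsCompatibleLibstdcxx = prev.pkgs.buildPackages.gcc.cc.lib;
};
```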

Contributor Author

Such a relief to hear this! I was slowly arriving at this interpretation, but with every morning I feel confused about the build/host/target terminology again

# older libstdc++. This, in practice, means that we should use libstdc++ from
# the same stdenv that the rest of nixpkgs uses.
# We currently do not try to support anything other than gcc and linux.
libcxx = nixpkgsCompatibleLibstdcxx;
Contributor

Unfortunately I don't think libcxx here has the semantics you're looking for.

nixpkgsCompatibleLibstdcxx is added to propagatedDeps but I think that's it; if you look in cc-wrapper, libcxx is only consulted for libc++ include paths when libcxx.isLLVM is present and true:

+ optionalString (libcxx != null || (useGccForLibs && gccForLibs.langCC or false) || (isGNU && cc.langCC or false)) ''
touch "$out/nix-support/libcxx-cxxflags"
touch "$out/nix-support/libcxx-ldflags"
''
+ optionalString (libcxx == null && (useGccForLibs && gccForLibs.langCC or false)) ''
for dir in ${gccForLibs}${lib.optionalString (hostPlatform != targetPlatform) "/${targetPlatform.config}"}/include/c++/*; do
echo "-isystem $dir" >> $out/nix-support/libcxx-cxxflags
done
for dir in ${gccForLibs}${lib.optionalString (hostPlatform != targetPlatform) "/${targetPlatform.config}"}/include/c++/*/${targetPlatform.config}; do
echo "-isystem $dir" >> $out/nix-support/libcxx-cxxflags
done
''
+ optionalString (libcxx.isLLVM or false) ''
echo "-isystem ${lib.getDev libcxx}/include/c++/v1" >> $out/nix-support/libcxx-cxxflags
echo "-stdlib=libc++" >> $out/nix-support/libcxx-ldflags
echo "-l${libcxx.cxxabi.libName}" >> $out/nix-support/libcxx-ldflags
''

Even if nixpkgsCompatibleLibstdcxx.isLLVM were present and true this wouldn't work as intended; the logic for libcxx in cc-wrapper (i.e. -stdlib, cxxabi, include/c++/v1) really is specific to clang and libc++.

Further, setting libcxx does not inhibit cc-wrapper from tacking the gcc library paths, where the cc's libraries (i.e. libstdc++ and friends) live, onto cflags/ldflags.


Unfortunately I don't think cc-wrapper as it exists today has the right knobs to support this use case (cc = gcc but libstdc++ from a different gcc) but I think we can get pretty close with useCcForLibs and gccForLibs = gccFromWhichYouWantToGrabLibstdcxx. Here's a quick PoC (imperfect -- it seems to lose libc headers somehow -- but hopefully gets the point across): dc6a8f9

Note that with the commit above, using ${np.cudaPackages.backendStdenv.cc}/bin/g++ to compile binaries results in them being linked against libstdc++ from gcc12 (observable by running ldd on the binaries). With the PR as it exists right now, such binaries are still linked against libstdc++ from backendStdenv.cc.cc.

Depending on what the requirement here is (just a newer libstdc++? compiler builtins like libgcc_s too? etc) getting cc-wrapper to support this use case might make more sense than adding link options for the desired libstdc++ to packages in an ad-hoc way.
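
For reference, a rough sketch of what that could look like. This is my reconstruction of the suggestion, assuming cc-wrapper's useCcForLibs / gccForLibs arguments; it is untested, and the linked PoC commit is the authoritative version:

```nix
# Sketch only: wrap the nvcc-compatible gcc, but let cc-wrapper source the
# C++ standard library (and the corresponding link flags) from the default
# nixpkgs gcc, instead of adding -L flags to individual packages.
ccForCuda = pkgs.wrapCCWith {
  cc = pkgs.buildPackages."${finalVersion.gcc}".cc;
  useCcForLibs = true;
  gccForLibs = pkgs.buildPackages.gcc.cc;  # the newer libstdc++ comes from here
};
```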

Contributor Author

Unfortunately I don't think libcxx here has the semantics you're looking for.

Yes, I was almost sure about that when I added the original line

nixpkgsCompatibleLibstdcxx is added to propagatedDeps but I think that's it

Indeed! And by sheer accident this has been sufficient for most of our broken CUDA packages. This is why I turned a blind eye to the obvious erroneousness of the current "approach" and hastily merged #223664

I think it shouldn't be hard, though, to develop this into a correctly working stdenv while keeping all of the interfaces intact

Even if nixpkgsCompatibleLibstdcxx.isLLVM were present and true this wouldn't work as intended; this logic ... is specific to clang

I suspected so. Again, effectively we've only relied on the propagated build inputs so far, but we definitely shouldn't keep it this way for long

Unfortunately I don't think cc-wrapper as it exists today has the right knobs to support this use case

I was afraid of that. I think it would make sense to extend cc-wrapper for our use-case

Here's a quick PoC ... dc6a8f9

Oh wow, thank you so much! I think this is exactly the direction we should be heading in!

With the PR as it exists right now, such binaries are still linked against libstdc++ from backendStdenv.cc.cc

Thankfully, backendStdenv.cc.cc does not really link to its .lib and happily picks up any libstdc++ we put in downstream package's buildInputs... unless they are built with Bazel

But, as I admit above, this behaviour is accidental and we must not rely on it

Contributor Author

Depending on what the requirement here is (just a newer libstdc++? compiler builtins like libgcc_s too? etc) getting cc-wrapper to support this use case might make more sense than adding link options for the desired libstdc++ to packages in an ad-hoc way.

At this point, I think we only need to override which C++ stdlib gets written into the runpaths of downstream derivations

Contributor

And by sheer accident this has been sufficient for most of our broken cuda packages. This is why took a blind eye on the obvious erroneousness of the current "approach" and hastily merged #223664

I think it shouldn't be hard, though, to develop this into a correctly working stdenv while keeping all of the interfaces intact

I think it would make sense to extend cc-wrapper for our use-case

But, as I admit above, this behaviour is accidental and we must not rely on it

Ah okay, thanks for clarifying; this is good to hear 🙂

Contributor

Thankfully, backendStdenv.cc.cc does not really link to its .lib and happily picks up any libstdc++ we put in downstream package's buildInputs... unless they are built with Bazel

Oh, interesting...

In general I'd think that your propagated deps would make their way into the linkopts that Bazel uses (propagatedDeps should get picked up by the bintools stdenv hooks and added to NIX_LDFLAGS which buildBazelPackage should map to linkopts).

Maybe it's an ordering thing? I think normally things specified in the NIX_LDFLAGS env var would have precedence over flags in libc-ldflags; with Bazel's linkopts this wouldn't be the case.

Easiest way to test would be to add NIX_LDFLAGS as an action_env in the Bazel options I think.
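
As a sketch of that test (my rendering, assuming buildBazelPackage's bazelFlags attribute; not part of this PR):

```nix
# Forward NIX_LDFLAGS into Bazel's action environment so its flags compete
# with Bazel's own linkopts, making the precedence directly comparable to a
# plain stdenv build.
bazelFlags = [ "--action_env=NIX_LDFLAGS" ];
```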

@SomeoneSerge
Contributor Author

SomeoneSerge commented Apr 12, 2023

@rrbutani thank you so much for your reviews, they were more useful than you might suspect! I think we're going to merge this PR just with the ad hoc hacks for jaxlib (modulo cudatoolkit/extension.nix confusion) because we want the CI to start building cache again. We want, however, to rework cudaPackages.backendStdenv entirely, and thanks to your suggestions I now have a much clearer picture of how we could see that through. I'll make sure to link the relevant PRs when we get to them!

@rrbutani
Contributor

rrbutani commented Apr 12, 2023

no problem!

I think we're going to merge this PR just with the ad hoc hacks for jaxlib (modulo cudatoolkit/extension.nix confusion) because we want the CI to start building cache again.

sounds good!

I suspect tweaking cc-wrapper for this use case will quickly get entangled in "compiler builtins lib source vs C++ stdlib source" / other cc-wrapper clean up efforts; makes sense not to block any of this work on that. 🙂

@SomeoneSerge
Contributor Author

@samuela I think we can merge this now then!

By the way, I don't know if you noticed (it's been mostly discussed on matrix), but we've been maintaining tickets at https://github.com/orgs/NixOS/projects/27/views/1 (thanks @ConnorBaker for the idea)

@samuela
Member

samuela commented Apr 13, 2023

jax and jaxlib both build for me with cudaSupport = true!

@samuela samuela merged commit 929a328 into NixOS:master Apr 13, 2023
@samuela
Member

samuela commented Apr 13, 2023

Thanks so much @SomeoneSerge ! I use JAX a lot so I will def be using this fix!

@SomeoneSerge
Contributor Author

No time for a cc-wrapper PR yet, but opened an issue to track: #226165
