Add code for unrolling affine.for under threshold #441

Draft · wants to merge 3 commits into base: main

Conversation

@snarang181 (Collaborator)

Add corresponding test

@snarang181 (Collaborator, Author)

When I run gdb --args bazel-bin/enzymexlamlir-opt --raise-affine-to-stablehlo Enzyme-JAX/test/lit_tests/raising/raiseaffinefor_unroll.mlir, I see

/home/snarang181/Enzyme-JAX/test/lit_tests/raising/raiseaffinefor_unroll.mlir:10:13: error: 'stablehlo.dynamic_update_slice' op operation destroyed but still has uses
            affine.store %4, %arg0[0, %arg2 + 136, %arg1 + 7] : memref<1x187x194xf64, 1>
            ^
/home/snarang181/Enzyme-JAX/test/lit_tests/raising/raiseaffinefor_unroll.mlir:10:13: note: see current operation: %0 = "stablehlo.dynamic_update_slice"(<<NULL VALUE>>, <<NULL VALUE>>, <<NULL VALUE>>, <<NULL VALUE>>, <<NULL VALUE>>) : (<<NULL TYPE>>, <<NULL TYPE>>, <<NULL TYPE>>, <<NULL TYPE>>, <<NULL TYPE>>) -> tensor<1x187x194xf64>
/home/snarang181/Enzyme-JAX/test/lit_tests/raising/raiseaffinefor_unroll.mlir:19:9: note: - use: %39 = "stablehlo.dynamic_update_slice"(<<UNKNOWN SSA VALUE>>, %38, %34, %35, %36) : (tensor<1x187x194xf64>, tensor<1x1x180xf64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x187x194xf64>

        affine.store %3, %arg0[0, 135, %arg1 + 7] : memref<1x187x194xf64, 1>
        ^
/home/snarang181/Enzyme-JAX/test/lit_tests/raising/raiseaffinefor_unroll.mlir:13:14: note: - use: %17 = "stablehlo.slice"(<<UNKNOWN SSA VALUE>>) <{limit_indices = array<i64: 1, 136, 187>, start_indices = array<i64: 0, 135, 7>, strides = array<i64: 1, 1, 1>}> : (tensor<1x187x194xf64>) -> tensor<1x1x180xf64>

        %2 = affine.load %arg0[0, 135, %arg1 + 7] : memref<1x187x194xf64, 1>
             ^
/home/snarang181/Enzyme-JAX/test/lit_tests/raising/raiseaffinefor_unroll.mlir:12:14: note: - use: %14 = "stablehlo.slice"(<<UNKNOWN SSA VALUE>>) <{limit_indices = array<i64: 1, 136, 187>, start_indices = array<i64: 0, 135, 7>, strides = array<i64: 1, 1, 1>}> : (tensor<1x187x194xf64>) -> tensor<1x1x180xf64>

        %1 = affine.load %arg0[0, 135, -%arg1 + 186] : memref<1x187x194xf64, 1>
             ^
LLVM ERROR: operation destroyed but still has uses

This presumably has something to do with the use of innerOp in the for-loop. Do I need to use something like replaceAllUsesWith?

@wsmoses (Member) commented Mar 8, 2025

> When I run gdb --args bazel-bin/enzymexlamlir-opt --raise-affine-to-stablehlo Enzyme-JAX/test/lit_tests/raising/raiseaffinefor_unroll.mlir, I see [the "operation destroyed but still has uses" error quoted in full above]. This presumably has something to do with the use of innerOp in the for-loop. Do I need to use something like replaceAllUsesWith?

@Pangoraw, since you said you had something like this earlier, can you take a quick look?

Also cc @chelini, who might have cycles to help co-debug.

affine-to-stable-hlo-raising pass
@Pangoraw (Collaborator) commented Mar 8, 2025

I ran the upstream affine LICM and affine unrolling passes before invoking raise-affine-to-stablehlo.

@ivanradanov (Collaborator)

@snarang181 I think the problem was that you were generating the stablehlo ops at the for op (which is inside the parallel that we delete later); see the last commit. Does that look right?

@snarang181 (Collaborator, Author)

> @snarang181 I think the problem was that you were generating the stablehlo ops at the for op (which is inside the parallel that we delete later); see the last commit. Does that look right?

Thanks for looking into this, @ivanradanov; yes, that makes sense.

snarang181 marked this pull request as ready for review March 8, 2025 14:13
snarang181 self-assigned this Mar 8, 2025
snarang181 marked this pull request as draft March 8, 2025 14:43
@snarang181 (Collaborator, Author)

This seems to work fine on the test case but crashes (core dumped) on the full unoptimized MLIR (running fullpipe.sh). Here is the error snippet:

enzymexlamlir-opt: src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp:282: mlir::affine::AffineValueMap alignMemoryAccess(mlir::Value&, mlir::affine::AffineValueMap, mlir::Value*, llvm::ArrayRef<mlir::affine::AffineValueMap>, mlir::OpBuilder&): Assertion `shapeA.size() == cast<RankedTensorType>(a.getType()).getShape().size()' failed

I am happy to debug further, though some co-debugging help to get me on the right track would be appreciated.

@wsmoses (Member) commented Mar 8, 2025

I'd probably start by adding a print before each function being raised, to help make a minimal test case.

@snarang181 (Collaborator, Author)

> I'd probably start by adding a print before each function being raised, to help make a minimal test case.

Do I need to build with particular flags to get the LLVM_DEBUG output to show up?

@wsmoses (Member) commented Mar 8, 2025

You would add -debug for that.

However, here I'm suggesting adding llvm::errs() << fn << "\n"; at an appropriate point inside the pass, recompiling, and running (which requires no extra flag).

@ivanradanov (Collaborator) commented Mar 9, 2025

I get this:

enzymexlamlir-opt: src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp:295: mlir::affine::AffineValueMap alignMemoryAccess(mlir::Value&, mlir::affine::AffineValueMap, mlir::Value*, llvm::ArrayRef<mlir::affine::AffineValueMap>, mlir::OpBuilder&): Assertion `shapeBs[i].size() == cast<RankedTensorType>(bs[i].getType()).getShape().size()' failed.

I'm not 100% sure, but it seems to be because we emit a null AffineMap for a constant here. I am not entirely sure what the correct one would be, though.

frame #7: 0x00000000073b4e46 enzymexlamlir-opt`tryRaisingOpToStableHLO(op=0x00000000123d31b0, mapping=0x00007fffffffb210, builder=0x00007fffffffb1d0, maps=0x00007fffffffb1b0) at AffineToStableHLORaising.cpp:784:65
   781            b = mapping.lookup(op->getOperand(1));
   782
   783      auto mapA = maps[a], mapB = maps[b];
-> 784      auto outputMap = alignMemoryAccess(a, mapA, b, mapB, builder);
   785      assert(a.getType() == b.getType());
   786
   787      auto IT = a.getType().cast<RankedTensorType>();
(lldb) p a.dump()
%770 = "stablehlo.constant"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64>

I am off for today but will continue investigating further tomorrow.

@snarang181 (Collaborator, Author)

> I get this: [the alignMemoryAccess assertion failure and lldb backtrace quoted in full above]. I am off for today but will continue investigating further tomorrow.

@ivanradanov, that's right. Feel free to ping me on Slack when you get online tomorrow (your time); I'd like to co-debug if possible to get a feel for things. Thanks again for looking into it.

@wsmoses (Member) commented Mar 9, 2025

cc @Pangoraw

@wsmoses (Member) commented Mar 9, 2025

Can you add the failing test here as a test file, for ease of reproduction?

@Pangoraw (Collaborator) commented Mar 9, 2025

The current approach is missing handling of the converted IV as a constant when raising the memref load/store. As it stands, I guess the load/store would result in slices over all possible values.

Solving this would probably be a first step toward turning the loops into stablehlo.while, since the IV would no longer be a constant but a block argument.

In the meantime, we can run the upstream affine unrolling pass?

@Pangoraw (Collaborator) commented Mar 9, 2025

> I'm not 100% sure, but it seems to be because we emit a null AffineMap for a constant here. I am not entirely sure what the correct one would be, though.

So these maps are used to identify, for each dim in the raised tensors, which induction variable it depends on. A scalar constant is raised to a 0-dim tensor and therefore has a map with zero dimensions.

For example, the following value:

%0 = affine.load %a[%i, %j] : memref<?x?xf64>

would have a map like (%i, %j) -> (%i, %j). The maps are then used to align the different axes via broadcast/transpose when two values need to interact to produce a resulting value.
