Skip to content

Commit f37d4a5

Browse files
authored
[SymmetrySectors] Non-abelian fusion (#1363)
1 parent 8a2abce commit f37d4a5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+2158
-782
lines changed

NDTensors/src/imports.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ for lib in [
3333
:MetalExtensions,
3434
:BroadcastMapConversion,
3535
:RankFactorization,
36-
:Sectors,
3736
:LabelledNumbers,
3837
:GradedAxes,
38+
:SymmetrySectors,
3939
:TensorAlgebra,
4040
:SparseArrayInterface,
4141
:SparseArrayDOKs,

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/Project.toml

-2
This file was deleted.

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl

-9
This file was deleted.

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/Project.toml

-3
This file was deleted.

NDTensors/src/lib/GradedAxes/ext/GradedAxesSectorsExt/test/runtests.jl

-15
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
module GradedAxes
22
include("blockedunitrange.jl")
33
include("gradedunitrange.jl")
4-
include("fusion.jl")
54
include("dual.jl")
65
include("unitrangedual.jl")
7-
include("../ext/GradedAxesSectorsExt/src/GradedAxesSectorsExt.jl")
6+
include("fusion.jl")
87
end

NDTensors/src/lib/GradedAxes/src/dual.jl

+2
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ using NDTensors.LabelledNumbers:
55
label_dual(x) = label_dual(LabelledStyle(x), x)
66
label_dual(::NotLabelled, x) = x
77
label_dual(::IsLabelled, x) = labelled(unlabel(x), dual(label(x)))
8+
9+
flip(g::AbstractGradedUnitRange) = dual(gradedrange(label_dual.(blocklengths(g))))
+44-31
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
using BlockArrays: AbstractBlockedUnitRange
1+
using BlockArrays: AbstractBlockedUnitRange, blocklengths
22

33
# Represents the range `1:1` or `Base.OneTo(1)`.
44
struct OneToOne{T} <: AbstractUnitRange{T} end
55
OneToOne() = OneToOne{Bool}()
66
Base.first(a::OneToOne) = one(eltype(a))
77
Base.last(a::OneToOne) = one(eltype(a))
8+
BlockArrays.blockaxes(g::OneToOne) = (Block.(g),) # BlockArrays default crashes for OneToOne{Bool}
89

910
# https://github.com/ITensor/ITensors.jl/blob/v0.3.57/NDTensors/src/lib/GradedAxes/src/tensor_product.jl
1011
# https://en.wikipedia.org/wiki/Tensor_product
@@ -18,23 +19,25 @@ function tensor_product(
1819
return foldl(tensor_product, (a1, a2, a3, a_rest...))
1920
end
2021

22+
flip_dual(r::AbstractUnitRange) = r
23+
flip_dual(r::UnitRangeDual) = flip(r)
2124
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
22-
return error("Not implemented yet.")
25+
return tensor_product(flip_dual(a1), flip_dual(a2))
2326
end
2427

2528
function tensor_product(a1::Base.OneTo, a2::Base.OneTo)
2629
return Base.OneTo(length(a1) * length(a2))
2730
end
2831

29-
function tensor_product(a1::OneToOne, a2::AbstractUnitRange)
32+
function tensor_product(::OneToOne, a2::AbstractUnitRange)
3033
return a2
3134
end
3235

33-
function tensor_product(a1::AbstractUnitRange, a2::OneToOne)
36+
function tensor_product(a1::AbstractUnitRange, ::OneToOne)
3437
return a1
3538
end
3639

37-
function tensor_product(a1::OneToOne, a2::OneToOne)
40+
function tensor_product(::OneToOne, ::OneToOne)
3841
return OneToOne()
3942
end
4043

@@ -45,27 +48,28 @@ function fuse_labels(x, y)
4548
end
4649

4750
function fuse_blocklengths(x::Integer, y::Integer)
48-
return x * y
51+
# return blocked unit range to keep non-abelian interface
52+
return blockedrange([x * y])
4953
end
5054

5155
using ..LabelledNumbers: LabelledInteger, label, labelled, unlabel
5256
function fuse_blocklengths(x::LabelledInteger, y::LabelledInteger)
53-
return labelled(unlabel(x) * unlabel(y), fuse_labels(label(x), label(y)))
57+
# return blocked unit range to keep non-abelian interface
58+
return blockedrange([labelled(x * y, fuse_labels(label(x), label(y)))])
5459
end
5560

5661
using BlockArrays: blockedrange, blocks
5762
function tensor_product(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
58-
blocklengths = map(vec(collect(Iterators.product(blocks(a1), blocks(a2))))) do x
59-
return mapreduce(length, fuse_blocklengths, x)
63+
nested = map(Iterators.flatten((Iterators.product(blocks(a1), blocks(a2)),))) do it
64+
return mapreduce(length, fuse_blocklengths, it)
6065
end
61-
return blockedrange(blocklengths)
66+
new_blocklengths = mapreduce(blocklengths, vcat, nested)
67+
return blockedrange(new_blocklengths)
6268
end
6369

64-
function blocksortperm(a::AbstractBlockedUnitRange)
65-
# TODO: Figure out how to deal with dual sectors.
66-
# TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`.
67-
## return Block.(sortperm(nondual_sectors(a); rev=isdual(a)))
68-
return Block.(sortperm(blocklabels(a)))
70+
# convention: sort UnitRangeDual according to nondual blocks
71+
function blocksortperm(a::AbstractUnitRange)
72+
return Block.(sortperm(blocklabels(nondual(a))))
6973
end
7074

7175
using BlockArrays: Block, BlockVector
@@ -82,25 +86,34 @@ end
8286
# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`.
8387
# Get the permutation for sorting, then group by common elements.
8488
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
85-
function blockmergesortperm(a::AbstractBlockedUnitRange)
86-
# If it is dual, reverse the sorting so the sectors
87-
# end up sorted in the same way whether or not the space
88-
# is dual.
89-
# TODO: Figure out how to deal with dual sectors.
90-
# TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`.
91-
## return Block.(groupsortperm(nondual_sectors(a); rev=isdual(a)))
92-
return Block.(groupsortperm(blocklabels(a)))
89+
function blockmergesortperm(a::AbstractUnitRange)
90+
return Block.(groupsortperm(blocklabels(nondual(a))))
9391
end
9492

9593
# Used by `TensorAlgebra.splitdims` in `BlockSparseArraysGradedAxesExt`.
9694
invblockperm(a::Vector{<:Block{1}}) = Block.(invperm(Int.(a)))
9795

98-
# Used by `TensorAlgebra.fusedims` in `BlockSparseArraysGradedAxesExt`.
99-
function blockmergesortperm(a::GradedUnitRange)
100-
# If it is dual, reverse the sorting so the sectors
101-
# end up sorted in the same way whether or not the space
102-
# is dual.
103-
# TODO: Figure out how to deal with dual sectors.
104-
# TODO: `rev=isdual(a)` may not be correct for symmetries beyond `U(1)`.
105-
return Block.(groupsortperm(blocklabels(a)))
96+
function blockmergesort(g::AbstractGradedUnitRange)
97+
glabels = blocklabels(g)
98+
gblocklengths = blocklengths(g)
99+
new_blocklengths = map(sort(unique(glabels))) do la
100+
return labelled(sum(gblocklengths[findall(==(la), glabels)]; init=0), la)
101+
end
102+
return gradedrange(new_blocklengths)
103+
end
104+
105+
blockmergesort(g::UnitRangeDual) = flip(blockmergesort(flip(g)))
106+
blockmergesort(g::AbstractUnitRange) = g
107+
108+
# fusion_product produces a sorted, non-dual GradedUnitRange
109+
function fusion_product(g1, g2)
110+
return blockmergesort(tensor_product(g1, g2))
111+
end
112+
113+
fusion_product(g::AbstractUnitRange) = blockmergesort(g)
114+
fusion_product(g::UnitRangeDual) = fusion_product(flip(g))
115+
116+
# recursive fusion_product. Simpler than reduce + fix type stability issues with reduce
117+
function fusion_product(g1, g2, g3...)
118+
return fusion_product(fusion_product(g1, g2), g3...)
106119
end

NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl

+14-1
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@ using BlockArrays:
1212
blockedrange,
1313
BlockIndexRange,
1414
blockfirsts,
15-
blocklasts,
15+
blockisequal,
1616
blocklength,
1717
blocklengths,
1818
findblock,
1919
findblockindex,
2020
mortar
2121
using Compat: allequal
22+
using FillArrays: Fill
2223
using ..LabelledNumbers:
2324
LabelledNumbers, LabelledInteger, LabelledUnitRange, label, labelled, unlabel
2425

@@ -37,6 +38,18 @@ function Base.OrdinalRange{T,T}(a::GradedOneTo{<:LabelledInteger{T}}) where {T}
3738
return unlabel_blocks(a)
3839
end
3940

41+
# == is just a range comparison that ignores labels. Need dedicated function to check equality.
42+
struct NoLabel end
43+
blocklabels(r::AbstractUnitRange) = Fill(NoLabel(), blocklength(r))
44+
45+
function labelled_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
46+
return blockisequal(a1, a2) && (blocklabels(a1) == blocklabels(a2))
47+
end
48+
49+
function space_isequal(a1::AbstractUnitRange, a2::AbstractUnitRange)
50+
return (isdual(a1) == isdual(a2)) && labelled_isequal(a1, a2)
51+
end
52+
4053
# This is only needed in certain Julia versions below 1.10
4154
# (for example Julia 1.6).
4255
# TODO: Delete this once we drop Julia 1.6 support.

NDTensors/src/lib/GradedAxes/src/unitrangedual.jl

+16
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ UnitRangeDual(a::AbstractUnitRange) = UnitRangeDual{eltype(a),typeof(a)}(a)
66
dual(a::AbstractUnitRange) = UnitRangeDual(a)
77
nondual(a::UnitRangeDual) = a.nondual_unitrange
88
dual(a::UnitRangeDual) = nondual(a)
9+
flip(a::UnitRangeDual) = dual(flip(nondual(a)))
910
nondual(a::AbstractUnitRange) = a
11+
isdual(::AbstractUnitRange) = false
12+
isdual(::UnitRangeDual) = true
1013
## TODO: Define this to instantiate a dual unit range.
1114
## materialize_dual(a::UnitRangeDual) = materialize_dual(nondual(a))
1215

@@ -16,6 +19,16 @@ Base.step(a::UnitRangeDual) = label_dual(step(nondual(a)))
1619

1720
Base.view(a::UnitRangeDual, index::Block{1}) = a[index]
1821

22+
function Base.show(io::IO, a::UnitRangeDual)
23+
return print(io, UnitRangeDual, "(", blocklasts(a), ")")
24+
end
25+
26+
function Base.show(io::IO, mimetype::MIME"text/plain", a::UnitRangeDual)
27+
return Base.invoke(
28+
show, Tuple{typeof(io),MIME"text/plain",AbstractArray}, io, mimetype, a
29+
)
30+
end
31+
1932
function Base.getindex(a::UnitRangeDual, indices::AbstractUnitRange{<:Integer})
2033
return dual(getindex(nondual(a), indices))
2134
end
@@ -92,6 +105,9 @@ BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a))
92105
BlockArrays.blockfirsts(a::UnitRangeDual) = label_dual.(blockfirsts(nondual(a)))
93106
BlockArrays.blocklasts(a::UnitRangeDual) = label_dual.(blocklasts(nondual(a)))
94107
BlockArrays.findblock(a::UnitRangeDual, index::Integer) = findblock(nondual(a), index)
108+
109+
blocklabels(a::UnitRangeDual) = dual.(blocklabels(nondual(a)))
110+
95111
function BlockArrays.combine_blockaxes(a1::UnitRangeDual, a2::UnitRangeDual)
96112
return dual(combine_blockaxes(dual(a1), dual(a2)))
97113
end

NDTensors/src/lib/GradedAxes/test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using Test: @testset
33
@testset "GradedAxes" begin
44
include("test_basics.jl")
5-
include("test_tensor_product.jl")
65
include("test_dual.jl")
6+
include("test_tensor_product.jl")
77
end
88
end

NDTensors/src/lib/GradedAxes/test/test_basics.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using BlockArrays:
99
blocklength,
1010
blocklengths,
1111
blocks
12-
using NDTensors.GradedAxes: GradedOneTo, GradedUnitRange, blocklabels, gradedrange
12+
using NDTensors.GradedAxes:
13+
GradedOneTo, GradedUnitRange, blocklabels, labelled_isequal, gradedrange
1314
using NDTensors.LabelledNumbers: LabelledUnitRange, islabelled, label, labelled, unlabel
1415
using Test: @test, @test_broken, @testset
1516
@testset "GradedAxes basics" begin
@@ -40,6 +41,7 @@ using Test: @test, @test_broken, @testset
4041
@test label(x) == "y"
4142
end
4243
@test isnothing(iterate(a, labelled(5, "y")))
44+
@test labelled_isequal(a, a)
4345
@test length(a) == 5
4446
@test step(a) == 1
4547
@test !islabelled(step(a))

NDTensors/src/lib/GradedAxes/test/test_dual.jl

+57-5
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,40 @@
11
@eval module $(gensym())
2-
using BlockArrays: Block, blockaxes, blockfirsts, blocklasts, blocks, findblock
3-
using NDTensors.GradedAxes: GradedAxes, UnitRangeDual, dual, gradedrange, nondual
2+
using BlockArrays:
3+
Block, blockaxes, blockfirsts, blocklasts, blocklength, blocklengths, blocks, findblock
4+
using NDTensors.GradedAxes:
5+
GradedAxes,
6+
UnitRangeDual,
7+
blocklabels,
8+
blockmergesortperm,
9+
blocksortperm,
10+
dual,
11+
flip,
12+
space_isequal,
13+
gradedrange,
14+
isdual,
15+
nondual
416
using NDTensors.LabelledNumbers: LabelledInteger, label, labelled
517
using Test: @test, @test_broken, @testset
618
struct U1
719
n::Int
820
end
921
GradedAxes.dual(c::U1) = U1(-c.n)
22+
Base.isless(c1::U1, c2::U1) = c1.n < c2.n
1023
@testset "dual" begin
1124
a = gradedrange([U1(0) => 2, U1(1) => 3])
1225
ad = dual(a)
1326
@test eltype(ad) == LabelledInteger{Int,U1}
14-
@test dual(ad) == a
15-
@test nondual(ad) == a
16-
@test nondual(a) == a
27+
28+
@test space_isequal(dual(ad), a)
29+
@test space_isequal(nondual(ad), a)
30+
@test space_isequal(nondual(a), a)
31+
@test space_isequal(ad, ad)
32+
@test !space_isequal(a, ad)
33+
@test !space_isequal(ad, a)
34+
35+
@test isdual(ad)
36+
@test !isdual(a)
37+
1738
@test blockfirsts(ad) == [labelled(1, U1(0)), labelled(3, U1(-1))]
1839
@test blocklasts(ad) == [labelled(2, U1(0)), labelled(5, U1(-1))]
1940
@test findblock(ad, 4) == Block(2)
@@ -34,5 +55,36 @@ GradedAxes.dual(c::U1) = U1(-c.n)
3455
@test label(ad[[Block(2), Block(1)]][Block(1)]) == U1(-1)
3556
@test ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)] == 3:4
3657
@test label(ad[[Block(2)[1:2], Block(1)[1:2]]][Block(1)]) == U1(-1)
58+
@test blocksortperm(a) == [Block(1), Block(2)]
59+
@test blocksortperm(ad) == [Block(1), Block(2)]
60+
@test blocklength(blockmergesortperm(a)) == 2
61+
@test blocklength(blockmergesortperm(ad)) == 2
62+
@test blockmergesortperm(a) == [Block(1), Block(2)]
63+
@test blockmergesortperm(ad) == [Block(1), Block(2)]
64+
end
65+
66+
@testset "flip" begin
67+
a = gradedrange([U1(0) => 2, U1(1) => 3])
68+
ad = dual(a)
69+
@test space_isequal(flip(a), dual(gradedrange([U1(0) => 2, U1(-1) => 3])))
70+
@test space_isequal(flip(ad), gradedrange([U1(0) => 2, U1(-1) => 3]))
71+
72+
@test blocklabels(a) == [U1(0), U1(1)]
73+
@test blocklabels(dual(a)) == [U1(0), U1(-1)]
74+
@test blocklabels(flip(a)) == [U1(0), U1(1)]
75+
@test blocklabels(flip(dual(a))) == [U1(0), U1(-1)]
76+
@test blocklabels(dual(flip(a))) == [U1(0), U1(-1)]
77+
78+
@test blocklengths(a) == [2, 3]
79+
@test blocklengths(dual(a)) == [2, 3]
80+
@test blocklengths(flip(a)) == [2, 3]
81+
@test blocklengths(flip(dual(a))) == [2, 3]
82+
@test blocklengths(dual(flip(a))) == [2, 3]
83+
84+
@test !isdual(a)
85+
@test isdual(dual(a))
86+
@test isdual(flip(a))
87+
@test !isdual(flip(dual(a)))
88+
@test !isdual(dual(flip(a)))
3789
end
3890
end

0 commit comments

Comments
 (0)