Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor strings #3

Merged
merged 3 commits into from
Apr 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

Tools to enable [StaticCompiler.jl](https://github.com/tshort/StaticCompiler.jl)-based static compilation of Julia code to standalone native binaries by eliding GC allocations and `llvmcall`-ing all the things.

This package currently requires Julia 1.8+
This package currently requires Julia 1.8+ (in particular, 1.8.0-beta3 is known to work).

You probably also want to check out the latest main branch of this package rather than the latest registered version. Because the package requires Julia 1.8, new releases must be merged manually into the General registry until the registry itself switches over to 1.8, so I'm holding off on registering new versions to spare the registry maintainers the workload.

Caution: this package should be considered experimental at present, and involves a lot of juggling of pointers

Expand Down Expand Up @@ -129,7 +131,7 @@ julia> function times_table(argc::Int, argv::Ptr{Ptr{UInt8}})
times_table (generic function with 1 method)

julia> filepath = compile_executable(times_table, (Int64, Ptr{Ptr{UInt8}}), "./")
"/Users/user/times_table"
"/Users/user/code/StaticTools.jl/times_table"

shell> ./times_table 3 3
1 2 3
Expand All @@ -145,3 +147,33 @@ The same array, reinterpreted as Int32:
3 6 9
0 0 0
```

We also have random number generators:
```julia
julia> function rand_matrix(argc::Int, argv::Ptr{Ptr{UInt8}})
argc == 3 || return printf(stderrp(), c"Incorrect number of command-line arguments\n")
rows = parse(Int64, argv, 2) # First command-line argument
cols = parse(Int64, argv, 3) # Second command-line argument

M = MallocArray{Float64}(undef, rows, cols)
rng = static_rng()
@inbounds for i=1:rows
for j=1:cols
M[i,j] = rand(rng)
end
end
printf(M)
free(M)
end
rand_matrix (generic function with 1 method)

julia> compile_executable(rand_matrix, (Int64, Ptr{Ptr{UInt8}}), "./")
"/Users/user/code/StaticTools.jl/rand_matrix"

shell> ./rand_matrix 5 5
7.890932e-01 7.532989e-01 8.593202e-01 4.790301e-01 6.464508e-01
5.619692e-01 9.800402e-02 8.545220e-02 5.545224e-02 2.966089e-01
7.021460e-01 4.587692e-01 9.316740e-01 8.736913e-01 8.271038e-01
8.098993e-01 5.368138e-01 3.055373e-02 3.972266e-01 8.146640e-01
8.241520e-01 7.532375e-01 2.969434e-01 9.436580e-01 2.819992e-01
```
24 changes: 13 additions & 11 deletions src/mallocstring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,19 @@
# Fundamentals
# Reinterpret the string's stored pointer as a `Ptr{T}`.
# The argument is named `s` because the body reads `s.pointer`; with the
# previous parameter name `m`, `s` was not bound to the argument at all.
@inline Base.unsafe_convert(::Type{Ptr{T}}, s::MallocString) where {T} = Ptr{T}(s.pointer)
@inline Base.pointer(s::MallocString) = s.pointer
@inline Base.length(s::MallocString) = s.length
@inline Base.sizeof(s::MallocString) = s.length
@inline Base.length(s::MallocString) = s.length-1 # For consistency with base
@inline Base.sizeof(s::MallocString) = s.length # Thou shalt not lie
@inline function Base.:(==)(a::MallocString, b::MallocString)
(N = length(a)) == length(b) || return false
pa, pb = pointer(a), pointer(b)
for n in 0:N-1
for n ∈ 0:N
unsafe_load(pa + n) == unsafe_load(pb + n) || return false
end
return true
end
const NullTerminatedString = Union{StaticString, MallocString}
@inline function Base.:(==)(a::NullTerminatedString, b::NullTerminatedString)
const AnyString = Union{NullTerminatedString, AbstractString}
@inline function Base.:(==)(a::AnyString, b::AnyString)
GC.@preserve a b begin
(N = length(a)) == length(b) || return false
pa, pb = pointer(a), pointer(b)
Expand Down Expand Up @@ -71,6 +72,7 @@
# Some of the AbstractArray interface:
@inline Base.firstindex(s::MallocString) = 1
@inline Base.lastindex(s::MallocString) = s.length
@inline Base.eachindex(s::MallocString) = 1:s.length
@inline Base.getindex(s::MallocString, i::Int) = unsafe_load(pointer(s)+(i-1))
@inline Base.setindex!(s::MallocString, x::UInt8, i::Integer) = unsafe_store!(pointer(s)+(i-1), x)
@inline Base.setindex!(s::MallocString, x, i::Integer) = unsafe_store!(pointer(s)+(i-1), convert(UInt8,x))
Expand All @@ -84,13 +86,13 @@
end
end
@inline function Base.setindex!(s::MallocString, x, ::Colon)
ix₀ = firstindex(x)-1
for i = 1:length(s)
ix₀ = firstindex(x)-firstindex(s)
for i ∈ eachindex(s)
s[i] = x[i+ix₀]
end
end
@inline function Base.copy(s::MallocString)
new_s = MallocString(undef, length(s))
new_s = MallocString(undef, ncodeunits(s))
new_s[:] = s
return new_s
end
Expand All @@ -101,15 +103,15 @@
@inline Base.codeunit(s::MallocString) = UInt8
@inline Base.codeunit(s::MallocString, i::Integer) = s[i]
@inline function Base.:*(a::MallocString, b::MallocString) # Concatenation
N = length(a) + length(b) - 1
N = length(a) + length(b) + 1
c = MallocString(undef, N)
c[1:length(a)-1] = a
c[length(a):end-1] = b
c[1:length(a)] = a
c[length(a)+1:length(a)+length(b)] = b
c[end] = 0x00 # Null-terminate
return c
end
@inline function Base.:^(s::MallocString, n::Integer) # Repetition
l = length(s)-1 # Excluding the null-termination
l = length(s) # Excluding the null-termination
N = n*l + 1
c = MallocString(undef, N)
for i=1:n
Expand Down
19 changes: 10 additions & 9 deletions src/staticstring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
# Fundamentals
@inline Base.unsafe_convert(::Type{Ptr{T}}, m::StaticString) where {T} = Ptr{T}(pointer_from_objref(m))
@inline Base.pointer(m::StaticString{N}) where {N} = Ptr{UInt8}(pointer_from_objref(m))
@inline Base.length(s::StaticString{N}) where N = N
@inline Base.length(s::StaticString{N}) where N = N-1
@inline Base.:(==)(::StaticString, ::StaticString) = false
@inline function Base.:(==)(a::StaticString{N}, b::StaticString{N}) where N
GC.@preserve a b begin
pa, pb = pointer(a), pointer(b)
for n in 0:N-1
for n ∈ 0:N-1
unsafe_load(pa + n) == unsafe_load(pb + n) || return false
end
return true
Expand All @@ -50,7 +50,8 @@

# Implement some of the AbstractArray interface:
@inline Base.firstindex(s::StaticString) = 1
@inline Base.lastindex(s::StaticString{N}) where N = N
@inline Base.lastindex(s::StaticString{N}) where {N} = N
@inline Base.eachindex(s::StaticString{N}) where {N} = 1:N
@inline Base.getindex(s::StaticString, i::Int) = unsafe_load(pointer(s)+(i-1))
@inline Base.getindex(s::StaticString, r::AbstractArray{Int}) = StaticString(codeunits(s)[r]) # Should probably null-terminate
@inline Base.getindex(s::StaticString, ::Colon) = s
Expand All @@ -64,8 +65,8 @@
end
end
@inline function Base.setindex!(s::StaticString, x, ::Colon)
ix₀ = firstindex(x)-1
@inbounds for i = 1:length(s)
ix₀ = firstindex(x)-firstindex(s)
@inbounds for i ∈ eachindex(s)
setindex!(s, x[i+ix₀], i)
end
end
Expand All @@ -78,15 +79,15 @@
@inline Base.codeunit(s::StaticString) = UInt8
@inline Base.codeunit(s::StaticString, i::Integer) = s[i]
@inline function Base.:*(a::StaticString, b::StaticString) # Concatenation
N = length(a) + length(b) - 1
N = length(a) + length(b) + 1
c = StaticString{N}(undef)
c[1:length(a)-1] = a
c[length(a):end-1] = b
c[1:length(a)] = a
c[length(a)+1:length(a)+length(b)] = b
c[end] = 0x00 # Null-terminate
return c
end
@inline function Base.:^(s::StaticString, n::Integer) # Repetition
l = length(s)-1 # Excluding the null-termination
l = length(s) # Excluding the null-termination, remember
N = n*l + 1
c = StaticString{N}(undef)
for i=1:n
Expand Down
14 changes: 9 additions & 5 deletions test/testmallocstring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
# Test MallocString constructors
str = m"Hello, world! 🌍"
@test isa(str, MallocString)
@test length(str) == 19
@test sizeof(str) == 19
@test StaticTools.strlen(str) == 18
@test StaticTools.strlen(str) == length(str) == 18

# Test basic string operations
@test str == m"Hello, world! 🌍"
Expand All @@ -15,7 +14,7 @@
@test free(p) == 0
@test codeunit(str) === UInt8
@test codeunit(str, 5) == UInt8('o')
@test ncodeunits(str) == length(str)
@test ncodeunits(str) == length(str)+1
@test codeunits(str) == codeunits(c"Hello, world! 🌍")
@test codeunits(c"Hello, world! 🌍") == codeunits(str)

Expand All @@ -37,14 +36,14 @@
# Test ascii escaping
many_escapes = m"\"\0\a\b\f\n\r\t\v\'\"\\"
@test isa(many_escapes, MallocString)
@test length(many_escapes) == 13
@test length(many_escapes) == 12
@test codeunits(many_escapes) == codeunits("\"\0\a\b\f\n\r\t\v'\"\\\0")

# Test unsafe_mallocstring
s = "Hello there!"
m = unsafe_mallocstring(pointer(s))
@test isa(m, MallocString)
@test length(m) == 13
@test length(m) == 12
@test codeunits(m) == codeunits(c"Hello there!")
@test free(m) == 0

Expand All @@ -54,3 +53,8 @@
argv = pointer(a)
@test MallocString(argv,1) == c"Hello"
@test MallocString(argv,2) == c"there"

# Test consistency with base strings
abc = m"abc"
@test abc == "abc"
free(abc)
13 changes: 8 additions & 5 deletions test/teststaticstring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
str = c"Hello, world! 🌍"
@test isa(str, StaticString{19})
@test sizeof(str) == 19
@test StaticTools.strlen(str) == 18
@test StaticTools.strlen(str) == length(str) == 18

# Test basic string operations
@test str == c"Hello, world! 🌍"
@test str*str == str^2
@test codeunit(str) === UInt8
@test codeunit(str, 5) == UInt8('o')
@test ncodeunits(str) == length(str)
@test ncodeunits(str) == length(str)+1
@test codeunits(c"Hello") == codeunits(c"Hello")

# Test mutability
Expand All @@ -28,7 +28,10 @@
@test str == copy(str)

# Test ascii escaping
many_escapes = c"\0\a\b\f\n\r\t\v'\"\\"
@test isa(many_escapes, StaticString{12})
many_escapes = c"\"\0\a\b\f\n\r\t\v\'\"\\"
@test isa(many_escapes, StaticString{13})
@test length(many_escapes) == 12
@test all(codeunits(many_escapes) .== codeunits("\0\a\b\f\n\r\t\v'\"\\\0"))
@test all(codeunits(many_escapes) .== codeunits("\"\0\a\b\f\n\r\t\v'\"\\\0"))

# Test consistency with base strings
@test c"abc" == "abc"