Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: UTF-8 string validation #3958

Merged
merged 3 commits into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/Init/Data/String/Extra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,25 @@ def toNat! (s : String) : Nat :=
else
panic! "Nat expected"

/--
Convert a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`.
The result is unspecified if `a` is not properly UTF-8 encoded.
-/
@[extern "lean_string_from_utf8_unchecked"]
opaque fromUTF8Unchecked (a : @& ByteArray) : String
/--
Returns true if the given byte array consists of valid UTF-8.

Implemented by the runtime function `lean_string_validate_utf8`.
-/
@[extern "lean_string_validate_utf8"]
opaque validateUTF8 (a : @& ByteArray) : Bool

/--
Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`.

The argument `h` is a proof that `validateUTF8 a` holds, so this function can only be called on
bytes that have passed validation. Implemented by the runtime function `lean_string_from_utf8`.
-/
@[extern "lean_string_from_utf8"]
opaque fromUTF8 (a : @& ByteArray) (h : validateUTF8 a) : String

/--
Attempts to decode the [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` `a`.

Returns `some s`, where `s` is the decoded `String`, when `validateUTF8 a` holds,
and `none` when `a` is not properly UTF-8 encoded.
-/
@[inline] def fromUTF8? (a : ByteArray) : Option String :=
  if valid : validateUTF8 a then
    some (fromUTF8 a valid)
  else
    none

/--
Converts a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded `ByteArray` string to `String`,
or panics if `a` is not properly UTF-8 encoded.

The bytes are checked with `validateUTF8`; when the check fails the function panics with the
message `"invalid UTF-8 string"`.
-/
@[inline] def fromUTF8! (a : ByteArray) : String :=
if h : validateUTF8 a then fromUTF8 a h else panic! "invalid UTF-8 string"

/-- Convert the given `String` to a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded byte array. -/
/-- Converts the given `String` to a [UTF-8](https://en.wikipedia.org/wiki/UTF-8) encoded byte array. -/
@[extern "lean_string_to_utf8"]
opaque toUTF8 (a : @& String) : ByteArray

Expand Down
18 changes: 11 additions & 7 deletions src/Init/System/IO.lean
Original file line number Diff line number Diff line change
Expand Up @@ -768,12 +768,16 @@ def ofBuffer (r : Ref Buffer) : Stream where
write := fun data => r.modify fun b =>
-- set `exact` to `false` so that repeatedly writing to the stream does not impose quadratic run time
{ b with data := data.copySlice 0 b.data b.pos data.size false, pos := b.pos + data.size }
getLine := r.modifyGet fun b =>
let pos := match b.data.findIdx? (start := b.pos) fun u => u == 0 || u = '\n'.toNat.toUInt8 with
-- include '\n', but not '\0'
| some pos => if b.data.get! pos == 0 then pos else pos + 1
| none => b.data.size
(String.fromUTF8Unchecked <| b.data.extract b.pos pos, { b with pos := pos })
getLine := do
let buf ← r.modifyGet fun b =>
let pos := match b.data.findIdx? (start := b.pos) fun u => u == 0 || u = '\n'.toNat.toUInt8 with
-- include '\n', but not '\0'
| some pos => if b.data.get! pos == 0 then pos else pos + 1
| none => b.data.size
(b.data.extract b.pos pos, { b with pos := pos })
match String.fromUTF8? buf with
| some str => pure str
| none => throw (.userError "invalid UTF-8")
putStr := fun s => r.modify fun b =>
let data := s.toUTF8
{ b with data := data.copySlice 0 b.data b.pos data.size false, pos := b.pos + data.size }
Expand All @@ -791,7 +795,7 @@ def withIsolatedStreams [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (x : m
(if isolateStderr then withStderr (Stream.ofBuffer bOut) else id) <|
x
let bOut ← liftM (m := BaseIO) bOut.get
let out := String.fromUTF8Unchecked bOut.data
let out := String.fromUTF8! bOut.data
pure (out, r)

end FS
Expand Down
2 changes: 1 addition & 1 deletion src/Init/System/Uri.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def decodeUri (uri : String) : String := Id.run do
((decoded.push c).push h1, i + 2)
else
(decoded.push c, i + 1)
return String.fromUTF8Unchecked decoded
return String.fromUTF8! decoded
where hexDigitToUInt8? (c : UInt8) : Option UInt8 :=
if zero ≤ c ∧ c ≤ nine then some (c - zero)
else if lettera ≤ c ∧ c ≤ letterf then some (c - lettera + 10)
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Data/Json/Stream.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ open IO
/-- Consumes `nBytes` bytes from the stream, interprets the bytes as a utf-8 string and the string as a valid JSON object. -/
def readJson (h : FS.Stream) (nBytes : Nat) : IO Json := do
let bytes ← h.read (USize.ofNat nBytes)
let s := String.fromUTF8Unchecked bytes
let some s := String.fromUTF8? bytes | throw (IO.userError "invalid UTF-8")
ofExcept (Json.parse s)

def writeJson (h : FS.Stream) (j : Json) : IO Unit := do
Expand Down
32 changes: 18 additions & 14 deletions src/runtime/object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1614,10 +1614,14 @@ extern "C" LEAN_EXPORT object * lean_mk_string(char const * s) {
return lean_mk_string_from_bytes(s, strlen(s));
}

extern "C" LEAN_EXPORT obj_res lean_string_from_utf8_unchecked(b_obj_arg a) {
extern "C" LEAN_EXPORT obj_res lean_string_from_utf8(b_obj_arg a) {
return lean_mk_string_from_bytes(reinterpret_cast<char *>(lean_sarray_cptr(a)), lean_sarray_size(a));
}

/* Exported entry point named by the `@[extern "lean_string_validate_utf8"]` attribute on
   `validateUTF8`. Forwards the byte array's data pointer and size to `validate_utf8` and
   returns its result as a `uint8`. */
extern "C" LEAN_EXPORT uint8 lean_string_validate_utf8(b_obj_arg a) {
return validate_utf8(lean_sarray_cptr(a), lean_sarray_size(a));
}

extern "C" LEAN_EXPORT obj_res lean_string_to_utf8(b_obj_arg s) {
size_t sz = lean_string_size(s) - 1;
obj_res r = lean_alloc_sarray(1, sz, sz);
Expand Down Expand Up @@ -1741,38 +1745,38 @@ extern "C" LEAN_EXPORT obj_res lean_string_data(obj_arg s) {

static bool lean_string_utf8_get_core(char const * str, usize size, usize i, uint32 & result) {
unsigned c = static_cast<unsigned char>(str[i]);
/* zero continuation (0 to 127) */
/* zero continuation (0 to 0x7F) */
if ((c & 0x80) == 0) {
result = c;
return true;
}

/* one continuation (128 to 2047) */
/* one continuation (0x80 to 0x7FF) */
if ((c & 0xe0) == 0xc0 && i + 1 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
result = ((c & 0x1f) << 6) | (c1 & 0x3f);
if (result >= 128) {
if (result >= 0x80) {
return true;
}
}

/* two continuations (2048 to 55295 and 57344 to 65535) */
/* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
if ((c & 0xf0) == 0xe0 && i + 2 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
result = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (result >= 2048 && (result < 55296 || result > 57343)) {
if (result >= 0x800 && (result < 0xD800 || result > 0xDFFF)) {
return true;
}
}

/* three continuations (65536 to 1114111) */
/* three continuations (0x10000 to 0x10FFFF) */
if ((c & 0xf8) == 0xf0 && i + 3 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
unsigned c3 = static_cast<unsigned char>(str[i+3]);
result = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
if (result >= 65536 && result <= 1114111) {
if (result >= 0x10000 && result <= 0x10FFFF) {
return true;
}
}
Expand Down Expand Up @@ -1810,32 +1814,32 @@ extern "C" LEAN_EXPORT uint32 lean_string_utf8_get(b_obj_arg s, b_obj_arg i0) {
}

extern "C" LEAN_EXPORT uint32_t lean_string_utf8_get_fast_cold(char const * str, size_t i, size_t size, unsigned char c) {
/* one continuation (128 to 2047) */
/* one continuation (0x80 to 0x7FF) */
if ((c & 0xe0) == 0xc0 && i + 1 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
uint32_t result = ((c & 0x1f) << 6) | (c1 & 0x3f);
if (result >= 128) {
if (result >= 0x80) {
return result;
}
}

/* two continuations (2048 to 55295 and 57344 to 65535) */
/* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
if ((c & 0xf0) == 0xe0 && i + 2 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
uint32_t result = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (result >= 2048 && (result < 55296 || result > 57343)) {
if (result >= 0x800 && (result < 0xD800 || result > 0xDFFF)) {
return result;
}
}

/* three continuations (65536 to 1114111) */
/* three continuations (0x10000 to 0x10FFFF) */
if ((c & 0xf8) == 0xf0 && i + 3 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
unsigned c3 = static_cast<unsigned char>(str[i+3]);
uint32_t result = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
if (result >= 65536 && result <= 1114111) {
if (result >= 0x10000 && result <= 0x10FFFF) {
return result;
}
}
Expand Down
66 changes: 58 additions & 8 deletions src/runtime/utf8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ unsigned utf8_to_unicode(uchar const * begin, uchar const * end) {
auto it = begin;
unsigned c = *it;
++it;
if (c < 128)
if (c < 0x80)
return c;
unsigned mask = (1u << 6) -1;
unsigned hmask = mask;
Expand Down Expand Up @@ -164,40 +164,40 @@ optional<unsigned> get_utf8_first_byte_opt(unsigned char c) {

unsigned next_utf8(char const * str, size_t size, size_t & i) {
unsigned c = static_cast<unsigned char>(str[i]);
/* zero continuation (0 to 127) */
/* zero continuation (0 to 0x7F) */
if ((c & 0x80) == 0) {
i++;
return c;
}

/* one continuation (128 to 2047) */
/* one continuation (0x80 to 0x7FF) */
if ((c & 0xe0) == 0xc0 && i + 1 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned r = ((c & 0x1f) << 6) | (c1 & 0x3f);
if (r >= 128) {
if (r >= 0x80) {
i += 2;
return r;
}
}

/* two continuations (2048 to 55295 and 57344 to 65535) */
/* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
if ((c & 0xf0) == 0xe0 && i + 2 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
if (r >= 2048 && (r < 55296 || r > 57343)) {
if (r >= 0x800 && (r < 0xD800 || r > 0xDFFF)) {
i += 3;
return r;
}
}

/* three continuations (65536 to 1114111) */
/* three continuations (0x10000 to 0x10FFFF) */
if ((c & 0xf8) == 0xf0 && i + 3 < size) {
unsigned c1 = static_cast<unsigned char>(str[i+1]);
unsigned c2 = static_cast<unsigned char>(str[i+2]);
unsigned c3 = static_cast<unsigned char>(str[i+3]);
unsigned r = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
if (r >= 65536 && r <= 1114111) {
if (r >= 0x10000 && r <= 0x10FFFF) {
i += 4;
return r;
}
Expand All @@ -220,6 +220,56 @@ void utf8_decode(std::string const & str, std::vector<unsigned> & out) {
}
}

/* Returns true iff the `size` bytes starting at `str` are well-formed UTF-8.

   Every encoded scalar value must:
   - use a 1 to 4 byte form whose continuation bytes all carry the `10xxxxxx` tag,
   - not be an overlong encoding (the decoded value must need that many bytes),
   - not be a surrogate (0xD800 to 0xDFFF, both ends inclusive),
   - not exceed 0x10FFFF.

   A truncated trailing sequence makes the input invalid. An empty input is valid. */
bool validate_utf8(uint8_t const * str, size_t size) {
    size_t i = 0;
    while (i < size) {
        unsigned c = str[i];
        if ((c & 0x80) == 0) {
            /* zero continuation (0 to 0x7F) */
            i++;
        } else if ((c & 0xe0) == 0xc0) {
            /* one continuation (0x80 to 0x7FF) */
            if (i + 1 >= size) return false;

            unsigned c1 = str[i+1];
            if ((c1 & 0xc0) != 0x80) return false;

            unsigned r = ((c & 0x1f) << 6) | (c1 & 0x3f);
            /* reject overlong encodings of ASCII */
            if (r < 0x80) return false;

            i += 2;
        } else if ((c & 0xf0) == 0xe0) {
            /* two continuations (0x800 to 0xD7FF and 0xE000 to 0xFFFF) */
            if (i + 2 >= size) return false;

            unsigned c1 = str[i+1];
            unsigned c2 = str[i+2];
            if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80) return false;

            unsigned r = ((c & 0x0f) << 12) | ((c1 & 0x3f) << 6) | (c2 & 0x3f);
            /* reject overlong encodings and the whole surrogate range; the upper bound
               0xDFFF is inclusive, matching the range rejected by the decoders */
            if (r < 0x800 || (r >= 0xD800 && r <= 0xDFFF)) return false;

            i += 3;
        } else if ((c & 0xf8) == 0xf0) {
            /* three continuations (0x10000 to 0x10FFFF) */
            if (i + 3 >= size) return false;

            unsigned c1 = str[i+1];
            unsigned c2 = str[i+2];
            unsigned c3 = str[i+3];
            if ((c1 & 0xc0) != 0x80 || (c2 & 0xc0) != 0x80 || (c3 & 0xc0) != 0x80) return false;

            unsigned r = ((c & 0x07) << 18) | ((c1 & 0x3f) << 12) | ((c2 & 0x3f) << 6) | (c3 & 0x3f);
            /* reject overlong encodings and values beyond the Unicode range */
            if (r < 0x10000 || r > 0x10FFFF) return false;

            i += 4;
        } else {
            /* stray continuation byte or an invalid lead byte (0xF8 to 0xFF) */
            return false;
        }
    }
    return true;
}

#define TAG_CONT static_cast<unsigned char>(0b10000000)
#define TAG_TWO_B static_cast<unsigned char>(0b11000000)
#define TAG_THREE_B static_cast<unsigned char>(0b11100000)
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/utf8.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ LEAN_EXPORT unsigned next_utf8(char const * str, size_t size, size_t & i);
/* Decode a UTF-8 encoded string `str` into unicode scalar values */
LEAN_EXPORT void utf8_decode(std::string const & str, std::vector<unsigned> & out);

/* Returns true if the provided string is valid UTF-8 */
LEAN_EXPORT bool validate_utf8(uint8_t const * str, size_t size);

/* Push a unicode scalar value into a utf-8 encoded string */
LEAN_EXPORT void push_unicode_scalar(std::string & s, unsigned code);

Expand Down
Loading