Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: switching List lookup normal forms to L[n] and L[n]? #4400

Merged
merged 21 commits into from
Jun 15, 2024
55 changes: 30 additions & 25 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import Init.TacticsExtra
/-!
## Bootstrapping theorems about arrays

This file contains some theorems about `Array` and `List` needed for `Std.List.Basic`.
This file contains some theorems about `Array` and `List` needed for `Init.Data.List.Impl`.
-/

namespace Array
Expand All @@ -34,9 +34,13 @@ attribute [simp] data_toArray uset

@[simp] theorem size_mk (as : List α) : (Array.mk as).size = as.length := by simp [size]

theorem getElem_eq_data_get (a : Array α) (h : i < a.size) : a[i] = a.data.get ⟨i, h⟩ := by
theorem getElem_eq_data_getElem (a : Array α) (h : i < a.size) : a[i] = a.data[i] := by
by_cases i < a.size <;> (try simp [*]) <;> rfl

@[deprecated getElem_eq_data_getElem (since := "2024-06-12")]
theorem getElem_eq_data_get (a : Array α) (h : i < a.size) : a[i] = a.data.get ⟨i, h⟩ := by
simp [getElem_eq_data_getElem]

theorem foldlM_eq_foldlM_data.aux [Monad m]
(f : β → α → m β) (arr : Array α) (i j) (H : arr.size ≤ i + j) (b) :
foldlM.loop f arr arr.size (Nat.le_refl _) i j b = (arr.data.drop j).foldlM f b := by
Expand Down Expand Up @@ -114,11 +118,11 @@ theorem foldr_push (f : α → β → β) (init : β) (arr : Array α) (a : α)
theorem get_push_lt (a : Array α) (x : α) (i : Nat) (h : i < a.size) :
have : i < (a.push x).size := by simp [*, Nat.lt_succ_of_le, Nat.le_of_lt]
(a.push x)[i] = a[i] := by
simp only [push, getElem_eq_data_get, List.concat_eq_append, List.get_append_left, h]
simp only [push, getElem_eq_data_getElem, List.concat_eq_append, List.getElem_append_left, h]

@[simp] theorem get_push_eq (a : Array α) (x : α) : (a.push x)[a.size] = x := by
simp only [push, getElem_eq_data_get, List.concat_eq_append]
rw [List.get_append_right] <;> simp [getElem_eq_data_get, Nat.zero_lt_one]
simp only [push, getElem_eq_data_getElem, List.concat_eq_append]
rw [List.getElem_append_right] <;> simp [getElem_eq_data_getElem, Nat.zero_lt_one]

theorem get_push (a : Array α) (x : α) (i : Nat) (h : i < (a.push x).size) :
(a.push x)[i] = if h : i < a.size then a[i] else x := by
Expand Down Expand Up @@ -233,11 +237,11 @@ theorem get!_eq_getD [Inhabited α] (a : Array α) : a.get! n = a.getD n default
@[simp] theorem getElem_set_eq (a : Array α) (i : Fin a.size) (v : α) {j : Nat}
(eq : i.val = j) (p : j < (a.set i v).size) :
(a.set i v)[j]'p = v := by
simp [set, getElem_eq_data_get, ←eq]
simp [set, getElem_eq_data_getElem, ←eq]

@[simp] theorem getElem_set_ne (a : Array α) (i : Fin a.size) (v : α) {j : Nat} (pj : j < (a.set i v).size)
(h : i.val ≠ j) : (a.set i v)[j]'pj = a[j]'(size_set a i v ▸ pj) := by
simp only [set, getElem_eq_data_get, List.get_set_ne _ h]
simp only [set, getElem_eq_data_getElem, List.getElem_set_ne _ h]

theorem getElem_set (a : Array α) (i : Fin a.size) (v : α) (j : Nat)
(h : j < (a.set i v).size) :
Expand Down Expand Up @@ -321,7 +325,7 @@ termination_by n - i
@[simp] theorem mkArray_data (n : Nat) (v : α) : (mkArray n v).data = List.replicate n v := rfl

@[simp] theorem getElem_mkArray (n : Nat) (v : α) (h : i < (mkArray n v).size) :
(mkArray n v)[i] = v := by simp [Array.getElem_eq_data_get]
(mkArray n v)[i] = v := by simp [Array.getElem_eq_data_getElem]

/-- # mem -/

Expand All @@ -332,7 +336,7 @@ theorem not_mem_nil (a : α) : ¬ a ∈ #[] := nofun
/-- # get lemmas -/

theorem getElem?_mem {l : Array α} {i : Fin l.size} : l[i] ∈ l := by
erw [Array.mem_def, getElem_eq_data_get]
erw [Array.mem_def, getElem_eq_data_getElem]
apply List.get_mem

theorem getElem_fin_eq_data_get (a : Array α) (i : Fin _) : a[i] = a.data.get i := rfl
Expand All @@ -347,7 +351,7 @@ theorem get?_len_le (a : Array α) (i : Nat) (h : a.size ≤ i) : a[i]? = none :
simp [getElem?_neg, h]

theorem getElem_mem_data (a : Array α) (h : i < a.size) : a[i] ∈ a.data := by
simp only [getElem_eq_data_get, List.get_mem]
simp only [getElem_eq_data_getElem, List.getElem_mem]

theorem getElem?_eq_data_get? (a : Array α) (i : Nat) : a[i]? = a.data.get? i := by
by_cases i < a.size <;> simp_all [getElem?_pos, getElem?_neg, List.get?_eq_get, eq_comm]; rfl
Expand Down Expand Up @@ -395,7 +399,7 @@ theorem get?_push {a : Array α} : (a.push x)[i]? = if i = a.size then some x el

theorem get_set_eq (a : Array α) (i : Fin a.size) (v : α) :
(a.set i v)[i.1] = v := by
simp only [set, getElem_eq_data_get, List.get_set_eq]
simp only [set, getElem_eq_data_getElem, List.getElem_set_eq]

theorem get?_set_eq (a : Array α) (i : Fin a.size) (v : α) :
(a.set i v)[i.1]? = v := by simp [getElem?_pos, i.2]
Expand All @@ -414,7 +418,7 @@ theorem get_set (a : Array α) (i : Fin a.size) (j : Nat) (hj : j < a.size) (v :

@[simp] theorem get_set_ne (a : Array α) (i : Fin a.size) {j : Nat} (v : α) (hj : j < a.size)
(h : i.1 ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
simp only [set, getElem_eq_data_get, List.get_set_ne _ h]
simp only [set, getElem_eq_data_getElem, List.getElem_set_ne _ h]

theorem getElem_setD (a : Array α) (i : Nat) (v : α) (h : i < (setD a i v).size) :
(setD a i v)[i] = v := by
Expand Down Expand Up @@ -452,7 +456,7 @@ theorem swapAt!_def (a : Array α) (i : Nat) (v : α) (h : i < a.size) :

@[simp] theorem getElem_pop (a : Array α) (i : Nat) (hi : i < a.pop.size) :
a.pop[i] = a[i]'(Nat.lt_of_lt_of_le (a.size_pop ▸ hi) (Nat.sub_le _ _)) :=
List.get_dropLast ..
List.getElem_dropLast ..

theorem eq_empty_of_size_eq_zero {as : Array α} (h : as.size = 0) : as = #[] := by
apply ext
Expand Down Expand Up @@ -500,6 +504,7 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
simp only [mkEmpty_eq, size_push] at *
omega

set_option linter.deprecated false in
@[simp] theorem reverse_data (a : Array α) : a.reverse.data = a.data.reverse := by
let rec go (as : Array α) (i j hj)
(h : i + j + 1 = a.size) (h₂ : as.size = a.size)
Expand All @@ -517,10 +522,10 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
simp only [H, getElem_eq_data_get, ← List.get?_eq_get, Nat.le_of_lt h₁, getElem?_eq_data_get?]
split <;> rename_i h₂
· simp only [← h₂, Nat.not_le.2 (Nat.lt_succ_self _), Nat.le_refl, and_false]
exact (List.get?_reverse' _ _ (Eq.trans (by simp_arith) h)).symm
exact (List.get?_reverse' (j+1) i (Eq.trans (by simp_arith) h)).symm
split <;> rename_i h₃
· simp only [← h₃, Nat.not_le.2 (Nat.lt_succ_self _), Nat.le_refl, false_and]
exact (List.get?_reverse' _ _ (Eq.trans (by simp_arith) h)).symm
exact (List.get?_reverse' i (j+1) (Eq.trans (by simp_arith) h)).symm
simp only [Nat.succ_le, Nat.lt_iff_le_and_ne.trans (and_iff_left h₃),
Nat.lt_succ.symm.trans (Nat.lt_iff_le_and_ne.trans (and_iff_left (Ne.symm h₂)))]
· rw [H]; split <;> rename_i h₂
Expand All @@ -533,7 +538,7 @@ theorem size_eq_length_data (as : Array α) : as.size = as.data.length := rfl
split
· match a with | ⟨[]⟩ | ⟨[_]⟩ => rfl
· have := Nat.sub_add_cancel (Nat.le_of_not_le ‹_›)
refine List.ext <| go _ _ _ _ (by simp [this]) rfl fun k => ?_
refine List.ext_get? <| go _ _ _ _ (by simp [this]) rfl fun k => ?_
split
· rfl
· rename_i h
Expand Down Expand Up @@ -769,17 +774,17 @@ theorem size_append (as bs : Array α) : (as ++ bs).size = as.size + bs.size :=

theorem get_append_left {as bs : Array α} {h : i < (as ++ bs).size} (hlt : i < as.size) :
(as ++ bs)[i] = as[i] := by
simp only [getElem_eq_data_get]
simp only [getElem_eq_data_getElem]
have h' : i < (as.data ++ bs.data).length := by rwa [← data_length, append_data] at h
conv => rhs; rw [← List.get_append_left (bs:=bs.data) (h':=h')]
conv => rhs; rw [← List.getElem_append_left (bs := bs.data) (h' := h')]
apply List.get_of_eq; rw [append_data]

theorem get_append_right {as bs : Array α} {h : i < (as ++ bs).size} (hle : as.size ≤ i)
(hlt : i - as.size < bs.size := Nat.sub_lt_left_of_lt_add hle (size_append .. ▸ h)) :
(as ++ bs)[i] = bs[i - as.size] := by
simp only [getElem_eq_data_get]
simp only [getElem_eq_data_getElem]
have h' : i < (as.data ++ bs.data).length := by rwa [← data_length, append_data] at h
conv => rhs; rw [← List.get_append_right (h':=h') (h:=Nat.not_lt_of_ge hle)]
conv => rhs; rw [← List.getElem_append_right (h' := h') (h := Nat.not_lt_of_ge hle)]
apply List.get_of_eq; rw [append_data]

@[simp] theorem append_nil (as : Array α) : as ++ #[] = as := by
Expand Down Expand Up @@ -987,13 +992,13 @@ theorem all_eq_true (p : α → Bool) (as : Array α) : all as p ↔ ∀ i : Fin
simp [all_iff_forall, Fin.isLt]

theorem all_def {p : α → Bool} (as : Array α) : as.all p = as.data.all p := by
rw [Bool.eq_iff_iff, all_eq_true, List.all_eq_true]; simp only [List.mem_iff_get]
rw [Bool.eq_iff_iff, all_eq_true, List.all_eq_true]; simp only [List.mem_iff_getElem]
constructor
· rintro w x ⟨r, rfl⟩
rw [← getElem_eq_data_get]
apply w
· rintro w x ⟨r, h, rfl⟩
rw [← getElem_eq_data_getElem]
exact w ⟨r, h⟩
· intro w i
exact w as[i] ⟨i, (getElem_eq_data_get as i.2).symm⟩
exact w as[i] ⟨i, i.2, (getElem_eq_data_getElem as i.2).symm⟩

theorem all_eq_true_iff_forall_mem {l : Array α} : l.all p ↔ ∀ x, x ∈ l → p x := by
simp only [all_def, List.all_eq_true, mem_def]
Expand Down
10 changes: 6 additions & 4 deletions src/Init/Data/List/BasicAux.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ See also `get?` and `get!`.
def getD (as : List α) (i : Nat) (fallback : α) : α :=
(as.get? i).getD fallback

@[ext] theorem ext : ∀ {l₁ l₂ : List α}, (∀ n, l₁.get? n = l₂.get? n) → l₁ = l₂
theorem ext_get? : ∀ {l₁ l₂ : List α}, (∀ n, l₁.get? n = l₂.get? n) → l₁ = l₂
| [], [], _ => rfl
| a :: l₁, [], h => nomatch h 0
| [], a' :: l₂, h => nomatch h 0
| a :: l₁, a' :: l₂, h => by
have h0 : some a = some a' := h 0
injection h0 with aa; simp only [aa, ext fun n => h (n+1)]
injection h0 with aa; simp only [aa, ext_get? fun n => h (n+1)]

@[deprecated (since := "2024-06-07")] abbrev ext := @ext_get?

/--
Returns the first element in the list.
Expand Down Expand Up @@ -191,15 +193,15 @@ def rotateRight (xs : List α) (n : Nat := 1) : List α :=
let e := xs.drop n
e ++ b

theorem get_append_left (as bs : List α) (h : i < as.length) {h'} : (as ++ bs).get ⟨i, h'⟩ = as.get ⟨i, h⟩ := by
theorem getElem_append_left (as bs : List α) (h : i < as.length) {h'} : (as ++ bs)[i] = as[i] := by
induction as generalizing i with
| nil => trivial
| cons a as ih =>
cases i with
| zero => rfl
| succ i => apply ih

theorem get_append_right (as bs : List α) (h : ¬ i < as.length) {h' h''} : (as ++ bs).get ⟨i, h'⟩ = bs.get ⟨i - as.length, h'' := by
theorem getElem_append_right (as bs : List α) (h : ¬ i < as.length) {h' h''} : (as ++ bs)[i]'h' = bs[i - as.length]'h'' := by
induction as generalizing i with
| nil => trivial
| cons a as ih =>
Expand Down
Loading
Loading