Skip to content

Commit 404a910

Browse files
kim-em authored and luisacicolini committed
feat: refactor of find functions on List/Array/Vector (leanprover#6833)
This PR makes the signatures of `find` functions across `List`/`Array`/`Vector` consistent. Verification lemmas will follow in subsequent PRs. We were previously quite inconsistent about the signature of `indexOf`/`findIdx` functions across `List` and `Array`. Moreover, there are still quite large gaps in the verification lemma coverage for these even at the `List` level. My intention is to make the signatures consistent by providing: `findIdx` / `findIdx?` / `findFinIdx?` (these all take a predicate, and return respectively a `Nat`, `Option Nat`, `Option (Fin l.length)`) and similarly `idxOf` / `idxOf?` / `finIdxOf?` (which look for an element) for each of List/Array/Vector. I've seen enough examples by now where each variant is genuinely the most convenient at the call-site, so I'm going to accept the cost of having many closely related functions. *Hopefully* for the verification lemmas we can simp all of these into "projections" of the `Option (Fin l.length)` versions, and then only have to specify that. However, I do not plan to immediately fill in the missing verification lemmas (or even to decide what the simp normal forms relating these operations are); for now I will just reach parity amongst List/Array/Vector for what is already there.
1 parent e9c203f commit 404a910

21 files changed

+186
-103
lines changed

src/Init/Data/Array/Basic.lean

+20-8
Original file line number | Diff line number | Diff line change
@@ -674,18 +674,30 @@ def findFinIdx? {α : Type u} (p : α → Bool) (as : Array α) : Option (Fin as
674674
decreasing_by simp_wf; decreasing_trivial_pre_omega
675675
loop 0
676676

677+
@[inline]
678+
def findIdx (p : α → Bool) (as : Array α) : Nat := (as.findIdx? p).getD as.size
679+
677680
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
678-
def indexOfAux [BEq α] (a : Array α) (v : α) (i : Nat) : Option (Fin a.size) :=
681+
def idxOfAux [BEq α] (a : Array α) (v : α) (i : Nat) : Option (Fin a.size) :=
679682
if h : i < a.size then
680683
if a[i] == v then some ⟨i, h⟩
681-
else indexOfAux a v (i+1)
684+
else idxOfAux a v (i+1)
682685
else none
683686
decreasing_by simp_wf; decreasing_trivial_pre_omega
684687

685-
def indexOf? [BEq α] (a : Array α) (v : α) : Option (Fin a.size) :=
686-
indexOfAux a v 0
688+
@[deprecated idxOfAux (since := "2025-01-29")]
689+
abbrev indexOfAux := @idxOfAux
690+
691+
def finIdxOf? [BEq α] (a : Array α) (v : α) : Option (Fin a.size) :=
692+
idxOfAux a v 0
687693

688-
@[deprecated indexOf? (since := "2024-11-20")]
694+
@[deprecated "`Array.indexOf?` has been deprecated, use `idxOf?` or `finIdxOf?` instead." (since := "2025-01-29")]
695+
abbrev indexOf? := @finIdxOf?
696+
697+
def idxOf? [BEq α] (a : Array α) (v : α) : Option Nat :=
698+
(a.finIdxOf? v).map (·.val)
699+
700+
@[deprecated idxOf? (since := "2024-11-20")]
689701
def getIdx? [BEq α] (a : Array α) (v : α) : Option Nat :=
690702
a.findIdx? fun a => a == v
691703

@@ -884,7 +896,7 @@ def eraseIdx! (a : Array α) (i : Nat) : Array α :=
884896
This function takes worst case O(n) time because
885897
it has to backshift all later elements. -/
886898
def erase [BEq α] (as : Array α) (a : α) : Array α :=
887-
match as.indexOf? a with
899+
match as.finIdxOf? a with
888900
| none => as
889901
| some i => as.eraseIdx i
890902

@@ -893,9 +905,9 @@ def erase [BEq α] (as : Array α) (a : α) : Array α :=
893905
This function takes worst case O(n) time because
894906
it has to backshift all later elements. -/
895907
def eraseP (as : Array α) (p : α → Bool) : Array α :=
896-
match as.findIdx? p with
908+
match as.findFinIdx? p with
897909
| none => as
898-
| some i => as.eraseIdxIfInBounds i
910+
| some i => as.eraseIdx i
899911

900912
/-- Insert element `a` at position `i`. -/
901913
@[inline] def insertIdx (as : Array α) (i : Nat) (a : α) (_ : i ≤ as.size := by get_elem_tactic) : Array α :=

src/Init/Data/List/Basic.lean

+43-6
Original file line number | Diff line number | Diff line change
@@ -1269,21 +1269,58 @@ theorem findSome?_cons {f : α → Option β} :
12691269
/-! ### indexOf -/
12701270

12711271
/-- Returns the index of the first element equal to `a`, or the length of the list otherwise. -/
1272-
def indexOf [BEq α] (a : α) : List α → Nat := findIdx (· == a)
1272+
def idxOf [BEq α] (a : α) : List α → Nat := findIdx (· == a)
12731273

1274-
@[simp] theorem indexOf_nil [BEq α] : ([] : List α).indexOf x = 0 := rfl
1274+
/-- Returns the index of the first element equal to `a`, or the length of the list otherwise. -/
1275+
@[deprecated idxOf (since := "2025-01-29")] abbrev indexOf := @idxOf
1276+
1277+
@[simp] theorem idxOf_nil [BEq α] : ([] : List α).idxOf x = 0 := rfl
1278+
1279+
@[deprecated idxOf_nil (since := "2025-01-29")]
1280+
theorem indexOf_nil [BEq α] : ([] : List α).idxOf x = 0 := rfl
12751281

12761282
/-! ### findIdx? -/
12771283

12781284
/-- Return the index of the first occurrence of an element satisfying `p`. -/
1279-
def findIdx? (p : α → Bool) : List α → (start : Nat := 0) → Option Nat
1280-
| [], _ => none
1281-
| a :: l, i => if p a then some i else findIdx? p l (i + 1)
1285+
def findIdx? (p : α → Bool) (l : List α) : Option Nat :=
1286+
go l 0
1287+
where
1288+
go : List α → Nat → Option Nat
1289+
| [], _ => none
1290+
| a :: l, i => if p a then some i else go l (i + 1)
12821291

12831292
/-! ### indexOf? -/
12841293

12851294
/-- Return the index of the first occurrence of `a` in the list. -/
1286-
@[inline] def indexOf? [BEq α] (a : α) : List α → Option Nat := findIdx? (· == a)
1295+
@[inline] def idxOf? [BEq α] (a : α) : List α → Option Nat := findIdx? (· == a)
1296+
1297+
/-- Return the index of the first occurrence of `a` in the list. -/
1298+
@[deprecated idxOf? (since := "2025-01-29")]
1299+
abbrev indexOf? := @idxOf?
1300+
1301+
/-! ### findFinIdx? -/
1302+
1303+
/-- Return the index of the first occurrence of an element satisfying `p`, as a `Fin l.length`,
1304+
or `none` if no such element is found. -/
1305+
@[inline] def findFinIdx? (p : α → Bool) (l : List α) : Option (Fin l.length) :=
1306+
go l 0 (by simp)
1307+
where
1308+
go : (l' : List α) → (i : Nat) → (h : l'.length + i = l.length) → Option (Fin l.length)
1309+
| [], _, _ => none
1310+
| a :: l, i, h =>
1311+
if p a then
1312+
some ⟨i, by
1313+
simp only [Nat.add_comm _ i, ← Nat.add_assoc] at h
1314+
exact Nat.lt_of_add_right_lt (Nat.lt_of_succ_le (Nat.le_of_eq h))⟩
1315+
else
1316+
go l (i + 1) (by simp at h; simpa [← Nat.add_assoc, Nat.add_right_comm] using h)
1317+
1318+
/-! ### finIdxOf? -/
1319+
1320+
/-- Return the index of the first occurrence of `a`, as a `Fin l.length`,
1321+
or `none` if no such element is found. -/
1322+
@[inline] def finIdxOf? [BEq α] (a : α) : (l : List α) → Option (Fin l.length) :=
1323+
findFinIdx? (· == a)
12871324

12881325
/-! ### countP -/
12891326

src/Init/Data/List/Erase.lean

+9-6
Original file line number | Diff line number | Diff line change
@@ -472,13 +472,13 @@ theorem getLast_erase_mem (xs : List α) (a : α) (h) : (xs.erase a).getLast h
472472
(erase_sublist a xs).getLast_mem h
473473

474474
theorem erase_eq_eraseIdx [LawfulBEq α] (l : List α) (a : α) :
475-
l.erase a = match l.indexOf? a with
475+
l.erase a = match l.idxOf? a with
476476
| none => l
477477
| some i => l.eraseIdx i := by
478478
induction l with
479479
| nil => simp
480480
| cons x xs ih =>
481-
rw [erase_cons, indexOf?_cons]
481+
rw [erase_cons, idxOf?_cons]
482482
split
483483
· simp
484484
· simp [ih]
@@ -600,20 +600,23 @@ protected theorem IsPrefix.eraseIdx {l l' : List α} (h : l <+: l') (k : Nat) :
600600
-- See also `mem_eraseIdx_iff_getElem` and `mem_eraseIdx_iff_getElem?` in
601601
-- `Init/Data/List/Nat/Basic.lean`.
602602

603-
theorem erase_eq_eraseIdx_of_indexOf [BEq α] [LawfulBEq α]
604-
(l : List α) (a : α) (i : Nat) (w : l.indexOf a = i) :
603+
theorem erase_eq_eraseIdx_of_idxOf [BEq α] [LawfulBEq α]
604+
(l : List α) (a : α) (i : Nat) (w : l.idxOf a = i) :
605605
l.erase a = l.eraseIdx i := by
606606
subst w
607607
rw [erase_eq_iff]
608608
by_cases h : a ∈ l
609609
· right
610610
obtain ⟨as, bs, rfl, h'⟩ := eq_append_cons_of_mem h
611611
refine ⟨as, bs, h', by simp, ?_⟩
612-
rw [indexOf_append, if_neg h', indexOf_cons_self, eraseIdx_append_of_length_le] <;>
612+
rw [idxOf_append, if_neg h', idxOf_cons_self, eraseIdx_append_of_length_le] <;>
613613
simp
614614
· left
615615
refine ⟨h, ?_⟩
616616
rw [eq_comm, eraseIdx_eq_self]
617-
exact Nat.le_of_eq (indexOf_eq_length h).symm
617+
exact Nat.le_of_eq (idxOf_eq_length h).symm
618+
619+
@[deprecated erase_eq_eraseIdx_of_idxOf (since := "2025-01-29")]
620+
abbrev erase_eq_eraseIdx_of_indexOf := @erase_eq_eraseIdx_of_idxOf
618621

619622
end List

0 commit comments

Comments
 (0)