Skip to content

Commit a7cee58

Browse files
kim-emluisacicolini
authored andcommitted
feat: alignment of lemmas about monadic functions on List/Array/Vector (leanprover#6883)
This PR completes the alignment of lemmas about monadic functions on `List/Array/Vector`. Amongst other changes, we change the simp normal form from `List.forM` to `ForM.forM`, and correct the definition of `List.flatMapM`, which previously was returning results in the incorrect order. There remain many gaps in the verification lemmas for monadic functions; this PR only makes the lemmas uniform across `List/Array/Vector`.
1 parent 3800f1e commit a7cee58

15 files changed

+604
-60
lines changed

src/Init/Control/Lawful.lean

+1
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ Authors: Sebastian Ullrich, Leonardo de Moura, Mario Carneiro
66
prelude
77
import Init.Control.Lawful.Basic
88
import Init.Control.Lawful.Instances
9+
import Init.Control.Lawful.Lemmas

src/Init/Control/Lawful/Lemmas.lean

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/-
2+
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Kim Morrison
5+
-/
6+
prelude
7+
import Init.Control.Lawful.Basic
8+
import Init.RCases
9+
import Init.ByCases
10+
11+
-- Mapping by a function with a left inverse is injective.
12+
theorem map_inj_of_left_inverse [Applicative m] [LawfulApplicative m] {f : α → β}
13+
(w : ∃ g : β → α, ∀ x, g (f x) = x) {x y : m α}
14+
(h : f <$> x = f <$> y) : x = y := by
15+
rcases w with ⟨g, w⟩
16+
replace h := congrArg (g <$> ·) h
17+
simpa [w] using h
18+
19+
-- Mapping by an injective function is injective, as long as the domain is nonempty.
20+
theorem map_inj_of_inj [Applicative m] [LawfulApplicative m] [Nonempty α] {f : α → β}
21+
(w : ∀ x y, f x = f y → x = y) {x y : m α}
22+
(h : f <$> x = f <$> y) : x = y := by
23+
apply map_inj_of_left_inverse ?_ h
24+
let ⟨a⟩ := ‹Nonempty α›
25+
refine ⟨?_, ?_⟩
26+
· intro b
27+
by_cases p : ∃ a, f a = b
28+
· exact Exists.choose p
29+
· exact a
30+
· intro b
31+
simp only [exists_apply_eq_apply, ↓reduceDIte]
32+
apply w
33+
apply Exists.choose_spec (p := fun a => f a = f b)

src/Init/Data/Array/Attach.lean

+21-7
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,20 @@ theorem foldr_pmap (l : Array α) {P : α → Prop} (f : (a : α) → P a → β
291291
(l.pmap f H).foldr g x = l.attach.foldr (fun a acc => g (f a.1 (H _ a.2)) acc) x := by
292292
rw [pmap_eq_map_attach, foldr_map]
293293

294+
@[simp] theorem foldl_attachWith
295+
(l : Array α) {q : α → Prop} (H : ∀ a, a ∈ l → q a) {f : β → { x // q x} → β} {b} (w : stop = l.size) :
296+
(l.attachWith q H).foldl f b 0 stop = l.attach.foldl (fun b ⟨a, h⟩ => f b ⟨a, H _ h⟩) b := by
297+
subst w
298+
rcases l with ⟨l⟩
299+
simp [List.foldl_attachWith, List.foldl_map]
300+
301+
@[simp] theorem foldr_attachWith
302+
(l : Array α) {q : α → Prop} (H : ∀ a, a ∈ l → q a) {f : { x // q x} → β → β} {b} (w : start = l.size) :
303+
(l.attachWith q H).foldr f b start 0 = l.attach.foldr (fun a acc => f ⟨a.1, H _ a.2⟩ acc) b := by
304+
subst w
305+
rcases l with ⟨l⟩
306+
simp [List.foldr_attachWith, List.foldr_map]
307+
294308
/--
295309
If we fold over `l.attach` with a function that ignores the membership predicate,
296310
we get the same results as folding over `l` directly.
@@ -571,7 +585,7 @@ and simplifies these to the function directly taking the value.
571585
-/
572586
theorem foldl_subtype {p : α → Prop} {l : Array { x // p x }}
573587
{f : β → { x // p x } → β} {g : β → α → β} {x : β}
574-
{hf : ∀ b x h, f b ⟨x, h⟩ = g b x} :
588+
(hf : ∀ b x h, f b ⟨x, h⟩ = g b x) :
575589
l.foldl f x = l.unattach.foldl g x := by
576590
cases l
577591
simp only [List.foldl_toArray', List.unattach_toArray]
@@ -581,7 +595,7 @@ theorem foldl_subtype {p : α → Prop} {l : Array { x // p x }}
581595
/-- Variant of `foldl_subtype` with side condition to check `stop = l.size`. -/
582596
@[simp] theorem foldl_subtype' {p : α → Prop} {l : Array { x // p x }}
583597
{f : β → { x // p x } → β} {g : β → α → β} {x : β}
584-
{hf : ∀ b x h, f b ⟨x, h⟩ = g b x} (h : stop = l.size) :
598+
(hf : ∀ b x h, f b ⟨x, h⟩ = g b x) (h : stop = l.size) :
585599
l.foldl f x 0 stop = l.unattach.foldl g x := by
586600
subst h
587601
rwa [foldl_subtype]
@@ -592,7 +606,7 @@ and simplifies these to the function directly taking the value.
592606
-/
593607
theorem foldr_subtype {p : α → Prop} {l : Array { x // p x }}
594608
{f : { x // p x } → β → β} {g : α → β → β} {x : β}
595-
{hf : ∀ x h b, f ⟨x, h⟩ b = g x b} :
609+
(hf : ∀ x h b, f ⟨x, h⟩ b = g x b) :
596610
l.foldr f x = l.unattach.foldr g x := by
597611
cases l
598612
simp only [List.foldr_toArray', List.unattach_toArray]
@@ -602,7 +616,7 @@ theorem foldr_subtype {p : α → Prop} {l : Array { x // p x }}
602616
/-- Variant of `foldr_subtype` with side condition to check `stop = l.size`. -/
603617
@[simp] theorem foldr_subtype' {p : α → Prop} {l : Array { x // p x }}
604618
{f : { x // p x } → β → β} {g : α → β → β} {x : β}
605-
{hf : ∀ x h b, f ⟨x, h⟩ b = g x b} (h : start = l.size) :
619+
(hf : ∀ x h b, f ⟨x, h⟩ b = g x b) (h : start = l.size) :
606620
l.foldr f x start 0 = l.unattach.foldr g x := by
607621
subst h
608622
rwa [foldr_subtype]
@@ -612,15 +626,15 @@ This lemma identifies maps over arrays of subtypes, where the function only depe
612626
and simplifies these to the function directly taking the value.
613627
-/
614628
@[simp] theorem map_subtype {p : α → Prop} {l : Array { x // p x }}
615-
{f : { x // p x } → β} {g : α → β} {hf : ∀ x h, f ⟨x, h⟩ = g x} :
629+
{f : { x // p x } → β} {g : α → β} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
616630
l.map f = l.unattach.map g := by
617631
cases l
618632
simp only [List.map_toArray, List.unattach_toArray]
619633
rw [List.map_subtype]
620634
simp [hf]
621635

622636
@[simp] theorem filterMap_subtype {p : α → Prop} {l : Array { x // p x }}
623-
{f : { x // p x } → Option β} {g : α → Option β} {hf : ∀ x h, f ⟨x, h⟩ = g x} :
637+
{f : { x // p x } → Option β} {g : α → Option β} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
624638
l.filterMap f = l.unattach.filterMap g := by
625639
cases l
626640
simp only [size_toArray, List.filterMap_toArray', List.unattach_toArray, List.length_unattach,
@@ -629,7 +643,7 @@ and simplifies these to the function directly taking the value.
629643
simp [hf]
630644

631645
@[simp] theorem unattach_filter {p : α → Prop} {l : Array { x // p x }}
632-
{f : { x // p x } → Bool} {g : α → Bool} {hf : ∀ x h, f ⟨x, h⟩ = g x} :
646+
{f : { x // p x } → Bool} {g : α → Bool} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
633647
(l.filter f).unattach = l.unattach.filter g := by
634648
cases l
635649
simp [hf]

src/Init/Data/Array/Basic.lean

+9-2
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,9 @@ instance : ForIn' m (Array α) α inferInstance where
357357

358358
-- No separate `ForIn` instance is required because it can be derived from `ForIn'`.
359359

360+
-- We simplify `Array.forIn'` to `forIn'`.
361+
@[simp] theorem forIn'_eq_forIn' [Monad m] : @Array.forIn' α β m _ = forIn' := rfl
362+
360363
/-- See comment at `forIn'Unsafe` -/
361364
@[inline]
362365
unsafe def foldlMUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : β → α → m β) (init : β) (as : Array α) (start := 0) (stop := as.size) : m β :=
@@ -585,11 +588,15 @@ def findRevM? {α : Type} {m : Type → Type w} [Monad m] (p : α → m Bool) (a
585588
as.findSomeRevM? fun a => return if (← p a) then some a else none
586589

587590
@[inline]
588-
def forM {α : Type u} {m : Type v → Type w} [Monad m] (f : α → m PUnit) (as : Array α) (start := 0) (stop := as.size) : m PUnit :=
591+
protected def forM {α : Type u} {m : Type v → Type w} [Monad m] (f : α → m PUnit) (as : Array α) (start := 0) (stop := as.size) : m PUnit :=
589592
as.foldlM (fun _ => f) ⟨⟩ start stop
590593

591594
instance : ForM m (Array α) α where
592-
forM xs f := forM f xs
595+
forM xs f := Array.forM f xs
596+
597+
-- We simplify `Array.forM` to `forM`.
598+
@[simp] theorem forM_eq_forM [Monad m] (f : α → m PUnit) :
599+
Array.forM f as 0 as.size = forM as f := rfl
593600

594601
@[inline]
595602
def forRevM {α : Type u} {m : Type v → Type w} [Monad m] (f : α → m PUnit) (as : Array α) (start := as.size) (stop := 0) : m PUnit :=

src/Init/Data/Array/Monadic.lean

+108-23
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ open Nat
2020

2121
/-! ### mapM -/
2222

23+
@[simp] theorem mapM_append [Monad m] [LawfulMonad m] (f : α → m β) {l₁ l₂ : Array α} :
24+
(l₁ ++ l₂).mapM f = (return (← l₁.mapM f) ++ (← l₂.mapM f)) := by
25+
rcases l₁ with ⟨l₁⟩
26+
rcases l₂ with ⟨l₂⟩
27+
simp
28+
2329
theorem mapM_eq_foldlM_push [Monad m] [LawfulMonad m] (f : α → m β) (l : Array α) :
2430
mapM f l = l.foldlM (fun acc a => return (acc.push (← f a))) #[] := by
2531
rcases l with ⟨l⟩
@@ -37,58 +43,85 @@ theorem mapM_eq_foldlM_push [Monad m] [LawfulMonad m] (f : α → m β) (l : Arr
3743

3844
/-! ### foldlM and foldrM -/
3945

40-
theorem foldlM_map [Monad m] (f : β₁ → β₂) (g : α → β₂ → m α) (l : Array β₁) (init : α) :
41-
(l.map f).foldlM g init = l.foldlM (fun x y => g x (f y)) init := by
46+
theorem foldlM_map [Monad m] (f : β₁ → β₂) (g : α → β₂ → m α) (l : Array β₁) (init : α) (w : stop = l.size) :
47+
(l.map f).foldlM g init 0 stop = l.foldlM (fun x y => g x (f y)) init 0 stop := by
48+
subst w
4249
cases l
43-
rw [List.map_toArray] -- Why doesn't this fire via `simp`?
4450
simp [List.foldlM_map]
4551

4652
theorem foldrM_map [Monad m] [LawfulMonad m] (f : β₁ → β₂) (g : β₂ → α → m α) (l : Array β₁)
47-
(init : α) : (l.map f).foldrM g init = l.foldrM (fun x y => g (f x) y) init := by
53+
(init : α) (w : start = l.size) :
54+
(l.map f).foldrM g init start 0 = l.foldrM (fun x y => g (f x) y) init start 0 := by
55+
subst w
4856
cases l
49-
rw [List.map_toArray] -- Why doesn't this fire via `simp`?
5057
simp [List.foldrM_map]
5158

52-
theorem foldlM_filterMap [Monad m] [LawfulMonad m] (f : α → Option β) (g : γ → β → m γ) (l : Array α) (init : γ) :
53-
(l.filterMap f).foldlM g init =
59+
theorem foldlM_filterMap [Monad m] [LawfulMonad m] (f : α → Option β) (g : γ → β → m γ)
60+
(l : Array α) (init : γ) (w : stop = (l.filterMap f).size) :
61+
(l.filterMap f).foldlM g init 0 stop =
5462
l.foldlM (fun x y => match f y with | some b => g x b | none => pure x) init := by
63+
subst w
5564
cases l
56-
rw [List.filterMap_toArray] -- Why doesn't this fire via `simp`?
5765
simp [List.foldlM_filterMap]
5866
rfl
5967

60-
theorem foldrM_filterMap [Monad m] [LawfulMonad m] (f : α → Option β) (g : β → γ → m γ) (l : Array α) (init : γ) :
61-
(l.filterMap f).foldrM g init =
68+
theorem foldrM_filterMap [Monad m] [LawfulMonad m] (f : α → Option β) (g : β → γ → m γ)
69+
(l : Array α) (init : γ) (w : start = (l.filterMap f).size) :
70+
(l.filterMap f).foldrM g init start 0 =
6271
l.foldrM (fun x y => match f x with | some b => g b y | none => pure y) init := by
72+
subst w
6373
cases l
64-
rw [List.filterMap_toArray] -- Why doesn't this fire via `simp`?
6574
simp [List.foldrM_filterMap]
6675
rfl
6776

68-
theorem foldlM_filter [Monad m] [LawfulMonad m] (p : α → Bool) (g : β → α → m β) (l : Array α) (init : β) :
69-
(l.filter p).foldlM g init =
77+
theorem foldlM_filter [Monad m] [LawfulMonad m] (p : α → Bool) (g : β → α → m β)
78+
(l : Array α) (init : β) (w : stop = (l.filter p).size) :
79+
(l.filter p).foldlM g init 0 stop =
7080
l.foldlM (fun x y => if p y then g x y else pure x) init := by
81+
subst w
7182
cases l
72-
rw [List.filter_toArray] -- Why doesn't this fire via `simp`?
7383
simp [List.foldlM_filter]
7484

75-
theorem foldrM_filter [Monad m] [LawfulMonad m] (p : α → Bool) (g : α → β → m β) (l : Array α) (init : β) :
76-
(l.filter p).foldrM g init =
85+
theorem foldrM_filter [Monad m] [LawfulMonad m] (p : α → Bool) (g : α → β → m β)
86+
(l : Array α) (init : β) (w : start = (l.filter p).size) :
87+
(l.filter p).foldrM g init start 0 =
7788
l.foldrM (fun x y => if p x then g x y else pure y) init := by
89+
subst w
7890
cases l
79-
rw [List.filter_toArray] -- Why doesn't this fire via `simp`?
8091
simp [List.foldrM_filter]
8192

93+
@[simp] theorem foldlM_attachWith [Monad m]
94+
(l : Array α) {q : α → Prop} (H : ∀ a, a ∈ l → q a) {f : β → { x // q x} → m β} {b} (w : stop = l.size):
95+
(l.attachWith q H).foldlM f b 0 stop =
96+
l.attach.foldlM (fun b ⟨a, h⟩ => f b ⟨a, H _ h⟩) b := by
97+
subst w
98+
rcases l with ⟨l⟩
99+
simp [List.foldlM_map]
100+
101+
@[simp] theorem foldrM_attachWith [Monad m] [LawfulMonad m]
102+
(l : Array α) {q : α → Prop} (H : ∀ a, a ∈ l → q a) {f : { x // q x} → β → m β} {b} (w : start = l.size):
103+
(l.attachWith q H).foldrM f b start 0 =
104+
l.attach.foldrM (fun a acc => f ⟨a.1, H _ a.2⟩ acc) b := by
105+
subst w
106+
rcases l with ⟨l⟩
107+
simp [List.foldrM_map]
108+
82109
/-! ### forM -/
83110

84111
@[congr] theorem forM_congr [Monad m] {as bs : Array α} (w : as = bs)
85112
{f : α → m PUnit} :
86-
as.forM f = bs.forM f := by
113+
forM as f = forM bs f := by
87114
cases as <;> cases bs
88115
simp_all
89116

117+
@[simp] theorem forM_append [Monad m] [LawfulMonad m] (l₁ l₂ : Array α) (f : α → m PUnit) :
118+
forM (l₁ ++ l₂) f = (do forM l₁ f; forM l₂ f) := by
119+
rcases l₁ with ⟨l₁⟩
120+
rcases l₂ with ⟨l₂⟩
121+
simp
122+
90123
@[simp] theorem forM_map [Monad m] [LawfulMonad m] (l : Array α) (g : α → β) (f : β → m PUnit) :
91-
(l.map g).forM f = l.forM (fun a => f (g a)) := by
124+
forM (l.map g) f = forM l (fun a => f (g a)) := by
92125
cases l
93126
simp
94127

@@ -115,9 +148,7 @@ theorem forIn'_eq_foldlM [Monad m] [LawfulMonad m]
115148
| .yield b => f a m b
116149
| .done b => pure (.done b)) (ForInStep.yield init) := by
117150
cases l
118-
rw [List.attach_toArray] -- Why doesn't this fire via `simp`?
119-
simp only [List.forIn'_toArray, List.forIn'_eq_foldlM, List.attachWith_mem_toArray, size_toArray,
120-
List.length_map, List.length_attach, List.foldlM_toArray', List.foldlM_map]
151+
simp [List.forIn'_eq_foldlM, List.foldlM_map]
121152
congr
122153

123154
/-- We can express a for loop over an array which always yields as a fold. -/
@@ -126,7 +157,6 @@ theorem forIn'_eq_foldlM [Monad m] [LawfulMonad m]
126157
forIn' l init (fun a m b => (fun c => .yield (g a m b c)) <$> f a m b) =
127158
l.attach.foldlM (fun b ⟨a, m⟩ => g a m b <$> f a m b) init := by
128159
cases l
129-
rw [List.attach_toArray] -- Why doesn't this fire via `simp`?
130160
simp [List.foldlM_map]
131161

132162
theorem forIn'_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
@@ -191,4 +221,59 @@ theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
191221
cases l
192222
simp
193223

224+
/-! ### Recognizing higher order functions using a function that only depends on the value. -/
225+
226+
/--
227+
This lemma identifies monadic folds over lists of subtypes, where the function only depends on the value, not the proposition,
228+
and simplifies these to the function directly taking the value.
229+
-/
230+
@[simp] theorem foldlM_subtype [Monad m] {p : α → Prop} {l : Array { x // p x }}
231+
{f : β → { x // p x } → m β} {g : β → α → m β} {x : β}
232+
(hf : ∀ b x h, f b ⟨x, h⟩ = g b x) (w : stop = l.size) :
233+
l.foldlM f x 0 stop = l.unattach.foldlM g x 0 stop := by
234+
subst w
235+
rcases l with ⟨l⟩
236+
simp
237+
rw [List.foldlM_subtype hf]
238+
239+
/--
240+
This lemma identifies monadic folds over lists of subtypes, where the function only depends on the value, not the proposition,
241+
and simplifies these to the function directly taking the value.
242+
-/
243+
@[simp] theorem foldrM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }}
244+
{f : { x // p x } → β → m β} {g : α → β → m β} {x : β}
245+
(hf : ∀ x h b, f ⟨x, h⟩ b = g x b) (w : start = l.size) :
246+
l.foldrM f x start 0 = l.unattach.foldrM g x start 0:= by
247+
subst w
248+
rcases l with ⟨l⟩
249+
simp
250+
rw [List.foldrM_subtype hf]
251+
252+
/--
253+
This lemma identifies monadic maps over lists of subtypes, where the function only depends on the value, not the proposition,
254+
and simplifies these to the function directly taking the value.
255+
-/
256+
@[simp] theorem mapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }}
257+
{f : { x // p x } → m β} {g : α → m β} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
258+
l.mapM f = l.unattach.mapM g := by
259+
rcases l with ⟨l⟩
260+
simp
261+
rw [List.mapM_subtype hf]
262+
263+
-- Without `filterMapM_toArray` relating `filterMapM` on `List` and `Array` we can't prove this yet:
264+
-- @[simp] theorem filterMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }}
265+
-- {f : { x // p x } → m (Option β)} {g : α → m (Option β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
266+
-- l.filterMapM f = l.unattach.filterMapM g := by
267+
-- rcases l with ⟨l⟩
268+
-- simp
269+
-- rw [List.filterMapM_subtype hf]
270+
271+
-- Without `flatMapM_toArray` relating `flatMapM` on `List` and `Array` we can't prove this yet:
272+
-- @[simp] theorem flatMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }}
273+
-- {f : { x // p x } → m (Array β)} {g : α → m (Array β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
274+
-- (l.flatMapM f) = l.unattach.flatMapM g := by
275+
-- rcases l with ⟨l⟩
276+
-- simp
277+
-- rw [List.flatMapM_subtype hf]
278+
194279
end Array

0 commit comments

Comments
 (0)