Skip to content

Commit 3723341

Browse files
kim-emluisacicolini
authored andcommitted
feat: align take/drop/extract across List/Array/Vector (leanprover#6860)
This PR makes `take`/`drop`/`extract` available for each of `List`/`Array`/`Vector`. The simp normal forms differ, however: in `List`, we simplify `extract` to `take+drop`, while in `Array` and `Vector` we simplify `take` and `drop` to `extract`. We also provide `Array/Vector.shrink`, which simplifies to `take`, but is implemented by repeatedly popping. Verification lemmas for `Array/Vector.extract` to follow in a subsequent PR.
1 parent 404a910 commit 3723341

File tree

15 files changed

+110
-76
lines changed

15 files changed

+110
-76
lines changed

src/Init/Data/Array/Basic.lean

+11-3
Original file line numberDiff line numberDiff line change
@@ -270,14 +270,22 @@ def swapAt! (a : Array α) (i : Nat) (v : α) : α × Array α :=
270270
have : Inhabited (α × Array α) := ⟨(v, a)⟩
271271
panic! ("index " ++ toString i ++ " out of bounds")
272272

273-
/-- `take a n` returns the first `n` elements of `a`. -/
274-
def take (a : Array α) (n : Nat) : Array α :=
273+
/-- `shrink a n` returns the first `n` elements of `a`, implemented by repeatedly popping the last element. -/
274+
def shrink (a : Array α) (n : Nat) : Array α :=
275275
let rec loop
276276
| 0, a => a
277277
| n+1, a => loop n a.pop
278278
loop (a.size - n) a
279279

280-
@[deprecated take (since := "2024-10-22")] abbrev shrink := @take
280+
/-- `take a n` returns the first `n` elements of `a`, implemented by copying the first `n` elements. -/
281+
abbrev take (a : Array α) (n : Nat) : Array α := extract a 0 n
282+
283+
@[simp] theorem take_eq_extract (a : Array α) (n : Nat) : a.take n = a.extract 0 n := rfl
284+
285+
/-- `drop a n` removes the first `n` elements of `a`, implemented by copying the remaining elements. -/
286+
abbrev drop (a : Array α) (n : Nat) : Array α := extract a n a.size
287+
288+
@[simp] theorem drop_eq_extract (a : Array α) (n : Nat) : a.drop n = a.extract n a.size := rfl
281289

282290
@[inline]
283291
unsafe def modifyMUnsafe [Monad m] (a : Array α) (i : Nat) (f : α → m α) : m (Array α) := do

src/Init/Data/Array/Lemmas.lean

+25-19
Original file line numberDiff line numberDiff line change
@@ -2565,8 +2565,14 @@ theorem getElem?_extract {as : Array α} {start stop : Nat} :
25652565
· omega
25662566
· rfl
25672567

2568+
@[congr] theorem extract_congr {as bs : Array α}
2569+
(w : as = bs) (h : start = start') (h' : stop = stop') :
2570+
as.extract start stop = bs.extract start' stop' := by
2571+
subst w h h'
2572+
rfl
2573+
25682574
@[simp] theorem toList_extract (as : Array α) (start stop : Nat) :
2569-
(as.extract start stop).toList = (as.toList.drop start).take (stop - start) := by
2575+
(as.extract start stop).toList = as.toList.extract start stop := by
25702576
apply List.ext_getElem
25712577
· simp only [length_toList, size_extract, List.length_take, List.length_drop]
25722578
omega
@@ -2595,7 +2601,7 @@ theorem extract_empty_of_size_le_start (as : Array α) {start stop : Nat} (h : a
25952601
extract_empty_of_size_le_start _ (Nat.zero_le _)
25962602

25972603
@[simp] theorem _root_.List.extract_toArray (l : List α) (start stop : Nat) :
2598-
l.toArray.extract start stop = ((l.drop start).take (stop - start)).toArray := by
2604+
l.toArray.extract start stop = (l.extract start stop).toArray := by
25992605
apply ext'
26002606
simp
26012607

@@ -3363,36 +3369,36 @@ theorem size_eq_length_toList (as : Array α) : as.size = as.toList.length := rf
33633369
theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Array.range n)[x] = x := by
33643370
simp [← getElem_toList]
33653371

3372+
/-! ### shrink -/
33663373

3367-
3368-
3369-
/-! ### take -/
3370-
3371-
@[simp] theorem size_take_loop (a : Array α) (n : Nat) : (take.loop n a).size = a.size - n := by
3374+
@[simp] theorem size_shrink_loop (a : Array α) (n : Nat) : (shrink.loop n a).size = a.size - n := by
33723375
induction n generalizing a with
3373-
| zero => simp [take.loop]
3376+
| zero => simp [shrink.loop]
33743377
| succ n ih =>
3375-
simp [take.loop, ih]
3378+
simp [shrink.loop, ih]
33763379
omega
33773380

3378-
@[simp] theorem getElem_take_loop (a : Array α) (n : Nat) (i : Nat) (h : i < (take.loop n a).size) :
3379-
(take.loop n a)[i] = a[i]'(by simp at h; omega) := by
3381+
@[simp] theorem getElem_shrink_loop (a : Array α) (n : Nat) (i : Nat) (h : i < (shrink.loop n a).size) :
3382+
(shrink.loop n a)[i] = a[i]'(by simp at h; omega) := by
33803383
induction n generalizing a i with
3381-
| zero => simp [take.loop]
3384+
| zero => simp [shrink.loop]
33823385
| succ n ih =>
3383-
simp [take.loop, ih]
3386+
simp [shrink.loop, ih]
33843387

3385-
@[simp] theorem size_take (a : Array α) (n : Nat) : (a.take n).size = min n a.size := by
3386-
simp [take]
3388+
@[simp] theorem size_shrink (a : Array α) (n : Nat) : (a.shrink n).size = min n a.size := by
3389+
simp [shrink]
33873390
omega
33883391

3389-
@[simp] theorem getElem_take (a : Array α) (n : Nat) (i : Nat) (h : i < (a.take n).size) :
3390-
(a.take n)[i] = a[i]'(by simp at h; omega) := by
3391-
simp [take]
3392+
@[simp] theorem getElem_shrink (a : Array α) (n : Nat) (i : Nat) (h : i < (a.shrink n).size) :
3393+
(a.shrink n)[i] = a[i]'(by simp at h; omega) := by
3394+
simp [shrink]
33923395

3393-
@[simp] theorem toList_take (a : Array α) (n : Nat) : (a.take n).toList = a.toList.take n := by
3396+
@[simp] theorem toList_shrink (a : Array α) (n : Nat) : (a.shrink n).toList = a.toList.take n := by
33943397
apply List.ext_getElem <;> simp
33953398

3399+
@[simp] theorem shrink_eq_take (a : Array α) (n : Nat) : a.shrink n = a.take n := by
3400+
ext <;> simp
3401+
33963402
/-! ### forIn -/
33973403

33983404
@[simp] theorem forIn_toList [Monad m] (as : Array α) (b : β) (f : α → β → m (ForInStep β)) :

src/Init/Data/List/Basic.lean

+11
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,17 @@ theorem drop_eq_nil_of_le {as : List α} {i : Nat} (h : as.length ≤ i) : as.dr
823823
| _::_, 0 => simp at h
824824
| _::as, i+1 => simp only [length_cons] at h; exact @drop_eq_nil_of_le as i (Nat.le_of_succ_le_succ h)
825825

826+
/-! ### extract -/
827+
828+
/-- `extract l start stop` returns the slice of `l` from indices `start` to `stop` (exclusive). -/
829+
-- This is only an abbreviation for the operation in terms of `drop` and `take`.
830+
-- We do not prove properties of extract itself.
831+
abbrev extract (l : List α) (start : Nat := 0) (stop : Nat := l.length) : List α :=
832+
(l.drop start).take (stop - start)
833+
834+
@[simp] theorem extract_eq_drop_take (l : List α) (start stop : Nat) :
835+
l.extract start stop = (l.drop start).take (stop - start) := rfl
836+
826837
/-! ### takeWhile -/
827838

828839
/--

src/Init/Data/Vector/Basic.lean

+28-15
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,36 @@ instance : HAppend (Vector α n) (Vector α m) (Vector α (n + m)) where
163163
Extracts the slice of a vector from indices `start` to `stop` (exclusive). If `start ≥ stop`, the
164164
result is empty. If `stop` is greater than the size of the vector, the size is used instead.
165165
-/
166-
@[inline] def extract (v : Vector α n) (start stop : Nat) : Vector α (min stop n - start) :=
166+
@[inline] def extract (v : Vector α n) (start : Nat := 0) (stop : Nat := n) : Vector α (min stop n - start) :=
167167
⟨v.toArray.extract start stop, by simp⟩
168168

169+
/--
170+
Extract the first `m` elements of a vector. If `m` is greater than or equal to the size of the
171+
vector then the vector is returned unchanged.
172+
-/
173+
@[inline] def take (v : Vector α n) (m : Nat) : Vector α (min m n) :=
174+
⟨v.toArray.take m, by simp⟩
175+
176+
@[simp] theorem take_eq_extract (v : Vector α n) (m : Nat) : v.take m = v.extract 0 m := rfl
177+
178+
/--
179+
Deletes the first `m` elements of a vector. If `m` is greater than or equal to the size of the
180+
vector then the empty vector is returned.
181+
-/
182+
@[inline] def drop (v : Vector α n) (m : Nat) : Vector α (n - m) :=
183+
⟨v.toArray.drop m, by simp⟩
184+
185+
@[simp] theorem drop_eq_cast_extract (v : Vector α n) (m : Nat) :
186+
v.drop m = (v.extract m n).cast (by simp) := by
187+
simp [drop, extract, Vector.cast]
188+
189+
/-- Shrinks a vector to the first `m` elements, by repeatedly popping the last element. -/
190+
@[inline] def shrink (v : Vector α n) (m : Nat) : Vector α (min m n) :=
191+
⟨v.toArray.shrink m, by simp⟩
192+
193+
@[simp] theorem shrink_eq_take (v : Vector α n) (m : Nat) : v.shrink m = v.take m := by
194+
simp [shrink, take]
195+
169196
/-- Maps elements of a vector using the function `f`. -/
170197
@[inline] def map (f : α → β) (v : Vector α n) : Vector β n :=
171198
⟨v.toArray.map f, by simp⟩
@@ -291,20 +318,6 @@ This will perform the update destructively provided that the vector has a refere
291318
/-- The vector `#v[0,1,2,...,n-1]`. -/
292319
@[inline] def range (n : Nat) : Vector Nat n := ⟨Array.range n, by simp⟩
293320

294-
/--
295-
Extract the first `m` elements of a vector. If `m` is greater than or equal to the size of the
296-
vector then the vector is returned unchanged.
297-
-/
298-
@[inline] def take (v : Vector α n) (m : Nat) : Vector α (min m n) :=
299-
⟨v.toArray.take m, by simp⟩
300-
301-
/--
302-
Deletes the first `m` elements of a vector. If `m` is greater than or equal to the size of the
303-
vector then the empty vector is returned.
304-
-/
305-
@[inline] def drop (v : Vector α n) (m : Nat) : Vector α (n - m) :=
306-
⟨v.toArray.extract m v.size, by simp⟩
307-
308321
/--
309322
Compares two vectors of the same size using a given boolean relation `r`. `isEqv v w r` returns
310323
`true` if and only if `r v[i] w[i]` is true for all indices `i`.

src/Init/Prelude.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -2706,7 +2706,7 @@ protected def Array.appendCore {α : Type u} (as : Array α) (bs : Array α) :
27062706
If `start` is greater or equal to `stop`, the result is empty.
27072707
If `stop` is greater than the length of `as`, the length is used instead. -/
27082708
-- NOTE: used in the quotation elaborator output
2709-
def Array.extract (as : Array α) (start stop : Nat) : Array α :=
2709+
def Array.extract (as : Array α) (start : Nat := 0) (stop : Nat := as.size) : Array α :=
27102710
let rec loop (i : Nat) (j : Nat) (bs : Array α) : Array α :=
27112711
dite (LT.lt j as.size)
27122712
(fun hlt =>

src/Lean/Compiler/LCNF/PullLetDecls.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ partial def withCheckpoint (x : PullM Code) : PullM Code := do
4646
else
4747
return c
4848
let (c, keep) := go toPullSizeSaved (← read).included |>.run #[]
49-
modify fun s => { s with toPull := s.toPull.take toPullSizeSaved ++ keep }
49+
modify fun s => { s with toPull := s.toPull.shrink toPullSizeSaved ++ keep }
5050
return c
5151

5252
def attachToPull (c : Code) : PullM Code := do

src/Lean/Elab/ParseImportsFast.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ partial def moduleIdent (runtimeOnly : Bool) : Parser := fun input s =>
182182
let s := p input s
183183
match s.error? with
184184
| none => many p input s
185-
| some _ => { pos, error? := none, imports := s.imports.take size }
185+
| some _ => { pos, error? := none, imports := s.imports.shrink size }
186186

187187
@[inline] partial def preludeOpt (k : String) : Parser :=
188188
keywordCore k (fun _ s => s.pushModule `Init false) (fun _ s => s)

src/Lean/Meta/Tactic/LinearArith/Solver.lean

+3-3
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ abbrev Assignment.get? (a : Assignment) (x : Var) : Option Rat :=
3737
abbrev Assignment.push (a : Assignment) (v : Rat) : Assignment :=
3838
{ a with val := a.val.push v }
3939

40-
abbrev Assignment.take (a : Assignment) (newSize : Nat) : Assignment :=
41-
{ a with val := a.val.take newSize }
40+
abbrev Assignment.shrink (a : Assignment) (newSize : Nat) : Assignment :=
41+
{ a with val := a.val.shrink newSize }
4242

4343
structure Poly where
4444
val : Array (Int × Var)
@@ -243,7 +243,7 @@ def resolve (s : State) (cl : Cnstr) (cu : Cnstr) : Sum Result State :=
243243
let maxVarIdx := c.lhs.getMaxVar.id
244244
match s with -- Hack: we avoid { s with ... } to make sure we get a destructive update
245245
| { lowers, uppers, int, assignment, } =>
246-
let assignment := assignment.take maxVarIdx
246+
let assignment := assignment.shrink maxVarIdx
247247
if c.lhs.getMaxVarCoeff < 0 then
248248
let lowers := lowers.modify maxVarIdx (·.push c)
249249
Sum.inr { lowers, uppers, int, assignment }

src/Lean/Meta/WHNF.lean

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ private def mkNullaryCtor (type : Expr) (nparams : Nat) : MetaM (Option Expr) :=
112112
let .const d lvls := type.getAppFn
113113
| return none
114114
let (some ctor) ← getFirstCtor d | pure none
115-
return mkAppN (mkConst ctor lvls) (type.getAppArgs.take nparams)
115+
return mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams)
116116

117117
private def getRecRuleFor (recVal : RecursorVal) (major : Expr) : Option RecursorRule :=
118118
match major.getAppFn with
@@ -180,7 +180,7 @@ private def toCtorWhenStructure (inductName : Name) (major : Expr) : MetaM Expr
180180
else
181181
let some ctorName ← getFirstCtor d | pure major
182182
let ctorInfo ← getConstInfoCtor ctorName
183-
let params := majorType.getAppArgs.take ctorInfo.numParams
183+
let params := majorType.getAppArgs.shrink ctorInfo.numParams
184184
let mut result := mkAppN (mkConst ctorName us) params
185185
for i in [:ctorInfo.numFields] do
186186
result := mkApp result (← mkProjFn ctorInfo us params i major)

src/Lean/Parser/Basic.lean

+4-4
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,7 @@ namespace ParserState
13321332

13331333
def keepTop (s : SyntaxStack) (startStackSize : Nat) : SyntaxStack :=
13341334
let node := s.back
1335-
s.take startStackSize |>.push node
1335+
s.shrink startStackSize |>.push node
13361336

13371337
def keepNewError (s : ParserState) (oldStackSize : Nat) : ParserState :=
13381338
match s with
@@ -1341,13 +1341,13 @@ def keepNewError (s : ParserState) (oldStackSize : Nat) : ParserState :=
13411341
def keepPrevError (s : ParserState) (oldStackSize : Nat) (oldStopPos : String.Pos) (oldError : Option Error) (oldLhsPrec : Nat) : ParserState :=
13421342
match s with
13431343
| ⟨stack, _, _, cache, _, errs⟩ =>
1344-
⟨stack.take oldStackSize, oldLhsPrec, oldStopPos, cache, oldError, errs⟩
1344+
⟨stack.shrink oldStackSize, oldLhsPrec, oldStopPos, cache, oldError, errs⟩
13451345

13461346
def mergeErrors (s : ParserState) (oldStackSize : Nat) (oldError : Error) : ParserState :=
13471347
match s with
13481348
| ⟨stack, lhsPrec, pos, cache, some err, errs⟩ =>
13491349
let newError := if oldError == err then err else oldError.merge err
1350-
⟨stack.take oldStackSize, lhsPrec, pos, cache, some newError, errs⟩
1350+
⟨stack.shrink oldStackSize, lhsPrec, pos, cache, some newError, errs⟩
13511351
| other => other
13521352

13531353
def keepLatest (s : ParserState) (startStackSize : Nat) : ParserState :=
@@ -1390,7 +1390,7 @@ def runLongestMatchParser (left? : Option Syntax) (startLhsPrec : Nat) (p : Pars
13901390
s -- success or error with the expected number of nodes
13911391
else if s.hasError then
13921392
-- error with an unexpected number of nodes.
1393-
s.takeStack startSize |>.pushSyntax Syntax.missing
1393+
s.shrinkStack startSize |>.pushSyntax Syntax.missing
13941394
else
13951395
-- parser succeeded with incorrect number of nodes
13961396
invalidLongestMatchParser s

src/Lean/Parser/Types.lean

+8-12
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,8 @@ def size (stack : SyntaxStack) : Nat :=
158158
def isEmpty (stack : SyntaxStack) : Bool :=
159159
stack.size == 0
160160

161-
def take (stack : SyntaxStack) (n : Nat) : SyntaxStack :=
162-
{ stack with raw := stack.raw.take (stack.drop + n) }
163-
164-
@[deprecated take (since := "2024-10-22")] abbrev shrink := @take
161+
def shrink (stack : SyntaxStack) (n : Nat) : SyntaxStack :=
162+
{ stack with raw := stack.raw.shrink (stack.drop + n) }
165163

166164
def push (stack : SyntaxStack) (a : Syntax) : SyntaxStack :=
167165
{ stack with raw := stack.raw.push a }
@@ -214,7 +212,7 @@ def stackSize (s : ParserState) : Nat :=
214212
s.stxStack.size
215213

216214
def restore (s : ParserState) (iniStackSz : Nat) (iniPos : String.Pos) : ParserState :=
217-
{ s with stxStack := s.stxStack.take iniStackSz, errorMsg := none, pos := iniPos }
215+
{ s with stxStack := s.stxStack.shrink iniStackSz, errorMsg := none, pos := iniPos }
218216

219217
def setPos (s : ParserState) (pos : String.Pos) : ParserState :=
220218
{ s with pos := pos }
@@ -228,10 +226,8 @@ def pushSyntax (s : ParserState) (n : Syntax) : ParserState :=
228226
def popSyntax (s : ParserState) : ParserState :=
229227
{ s with stxStack := s.stxStack.pop }
230228

231-
def takeStack (s : ParserState) (iniStackSz : Nat) : ParserState :=
232-
{ s with stxStack := s.stxStack.take iniStackSz }
233-
234-
@[deprecated takeStack (since := "2024-10-22")] abbrev shrinkStack := @takeStack
229+
def shrinkStack (s : ParserState) (iniStackSz : Nat) : ParserState :=
230+
{ s with stxStack := s.stxStack.shrink iniStackSz }
235231

236232
def next (s : ParserState) (input : String) (pos : String.Pos) : ParserState :=
237233
{ s with pos := input.next pos }
@@ -254,15 +250,15 @@ def mkNode (s : ParserState) (k : SyntaxNodeKind) (iniStackSz : Nat) : ParserSta
254250
⟨stack, lhsPrec, pos, cache, err, recovered⟩
255251
else
256252
let newNode := Syntax.node SourceInfo.none k (stack.extract iniStackSz stack.size)
257-
let stack := stack.take iniStackSz
253+
let stack := stack.shrink iniStackSz
258254
let stack := stack.push newNode
259255
⟨stack, lhsPrec, pos, cache, err, recovered⟩
260256

261257
def mkTrailingNode (s : ParserState) (k : SyntaxNodeKind) (iniStackSz : Nat) : ParserState :=
262258
match s with
263259
| ⟨stack, lhsPrec, pos, cache, err, errs⟩ =>
264260
let newNode := Syntax.node SourceInfo.none k (stack.extract (iniStackSz - 1) stack.size)
265-
let stack := stack.take (iniStackSz - 1)
261+
let stack := stack.shrink (iniStackSz - 1)
266262
let stack := stack.push newNode
267263
⟨stack, lhsPrec, pos, cache, err, errs⟩
268264

@@ -287,7 +283,7 @@ def mkEOIError (s : ParserState) (expected : List String := []) : ParserState :=
287283
def mkErrorsAt (s : ParserState) (ex : List String) (pos : String.Pos) (initStackSz? : Option Nat := none) : ParserState := Id.run do
288284
let mut s := s.setPos pos
289285
if let some sz := initStackSz? then
290-
s := s.takeStack sz
286+
s := s.shrinkStack sz
291287
s := s.setError { expected := ex }
292288
s.pushSyntax .missing
293289

src/Lean/PrettyPrinter/Delaborator/TopDownAnalyze.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ mutual
398398
let fType ← replaceLPsWithVars (← inferType f)
399399
let (mvars, bInfos, resultType) ← forallMetaBoundedTelescope fType args.size
400400
let rest := args.extract mvars.size args.size
401-
let args := args.take mvars.size
401+
let args := args.shrink mvars.size
402402

403403
-- Unify with the expected type
404404
if (← read).knowsType then tryUnify (← inferType (mkAppN f args)) resultType

src/Lean/PrettyPrinter/Formatter.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def fold (fn : Array Format → Format) (x : FormatterM Unit) : FormatterM Unit
146146
x
147147
let stack ← getStack
148148
let f := fn $ stack.extract sp stack.size
149-
setStack $ (stack.take sp).push f
149+
setStack $ (stack.shrink sp).push f
150150

151151
/-- Execute `x` and concatenate generated Format objects. -/
152152
def concat (x : FormatterM Unit) : FormatterM Unit := do

src/lake/Lake/Util/Log.lean

+1-1
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ instance : Append Log := ⟨Log.append⟩
321321

322322
/-- Removes log entries after `pos` (inclusive). -/
323323
@[inline] def dropFrom (log : Log) (pos : Log.Pos) : Log :=
324-
.mk <| log.entries.take pos.val
324+
.mk <| log.entries.shrink pos.val
325325

326326
/-- Takes log entries before `pos` (exclusive). -/
327327
@[inline] def takeFrom (log : Log) (pos : Log.Pos) : Log :=

0 commit comments

Comments
 (0)