Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: overlapping match patterns in grind #6733

Merged
merged 6 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/Init/Grind/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ When `EqMatch a b origin` is `True`, we mark `origin` as a resolved case-split.
-/
def EqMatch (a b : α) {_origin : α} : Prop := a = b

/--
Gadget for annotating conditions of `match` equational lemmas.
We use this annotation for two different reasons:
- We don't want to normalize them.
- We have a propagator for them.
-/
def MatchCond (p : Prop) : Prop := p

theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (@nestedProof p hp) (@nestedProof q hq) := by
subst h; apply HEq.refl

Expand Down
3 changes: 3 additions & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import Lean.Meta.Tactic.Grind.Main
import Lean.Meta.Tactic.Grind.CasesMatch
import Lean.Meta.Tactic.Grind.Arith
import Lean.Meta.Tactic.Grind.Ext
import Lean.Meta.Tactic.Grind.MatchCond
import Lean.Meta.Tactic.Grind.DoNotSimp

namespace Lean

Expand Down Expand Up @@ -70,5 +72,6 @@ builtin_initialize registerTraceClass `grind.debug.offset
builtin_initialize registerTraceClass `grind.debug.offset.proof
builtin_initialize registerTraceClass `grind.debug.ematch.pattern
builtin_initialize registerTraceClass `grind.debug.beta
builtin_initialize registerTraceClass `grind.debug.matchCond

end Lean
7 changes: 6 additions & 1 deletion src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Intro
import Lean.Meta.Tactic.Grind.DoNotSimp
import Lean.Meta.Tactic.Grind.MatchCond

namespace Lean.Meta.Grind
namespace EMatch
Expand Down Expand Up @@ -215,7 +216,11 @@ Helper function for marking parts of `match`-equation theorem as "do-not-simplif
-/
private partial def annotateMatchEqnType (prop : Expr) (initApp : Expr) : M Expr := do
if let .forallE n d b bi := prop then
withLocalDecl n bi (← markAsDoNotSimp d) fun x => do
let d ← if (← isProp d) then
markAsMatchCond d
else
pure d
withLocalDecl n bi d fun x => do
mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x) initApp)
else
let_expr f@Eq α lhs rhs := prop | return prop
Expand Down
30 changes: 26 additions & 4 deletions src/Lean/Meta/Tactic/Grind/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ private def pushCastHEqs (e : Expr) : GoalM Unit := do
private def preprocessGroundPattern (e : Expr) : GoalM Expr := do
shareCommon (← canon (← normalizeLevels (← unfoldReducible e)))

private def mkENode' (e : Expr) (generation : Nat) : GoalM Unit :=
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)

mutual
/-- Internalizes the nested ground terms in the given pattern. -/
private partial def internalizePattern (pattern : Expr) (generation : Nat) : GoalM Expr := do
Expand All @@ -122,6 +125,24 @@ private partial def internalizePattern (pattern : Expr) (generation : Nat) : Goa
else pattern.withApp fun f args => do
return mkAppN f (← args.mapM (internalizePattern · generation))

/-- Internalizes the `MatchCond` gadget. -/
private partial def internalizeMatchCond (matchCond : Expr) (generation : Nat) : GoalM Unit := do
let_expr Grind.MatchCond e ← matchCond | return ()
mkENode' matchCond generation
let mut e := e
repeat
let .forallE _ d b _ := e | break
let internalizeLhs (lhs : Expr) : GoalM Unit := do
unless lhs.hasLooseBVars do
internalize lhs generation
registerParent matchCond lhs
match_expr d with
| Eq _ lhs _ => internalizeLhs lhs
| HEq _ lhs _ _ => internalizeLhs lhs
| _ => pure ()
e := b
propagateUp matchCond

partial def activateTheorem (thm : EMatchTheorem) (generation : Nat) : GoalM Unit := do
-- Recall that we use the proof as part of the key for a set of instances found so far.
-- We don't want to use structural equality when comparing keys.
Expand Down Expand Up @@ -164,10 +185,9 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr :=
match e with
| .bvar .. => unreachable!
| .sort .. => return ()
| .fvar .. | .letE .. | .lam .. =>
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
| .fvar .. | .letE .. | .lam .. => mkENode' e generation
| .forallE _ d b _ =>
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
mkENode' e generation
if (← isProp d <&&> isProp e) then
internalize d generation e
registerParent e d
Expand All @@ -181,12 +201,14 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr :=
| .mdata ..
| .proj .. =>
reportIssue m!"unexpected kernel projection term during internalization{indentExpr e}\n`grind` uses a pre-processing step that folds them as projection applications, the pre-processor should have failed to fold this term"
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
mkENode' e generation
| .app .. =>
if (← isLitValue e) then
-- We do not want to internalize the components of a literal value.
mkENode e generation
Arith.internalize e parent?
else if e.isAppOfArity ``Grind.MatchCond 1 then
internalizeMatchCond e generation
else e.withApp fun f args => do
checkAndAddSplitCandidate e
pushCastHEqs e
Expand Down
155 changes: 155 additions & 0 deletions src/Lean/Meta/Tactic/Grind/MatchCond.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Grind
import Init.Simproc
import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Grind.PropagatorAttr

namespace Lean.Meta.Grind

/--
Returns `Grind.MatchCond e`.
Recall that `Grind.MatchCond` is an identity function,
but the following simproc is used to prevent the term `e` from being simplified,
and we have special support for propagating is truth value.
-/
def markAsMatchCond (e : Expr) : MetaM Expr :=
mkAppM ``Grind.MatchCond #[e]

builtin_dsimproc_decl reduceMatchCond (Grind.MatchCond _) := fun e => do
let_expr Grind.MatchCond _ ← e | return .continue
return .done e

/-- Adds `reduceMatchCond` to `s` -/
def addMatchCond (s : Simprocs) : CoreM Simprocs := do
s.add ``reduceMatchCond (post := false)

/--
Helper function for `isSatisfied`.
See `isSatisfied`.
-/
private partial def isMathCondFalseHyp (e : Expr) : GoalM Bool := do
match_expr e with
| Eq _ lhs rhs => isFalse lhs rhs
| HEq _ lhs _ rhs => isFalse lhs rhs
| _ => return false
where
isFalse (lhs rhs : Expr) : GoalM Bool := do
if lhs.hasLooseBVars then return false
let root ← getRootENode lhs
if root.ctor then
let some ctorLhs ← isConstructorApp? root.self | return false
let some ctorRhs ← isConstructorApp? rhs | return false
if ctorLhs.name ≠ ctorRhs.name then return true
let lhsArgs := root.self.getAppArgs
let rhsArgs := rhs.getAppArgs
for i in [ctorLhs.numParams : ctorLhs.numParams + ctorLhs.numFields] do
if (← isFalse lhsArgs[i]! rhsArgs[i]!) then
return true
return false
else if root.interpreted then
if rhs.hasLooseBVars then return false
unless (← isLitValue rhs) do return false
return (← normLitValue root.self) != (← normLitValue rhs)
else
return false

/--
Returns `true` if `e` is a `Grind.MatchCond`, and it has been satifisfied.
Recall that we use `Grind.MatchCond` to annotate conditional `match`-equations.
Consider the following example:
```
inductive S where
| mk1 (n : Nat)
| mk2 (n : Nat) (s : S)
| mk3 (n : Bool)
| mk4 (s1 s2 : S)

def f (x y : S) :=
match x, y with
| .mk1 _, _ => 2
| _, .mk2 1 (.mk4 _ _) => 3
| .mk3 _, _ => 4
| _, _ => 5
```
The `match`-expression in the example above has overlapping patterns and
consequently produces conditional `match` equations. Thus, `grind` generates
the following auxiliary `Grind.MatchCond` terms for an application `f a b`:
- `Grind.MatchCond (∀ (n : Nat), a = S.mk1 n → False)`
- `Grind.MatchCond (∀ (s1 s2 : S), b = S.mk2 1 (s1.mk4 s2) → False)`
- `Grind.MatchCond (∀ (n : Bool), a = S.mk3 n → False)`

`isSatisfied` uses the fact that constructor applications and literal values
are always the root of their equivalence classes.
-/
private partial def isStatisfied (e : Expr) : GoalM Bool := do
let_expr Grind.MatchCond e ← e | return false
let mut e := e
repeat
let .forallE _ d b _ := e | break
if (← isMathCondFalseHyp d) then
trace[grind.debug.matchCond] "satifised{indentExpr e}\nthe following equality is false{indentExpr d}"
return true
e := b
return false

private partial def mkMathCondProof? (e : Expr) : GoalM (Option Expr) := do
let_expr Grind.MatchCond f ← e | return none
forallTelescopeReducing f fun xs _ => do
for x in xs do
let type ← inferType x
if (← isMathCondFalseHyp type) then
trace[grind.debug.matchCond] ">>> {type}"
let some h ← go? x | pure ()
return some (← mkLambdaFVars xs h)
return none
where
go? (h : Expr) : GoalM (Option Expr) := do
trace[grind.debug.matchCond] "go?: {← inferType h}"
let (lhs, rhs, isHeq) ← match_expr (← inferType h) with
| Eq _ lhs rhs => pure (lhs, rhs, false)
| HEq _ lhs _ rhs => pure (lhs, rhs, true)
| _ => return none
let target ← (← get).mvarId.getType
let root ← getRootENode lhs
let h ← if isHeq then
mkEqOfHEq (← mkHEqTrans (← mkHEqProof root.self lhs) h)
else
mkEqTrans (← mkEqProof root.self lhs) h
if root.ctor then
let some ctorLhs ← isConstructorApp? root.self | return none
let some ctorRhs ← isConstructorApp? rhs | return none
let h ← mkNoConfusion target h
if ctorLhs.name ≠ ctorRhs.name then
return some h
else
let .forallE _ k _ _ ← whnfD (← inferType h)
| return none
forallTelescopeReducing k fun xs _ => do
for x in xs do
let some hx ← go? x | pure ()
return some (mkApp h (← mkLambdaFVars xs hx))
return none
else if root.interpreted then
if (← normLitValue root.self) != (← normLitValue rhs) then
let hne ← mkDecideProof (mkNot (← mkEq root.self rhs))
return some (mkApp hne h)
else
return none
else
return none

/-- Propagates `MatchCond` upwards -/
builtin_grind_propagator propagateMatchCond ↑Grind.MatchCond := fun e => do
trace[grind.debug.matchCond] "visiting{indentExpr e}"
if !(← isStatisfied e) then return ()
let some h ← mkMathCondProof? e
| reportIssue m!"failed to construct proof for{indentExpr e}"; return ()
trace[grind.debug.matchCond] "{← inferType h}"
pushEqTrue e <| mkEqTrueCore e h

end Lean.Meta.Grind
10 changes: 6 additions & 4 deletions src/Lean/Meta/Tactic/Grind/SimpUtil.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ prelude
import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Grind.Simp
import Lean.Meta.Tactic.Grind.DoNotSimp
import Lean.Meta.Tactic.Grind.MatchCond
import Lean.Meta.Tactic.Simp.BuiltinSimprocs.List

namespace Lean.Meta.Grind
Expand All @@ -24,7 +25,7 @@ def registerNormTheorems (preDeclNames : Array Name) (postDeclNames : Array Name

/-- Returns the array of simprocs used by `grind`. -/
protected def getSimprocs : MetaM (Array Simprocs) := do
let e ← Simp.getSEvalSimprocs
let s ← Simp.getSEvalSimprocs
/-
We don't want to apply `List.reduceReplicate` as a normalization operation in
`grind`. Consider the following example:
Expand All @@ -38,9 +39,10 @@ protected def getSimprocs : MetaM (Array Simprocs) := do
```
We don't want it to be simplified to `[] = []`.
-/
let e := e.erase ``List.reduceReplicate
let e ← addDoNotSimp e
return #[e]
let s := s.erase ``List.reduceReplicate
let s ← addDoNotSimp s
let s ← addMatchCond s
return #[s]

/-- Returns the simplification context used by `grind`. -/
protected def getSimpContext : MetaM Simp.Context := do
Expand Down
47 changes: 47 additions & 0 deletions tests/lean/run/grind_match_eq_propagation.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
inductive S where
| mk1 (n : Nat)
| mk2 (n : Nat) (s : S)
| mk3 (n : Bool)
| mk4 (s1 s2 : S)

def f (x y : S) :=
match x, y with
| .mk1 _, _ => 2
| _, .mk2 1 (.mk4 _ _) => 3
| .mk3 _, _ => 4
| _, _ => 5

example : f a b < 2 → b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → False := by
unfold f
grind (splits := 0)

example : b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → f a b = 5 := by
unfold f
grind (splits := 0)

example : b = .mk2 y1 y2 → y1 = 2 → a = .mk3 n → f a b = 4 := by
unfold f
grind (splits := 0)

example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk3 n → f a b = 3 := by
unfold f
grind (splits := 0)

example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk2 s3 s4 → f a b = 3 := by
unfold f
grind (splits := 0)

inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons : α → Vec α n → Vec α (n+1)

def g (v w : Vec α n) : Nat :=
match v, w with
| _, .cons _ (.cons _ _) => 20
| .nil, _ => 30
| _, _ => 40

-- TODO: introduce casts while instantiating equation theorems for `g.match_1`
-- example (a b : Vec α 2) : g a b = 20 := by
-- unfold g
-- grind
Loading