diff --git a/src/Init/Grind/Util.lean b/src/Init/Grind/Util.lean index 34bc7ec2df73..0a58ba5979cc 100644 --- a/src/Init/Grind/Util.lean +++ b/src/Init/Grind/Util.lean @@ -31,6 +31,14 @@ When `EqMatch a b origin` is `True`, we mark `origin` as a resolved case-split. -/ def EqMatch (a b : α) {_origin : α} : Prop := a = b +/-- +Gadget for annotating conditions of `match` equational lemmas. +We use this annotation for two different reasons: +- We don't want to normalize them. +- We have a propagator for them. +-/ +def MatchCond (p : Prop) : Prop := p + theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (@nestedProof p hp) (@nestedProof q hq) := by subst h; apply HEq.refl diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index 0bb04bafe04c..53f5373b46ba 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -26,6 +26,8 @@ import Lean.Meta.Tactic.Grind.Main import Lean.Meta.Tactic.Grind.CasesMatch import Lean.Meta.Tactic.Grind.Arith import Lean.Meta.Tactic.Grind.Ext +import Lean.Meta.Tactic.Grind.MatchCond +import Lean.Meta.Tactic.Grind.DoNotSimp namespace Lean @@ -70,5 +72,6 @@ builtin_initialize registerTraceClass `grind.debug.offset builtin_initialize registerTraceClass `grind.debug.offset.proof builtin_initialize registerTraceClass `grind.debug.ematch.pattern builtin_initialize registerTraceClass `grind.debug.beta +builtin_initialize registerTraceClass `grind.debug.matchCond end Lean diff --git a/src/Lean/Meta/Tactic/Grind/EMatch.lean b/src/Lean/Meta/Tactic/Grind/EMatch.lean index 8341d44add46..bef989aaa6f7 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatch.lean @@ -7,6 +7,7 @@ prelude import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Intro import Lean.Meta.Tactic.Grind.DoNotSimp +import Lean.Meta.Tactic.Grind.MatchCond namespace Lean.Meta.Grind namespace EMatch @@ -215,7 +216,11 @@ Helper function for marking parts of `match`-equation theorem as "do-not-simplif -/ private partial def annotateMatchEqnType (prop : Expr) (initApp : Expr) : M Expr := do if let .forallE n d b bi := prop then - withLocalDecl n bi (← markAsDoNotSimp d) fun x => do + let d ← if (← isProp d) then + markAsMatchCond d + else + pure d + withLocalDecl n bi d fun x => do mkForallFVars #[x] (← annotateMatchEqnType (b.instantiate1 x) initApp) else let_expr f@Eq α lhs rhs := prop | return prop diff --git a/src/Lean/Meta/Tactic/Grind/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Internalize.lean index 8e6c85b1f761..e4ce20405d2f 100644 --- a/src/Lean/Meta/Tactic/Grind/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Internalize.lean @@ -110,6 +110,9 @@ private def pushCastHEqs (e : Expr) : GoalM Unit := do private def preprocessGroundPattern (e : Expr) : GoalM Expr := do shareCommon (← canon (← normalizeLevels (← unfoldReducible e))) +private def mkENode' (e : Expr) (generation : Nat) : GoalM Unit := + mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) + mutual /-- Internalizes the nested ground terms in the given pattern. -/ private partial def internalizePattern (pattern : Expr) (generation : Nat) : GoalM Expr := do @@ -122,6 +125,24 @@ private partial def internalizePattern (pattern : Expr) (generation : Nat) : Goa else pattern.withApp fun f args => do return mkAppN f (← args.mapM (internalizePattern · generation)) +/-- Internalizes the `MatchCond` gadget. -/ +private partial def internalizeMatchCond (matchCond : Expr) (generation : Nat) : GoalM Unit := do + let_expr Grind.MatchCond e ← matchCond | return () + mkENode' matchCond generation + let mut e := e + repeat + let .forallE _ d b _ := e | break + let internalizeLhs (lhs : Expr) : GoalM Unit := do + unless lhs.hasLooseBVars do + internalize lhs generation + registerParent matchCond lhs + match_expr d with + | Eq _ lhs _ => internalizeLhs lhs + | HEq _ lhs _ _ => internalizeLhs lhs + | _ => pure () + e := b + propagateUp matchCond + partial def activateTheorem (thm : EMatchTheorem) (generation : Nat) : GoalM Unit := do -- Recall that we use the proof as part of the key for a set of instances found so far. -- We don't want to use structural equality when comparing keys. @@ -164,10 +185,9 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr := match e with | .bvar .. => unreachable! | .sort .. => return () - | .fvar .. | .letE .. | .lam .. => - mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) + | .fvar .. | .letE .. | .lam .. => mkENode' e generation | .forallE _ d b _ => - mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) + mkENode' e generation if (← isProp d <&&> isProp e) then internalize d generation e registerParent e d @@ -181,12 +201,14 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr := | .mdata .. | .proj .. => reportIssue m!"unexpected kernel projection term during internalization{indentExpr e}\n`grind` uses a pre-processing step that folds them as projection applications, the pre-processor should have failed to fold this term" - mkENodeCore e (ctor := false) (interpreted := false) (generation := generation) + mkENode' e generation | .app .. => if (← isLitValue e) then -- We do not want to internalize the components of a literal value. mkENode e generation Arith.internalize e parent? + else if e.isAppOfArity ``Grind.MatchCond 1 then + internalizeMatchCond e generation else e.withApp fun f args => do checkAndAddSplitCandidate e pushCastHEqs e diff --git a/src/Lean/Meta/Tactic/Grind/MatchCond.lean b/src/Lean/Meta/Tactic/Grind/MatchCond.lean new file mode 100644 index 000000000000..da3b990c7c9e --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/MatchCond.lean @@ -0,0 +1,155 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Grind +import Init.Simproc +import Lean.Meta.Tactic.Simp.Simproc +import Lean.Meta.Tactic.Grind.PropagatorAttr + +namespace Lean.Meta.Grind + +/-- +Returns `Grind.MatchCond e`. +Recall that `Grind.MatchCond` is an identity function, +but the following simproc is used to prevent the term `e` from being simplified, +and we have special support for propagating is truth value. +-/ +def markAsMatchCond (e : Expr) : MetaM Expr := + mkAppM ``Grind.MatchCond #[e] + +builtin_dsimproc_decl reduceMatchCond (Grind.MatchCond _) := fun e => do + let_expr Grind.MatchCond _ ← e | return .continue + return .done e + +/-- Adds `reduceMatchCond` to `s` -/ +def addMatchCond (s : Simprocs) : CoreM Simprocs := do + s.add ``reduceMatchCond (post := false) + +/-- +Helper function for `isSatisfied`. +See `isSatisfied`. +-/ +private partial def isMathCondFalseHyp (e : Expr) : GoalM Bool := do + match_expr e with + | Eq _ lhs rhs => isFalse lhs rhs + | HEq _ lhs _ rhs => isFalse lhs rhs + | _ => return false +where + isFalse (lhs rhs : Expr) : GoalM Bool := do + if lhs.hasLooseBVars then return false + let root ← getRootENode lhs + if root.ctor then + let some ctorLhs ← isConstructorApp? root.self | return false + let some ctorRhs ← isConstructorApp? rhs | return false + if ctorLhs.name ≠ ctorRhs.name then return true + let lhsArgs := root.self.getAppArgs + let rhsArgs := rhs.getAppArgs + for i in [ctorLhs.numParams : ctorLhs.numParams + ctorLhs.numFields] do + if (← isFalse lhsArgs[i]! rhsArgs[i]!) then + return true + return false + else if root.interpreted then + if rhs.hasLooseBVars then return false + unless (← isLitValue rhs) do return false + return (← normLitValue root.self) != (← normLitValue rhs) + else + return false + +/-- +Returns `true` if `e` is a `Grind.MatchCond`, and it has been satifisfied. +Recall that we use `Grind.MatchCond` to annotate conditional `match`-equations. +Consider the following example: +``` +inductive S where + | mk1 (n : Nat) + | mk2 (n : Nat) (s : S) + | mk3 (n : Bool) + | mk4 (s1 s2 : S) + +def f (x y : S) := + match x, y with + | .mk1 _, _ => 2 + | _, .mk2 1 (.mk4 _ _) => 3 + | .mk3 _, _ => 4 + | _, _ => 5 +``` +The `match`-expression in the example above has overlapping patterns and +consequently produces conditional `match` equations. Thus, `grind` generates +the following auxiliary `Grind.MatchCond` terms for an application `f a b`: +- `Grind.MatchCond (∀ (n : Nat), a = S.mk1 n → False)` +- `Grind.MatchCond (∀ (s1 s2 : S), b = S.mk2 1 (s1.mk4 s2) → False)` +- `Grind.MatchCond (∀ (n : Bool), a = S.mk3 n → False)` + +`isSatisfied` uses the fact that constructor applications and literal values +are always the root of their equivalence classes. +-/ +private partial def isStatisfied (e : Expr) : GoalM Bool := do + let_expr Grind.MatchCond e ← e | return false + let mut e := e + repeat + let .forallE _ d b _ := e | break + if (← isMathCondFalseHyp d) then + trace[grind.debug.matchCond] "satifised{indentExpr e}\nthe following equality is false{indentExpr d}" + return true + e := b + return false + +private partial def mkMathCondProof? (e : Expr) : GoalM (Option Expr) := do + let_expr Grind.MatchCond f ← e | return none + forallTelescopeReducing f fun xs _ => do + for x in xs do + let type ← inferType x + if (← isMathCondFalseHyp type) then + trace[grind.debug.matchCond] ">>> {type}" + let some h ← go? x | pure () + return some (← mkLambdaFVars xs h) + return none +where + go? (h : Expr) : GoalM (Option Expr) := do + trace[grind.debug.matchCond] "go?: {← inferType h}" + let (lhs, rhs, isHeq) ← match_expr (← inferType h) with + | Eq _ lhs rhs => pure (lhs, rhs, false) + | HEq _ lhs _ rhs => pure (lhs, rhs, true) + | _ => return none + let target ← (← get).mvarId.getType + let root ← getRootENode lhs + let h ← if isHeq then + mkEqOfHEq (← mkHEqTrans (← mkHEqProof root.self lhs) h) + else + mkEqTrans (← mkEqProof root.self lhs) h + if root.ctor then + let some ctorLhs ← isConstructorApp? root.self | return none + let some ctorRhs ← isConstructorApp? rhs | return none + let h ← mkNoConfusion target h + if ctorLhs.name ≠ ctorRhs.name then + return some h + else + let .forallE _ k _ _ ← whnfD (← inferType h) + | return none + forallTelescopeReducing k fun xs _ => do + for x in xs do + let some hx ← go? x | pure () + return some (mkApp h (← mkLambdaFVars xs hx)) + return none + else if root.interpreted then + if (← normLitValue root.self) != (← normLitValue rhs) then + let hne ← mkDecideProof (mkNot (← mkEq root.self rhs)) + return some (mkApp hne h) + else + return none + else + return none + +/-- Propagates `MatchCond` upwards -/ +builtin_grind_propagator propagateMatchCond ↑Grind.MatchCond := fun e => do + trace[grind.debug.matchCond] "visiting{indentExpr e}" + if !(← isStatisfied e) then return () + let some h ← mkMathCondProof? e + | reportIssue m!"failed to construct proof for{indentExpr e}"; return () + trace[grind.debug.matchCond] "{← inferType h}" + pushEqTrue e <| mkEqTrueCore e h + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/SimpUtil.lean b/src/Lean/Meta/Tactic/Grind/SimpUtil.lean index 4473ebdfd245..d08f8d609b0c 100644 --- a/src/Lean/Meta/Tactic/Grind/SimpUtil.lean +++ b/src/Lean/Meta/Tactic/Grind/SimpUtil.lean @@ -7,6 +7,7 @@ prelude import Lean.Meta.Tactic.Simp.Simproc import Lean.Meta.Tactic.Grind.Simp import Lean.Meta.Tactic.Grind.DoNotSimp +import Lean.Meta.Tactic.Grind.MatchCond import Lean.Meta.Tactic.Simp.BuiltinSimprocs.List namespace Lean.Meta.Grind @@ -24,7 +25,7 @@ def registerNormTheorems (preDeclNames : Array Name) (postDeclNames : Array Name /-- Returns the array of simprocs used by `grind`. -/ protected def getSimprocs : MetaM (Array Simprocs) := do - let e ← Simp.getSEvalSimprocs + let s ← Simp.getSEvalSimprocs /- We don't want to apply `List.reduceReplicate` as a normalization operation in `grind`. Consider the following example: @@ -38,9 +39,10 @@ protected def getSimprocs : MetaM (Array Simprocs) := do ``` We don't want it to be simplified to `[] = []`. -/ - let e := e.erase ``List.reduceReplicate - let e ← addDoNotSimp e - return #[e] + let s := s.erase ``List.reduceReplicate + let s ← addDoNotSimp s + let s ← addMatchCond s + return #[s] /-- Returns the simplification context used by `grind`. -/ protected def getSimpContext : MetaM Simp.Context := do diff --git a/tests/lean/run/grind_match_eq_propagation.lean b/tests/lean/run/grind_match_eq_propagation.lean new file mode 100644 index 000000000000..05d364fb6b5a --- /dev/null +++ b/tests/lean/run/grind_match_eq_propagation.lean @@ -0,0 +1,47 @@ +inductive S where + | mk1 (n : Nat) + | mk2 (n : Nat) (s : S) + | mk3 (n : Bool) + | mk4 (s1 s2 : S) + +def f (x y : S) := + match x, y with + | .mk1 _, _ => 2 + | _, .mk2 1 (.mk4 _ _) => 3 + | .mk3 _, _ => 4 + | _, _ => 5 + +example : f a b < 2 → b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → False := by + unfold f + grind (splits := 0) + +example : b = .mk2 y1 y2 → y1 = 2 → a = .mk4 y3 y4 → f a b = 5 := by + unfold f + grind (splits := 0) + +example : b = .mk2 y1 y2 → y1 = 2 → a = .mk3 n → f a b = 4 := by + unfold f + grind (splits := 0) + +example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk3 n → f a b = 3 := by + unfold f + grind (splits := 0) + +example : b = .mk2 y1 y2 → y1 = 1 → y2 = .mk4 s1 s2 → a = .mk2 s3 s4 → f a b = 3 := by + unfold f + grind (splits := 0) + +inductive Vec (α : Type u) : Nat → Type u + | nil : Vec α 0 + | cons : α → Vec α n → Vec α (n+1) + +def g (v w : Vec α n) : Nat := + match v, w with + | _, .cons _ (.cons _ _) => 20 + | .nil, _ => 30 + | _, _ => 40 + +-- TODO: introduce casts while instantiating equation theorems for `g.match_1` +-- example (a b : Vec α 2) : g a b = 20 := by +-- unfold g +-- grind