Skip to content

Commit

Permalink
feat: simp +arith for integers (leanprover#7011)
Browse files Browse the repository at this point in the history
This PR adds `simp +arith` for integers. It uses the new `grind`
normalizer for linear integer arithmetic. We still need to implement
support for dividing the coefficients by their GCD. It also fixes
several bugs in the normalizer.
  • Loading branch information
leodemoura authored and tobiasgrosser committed Feb 16, 2025
1 parent dc8073e commit 9769781
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 39 deletions.
14 changes: 14 additions & 0 deletions src/Init/Data/Int/Linear.lean
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,20 @@ theorem ExprCnstr.eq_of_toPoly_eq (ctx : Context) (c c' : ExprCnstr) (h : c.toPo
rw [denote_toPoly, denote_toPoly] at h
assumption

theorem ExprCnstr.eq_of_toPoly_eq_var (ctx : Context) (x y : Var) (c : ExprCnstr) (h : c.toPoly == .eq (.add 1 x (.add (-1) y (.num 0))))
: c.denote ctx = (x.denote ctx = y.denote ctx) := by
have h := congrArg (PolyCnstr.denote ctx) (eq_of_beq h)
rw [denote_toPoly] at h
rw [h]; simp
rw [← Int.sub_eq_add_neg, Int.sub_eq_zero]

theorem ExprCnstr.eq_of_toPoly_eq_const (ctx : Context) (x : Var) (k : Int) (c : ExprCnstr) (h : c.toPoly == .eq (.add 1 x (.num (-k))))
: c.denote ctx = (x.denote ctx = k) := by
have h := congrArg (PolyCnstr.denote ctx) (eq_of_beq h)
rw [denote_toPoly] at h
rw [h]; simp
rw [Int.add_comm, ← Int.sub_eq_add_neg, Int.sub_eq_zero]

def PolyCnstr.isUnsat : PolyCnstr → Bool
| .eq (.num k) => k != 0
| .eq _ => false
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Meta/Tactic/LinearArith/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def isLinearTerm (e : Expr) : Bool :=
false
else
let n := f.constName!
n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``HSub.hSub || n == ``Nat.succ
n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``HSub.hSub || n == ``Neg.neg || n == ``Nat.succ
|| n == ``Add.add || n == ``Mul.mul || n == ``Sub.sub

/-- Quick filter for linear constraints. -/
partial def isLinearCnstr (e : Expr) : Bool :=
Expand Down
29 changes: 16 additions & 13 deletions src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,6 @@ def addAsVar (e : Expr) : M LinearExpr := do
set { varMap := (← s.varMap.insert e x), vars := s.vars.push e : State }
return var x

private def toInt? (e : Expr) : MetaM (Option Int) := do
let_expr OfNat.ofNat _ n i ← e | return none
unless (← isInstOfNatInt i) do return none
let some n ← evalNat n |>.run | return none
return some (Int.ofNat n)

partial def toLinearExpr (e : Expr) : M LinearExpr := do
match e with
| .mdata _ e => toLinearExpr e
Expand All @@ -119,14 +113,14 @@ partial def toLinearExpr (e : Expr) : M LinearExpr := do
where
visit (e : Expr) : M LinearExpr := do
let mul (a b : Expr) := do
match (← toInt? a) with
match (← getIntValue? a) with
| some k => return .mulL k (← toLinearExpr b)
| none => match (← toInt? b) with
| none => match (← getIntValue? b) with
| some k => return .mulR (← toLinearExpr a) k
| none => addAsVar e
match_expr e with
| OfNat.ofNat _ n i =>
if (← isInstOfNatInt i) then toLinearExpr n
| OfNat.ofNat _ _ _ =>
if let some n ← getIntValue? e then return .num n
else addAsVar e
| Int.neg a => return .neg (← toLinearExpr a)
| Neg.neg _ i a =>
Expand All @@ -144,7 +138,7 @@ where
if (← isInstSubInt i) then return .sub (← toLinearExpr a) (← toLinearExpr b)
else addAsVar e
| HSub.hSub _ _ _ i a b =>
if (← isInstSubInt i) then return .sub (← toLinearExpr a) (← toLinearExpr b)
if (← isInstHSubInt i) then return .sub (← toLinearExpr a) (← toLinearExpr b)
else addAsVar e
| Int.mul a b => mul a b
| Mul.mul _ i a b =>
Expand All @@ -159,13 +153,22 @@ partial def toLinearCnstr? (e : Expr) : M (Option LinearCnstr) := OptionT.run do
match_expr e with
| Eq α a b =>
let_expr Int ← α | failure
return .eq (← toLinearExpr a) (← toLinearExpr b)
let a ← toLinearExpr a
let b ← toLinearExpr b
match a, b with
/-
We do not want to convert `x = y` into `x + -1*y = 0`.
Similarly, we don't want to convert `x = 3` into `x + -3 = 0`.
`grind` and other tactics have better support for this kind of equalities.
-/
| .var _, .var _ | .var _, .num _ | .num _, .var _ => failure
| _, _ => return .eq a b
| Int.le a b =>
return .le (← toLinearExpr a) (← toLinearExpr b)
| Int.lt a b =>
return .le (.add (← toLinearExpr a) (.num 1)) (← toLinearExpr b)
| LE.le _ i a b =>
guard (← isInstLENat i)
guard (← isInstLEInt i)
return .le (← toLinearExpr a) (← toLinearExpr b)
| LT.lt _ i a b =>
guard (← isInstLTInt i)
Expand Down
16 changes: 13 additions & 3 deletions src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,19 @@ def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
else
let c' : LinearCnstr := p.toExprCnstr
if c != c' then
let r ← c'.toArith ctx
let p := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr c) (toExpr c') reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
match p with
| .eq (.add 1 x (.add (-1) y (.num 0))) =>
let r := mkIntEq ctx[x]! ctx[y]!
let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_var) (toContextExpr ctx) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
| .eq (.add 1 x (.num k)) =>
let r := mkIntEq ctx[x]! (toExpr (-k))
let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr ctx) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
| _ =>
let r ← c'.toArith ctx
let p := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr c) (toExpr c') reflBoolTrue
return some (r, ← mkExpectedTypeHint p (← mkEq lhs r))
else
return none

Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/LinearArith/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Tactic.LinearArith.Basic
import Lean.Meta.Tactic.LinearArith.Nat.Simp
import Lean.Meta.Tactic.LinearArith.Int.Simp

namespace Lean.Meta.Linear

Expand Down
18 changes: 12 additions & 6 deletions src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -284,16 +284,22 @@ def simpArith (e : Expr) : SimpM Step := do
unless (← getConfig).arith do
return .continue
if Linear.isLinearCnstr e then
let some (e', h) ← Linear.Nat.simpCnstr? e
| return .continue
return .visit { expr := e', proof? := h }
if let some (e', h) ← Linear.Nat.simpCnstr? e then
return .visit { expr := e', proof? := h }
else if let some (e', h) ← Linear.Int.simpCnstr? e then
return .visit { expr := e', proof? := h }
else
return .continue
else if Linear.isLinearTerm e then
if Linear.parentIsTarget (← getContext).parent? then
-- We mark `cache := false` to ensure we do not miss simplifications.
return .continue (some { expr := e, cache := false })
let some (e', h) ← Linear.Nat.simpExpr? e
| return .continue
return .visit { expr := e', proof? := h }
else if let some (e', h) ← Linear.Nat.simpExpr? e then
return .visit { expr := e', proof? := h }
else if let some (e', h) ← Linear.Int.simpExpr? e then
return .visit { expr := e', proof? := h }
else
return .continue
else
return .continue

Expand Down
7 changes: 1 addition & 6 deletions tests/lean/run/2615.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,2 @@
-- `simp_arith` does not support `Int` yet.
-- But, the weird error message at #2615 is not generated anymore
/--
error: simp made no progress
-/
#guard_msgs (error) in
-- `simp +arith` supports integers now
theorem huh (x : Int) : x + 1 = 1 + x := by simp_arith
10 changes: 0 additions & 10 deletions tests/lean/run/grind_regression.lean
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,6 @@ end

section

example {a : Int} : ((a * b) - (2 * c)) * d - (a * b) = (d - 1) * (a * b) - (2 * c * d) := by
grind only [Int.sub_mul, Int.sub_sub, Int.add_comm, Int.mul_comm, Int.one_mul]

end

section

example : Nat → (x : Nat) → x = x := by
intro x
grind
Expand Down Expand Up @@ -548,9 +541,6 @@ example (as bs : List α) : reverse (as ++ bs) = (reverse bs) ++ (reverse as) :=

variable (a b c d : Int)

example : ((a * b) - (2 * c)) * d - (a * b) = (d - 1) * (a * b) - (2 * c * d) := by
grind only [Int.sub_mul, Int.sub_sub, Int.add_comm, Int.mul_comm, Int.one_mul]

example {p q r : Prop} (h₁ : p) (h₂ : p ↔ q) (h₃ : q → (p ↔ r)) : p ↔ r := by
grind

Expand Down
135 changes: 135 additions & 0 deletions tests/lean/run/simp_int_arith.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
example (x y : Int) : x + y + 2 + y = y + 1 + 1 + x + y := by
simp +arith only

example (x y : Int) (h : x + y + 2 + y < y + 1 + 1 + x + y) : False := by
simp +arith only at h

example (x y : Int) (h : x + y + 2 + y > y + 1 + 1 + x + y) : False := by
simp +arith only at h

example (x y : Int) (_h : x + y + 3 + y > y + 1 + 1 + x + y) : True := by
simp +arith only at _h
guard_hyp _h : True
constructor

example (x y : Int) (h : x + y + 2 + y > 1 + 1 + x + x + y + 2*x) : 3*x + -1*y+10 := by
simp +arith only at h
guard_hyp h : 3 * x + -1*y + 10
assumption

example (x y : Int) (h : 6*x + y + 3 + y + 1 < y + 1 + 1 + x + 5*y) : 5*x + -4*y + 30 := by
simp +arith only at h
guard_hyp h : 5*x + -4*y + 30
assumption

example (x y : Int) : x + y + 2 + y ≤ y + 1 + 1 + x + y := by
simp +arith only

example (x y : Int) : x + y + 2 + y ≤ y + 1 + 1 + 5 + x + y := by
simp +arith only

example (x y z : Int) : x + y + 2 + y + z + z ≤ y + 3*z + 1 + 1 + x + y - z := by
simp +arith only

example (x y : Int) (h : False) : x + y + 20 + y ≤ y + 1 + 1 + 5 + x + y := by
simp +arith only
guard_target = False
assumption

example (x y : Int) (h : False) : x = y := by
fail_if_success simp +arith only
guard_target = x = y
contradiction

example (x : Int) (h : False) : x = 3 := by
fail_if_success simp +arith only
guard_target = x = 3
contradiction

example (x : Int) (h : False) : 3 = x := by
fail_if_success simp +arith only
guard_target = 3 = x
contradiction

example (x : Int) (h : False) : 2*x = x + 3 := by
simp +arith only
guard_target = x = 3
contradiction

example (x y : Int) (h : False) : 2*x = x + y := by
simp +arith only
guard_target = x = y
contradiction

example (x : Int) (h : False) : 2*x + 1 = x := by
simp +arith only
guard_target = x = -1
contradiction

example (x : Int) (h : False) (f : Int → Int) : f (0 + x + x) = 3 := by
simp +arith only
guard_target = f (2*x) = 3
contradiction

example (x y : Int) (h : False) (f : Int → Int) : f (x + y - x) = 3 := by
simp +arith only
guard_target = f y = 3
contradiction

example (x y : Int) (h : False) (f : Int → Int) : f (x + y - x + 2 - y + 1) = 3 := by
simp +arith only
guard_target = f 3 = 3
contradiction

example (x y : Int) (h : False) (f : Int → Int) : f (x + y - x + (2 - y)*2 + 1) = 3 := by
simp +arith only
guard_target = f (-1*y + 5) = 3
contradiction

example (x y : Int) (h : False) (f : Int → Int) : f (x + y - x + (2 - y)*(1+1) + 1) = 3 := by
simp +arith only [Int.reduceAdd]
guard_target = f (-1*y + 5) = 3
contradiction

example (x y : Int) (h : False) (f : Int → Int) : f (x + y - x + (1-1+2)*(2 - y) + 1) = 3 := by
simp +arith only [Int.reduceAdd, Int.reduceSub]
guard_target = f (-1*y + 5) = 3
contradiction

example (x y : Int) (h : False) (f : Int → Int) : f (Int.add x y - x + (2 - y)*2 + 1) = 3 := by
simp +arith only
guard_target = f (-1*y + 5) = 3
contradiction

example (x y : Int) (h : False) (f : Int → Int) : f (x + y - x + Int.mul (1-1+2) (2 - y) + 1) = 3 := by
simp +arith only [Int.reduceAdd, Int.reduceSub]
guard_target = f (-1*y + 5) = 3
contradiction

example (x y : Int) (h : False) (f : Int → Int) : f (Int.add x y - x + (2 - y)*(-2) + 1) = 3 := by
simp +arith only
guard_target = f (3*y + -3) = 3
contradiction

example (x : Int) : x > x - 1 := by
simp +arith only

example (x : Int) : x - 1 < x := by
simp +arith only

example (x : Int) : x < x + 1 := by
simp +arith only

example (x : Int) : x ≥ x - 1 := by
simp +arith only

example (x : Int) : x ≤ x := by
simp +arith only

example (x : Int) : x ≤ x + 1 := by
simp +arith only

example (x : Int) (h : False) : x > x := by
simp +arith only
guard_target = False
assumption

0 comments on commit 9769781

Please sign in to comment.