Skip to content

Commit

Permalink
Merge pull request #585 from felixwellen/more-flexibel-reflection-sol…
Browse files Browse the repository at this point in the history
…ving

More flexibel reflection solving
  • Loading branch information
ecavallo authored Oct 22, 2021
2 parents 9f57beb + f4e2d7d commit d2b4b60
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 40 deletions.
35 changes: 35 additions & 0 deletions Cubical/Algebra/RingSolver/Examples.agda
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Cubical.Algebra.RingSolver.Examples where

open import Cubical.Foundations.Prelude
open import Cubical.Data.Int.Base hiding (_+_ ; _·_ ; _-_)
open import Cubical.Data.List

open import Cubical.Algebra.CommRing
open import Cubical.Algebra.RingSolver.ReflectionSolving
Expand All @@ -11,9 +12,13 @@ private
variable
: Level


module Test (R : CommRing ℓ) where
open CommRingStr (snd R)

_ : 0r ≡ 0r
_ = solve R

_ : 1r · (1r + 0r)
≡ (1r · 0r) + 1r
_ = solve R
Expand Down Expand Up @@ -60,3 +65,33 @@ module Test (R : CommRing ℓ) where
_ : (x y : (fst R)) → x ≡ y
_ = solve R
-}

module TestInPlaceSolving (R : CommRing ℓ) where
open CommRingStr (snd R)

testWithOneVariabl : (x : fst R) x + 0r ≡ 0r + x
testWithOneVariabl x = solveInPlace R (x ∷ [])

testEquationalReasoning : (x : fst R) x + 0r ≡ 0r + x
testEquationalReasoning x =
x + 0r ≡⟨solveIn R withVars (x ∷ []) ⟩
0r + x ∎

testWithTwoVariables : (x y : fst R) x + y ≡ y + x
testWithTwoVariables x y =
x + y ≡⟨solveIn R withVars (x ∷ y ∷ []) ⟩
y + x ∎

{-
So far, solving during equational reasoning has a serious
restriction:
The solver identifies variables by deBruijn indices and the variables
appearing in the equations to solve need to have indices 0,...,n. This
entails that in the following code, the order of 'p' and 'x' cannot be
switched.
-}
testEquationalReasoning' : (p : (y : fst R) 0r + y ≡ 1r) (x : fst R) x + 0r ≡ 1r
testEquationalReasoning' p x =
x + 0r ≡⟨solveIn R withVars (x ∷ []) ⟩
0r + x ≡⟨ p x ⟩
1r ∎
205 changes: 165 additions & 40 deletions Cubical/Algebra/RingSolver/ReflectionSolving.agda
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
{-
This is inspired by/copied from:
https://github.com/agda/agda-stdlib/blob/master/src/Tactic/MonoidSolver.agda
Boilerplate code for calling the ring solver is constructed automatically
with agda's reflection features.
-}
module Cubical.Algebra.RingSolver.ReflectionSolving where

Expand Down Expand Up @@ -44,21 +47,65 @@ private
varType : Arg Term
index :

getArgs : Term Maybe (Term × Term)
getArgs (def n xs) =
if n == (quote PathP)
{-
`getLastTwoArgsOf` maps a term 'def n (z₁ ∷ … ∷ zₙ ∷ x ∷ y ∷ [])' to the pair '(x,y)'
non-visible arguments are ignored.
-}
getLastTwoArgsOf : Name Term Maybe (Term × Term)
getLastTwoArgsOf n' (def n xs) =
if n == n'
then go xs
else nothing
where
go : List (Arg Term) Maybe (Term × Term)
go (varg x ∷ varg y ∷ []) = just (x , y)
go (x ∷ xs) = go xs
go _ = nothing
getArgs _ = nothing
getLastTwoArgsOf n' _ = nothing

{-
`getArgs` maps a term 'x ≡ y' to the pair '(x,y)'
-}
getArgs : Term Maybe (Term × Term)
getArgs = getLastTwoArgsOf (quote PathP)


firstVisibleArg : List (Arg Term) Maybe Term
firstVisibleArg [] = nothing
firstVisibleArg (varg x ∷ l) = just x
firstVisibleArg (x ∷ l) = firstVisibleArg l

{-
If the solver needs to be applied during equational reasoning,
the right hand side of the equation to solve cannot be given to
the solver directly. The folllowing function extracts this term y
from a more complex expression as in:
x ≡⟨ solve ... ⟩ (y ≡⟨ ... ⟩ z ∎)
-}
getRhs : Term Maybe Term
getRhs reasoningToTheRight@(def n xs) =
if n == (quote _∎)
then firstVisibleArg xs
else (if n == (quote _≡⟨_⟩_)
then firstVisibleArg xs
else nothing)
getRhs _ = nothing

constructSolution : List VarInfo Term Term Term Term
constructSolution n varInfos R lhs rhs =
encloseWithIteratedLambda (map VarInfo.varName varInfos) solverCall

private
solverCallAsTerm : Term Arg Term Term Term Term
solverCallAsTerm R varList lhs rhs =
def
(quote ringSolve)
(varg R ∷ varg lhs ∷ varg rhs
∷ varList
∷ varg (def (quote refl) []) ∷ [])

solverCallWithLambdas : List VarInfo Term Term Term Term
solverCallWithLambdas n varInfos R lhs rhs =
encloseWithIteratedLambda
(map VarInfo.varName varInfos)
(solverCallAsTerm R (variableList (rev varInfos)) lhs rhs)
where
encloseWithIteratedLambda : List String Term Term
encloseWithIteratedLambda (varName ∷ xs) t = lam visible (abs varName (encloseWithIteratedLambda xs t))
Expand All @@ -69,11 +116,16 @@ private
variableList (varInfo ∷ varInfos)
= varg (con (quote _∷vec_) (varg (var (VarInfo.index varInfo) []) ∷ (variableList varInfos) ∷ []))

solverCall = def
(quote ringSolve)
(varg R ∷ varg lhs ∷ varg rhs
∷ variableList (rev varInfos)
∷ varg (def (quote refl) []) ∷ [])
solverCallByVarIndices : List ℕ Term Term Term Term
solverCallByVarIndices n varIndices R lhs rhs =
solverCallAsTerm R (variableList (rev varIndices)) lhs rhs
where
variableList : List ℕ Arg Term
variableList [] = varg (con (quote emptyVec) [])
variableList (varIndex ∷ varIndices)
= varg (con (quote _∷vec_) (varg (var (varIndex) []) ∷ (variableList varIndices) ∷ []))



module pr (R : CommRing ℓ) {n : ℕ} where
private
Expand All @@ -93,18 +145,19 @@ module _ (cring : Term) where

open pr

`0` : List (Arg Term) Term
`0` [] = def (quote 0') (varg cring ∷ [])
`0` (varg fstcring ∷ xs) = `0` xs
`0` (harg _ ∷ xs) = `0` xs
`0` _ = unknown

`1` : List (Arg Term) Term
`1` [] = def (quote 1') (varg cring ∷ [])
`1` (varg fstcring ∷ xs) = `1` xs
`1` (harg _ ∷ xs) = `1` xs
`1` _ = unknown

mutual
`0` : List (Arg Term) Term
`0` [] = def (quote 0') (varg cring ∷ [])
`0` (varg fstcring ∷ xs) = `0` xs
`0` (harg _ ∷ xs) = `0` xs
`0` _ = unknown

`1` : List (Arg Term) Term
`1` [] = def (quote 1') (varg cring ∷ [])
`1` (varg fstcring ∷ xs) = `1` xs
`1` (harg _ ∷ xs) = `1` xs
`1` _ = unknown

`_·_` : List (Arg Term) Term
`_·_` (harg _ ∷ xs) = `_·_` xs
Expand Down Expand Up @@ -163,6 +216,13 @@ private
adjustDeBruijnIndex n (var k args) = var (k +ℕ n) args
adjustDeBruijnIndex n _ = unknown

extractVarIndices : Maybe (List Term) Maybe (List ℕ)
extractVarIndices (just ((var index _) ∷ l)) with extractVarIndices (just l)
... | just indices = just (index ∷ indices)
... | nothing = nothing
extractVarIndices (just []) = just []
extractVarIndices _ = nothing

getVarsAndEquation : Term Maybe (List VarInfo × Term)
getVarsAndEquation t =
let
Expand All @@ -172,7 +232,8 @@ private
where
extractVars : Term List (String × Arg Term) × Term
extractVars (pi argType (abs varName t)) with extractVars t
... | xs , equation = (varName , argType) ∷ xs , equation
... | xs , equation
= (varName , argType) ∷ xs , equation
extractVars equation = [] , equation

addIndices : List (String × Arg Term) Maybe (List VarInfo)
Expand All @@ -183,27 +244,91 @@ private
(addIndices countVar list)
addIndices _ _ = nothing

toListOfTerms : Term Maybe (List Term)
toListOfTerms (con c []) = if (c == (quote [])) then just [] else nothing
toListOfTerms (con c (varg t ∷ varg s ∷ args)) with toListOfTerms s
... | just terms = if (c == (quote _∷_)) then just (t ∷ terms) else nothing
... | nothing = nothing
toListOfTerms (con c (harg t ∷ args)) = toListOfTerms (con c args)
toListOfTerms _ = nothing

solve-macro : Term Term TC Unit
solve-macro cring hole = do
hole′ inferType hole >>= normalise
just (varInfos , equation) returnTC (getVarsAndEquation hole′)
where
nothing
typeError (strErr "Something went wrong when getting the variable names in "
∷ termErr hole′ ∷ [])
adjustedCring returnTC (adjustDeBruijnIndex (length varInfos) cring)
just (lhs , rhs) returnTC (toAlgebraExpression adjustedCring (getArgs equation))
where
nothing
typeError(
strErr "Error while trying to buils ASTs for the equation "
termErr equation ∷ [])
let solution = constructSolution (length varInfos) varInfos adjustedCring lhs rhs
unify hole solution
solve-macro cring hole =
do
hole′ inferType hole >>= normalise
just (varInfos , equation) returnTC (getVarsAndEquation hole′)
where
nothing
typeError (strErr "Something went wrong when getting the variable names in "
∷ termErr hole′ ∷ [])
{-
The call to the ring solver will be inside a lamba-expression.
That means, that we have to adjust the deBruijn-indices of the variables in 'cring'
-}
adjustedCring returnTC (adjustDeBruijnIndex (length varInfos) cring)
just (lhs , rhs) returnTC (toAlgebraExpression adjustedCring (getArgs equation))
where
nothing
typeError(
strErr "Error while trying to build ASTs for the equation "
termErr equation ∷ [])
let solution = solverCallWithLambdas (length varInfos) varInfos adjustedCring lhs rhs
unify hole solution

solveInPlace-macro : Term Term Term TC Unit
solveInPlace-macro cring varsToSolve hole =
do
equation inferType hole >>= normalise
just varIndices returnTC (extractVarIndices (toListOfTerms varsToSolve))
where
nothing
typeError(
strErr "Error reading variables to solve "
termErr varsToSolve ∷ [])
just (lhs , rhs) returnTC (toAlgebraExpression cring (getArgs equation))
where
nothing
typeError(
strErr "Error while trying to build ASTs for the equation "
termErr equation ∷ [])
let solution = solverCallByVarIndices (length varIndices) varIndices cring lhs rhs
unify hole solution

solveEqReasoning-macro : Term Term Term Term Term TC Unit
solveEqReasoning-macro lhs cring varsToSolve reasoningToTheRight hole =
do
just varIndices returnTC (extractVarIndices (toListOfTerms varsToSolve))
where
nothing
typeError(
strErr "Error reading variables to solve "
termErr varsToSolve ∷ [])
just rhs returnTC (getRhs reasoningToTheRight)
where
nothing
typeError(
strErr "Failed to extract right hand side of equation to solve from "
termErr reasoningToTheRight ∷ [])
just (lhsAST , rhsAST) returnTC (toAlgebraExpression cring (just (lhs , rhs)))
where
nothing
typeError(
strErr "Error while trying to build ASTs from "
termErr lhs ∷ strErr " and " ∷ termErr rhs ∷ [])
let solverCall = solverCallByVarIndices (length varIndices) varIndices cring lhsAST rhsAST
unify hole (def (quote _≡⟨_⟩_) (varg lhs ∷ varg solverCall ∷ varg reasoningToTheRight ∷ []))

macro
solve : Term Term TC _
solve = solve-macro

solveInPlace : Term Term Term TC _
solveInPlace = solveInPlace-macro

infixr 2 _≡⟨solveIn_withVars_⟩_
_≡⟨solveIn_withVars_⟩_ : Term Term Term Term Term TC Unit
_≡⟨solveIn_withVars_⟩_ = solveEqReasoning-macro


fromℤ : (R : CommRing ℓ) fst R
fromℤ = scalar

0 comments on commit d2b4b60

Please sign in to comment.