Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: add lean_instantiate_level_mvars #4910

Merged
merged 3 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions src/Lean/MetavarContext.lean
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ structure MetavarContext where
For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}

instance : Inhabited MetavarContext := ⟨{}⟩

/-- A monad with a stateful metavariable context, defining `getMCtx` and `modifyMCtx`. -/
class MonadMCtx (m : Type → Type) where
getMCtx : m MetavarContext
Expand All @@ -358,15 +360,27 @@ abbrev setMCtx [MonadMCtx m] (mctx : MetavarContext) : m Unit :=
abbrev getLevelMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : LMVarId) : m (Option Level) :=
return (← getMCtx).lAssignment.find? mvarId

@[export lean_get_lmvar_assignment]
def getLevelMVarAssignmentExp (m : MetavarContext) (mvarId : LMVarId) : Option Level :=
m.lAssignment.find? mvarId

def MetavarContext.getExprAssignmentCore? (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
m.eAssignment.find? mvarId

@[export lean_get_mvar_assignment]
def MetavarContext.getExprAssignmentExp (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
m.eAssignment.find? mvarId

def getExprMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Expr) :=
return (← getMCtx).getExprAssignmentCore? mvarId

def MetavarContext.getDelayedMVarAssignmentCore? (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
mctx.dAssignment.find? mvarId

@[export lean_get_delayed_mvar_assignment]
def MetavarContext.getDelayedMVarAssignmentExp (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
mctx.dAssignment.find? mvarId

def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) :=
return (← getMCtx).getDelayedMVarAssignmentCore? mvarId

Expand Down Expand Up @@ -478,6 +492,10 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool
def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit :=
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val }

@[export lean_assign_lmvar]
def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext :=
{ m with lAssignment := m.lAssignment.insert mvarId val }

/--
Add `mvarId := x` to the metavariable assignment.
This method does not check whether `mvarId` is already assigned, nor it checks whether
Expand All @@ -487,6 +505,10 @@ This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`.
def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit :=
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val }

@[export lean_assign_mvar]
def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext :=
{ m with eAssignment := m.eAssignment.insert mvarId val }

/--
Add a delayed assignment for the given metavariable. You must make sure that
the metavariable is not already assigned or delayed-assigned.
Expand Down Expand Up @@ -516,6 +538,9 @@ To avoid this term eta-expanded term, we apply beta-reduction when instantiating
This operation is performed at `instantiateExprMVars`, `elimMVarDeps`, and `levelMVarToParam`.
-/

@[extern "lean_instantiate_level_mvars"]
opaque instantiateLevelMVarsImp (mctx : MetavarContext) (l : Level) : MetavarContext × Level

partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
| lvl@(Level.succ lvl₁) => return Level.updateSucc! lvl (← instantiateLevelMVars lvl₁)
| lvl@(Level.max lvl₁ lvl₂) => return Level.updateMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂)
Expand All @@ -531,6 +556,9 @@ partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
| none => pure lvl
| lvl => pure lvl

@[extern "lean_instantiate_expr_mvars"]
opaque instantiateExprMVarsImp (mctx : MetavarContext) (e : Expr) : MetavarContext × Expr

/-- instantiateExprMVars main function -/
partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLiftT (ST ω) m] (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
if !e.hasMVar then
Expand Down Expand Up @@ -792,8 +820,6 @@ def localDeclDependsOnPred [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (pf :

namespace MetavarContext

instance : Inhabited MetavarContext := ⟨{}⟩

@[export lean_mk_metavar_ctx]
def mkMetavarContext : Unit → MetavarContext := fun _ => {}

Expand Down
2 changes: 1 addition & 1 deletion src/kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ add_library(kernel OBJECT level.cpp expr.cpp expr_eq_fn.cpp
for_each_fn.cpp replace_fn.cpp abstract.cpp instantiate.cpp
local_ctx.cpp declaration.cpp environment.cpp type_checker.cpp
init_module.cpp expr_cache.cpp equiv_manager.cpp quot.cpp
inductive.cpp trace.cpp)
inductive.cpp trace.cpp instantiate_mvars.cpp)
95 changes: 95 additions & 0 deletions src/kernel/instantiate_mvars.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Authors: Leonardo de Moura
*/
#include <unordered_map>
#include "runtime/option_ref.h"
#include "kernel/instantiate.h"
#include "kernel/abstract.h"

/*
This module is not used by the kernel. It just provides an efficient implementation of
`instantiateExprMVars`
*/

namespace lean {

extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid);
extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val);

typedef object_ref metavar_ctx;
void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) {
object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg());
mctx.set_box(r);
}

option_ref<level> get_lmvar_assignment(metavar_ctx & mctx, name const & mid) {
return option_ref<level>(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg()));
}

class instantiate_lmvar_fn {
metavar_ctx & m_mctx;
std::unordered_map<lean_object *, lean_object *> m_cache;

inline level cache(level const & l, level && r, bool shared) {
if (shared) {
m_cache.insert(mk_pair(l.raw(), r.raw()));
}
return r;
}
public:
instantiate_lmvar_fn(metavar_ctx & mctx):m_mctx(mctx) {}
level visit(level const & l) {
if (!has_mvar(l))
return l;
bool shared = false;
if (is_shared(l)) {
auto it = m_cache.find(l.raw());
if (it != m_cache.end()) {
return level(it->second, true);
}
shared = true;
}
switch (l.kind()) {
case level_kind::Succ:
return cache(l, update_succ(l, visit(succ_of(l))), shared);
case level_kind::Max: case level_kind::IMax:
return cache(l, update_max(l, visit(level_lhs(l)), visit(level_rhs(l))), shared);
case level_kind::Zero: case level_kind::Param:
lean_unreachable();
case level_kind::MVar: {
option_ref<level> r = get_lmvar_assignment(m_mctx, mvar_id(l));
if (!r) {
return l;
} else {
level a(r.get_val());
if (!has_mvar(a)) {
return a;
} else {
level a_new = visit(a);
if (!is_eqp(a, a_new)) {
assign_lmvar(m_mctx, mvar_id(l), a_new);
}
return a_new;
}
}
}}
}
level operator()(level const & l) { return visit(l); }
};

extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object * l) {
metavar_ctx mctx(m);
level l_new = instantiate_lmvar_fn(mctx)(level(l));
object * r = alloc_cnstr(0, 2, 0);
cnstr_set(r, 0, mctx.steal());
cnstr_set(r, 1, l_new.steal());
return r;
}

extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object *, object *) {
lean_internal_panic("not implemented yet");
}
}
2 changes: 2 additions & 0 deletions src/kernel/level.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ inline bool operator!=(level const & l1, level const & l2) { return !operator==(
struct level_hash { unsigned operator()(level const & n) const { return n.hash(); } };
struct level_eq { bool operator()(level const & n1, level const & n2) const { return n1 == n2; } };

inline bool is_shared(level const & l) { return !is_exclusive(l.raw()); }

inline optional<level> none_level() { return optional<level>(); }
inline optional<level> some_level(level const & e) { return optional<level>(e); }
inline optional<level> some_level(level && e) { return optional<level>(std::forward<level>(e)); }
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/object_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class object_ref {
s.m_obj = box(0);
return *this;
}
void set_box(object * o) {
lean_assert(is_scalar(m_obj));
m_obj = o;
}
object * raw() const { return m_obj; }
object * steal() { object * r = m_obj; m_obj = box(0); return r; }
object * to_obj_arg() const { inc(m_obj); return m_obj; }
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/option_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class option_ref : public object_ref {
explicit operator bool() const { return !is_scalar(raw()); }
optional<T> get() const { return *this ? some(static_cast<T const &>(cnstr_get_ref(*this, 0))) : optional<T>(); }

T get_val() const { lean_assert(*this); return static_cast<T const &>(cnstr_get_ref(*this, 0)); }

/** \brief Structural equality. */
friend bool operator==(option_ref const & o1, option_ref const & o2) {
return o1.get() == o2.get();
Expand Down
Loading