Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Passing struct to new functions #503

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 46 additions & 51 deletions act/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,91 +20,86 @@ type IRegistry interface {

// Registry keeps track of all policies and actions.
type Registry struct {
logger zerolog.Logger
Logger zerolog.Logger
// Timeout for policy evaluation.
policyTimeout time.Duration
PolicyTimeout time.Duration
// Default timeout for running actions
defaultActionTimeout time.Duration

Signals map[string]*sdkAct.Signal
Policies map[string]*sdkAct.Policy
Actions map[string]*sdkAct.Action
DefaultPolicy *sdkAct.Policy
DefaultSignal *sdkAct.Signal
DefaultActionTimeout time.Duration

Signals map[string]*sdkAct.Signal
Policies map[string]*sdkAct.Policy
Actions map[string]*sdkAct.Action
DefaultPolicyName string
DefaultPolicy *sdkAct.Policy
DefaultSignal *sdkAct.Signal
}

var _ IRegistry = (*Registry)(nil)

// NewActRegistry creates a new act registry with the specified default policy and timeout
// and the builtin signals, policies, and actions.
func NewActRegistry(
builtinSignals map[string]*sdkAct.Signal,
builtinsPolicies map[string]*sdkAct.Policy,
builtinActions map[string]*sdkAct.Action,
defaultPolicy string,
policyTimeout time.Duration,
defaultActionTimeout time.Duration,
logger zerolog.Logger,
registry Registry,
) *Registry {
if builtinSignals == nil || builtinsPolicies == nil || builtinActions == nil {
logger.Warn().Msg("Builtin signals, policies, or actions are nil, not adding")
if registry.Signals == nil || registry.Policies == nil || registry.Actions == nil {
registry.Logger.Warn().Msg("Builtin signals, policies, or actions are nil, not adding")
return nil
}

for _, signal := range builtinSignals {
for _, signal := range registry.Signals {
if signal == nil {
logger.Warn().Msg("Signal is nil, not adding")
registry.Logger.Warn().Msg("Signal is nil, not adding")
return nil
}
logger.Debug().Str("name", signal.Name).Msg("Registered builtin signal")
registry.Logger.Debug().Str("name", signal.Name).Msg("Registered builtin signal")
}

for _, policy := range builtinsPolicies {
for _, policy := range registry.Policies {
if policy == nil {
logger.Warn().Msg("Policy is nil, not adding")
registry.Logger.Warn().Msg("Policy is nil, not adding")
return nil
}
logger.Debug().Str("name", policy.Name).Msg("Registered builtin policy")
registry.Logger.Debug().Str("name", policy.Name).Msg("Registered builtin policy")
}

for _, action := range builtinActions {
for _, action := range registry.Actions {
if action == nil {
logger.Warn().Msg("Action is nil, not adding")
registry.Logger.Warn().Msg("Action is nil, not adding")
return nil
}
logger.Debug().Str("name", action.Name).Msg("Registered builtin action")
registry.Logger.Debug().Str("name", action.Name).Msg("Registered builtin action")
}

// The default policy must exist, otherwise use passthrough.
if _, exists := builtinsPolicies[defaultPolicy]; !exists || defaultPolicy == "" {
logger.Warn().Str("name", defaultPolicy).Msgf(
if _, exists := registry.Policies[registry.DefaultPolicyName]; !exists || registry.DefaultPolicyName == "" {
registry.Logger.Warn().Str("name", registry.DefaultPolicyName).Msgf(
"The specified default policy does not exist, using %s", config.DefaultPolicy)
defaultPolicy = config.DefaultPolicy
registry.DefaultPolicyName = config.DefaultPolicy
}

logger.Debug().Str("name", defaultPolicy).Msg("Using default policy")
registry.Logger.Debug().Str("name", registry.DefaultPolicyName).Msg("Using default policy")

return &Registry{
logger: logger,
policyTimeout: policyTimeout,
defaultActionTimeout: defaultActionTimeout,
Signals: builtinSignals,
Policies: builtinsPolicies,
Actions: builtinActions,
DefaultPolicy: builtinsPolicies[defaultPolicy],
DefaultSignal: builtinSignals[defaultPolicy],
Logger: registry.Logger,
PolicyTimeout: registry.PolicyTimeout,
DefaultActionTimeout: registry.DefaultActionTimeout,
Signals: registry.Signals,
Policies: registry.Policies,
Actions: registry.Actions,
DefaultPolicy: registry.Policies[registry.DefaultPolicyName],
DefaultSignal: registry.Signals[registry.DefaultPolicyName],
}
}

// Add adds a policy to the registry.
func (r *Registry) Add(policy *sdkAct.Policy) {
if policy == nil {
r.logger.Warn().Msg("Policy is nil, not adding")
r.Logger.Warn().Msg("Policy is nil, not adding")
return
}

if _, exists := r.Policies[policy.Name]; exists {
r.logger.Warn().Str("name", policy.Name).Msg("Policy already exists, overwriting")
r.Logger.Warn().Str("name", policy.Name).Msg("Policy already exists, overwriting")
}

// Builtin policies are can be overwritten by user-defined policies.
Expand All @@ -115,7 +110,7 @@ func (r *Registry) Add(policy *sdkAct.Policy) {
func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
// If there are no signals, apply the default policy.
if len(signals) == 0 {
r.logger.Debug().Msg("No signals provided, applying default signal")
r.Logger.Debug().Msg("No signals provided, applying default signal")
return r.Apply([]sdkAct.Signal{*r.DefaultSignal})
}

Expand All @@ -138,15 +133,15 @@ func (r *Registry) Apply(signals []sdkAct.Signal) []*sdkAct.Output {
// If the signal is terminal, all non-terminal signals are ignored. Also, it only
// makes sense to have a terminal signal if the action is synchronous and terminal.
if len(terminal) > 0 && slices.Contains(nonTerminal, signal.Name) {
r.logger.Warn().Str("name", signal.Name).Msg(
r.Logger.Warn().Str("name", signal.Name).Msg(
"Terminal signal takes precedence, ignoring non-terminal signals")
continue
}

// Apply the signal and append the output to the list of outputs.
output, err := r.apply(signal)
if err != nil {
r.logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal")
r.Logger.Error().Err(err).Str("name", signal.Name).Msg("Error applying signal")
// If there is an error evaluating the policy, continue to the next signal.
// This also prevents stack overflows from infinite loops of the external
// if condition below.
Expand Down Expand Up @@ -179,7 +174,7 @@ func (r *Registry) apply(signal sdkAct.Signal) (*sdkAct.Output, *gerr.GatewayDEr
}

// Create a context with a timeout for policy evaluation.
ctx, cancel := context.WithTimeout(context.Background(), r.policyTimeout)
ctx, cancel := context.WithTimeout(context.Background(), r.PolicyTimeout)
defer cancel()

// Evaluate the policy.
Expand Down Expand Up @@ -219,21 +214,21 @@ func (r *Registry) Run(
if output == nil {
// This should never happen, since the output is always set by the registry
// to be the default policy if no signals are provided.
r.logger.Debug().Msg("Output is nil, run aborted")
r.Logger.Debug().Msg("Output is nil, run aborted")
return nil, gerr.ErrNilPointer
}

action, ok := r.Actions[output.MatchedPolicy]
if !ok {
r.logger.Warn().Str("matchedPolicy", output.MatchedPolicy).Msg(
r.Logger.Warn().Str("matchedPolicy", output.MatchedPolicy).Msg(
"Action does not exist, run aborted")
return nil, gerr.ErrActionNotExist
}

// Prepend the logger to the parameters.
params = append([]sdkAct.Parameter{WithLogger(r.logger)}, params...)
params = append([]sdkAct.Parameter{WithLogger(r.Logger)}, params...)

timeout := r.defaultActionTimeout
timeout := r.DefaultActionTimeout
if action.Timeout > 0 {
timeout = time.Duration(action.Timeout) * time.Second
}
Expand All @@ -248,13 +243,13 @@ func (r *Registry) Run(
// If the action is synchronous, run it and return the result immediately.
if action.Sync {
defer cancel()
return runActionWithTimeout(ctx, action, output, params, r.logger)
return runActionWithTimeout(ctx, action, output, params, r.Logger)
}

// Run the action asynchronously.
go func() {
defer cancel()
_, _ = runActionWithTimeout(ctx, action, output, params, r.logger)
_, _ = runActionWithTimeout(ctx, action, output, params, r.Logger)
}()
return nil, gerr.ErrAsyncAction
}
Expand Down
Loading