diff --git a/.golangci.yaml b/.golangci.yaml index a23b0843..ef70c4fb 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -90,6 +90,9 @@ issues: - gosec text: "G404:" # warning about insecure math/rand. We dont care about this in tests! path: "\\w*_test.go" + - linters: + - gosimple + text: "S1023:" # allow redundant return statements. They can be nice for readability. # Enable default excludes, for common sense values. exclude-use-default: true diff --git a/ext/common_test.go b/ext/common_test.go index 2ca8594c..c82affd6 100644 --- a/ext/common_test.go +++ b/ext/common_test.go @@ -6,6 +6,7 @@ import ( type DummyHandler struct { F func(b *gotgbot.Bot, ctx *Context) error + N string } func (d DummyHandler) CheckUpdate(b *gotgbot.Bot, ctx *Context) bool { @@ -17,7 +18,7 @@ func (d DummyHandler) HandleUpdate(b *gotgbot.Bot, ctx *Context) error { } func (d DummyHandler) Name() string { - return "dummy" + return "dummy" + d.N } func (u *Updater) InjectUpdate(token string, upd gotgbot.Update) error { diff --git a/ext/dispatcher.go b/ext/dispatcher.go index 8913f4c4..f64e32ed 100644 --- a/ext/dispatcher.go +++ b/ext/dispatcher.go @@ -6,7 +6,6 @@ import ( "fmt" "log" "runtime/debug" - "sort" "strings" "sync" @@ -71,10 +70,8 @@ type Dispatcher struct { // If nil, logging is done via the log package's standard logger. ErrorLog *log.Logger - // handlerGroups represents the list of available handler groups, numerically sorted. - handlerGroups []int - // handlers represents all available handles, split into groups (see handlerGroups). - handlers map[int][]Handler + // handlers represents all available handlers. + handlers handlerMappings // limiter is how we limit the maximum number of goroutines for handling updates. // if nil, this is a limitless dispatcher. @@ -152,7 +149,7 @@ func NewDispatcher(opts *DispatcherOpts) *Dispatcher { Panic: panicHandler, UnhandledErrFunc: unhandledErrFunc, ErrorLog: errLog, - handlers: make(map[int][]Handler), + handlers: handlerMappings{}, limiter: limiter, waitGroup: sync.WaitGroup{}, } @@ -228,13 +225,21 @@ func (d *Dispatcher) AddHandler(handler Handler) { } // AddHandlerToGroup adds a handler to a specific group; lowest number will be processed first. -func (d *Dispatcher) AddHandlerToGroup(handler Handler, group int) { - currHandlers, ok := d.handlers[group] - if !ok { - d.handlerGroups = append(d.handlerGroups, group) - sort.Ints(d.handlerGroups) - } - d.handlers[group] = append(currHandlers, handler) +func (d *Dispatcher) AddHandlerToGroup(h Handler, group int) { + d.handlers.add(h, group) +} + +// RemoveHandlerFromGroup removes a handler by name from the specified group. +// If multiple handlers have the same name, only the first one is removed. +// Returns true if the handler was successfully removed. +func (d *Dispatcher) RemoveHandlerFromGroup(handlerName string, group int) bool { + return d.handlers.remove(handlerName, group) +} + +// RemoveGroup removes an entire group from the dispatcher's processing. +// If group can't be found, this is a noop. +func (d *Dispatcher) RemoveGroup(group int) bool { + return d.handlers.removeGroup(group) } // processRawUpdate takes a JSON update to be unmarshalled and processed by Dispatcher.ProcessUpdate. @@ -274,8 +279,8 @@ func (d *Dispatcher) ProcessUpdate(b *gotgbot.Bot, u *gotgbot.Update, data map[s } func (d *Dispatcher) iterateOverHandlerGroups(b *gotgbot.Bot, ctx *Context) error { - for _, groupNum := range d.handlerGroups { - for _, handler := range d.handlers[groupNum] { + for _, groups := range d.handlers.getGroups() { + for _, handler := range groups { if !handler.CheckUpdate(b, ctx) { // Handler filter doesn't match this update; continue. continue diff --git a/ext/dispatcher_ext_test.go b/ext/dispatcher_ext_test.go new file mode 100644 index 00000000..6a122e8f --- /dev/null +++ b/ext/dispatcher_ext_test.go @@ -0,0 +1,182 @@ +package ext_test + +import ( + "sort" + "testing" + + "github.com/PaulSonOfLars/gotgbot/v2" + "github.com/PaulSonOfLars/gotgbot/v2/ext" + "github.com/PaulSonOfLars/gotgbot/v2/ext/handlers" + "github.com/PaulSonOfLars/gotgbot/v2/ext/handlers/filters/message" +) + +func TestDispatcher(t *testing.T) { + type testHandler struct { + group int + shouldRun bool + returnVal error + } + + for name, testParams := range map[string]struct { + handlers []testHandler + numMatches int + }{ + "one group two handlers": { + handlers: []testHandler{ + { + group: 0, + shouldRun: true, + returnVal: nil, + }, { + group: 0, + shouldRun: false, // same group, so doesnt run + returnVal: nil, + }, + }, + numMatches: 1, + }, + "two handlers two groups": { + handlers: []testHandler{ + { + group: 0, + shouldRun: true, + returnVal: nil, + }, { + group: 1, + shouldRun: true, // second group, so also runs + returnVal: nil, + }, + }, + numMatches: 2, + }, + "end groups": { + handlers: []testHandler{ + { + group: 0, + shouldRun: true, + returnVal: ext.EndGroups, + }, { + group: 1, + shouldRun: false, // ended, so second group doesnt run + returnVal: nil, + }, + }, + numMatches: 1, + }, + "continue groups": { + handlers: []testHandler{ + { + group: 0, + shouldRun: true, + returnVal: ext.ContinueGroups, + }, { + group: 0, + shouldRun: true, // continued, so second item in same group runs + returnVal: nil, + }, + }, + numMatches: 2, + }, + } { + name, testParams := name, testParams + + t.Run(name, func(t *testing.T) { + d := ext.NewDispatcher(nil) + var events []int + for idx, h := range testParams.handlers { + idx, h := idx, h + + t.Logf("Loading handler %d in group %d", idx, h.group) + d.AddHandlerToGroup(handlers.NewMessage(message.All, func(b *gotgbot.Bot, ctx *ext.Context) error { + if !h.shouldRun { + t.Errorf("handler %d in group %d should not have run", idx, h.group) + t.FailNow() + } + + t.Logf("handler %d in group %d has run, as expected", idx, h.group) + events = append(events, idx) + return h.returnVal + }), h.group) + } + + t.Log("Processing one update...") + err := d.ProcessUpdate(nil, &gotgbot.Update{ + Message: &gotgbot.Message{Text: "test text"}, + }, nil) + if err != nil { + t.Errorf("Unexpected error while processing updates: %s", err.Error()) + } + + // ensure events handled in order + if !sort.IntsAreSorted(events) { + t.Errorf("order of events is not sorted: %v", events) + } + if len(events) != testParams.numMatches { + t.Errorf("got %d matches, expected %d ", len(events), testParams.numMatches) + } + }) + } +} + +func TestDispatcher_RemoveHandlerFromGroup(t *testing.T) { + d := ext.NewDispatcher(nil) + + const removeMe = "remove_me" + const group = 0 + + d.AddHandlerToGroup(handlers.NewNamedhandler(removeMe, handlers.NewMessage(message.All, nil)), group) + + if found := d.RemoveHandlerFromGroup(removeMe, group); !found { + t.Errorf("RemoveHandlerFromGroup() = %v, want true", found) + } +} + +func TestDispatcher_RemoveOneHandlerFromGroup(t *testing.T) { + d := ext.NewDispatcher(nil) + + const removeMe = "remove_me" + const group = 0 + + // Load handler twice. + d.AddHandlerToGroup(handlers.NewNamedhandler(removeMe, handlers.NewMessage(message.All, nil)), group) + d.AddHandlerToGroup(handlers.NewNamedhandler(removeMe, handlers.NewMessage(message.All, nil)), group) + + // Remove handler twice. + if found := d.RemoveHandlerFromGroup(removeMe, group); !found { + t.Errorf("RemoveHandlerFromGroup() = %v, want true", found) + } + if found := d.RemoveHandlerFromGroup(removeMe, group); !found { + t.Errorf("RemoveHandlerFromGroup() = %v, want true", found) + } + // fail! only 2 in there. + if found := d.RemoveHandlerFromGroup(removeMe, group); found { + t.Errorf("RemoveHandlerFromGroup() = %v, want false", found) + } +} + +func TestDispatcher_RemoveHandlerNonExistingHandlerFromGroup(t *testing.T) { + d := ext.NewDispatcher(nil) + + const keepMe = "keep_me" + const removeMe = "remove_me" + const group = 0 + + d.AddHandlerToGroup(handlers.NewNamedhandler(keepMe, handlers.NewMessage(message.All, nil)), group) + + if found := d.RemoveHandlerFromGroup(removeMe, group); found { + t.Errorf("RemoveHandlerFromGroup() = %v, want false", found) + } +} + +func TestDispatcher_RemoveHandlerHandlerFromNonExistingGroup(t *testing.T) { + d := ext.NewDispatcher(nil) + + const removeMe = "remove_me" + const group = 0 + const wrongGroup = 1 + d.AddHandlerToGroup(handlers.NewNamedhandler(removeMe, handlers.NewMessage(message.All, nil)), group) + + if found := d.RemoveHandlerFromGroup(removeMe, wrongGroup); found { + t.Errorf("RemoveHandlerFromGroup() = %v, want false", found) + } +} diff --git a/ext/handler_mapping.go b/ext/handler_mapping.go new file mode 100644 index 00000000..c25994e5 --- /dev/null +++ b/ext/handler_mapping.go @@ -0,0 +1,112 @@ +package ext + +import ( + "sort" + "sync" +) + +type handlerMappings struct { + // mutex is used to ensure everything threadsafe. + mutex sync.RWMutex + + // handlerGroups represents the list of available handler groups, numerically sorted. + handlerGroups []int + // handlers represents all available handlers, split into groups (see handlerGroups). + handlers map[int][]Handler +} + +func (m *handlerMappings) add(h Handler, group int) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.handlers == nil { + m.handlers = map[int][]Handler{} + } + currHandlers, ok := m.handlers[group] + if !ok { + m.handlerGroups = append(m.handlerGroups, group) + sort.Ints(m.handlerGroups) + } + m.handlers[group] = append(currHandlers, h) +} + +func (m *handlerMappings) remove(name string, group int) bool { + m.mutex.Lock() + defer m.mutex.Unlock() + + currHandlers, ok := m.handlers[group] + if !ok { + // group does not exist; removal failed. + return false + } + + for idx, handler := range currHandlers { + if handler.Name() != name { + continue + } + + // Only one item left, so just delete the group entirely. + if len(currHandlers) == 1 { + // get index of the current group to remove it from the list of handlergroups + gIdx := getIndex(group, m.handlerGroups) + if gIdx != -1 { + m.handlerGroups = append(m.handlerGroups[:gIdx], m.handlerGroups[gIdx+1:]...) + } + delete(m.handlers, group) + return true + } + + // Make sure to copy the handler list to ensure we don't change the values of the underlying arrays, which + // could cause slice access issues when used concurrently. + newHandlers := make([]Handler, len(m.handlers[group])) + copy(newHandlers, m.handlers[group]) + + m.handlers[group] = append(newHandlers[:idx], newHandlers[idx+1:]...) + return true + } + // handler not found - removal failed. + return false +} + +func getIndex(find int, is []int) int { + for i, v := range is { + if v == find { + return i + } + } + return -1 +} + +func (m *handlerMappings) removeGroup(group int) bool { + m.mutex.Lock() + defer m.mutex.Unlock() + + if _, ok := m.handlers[group]; !ok { + // Group doesn't exist in map, so already removed. + return false + } + + for idx, handlerGroup := range m.handlerGroups { + if handlerGroup != group { + continue + } + + m.handlerGroups = append(m.handlerGroups[:idx], m.handlerGroups[idx+1:]...) + delete(m.handlers, group) + // Group found, and deleted. Success! + return true + } + // Group not found in list - so already removed. + return false +} + +func (m *handlerMappings) getGroups() [][]Handler { + m.mutex.RLock() + defer m.mutex.RUnlock() + + allHandlers := make([][]Handler, len(m.handlerGroups)) + for idx, num := range m.handlerGroups { + allHandlers[idx] = m.handlers[num] + } + return allHandlers +} diff --git a/ext/handler_mapping_test.go b/ext/handler_mapping_test.go new file mode 100644 index 00000000..77d5cee1 --- /dev/null +++ b/ext/handler_mapping_test.go @@ -0,0 +1,117 @@ +package ext + +import ( + "testing" +) + +// This test should demonstrate that once obtained, a list will not be changed by any additions/removals to that list by another call. +func Test_handlerMappings_getGroupsConcurrentSafe(t *testing.T) { + m := handlerMappings{} + firstHandler := DummyHandler{N: "first"} + secondHandler := DummyHandler{N: "second"} + + // We expect 0 groups at the start + startGroups := m.getGroups() + if len(startGroups) != 0 { + t.Errorf("failed predicate group layout") + } + + // Add one handler. + m.add(firstHandler, 0) + currGroups := m.getGroups() + if len(currGroups) != 1 && len(startGroups) != 0 { + t.Errorf("Start groups should be 0, curr groups should be 1; got %d and %d", len(startGroups), len(currGroups)) + } + checkList(t, "currGroups", currGroups[0], firstHandler) + + // Add a second handler. + m.add(secondHandler, 0) + newGroups := m.getGroups() + checkList(t, "newgroups;currGroups", currGroups[0], firstHandler) + checkList(t, "newgroups;newGroups", newGroups[0], firstHandler, secondHandler) + + // Remove second handler.. + ok := m.remove(secondHandler.Name(), 0) + if !ok { + t.Errorf("failed to remove second handler") + } + delGroups := m.getGroups() + checkList(t, "delgroups;currGroups", currGroups[0], firstHandler) + checkList(t, "delgroups;newGroups", newGroups[0], firstHandler, secondHandler) + checkList(t, "delgroups;delGroups", delGroups[0], firstHandler) + + // Re-add second handler. + m.add(secondHandler, 0) + reAddedGroups := m.getGroups() + checkList(t, "readded;currGroups", currGroups[0], firstHandler) + checkList(t, "readded;newGroups", newGroups[0], firstHandler, secondHandler) + checkList(t, "readded;delGroups", delGroups[0], firstHandler) + checkList(t, "readded;reAddedGroups", reAddedGroups[0], firstHandler, secondHandler) + + // Remove first handler. + ok = m.remove(firstHandler.Name(), 0) + if !ok { + t.Errorf("failed to remove second handler") + } + noFirstGroups := m.getGroups() + checkList(t, "nofirst;currGroups", currGroups[0], firstHandler) + checkList(t, "nofirst;newGroups", newGroups[0], firstHandler, secondHandler) + checkList(t, "nofirst;delGroups", delGroups[0], firstHandler) + checkList(t, "nofirst;reAddedGroups", reAddedGroups[0], firstHandler, secondHandler) + checkList(t, "nofirst;noFirstGroups", noFirstGroups[0], secondHandler) +} + +func checkList(t *testing.T, name string, got []Handler, expected ...Handler) { + if len(got) != len(expected) { + t.Errorf("mismatch on length of expected outputs for %s - got %d, expected %d", name, len(got), len(expected)) + } + for idx, v := range got { + if v.Name() != expected[idx].Name() { + t.Errorf("unexpected output name for %s - IDX %d got %s, expected %s", name, idx, v.Name(), expected[idx].Name()) + } + } +} + +func Test_handlerMappings_remove(t *testing.T) { + m := &handlerMappings{} + handler := DummyHandler{N: "test"} + + t.Run("nonExistent", func(t *testing.T) { + // removing an item that doesnt exist returns "false" + if got := m.remove(handler.Name(), 0); got { + t.Errorf("remove() = %v, want false", got) + } + }) + + t.Run("removalSuccess", func(t *testing.T) { + m.add(handler, 0) + // removing an item that DOES exist, returns true + if got := m.remove(handler.Name(), 0); !got { + t.Errorf("remove() = %v, want true", got) + } + // And so the second time, it returns false + if got := m.remove(handler.Name(), 0); got { + t.Errorf("remove() = %v, want false", got) + } + }) + + t.Run("removalSuccess", func(t *testing.T) { + m.add(handler, 0) + // removing an item that DOES exist, returns true + if got := m.remove(handler.Name(), 0); !got { + t.Errorf("remove() = %v, want true", got) + } + // And so the second time, it returns false + if got := m.remove(handler.Name(), 0); got { + t.Errorf("remove() = %v, want false", got) + } + }) + + t.Run("removalDifferentIndexes", func(t *testing.T) { + m.add(handler, 1) + m.add(handler, 2) + if got := m.remove(handler.Name(), 2); !got { + t.Errorf("remove() = %v, want true", got) + } + }) +} diff --git a/ext/handlers/named.go b/ext/handlers/named.go new file mode 100644 index 00000000..ca91ca68 --- /dev/null +++ b/ext/handlers/named.go @@ -0,0 +1,23 @@ +package handlers + +import ( + "github.com/PaulSonOfLars/gotgbot/v2/ext" +) + +type Named struct { + // Custom name to identify handler by + CustomName string + // Inlined version of parent handler to inherit methods. + ext.Handler +} + +func (n Named) Name() string { + return n.CustomName +} + +func NewNamedhandler(name string, handler ext.Handler) Named { + return Named{ + CustomName: name, + Handler: handler, + } +}