From 4d9f0da8da9f7f38f5fde489c4dbca66ec39444a Mon Sep 17 00:00:00 2001
From: ikawaha
Date: Sat, 16 Oct 2021 14:16:17 +0900
Subject: [PATCH] Add a heap of nodes

---

  Review notes. They sit after the "---" marker so that they stay out of the
  commit message; the hunks below are unchanged.

  lattice.go
    * Only the wording of the Forward/Backward doc comments changes.
    * NBestBackward is added as an exported method without a doc comment.
      Its body holds nothing but a TODO comment, so calling it does nothing.

  node.go
    * NodeHeap is a binary heap kept in a slice: the parent of index i is
      (i-1)/2 and its children are 2i+1 and 2i+2. Ordering comes from the
      unexported `less` callback; Pop hands back the node that `less` ranks
      first.
    * Push appends the node and walks from the new index up to index 0,
      swapping with the parent whenever less(parent, child) is false. The
      loop has no early exit, so every ancestor is visited on each Push and
      nodes that compare equal are swapped as well.
    * Pop returns nil for an empty heap. Otherwise it takes index 0, moves
      the last element into its place, nils the vacated tail slot, shrinks
      the slice, and sifts down, stepping to a child whenever
      less(current, child) is false.
    * Push/Pop have pointer receivers while Empty/Size have value receivers.
    * `less` is unexported and this patch adds no constructor. With a nil
      `less`, the first comparison (a Push onto a non-empty heap) calls a
      nil func and panics.

  node_test.go
    * TestNodeHeap_PushPop orders nodes by ascending Node.ID and reuses one
      NodeHeap for all four cases, without subtests. Each case relies on the
      previous one having drained the heap.

  Suggested follow-ups, NOT applied here because the hunk lengths, the
  diffstat and the blob ids are tied to the exact payload:
    * break out of the Push loop once the parent sorts before the child;
    * add doc comments for NodeHeap and NBestBackward;
    * use the same receiver kind for all NodeHeap methods.

 tokenizer/lattice/lattice.go   |  8 +++--
 tokenizer/lattice/node.go      | 58 ++++++++++++++++++++++++++++++++++
 tokenizer/lattice/node_test.go | 52 ++++++++++++++++++++++++++++++
 3 files changed, 116 insertions(+), 2 deletions(-)

diff --git a/tokenizer/lattice/lattice.go b/tokenizer/lattice/lattice.go
index fe61ae6..4247ef8 100644
--- a/tokenizer/lattice/lattice.go
+++ b/tokenizer/lattice/lattice.go
@@ -212,7 +212,7 @@ func additionalCost(n *Node) int {
 	return 0
 }
 
-// Forward runs forward algorithm of the Viterbi.
+// Forward is the forward algorithm of the Viterbi.
 func (la *Lattice) Forward(m TokenizeMode) {
 	for i, size := 1, len(la.list); i < size; i++ {
 		currentList := la.list[i]
@@ -243,7 +243,7 @@ func (la *Lattice) Forward(m TokenizeMode) {
 	}
 }
 
-// Backward runs backward algorithm of the Viterbi.
+// Backward is the backward algorithm of the Viterbi.
 func (la *Lattice) Backward(m TokenizeMode) {
 	const bufferExpandRatio = 2
 	size := len(la.list)
@@ -279,6 +279,10 @@ func (la *Lattice) Backward(m TokenizeMode) {
 	}
 }
 
+func (la *Lattice) NBestBackward(m TokenizeMode) {
+	/** TODO **/
+}
+
 func posFeature(d *dict.Dict, u *dict.UserDict, t *Node) string {
 	var ret []string
 	switch t.Class {
diff --git a/tokenizer/lattice/node.go b/tokenizer/lattice/node.go
index fbd5f29..448fbd4 100644
--- a/tokenizer/lattice/node.go
+++ b/tokenizer/lattice/node.go
@@ -50,3 +50,61 @@ var nodePool = sync.Pool{
 		return new(Node)
 	},
 }
+
+type NodeHeap struct {
+	list []*Node
+	less func(x, y *Node) bool
+}
+
+// Push adds a node to the heap.
+func (h *NodeHeap) Push(n *Node) {
+	i := len(h.list)
+	h.list = append(h.list, n)
+	for i != 0 {
+		p := (i - 1) / 2
+		if !h.less(h.list[p], h.list[i]) {
+			h.list[p], h.list[i] = h.list[i], h.list[p]
+		}
+		i = p
+	}
+}
+
+// Pop returns the highest priority node of the heap. If the heap is empty, Pop returns nil.
+func (h *NodeHeap) Pop() *Node {
+	if len(h.list) < 1 {
+		return nil
+	}
+	ret := h.list[0]
+	if len(h.list) > 1 {
+		h.list[0] = h.list[len(h.list)-1]
+	}
+	h.list[len(h.list)-1] = nil
+	h.list = h.list[:len(h.list)-1]
+
+	for i := 0; ; {
+		min := i
+		if left := (i+1)*2 - 1; left < len(h.list) && !h.less(h.list[min], h.list[left]) {
+			min = left
+		}
+		if right := (i + 1) * 2; right < len(h.list) && !h.less(h.list[min], h.list[right]) {
+			min = right
+		}
+		if min == i {
+			break
+		}
+		h.list[i], h.list[min] = h.list[min], h.list[i]
+		i = min
+	}
+
+	return ret
+}
+
+// Empty returns true if the heap is empty.
+func (h NodeHeap) Empty() bool {
+	return len(h.list) == 0
+}
+
+// Size returns the size of the heap.
+func (h NodeHeap) Size() int {
+	return len(h.list)
+}
diff --git a/tokenizer/lattice/node_test.go b/tokenizer/lattice/node_test.go
index 72a5c98..5ca7063 100644
--- a/tokenizer/lattice/node_test.go
+++ b/tokenizer/lattice/node_test.go
@@ -1,6 +1,7 @@
 package lattice
 
 import (
+	"reflect"
 	"testing"
 )
 
@@ -21,3 +22,54 @@ func Test_NodeClassString(t *testing.T) {
 		}
 	}
 }
+
+func TestNodeHeap_PushPop(t *testing.T) {
+	idSorter := func(x, y *Node) bool {
+		return x.ID < y.ID
+	}
+	heap := NodeHeap{
+		less: idSorter,
+	}
+	testdata := []struct {
+		name string
+		ids  []int
+		want []int
+	}{
+		{
+			name: "ascending order",
+			ids:  []int{1, 2, 3, 4, 5, 6, 7},
+			want: []int{1, 2, 3, 4, 5, 6, 7},
+		},
+		{
+			name: "descending order",
+			ids:  []int{7, 6, 5, 4, 3, 2, 1},
+			want: []int{1, 2, 3, 4, 5, 6, 7},
+		},
+		{
+			name: "random order",
+			ids:  []int{3, 6, 4, 1, 7, 5, 2},
+			want: []int{1, 2, 3, 4, 5, 6, 7},
+		},
+		{
+			name: "list /w duplicate items",
+			ids:  []int{3, 6, 3, 4, 1, 3, 6, 2, 7, 5, 2},
+			want: []int{1, 2, 2, 3, 3, 3, 4, 5, 6, 6, 7},
+		},
+	}
+	for _, data := range testdata {
+		for _, v := range data.ids {
+			heap.Push(&Node{ID: v})
+		}
+		got := make([]int, 0, heap.Size())
+		for !heap.Empty() {
+			n := heap.Pop()
+			if n == nil {
+				t.Fatalf("unexpected nil node, heap=%+v", heap)
+			}
+			got = append(got, n.ID)
+		}
+		if !reflect.DeepEqual(got, data.want) {
+			t.Errorf("got %+v, want %+v", got, data.want)
+		}
+	}
+}