fix: unfair lock behavior in copied stream (#71)
* fix: unfair lock behavior in copied stream

* feat: modify struct and add comments for stream copy

* feat: add ErrRecvAfterClosed error
N3kox authored Feb 20, 2025
1 parent 4bfe584 commit 5e5dc1f
Showing 3 changed files with 226 additions and 143 deletions.
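The change is about StreamReader.Copy: the commit title points at unfair lock behavior among copied readers, and the old implementation in the diff below routes every child Recv through a shared recvMu plus a buffer mutex. For orientation, here is a minimal usage sketch of the fan-out pattern the fix targets; it is not part of the diff, and the import path github.com/cloudwego/eino/schema plus the exact signatures of Pipe, Copy, Send, Recv, and Close are assumed from the code visible in this commit.

package main

import (
    "fmt"
    "io"
    "sync"

    "github.com/cloudwego/eino/schema"
)

func main() {
    sr, sw := schema.Pipe[int](0)

    // Producer: emit a few chunks, then close the write side.
    go func() {
        defer sw.Close()
        for i := 0; i < 5; i++ {
            sw.Send(i, nil)
        }
    }()

    // Fan out into two independent children; each child must be
    // drained by exactly one goroutine.
    readers := sr.Copy(2)

    var wg sync.WaitGroup
    for idx, r := range readers {
        wg.Add(1)
        go func(idx int, r *schema.StreamReader[int]) {
            defer wg.Done()
            defer r.Close()
            for {
                v, err := r.Recv()
                if err == io.EOF {
                    return
                }
                if err != nil {
                    return
                }
                fmt.Printf("reader %d got %d\n", idx, v)
            }
        }(idx, r)
    }
    wg.Wait()
}

Each copied child is drained by a single goroutine; fairness here means one child's Recv loop should not be able to starve its sibling while both wait on the shared parent.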
194 changes: 68 additions & 126 deletions schema/stream.go
@@ -17,12 +17,12 @@
 package schema
 
 import (
-    "container/list"
     "errors"
     "io"
     "reflect"
     "runtime/debug"
     "sync"
+    "sync/atomic"
 
     "github.com/cloudwego/eino/utils/safe"
 )
@@ -45,6 +45,10 @@ import (
 // DO NOT use it under other circumstances.
 var ErrNoValue = errors.New("no value")
 
+// ErrRecvAfterClosed indicates that StreamReader.Recv was unexpectedly called after StreamReader.Close.
+// This error should not occur during normal use of StreamReader.Recv. If it does, please check your application code.
+var ErrRecvAfterClosed = errors.New("recv after stream closed")
+
 // Pipe creates a new stream with the given capacity that represented with StreamWriter and StreamReader.
 // The capacity is the maximum number of items that can be buffered in the stream.
 // e.g.
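As background for the new ErrRecvAfterClosed sentinel above: in the rewritten code further down, close(idx) nils out the child's slot, and a later peek on that slot returns this error instead of touching the parent stream. Below is a hedged sketch of the misuse it flags; it is not part of the diff, and it assumes the github.com/cloudwego/eino/schema import path and that the sentinel can be matched with errors.Is (it is a plain errors.New value, so equality holds).

package main

import (
    "errors"
    "fmt"

    "github.com/cloudwego/eino/schema"
)

func main() {
    sr, sw := schema.Pipe[string](1)
    sw.Send("only chunk", nil)
    sw.Close()

    children := sr.Copy(2)
    children[0].Close() // child 0 is done

    // Receiving from an already-closed child is a caller bug; the new
    // sentinel makes that visible instead of silently misbehaving.
    if _, err := children[0].Recv(); errors.Is(err, schema.ErrRecvAfterClosed) {
        fmt.Println("recv after close detected:", err)
    }

    children[1].Close()
}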
@@ -531,23 +535,29 @@ func (srw *streamReaderWithConvert[T]) toStream() *stream[T] {
     return ret
 }
 
-type listElement[T any] struct {
-    item     streamItem[T]
-    refCount int
+type cpStreamElement[T any] struct {
+    once sync.Once
+    next *cpStreamElement[T]
+    item streamItem[T]
 }
 
+// copyStreamReaders creates multiple independent StreamReaders from a single StreamReader.
+// Each child StreamReader can read from the original stream independently.
 func copyStreamReaders[T any](sr *StreamReader[T], n int) []*StreamReader[T] {
     cpsr := &parentStreamReader[T]{
-        sr:     sr,
-        recvMu: sync.Mutex{},
-        mem: &cpStreamMem[T]{
-            mu:            sync.Mutex{},
-            buf:           list.New(),
-            subStreamList: make([]*list.Element, n),
-            closedNum:     0,
-            closedList:    make([]bool, n),
-            hasFinished:   false,
-        },
+        sr:            sr,
+        subStreamList: make([]*cpStreamElement[T], n),
+        closedNum:     0,
     }
 
+    // Initialize subStreamList with an empty element, which acts like a tail node.
+    // A nil element (used for dereference) represents that the child has been closed.
+    // It is challenging to link the previous and current elements when the length of the original channel is unknown.
+    // Additionally, using a previous pointer complicates dereferencing elements, possibly requiring reference counting.
+    elem := &cpStreamElement[T]{}
+
+    for i := range cpsr.subStreamList {
+        cpsr.subStreamList[i] = elem
+    }
+
     ret := make([]*StreamReader[T], n)
@@ -565,131 +575,63 @@ func copyStreamReaders[T any](sr *StreamReader[T], n int) []*StreamReader[T] {
 }
 
 type parentStreamReader[T any] struct {
+    // sr is the original StreamReader.
     sr *StreamReader[T]
 
-    recvMu sync.Mutex
-
-    mem *cpStreamMem[T]
-}
-
-type cpStreamMem[T any] struct {
-    mu sync.Mutex
-
-    buf           *list.List
-    subStreamList []*list.Element
-
-    closedNum  int
-    closedList []bool
-
-    hasFinished bool
-}
-
-func (c *parentStreamReader[T]) peek(idx int) (T, error) {
-    if t, err, ok := c.mem.peek(idx); ok {
-        return t, err
-    }
-
-    c.recvMu.Lock()
-    defer c.recvMu.Unlock()
-
-    // retry read from buffer
-    if t, err, ok := c.mem.peek(idx); ok {
-        return t, err
-    }
-
-    // get value from StreamReader
-    nChunk, err := c.sr.Recv()
-
-    c.mem.set(idx, nChunk, err)
-
-    return nChunk, err
-}
-
-func (c *parentStreamReader[T]) close(idx int) {
-    if allClosed := c.mem.close(idx); allClosed {
-        c.sr.Close()
-    }
-}
-
-func (m *cpStreamMem[T]) peek(idx int) (T, error, bool) {
-    m.mu.Lock()
-    defer m.mu.Unlock()
-
-    if elem := m.subStreamList[idx]; elem != nil {
-        next := elem.Next()
-        cElem := elem.Value.(*listElement[T]) // nolint: byted_interface_check_golintx
-        cElem.refCount--
-        if cElem.refCount == 0 {
-            m.buf.Remove(elem)
-        }
-
-        m.subStreamList[idx] = next
-        return cElem.item.chunk, cElem.item.err, true
-    }
-
-    var t T
-
-    if m.hasFinished {
-        return t, io.EOF, true
-    }
-
-    return t, nil, false
-}
-
-func (m *cpStreamMem[T]) set(idx int, nChunk T, err error) {
-    m.mu.Lock()
-    defer m.mu.Unlock()
-
-    if err == io.EOF { // nolint: byted_s_error_binary
-        m.hasFinished = true
-        return
-    }
-
-    nElem := &listElement[T]{
-        item:     streamItem[T]{chunk: nChunk, err: err},
-        refCount: len(m.subStreamList) - m.closedNum - 1, // except chan receiver
-    }
-
-    if nElem.refCount == 0 {
-        // no need to set buffer when there's no other receivers
-        return
-    }
-
-    elem := m.buf.PushBack(nElem)
-    for i := range m.subStreamList {
-        if m.subStreamList[i] == nil && i != idx && !m.closedList[i] {
-            m.subStreamList[i] = elem
-        }
-    }
-}
-
-func (m *cpStreamMem[T]) close(idx int) (allClosed bool) {
-    m.mu.Lock()
-    defer m.mu.Unlock()
-
-    if m.closedList[idx] {
-        return false // avoid close multiple times
-    }
-
-    m.closedList[idx] = true
-    m.closedNum++
-    if m.closedNum == len(m.subStreamList) {
-        allClosed = true
-    }
-
-    p := m.subStreamList[idx]
-    for p != nil {
-        next := p.Next()
-        ptr := p.Value.(*listElement[T]) // nolint: byted_interface_check_golintx
-        ptr.refCount--
-        if ptr.refCount == 0 {
-            m.buf.Remove(p)
-        }
-
-        p = next
-    }
-
-    return allClosed
-}
+    // subStreamList maps each child's index to its latest read chunk.
+    // Each value comes from a hidden linked list of cpStreamElement.
+    subStreamList []*cpStreamElement[T]
+
+    // closedNum is the count of closed children.
+    closedNum uint32
+}
+
+// peek is not safe for concurrent use with the same idx but is safe for different idx.
+// Ensure that each child StreamReader uses a for-loop in a single goroutine.
+func (p *parentStreamReader[T]) peek(idx int) (t T, err error) {
+    elem := p.subStreamList[idx]
+    if elem == nil {
+        // Unexpected call to receive after the child has been closed.
+        return t, ErrRecvAfterClosed
+    }
+
+    // The sync.Once here is used to:
+    // 1. Write the content of this cpStreamElement.
+    // 2. Initialize the 'next' field of this cpStreamElement with an empty cpStreamElement,
+    //    similar to the initialization in copyStreamReaders.
+    elem.once.Do(func() {
+        t, err = p.sr.Recv()
+        elem.item = streamItem[T]{chunk: t, err: err}
+        if err != io.EOF {
+            elem.next = &cpStreamElement[T]{}
+            p.subStreamList[idx] = elem.next
+        }
+    })
+
+    // The element has been set and will not be modified again.
+    // Therefore, children can read this element's content and 'next' pointer concurrently.
+    t = elem.item.chunk
+    err = elem.item.err
+    if err != io.EOF {
+        p.subStreamList[idx] = elem.next
+    }
+
+    return t, err
+}
+
+func (p *parentStreamReader[T]) close(idx int) {
+    if p.subStreamList[idx] == nil {
+        return // avoid close multiple times
+    }
+
+    p.subStreamList[idx] = nil
+
+    curClosedNum := atomic.AddUint32(&p.closedNum, 1)
+
+    allClosed := int(curClosedNum) == len(p.subStreamList)
+    if allClosed {
+        p.sr.Close()
+    }
+}
 
 type childStreamReader[T any] struct {
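Taking a step back from the hunk above: the rewrite drops the mutex-guarded container/list buffer and replaces it with a singly linked list of cpStreamElement nodes. Each child keeps only a pointer to its next unread node; the first child to reach a node performs the single Recv inside sync.Once and links a fresh tail, and every other child merely reads the published item. The following is a standalone sketch of that pattern using a plain channel of ints; it is illustrative only and is not eino code.

package main

import (
    "fmt"
    "sync"
)

// node is a minimal stand-in for cpStreamElement: its value and next
// pointer are published exactly once via sync.Once, then read-only.
type node struct {
    once sync.Once
    val  int
    end  bool
    next *node
}

func main() {
    src := make(chan int, 3)
    src <- 1
    src <- 2
    src <- 3
    close(src)

    // Both cursors start at the same empty tail node, mirroring the
    // shared initial element set up in copyStreamReaders.
    tail := &node{}
    cursors := []*node{tail, tail}

    var wg sync.WaitGroup
    for i := range cursors {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            cur := cursors[i]
            for {
                // Whichever cursor arrives first pulls the next value and
                // links a fresh tail; everyone else reads the published result.
                cur.once.Do(func() {
                    v, ok := <-src
                    cur.val, cur.end = v, !ok
                    if ok {
                        cur.next = &node{}
                    }
                })
                if cur.end {
                    return
                }
                fmt.Printf("cursor %d read %d\n", i, cur.val)
                cur = cur.next
            }
        }(i)
    }
    wg.Wait()
}

Because sync.Once.Do gives every caller a happens-before edge with the body that ran, the later reads of val, end, and next need no extra locking, which is what removes the contended lock from the copied-stream read path.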
151 changes: 151 additions & 0 deletions schema/stream_copy_external_test.go
@@ -0,0 +1,151 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package schema

import (
    "fmt"
    "io"
    "runtime"
    "sort"
    "sync"
    "sync/atomic"
    "testing"
    "time"
)

func TestStream1(t *testing.T) {
    runtime.GOMAXPROCS(1)

    sr, sw := Pipe[int](0)
    go func() {
        for i := 0; i < 100; i++ {
            sw.Send(i, nil)
            time.Sleep(3 * time.Millisecond)
        }
        sw.Close()
    }()
    copied := sr.Copy(2)
    var (
        now   = time.Now().UnixMilli()
        ts    = []int64{now, now}
        tsOld = []int64{now, now}
    )
    var count int32
    wg := sync.WaitGroup{}
    wg.Add(2)
    go func() {
        s := copied[0]
        for {
            n, e := s.Recv()
            if e != nil {
                if e == io.EOF {
                    break
                }
            }
            tsOld[0] = ts[0]
            ts[0] = time.Now().UnixMilli()
            interval := ts[0] - tsOld[0]
            if interval >= 6 {
                atomic.AddInt32(&count, 1)
            }
            t.Logf("reader= 0, index= %d, interval= %v", n, interval)
        }
        wg.Done()
    }()
    go func() {
        s := copied[1]
        for {
            n, e := s.Recv()
            if e != nil {
                if e == io.EOF {
                    break
                }
            }
            tsOld[1] = ts[1]
            ts[1] = time.Now().UnixMilli()
            interval := ts[1] - tsOld[1]
            if interval >= 6 {
                atomic.AddInt32(&count, 1)
            }
            t.Logf("reader= 1, index= %d, interval= %v", n, interval)
        }
        wg.Done()
    }()
    wg.Wait()
    t.Logf("count= %d", count)
}

type info struct {
    idx     int
    ts      int64
    after   int64
    content string
}

func TestCopyDelay(t *testing.T) {
    runtime.GOMAXPROCS(10)
    n := 3
    //m := 100
    s := newStream[string](0)
    scp := s.asReader().Copy(n)
    go func() {
        s.send("1", nil)
        s.send("2", nil)
        time.Sleep(time.Second)
        s.send("3", nil)
        s.closeSend()
    }()
    wg := sync.WaitGroup{}
    wg.Add(n)
    infoList := make([][]info, n)
    for i := 0; i < n; i++ {
        j := i
        go func() {
            defer func() {
                scp[j].Close()
                wg.Done()
            }()
            for {
                lastTime := time.Now()
                str, err := scp[j].Recv()
                if err == io.EOF {
                    break
                }
                now := time.Now()
                infoList[j] = append(infoList[j], info{
                    idx:     j,
                    ts:      now.UnixMicro(),
                    after:   now.Sub(lastTime).Milliseconds(),
                    content: str,
                })
            }
        }()
    }
    wg.Wait()
    infos := make([]info, 0)
    for _, infoL := range infoList {
        for _, info := range infoL {
            infos = append(infos, info)
        }
    }
    sort.Slice(infos, func(i, j int) bool {
        return infos[i].ts < infos[j].ts
    })
    for _, info := range infos {
        fmt.Printf("child[%d] ts[%d] after[%5dms] content[%s]\n", info.idx, info.ts, info.after, info.content)
    }
}