From 6b81a5d0b9b320d4c6433bfd57e9faa3d0991084 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Fri, 19 Apr 2024 09:08:05 +0200 Subject: [PATCH 1/3] core/state: separate journal-implementation behind interface, implement createaccount --- core/state/journal.go | 88 ++++++++++++++++++-------------- core/state/journal_api.go | 71 ++++++++++++++++++++++++++ core/state/statedb.go | 9 ++-- core/state/statedb_test.go | 102 ++++++++++++++++++++++--------------- 4 files changed, 187 insertions(+), 83 deletions(-) create mode 100644 core/state/journal_api.go diff --git a/core/state/journal.go b/core/state/journal.go index a2fea6b6ecc5..f96936268a90 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -32,33 +32,36 @@ type revision struct { journalIndex int } -// journalEntry is a modification entry in the state change journal that can be +// journalEntry is a modification entry in the state change linear journal that can be // reverted on demand. type journalEntry interface { - // revert undoes the changes introduced by this journal entry. + // revert undoes the changes introduced by this entry. revert(*StateDB) - // dirtied returns the Ethereum address modified by this journal entry. + // dirtied returns the Ethereum address modified by this entry. dirtied() *common.Address - // copy returns a deep-copied journal entry. + // copy returns a deep-copied entry. copy() journalEntry } -// journal contains the list of state modifications applied since the last state +// linearJournal contains the list of state modifications applied since the last state // commit. These are tracked to be able to be reverted in the case of an execution // exception or request for reversal. 
-type journal struct { - entries []journalEntry // Current changes tracked by the journal +type linearJournal struct { + entries []journalEntry // Current changes tracked by the linearJournal dirties map[common.Address]int // Dirty accounts and the number of changes validRevisions []revision nextRevisionId int } -// newJournal creates a new initialized journal. -func newJournal() *journal { - return &journal{ +// compile-time interface check +var _ journal = (*linearJournal)(nil) + +// newLinearJournal creates a new initialized linearJournal. +func newLinearJournal() *linearJournal { + return &linearJournal{ dirties: make(map[common.Address]int), } } @@ -66,15 +69,24 @@ func newJournal() *journal { // reset clears the journal, after this operation the journal can be used anew. // It is semantically similar to calling 'newJournal', but the underlying slices // can be reused. -func (j *journal) reset() { +func (j *linearJournal) reset() { j.entries = j.entries[:0] j.validRevisions = j.validRevisions[:0] clear(j.dirties) j.nextRevisionId = 0 } +func (j linearJournal) dirtyAccounts() []common.Address { + dirty := make([]common.Address, 0, len(j.dirties)) + // flatten into list + for addr := range j.dirties { + dirty = append(dirty, addr) + } + return dirty +} + // snapshot returns an identifier for the current revision of the state. -func (j *journal) snapshot() int { +func (j *linearJournal) snapshot() int { id := j.nextRevisionId j.nextRevisionId++ j.validRevisions = append(j.validRevisions, revision{id, j.length()}) @@ -82,23 +94,23 @@ func (j *journal) snapshot() int { } // revertToSnapshot reverts all state changes made since the given revision. -func (j *journal) revertToSnapshot(revid int, s *StateDB) { +func (j *linearJournal) revertToSnapshot(revid int, s *StateDB) { // Find the snapshot in the stack of valid snapshots. 
idx := sort.Search(len(j.validRevisions), func(i int) bool { return j.validRevisions[i].id >= revid }) if idx == len(j.validRevisions) || j.validRevisions[idx].id != revid { - panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + panic(fmt.Errorf("revision id %v cannot be reverted (valid revisions: %d)", revid, len(j.validRevisions))) } snapshot := j.validRevisions[idx].journalIndex - // Replay the journal to undo changes and remove invalidated snapshots + // Replay the linearJournal to undo changes and remove invalidated snapshots j.revert(s, snapshot) j.validRevisions = j.validRevisions[:idx] } -// append inserts a new modification entry to the end of the change journal. -func (j *journal) append(entry journalEntry) { +// append inserts a new modification entry to the end of the change linearJournal. +func (j *linearJournal) append(entry journalEntry) { j.entries = append(j.entries, entry) if addr := entry.dirtied(); addr != nil { j.dirties[*addr]++ @@ -107,7 +119,7 @@ func (j *journal) append(entry journalEntry) { // revert undoes a batch of journalled modifications along with any reverted // dirty handling too. -func (j *journal) revert(statedb *StateDB, snapshot int) { +func (j *linearJournal) revert(statedb *StateDB, snapshot int) { for i := len(j.entries) - 1; i >= snapshot; i-- { // Undo the changes made by the operation j.entries[i].revert(statedb) @@ -125,22 +137,22 @@ func (j *journal) revert(statedb *StateDB, snapshot int) { // dirty explicitly sets an address to dirty, even if the change entries would // otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD // precompile consensus exception. -func (j *journal) dirty(addr common.Address) { +func (j *linearJournal) dirty(addr common.Address) { j.dirties[addr]++ } -// length returns the current number of entries in the journal. -func (j *journal) length() int { +// length returns the current number of entries in the linearJournal. 
+func (j *linearJournal) length() int { return len(j.entries) } // copy returns a deep-copied journal. -func (j *journal) copy() *journal { +func (j *linearJournal) copy() journal { entries := make([]journalEntry, 0, j.length()) for i := 0; i < j.length(); i++ { entries = append(entries, j.entries[i].copy()) } - return &journal{ + return &linearJournal{ entries: entries, dirties: maps.Clone(j.dirties), validRevisions: slices.Clone(j.validRevisions), @@ -148,23 +160,23 @@ func (j *journal) copy() *journal { } } -func (j *journal) logChange(txHash common.Hash) { +func (j *linearJournal) logChange(txHash common.Hash) { j.append(addLogChange{txhash: txHash}) } -func (j *journal) createObject(addr common.Address) { +func (j *linearJournal) createObject(addr common.Address) { j.append(createObjectChange{account: addr}) } -func (j *journal) createContract(addr common.Address) { +func (j *linearJournal) createContract(addr common.Address) { j.append(createContractChange{account: addr}) } -func (j *journal) destruct(addr common.Address) { +func (j *linearJournal) destruct(addr common.Address) { j.append(selfDestructChange{account: addr}) } -func (j *journal) storageChange(addr common.Address, key, prev, origin common.Hash) { +func (j *linearJournal) storageChange(addr common.Address, key, prev, origin common.Hash) { j.append(storageChange{ account: addr, key: key, @@ -173,7 +185,7 @@ func (j *journal) storageChange(addr common.Address, key, prev, origin common.Ha }) } -func (j *journal) transientStateChange(addr common.Address, key, prev common.Hash) { +func (j *linearJournal) transientStateChange(addr common.Address, key, prev common.Hash) { j.append(transientStorageChange{ account: addr, key: key, @@ -181,29 +193,29 @@ func (j *journal) transientStateChange(addr common.Address, key, prev common.Has }) } -func (j *journal) refundChange(previous uint64) { +func (j *linearJournal) refundChange(previous uint64) { j.append(refundChange{prev: previous}) } -func (j *journal) 
balanceChange(addr common.Address, previous *uint256.Int) { +func (j *linearJournal) balanceChange(addr common.Address, previous *uint256.Int) { j.append(balanceChange{ account: addr, prev: previous.Clone(), }) } -func (j *journal) setCode(address common.Address) { +func (j *linearJournal) setCode(address common.Address) { j.append(codeChange{account: address}) } -func (j *journal) nonceChange(address common.Address, prev uint64) { +func (j *linearJournal) nonceChange(address common.Address, prev uint64) { j.append(nonceChange{ account: address, prev: prev, }) } -func (j *journal) touchChange(address common.Address) { +func (j *linearJournal) touchChange(address common.Address) { j.append(touchChange{ account: address, }) @@ -214,11 +226,11 @@ func (j *journal) touchChange(address common.Address) { } } -func (j *journal) accessListAddAccount(addr common.Address) { +func (j *linearJournal) accessListAddAccount(addr common.Address) { j.append(accessListAddAccountChange{addr}) } -func (j *journal) accessListAddSlot(addr common.Address, slot common.Hash) { +func (j *linearJournal) accessListAddSlot(addr common.Address, slot common.Hash) { j.append(accessListAddSlotChange{ address: addr, slot: slot, @@ -231,7 +243,7 @@ type ( account common.Address } // createContractChange represents an account becoming a contract-account. - // This event happens prior to executing initcode. The journal-event simply + // This event happens prior to executing initcode. The linearJournal-event simply // manages the created-flag, in order to allow same-tx destruction. 
createContractChange struct { account common.Address @@ -457,7 +469,7 @@ func (ch addLogChange) copy() journalEntry { func (ch accessListAddAccountChange) revert(s *StateDB) { /* One important invariant here, is that whenever a (addr, slot) is added, if the - addr is not already present, the add causes two journal entries: + addr is not already present, the add causes two linearJournal entries: - one for the address, - one for the (address,slot) Therefore, when unrolling the change, we can always blindly delete the diff --git a/core/state/journal_api.go b/core/state/journal_api.go new file mode 100644 index 000000000000..0dd879fd8a9a --- /dev/null +++ b/core/state/journal_api.go @@ -0,0 +1,71 @@ +package state + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" +) + +type journal interface { + + // snapshot returns an identifier for the current revision of the state. + snapshot() int + + // revertToSnapshot reverts all state changes made since the given revision. + revertToSnapshot(revid int, s *StateDB) + + // reset clears the journal so it can be reused. + reset() + + // dirtyAccounts returns a list of all accounts modified in this journal + dirtyAccounts() []common.Address + + // accessListAddAccount journals the adding of addr to the access list + accessListAddAccount(addr common.Address) + + // accessListAddSlot journals the adding of addr/slot to the access list + accessListAddSlot(addr common.Address, slot common.Hash) + + // logChange journals the adding of a log related to the txHash + logChange(txHash common.Hash) + + // createObject journals the event of a new account created in the trie. + createObject(addr common.Address) + + // createContract journals the creation of a new contract at addr. + // OBS: This method must not be applied twice, it assumes that the pre-state + // (i.e the rollback-state) is non-created. + createContract(addr common.Address) + + // destruct journals the destruction of an account in the trie. 
+ // OBS: This method must not be applied twice -- it always assumes that the + // pre-state (i.e the rollback-state) is non-destructed. + destruct(addr common.Address) + + // storageChange journals a change in the storage data related to addr. + // It records the key and previous value of the slot. + storageChange(addr common.Address, key, prev, origin common.Hash) + + // transientStateChange journals a change in the t-storage data related to addr. + // It records the key and previous value of the slot. + transientStateChange(addr common.Address, key, prev common.Hash) + + // refundChange journals that the refund has been changed, recording the previous value. + refundChange(previous uint64) + + // balanceChange journals tha the balance of addr has been changed, recording the previous value + balanceChange(addr common.Address, previous *uint256.Int) + + // JournalSetCode journals that the code of addr has been set. + // OBS: This method must not be applied twice -- it always assumes that the + // pre-state (i.e the rollback-state) is "no code". + setCode(addr common.Address) + + // nonceChange journals that the nonce of addr was changed, recording the previous value. + nonceChange(addr common.Address, prev uint64) + + // touchChange journals that the account at addr was touched during execution. + touchChange(addr common.Address) + + // copy returns a deep-copied journal. + copy() journal +} diff --git a/core/state/statedb.go b/core/state/statedb.go index b2b4f8fb97b1..3738dc2f1540 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -136,7 +136,7 @@ type StateDB struct { // Journal of state modifications. This is the backbone of // Snapshot and RevertToSnapshot. 
- journal *journal + journal journal // State witness if cross validation is needed witness *stateless.Witness @@ -180,7 +180,7 @@ func New(root common.Hash, db Database) (*StateDB, error) { mutations: make(map[common.Address]*mutation), logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), - journal: newJournal(), + journal: newLinearJournal(), accessList: newAccessList(), transientStorage: newTransientStorage(), } @@ -721,8 +721,9 @@ func (s *StateDB) GetRefund() uint64 { // the journal as well as the refunds. Finalise, however, will not push any updates // into the tries just yet. Only IntermediateRoot or Commit will do that. func (s *StateDB) Finalise(deleteEmptyObjects bool) { - addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties)) - for addr := range s.journal.dirties { + dirties := s.journal.dirtyAccounts() + addressesToPrefetch := make([][]byte, 0, len(dirties)) + for _, addr := range dirties { obj, exist := s.stateObjects[addr] if !exist { // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2 diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 9441834c6a24..919c0a525928 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -228,7 +228,7 @@ func TestCopy(t *testing.T) { } // TestCopyWithDirtyJournal tests if Copy can correct create a equal copied -// stateDB with dirty journal present. +// stateDB with dirty linearJournal present. func TestCopyWithDirtyJournal(t *testing.T) { db := NewDatabaseForTesting() orig, _ := New(types.EmptyRootHash, db) @@ -408,8 +408,8 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { // We also set some code here, to prevent the // CreateContract action from being performed twice in a row, // which would cause a difference in state when unrolling - // the journal. 
(CreateContact assumes created was false prior to - // invocation, and the journal rollback sets it to false). + // the linearJournal. (CreateContact assumes created was false prior to + // invocation, and the linearJournal rollback sets it to false). s.SetCode(addr, []byte{1}) } }, @@ -675,22 +675,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) } - if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) { - getKeys := func(dirty map[common.Address]int) string { - var keys []common.Address - out := new(strings.Builder) - for key := range dirty { - keys = append(keys, key) - } - slices.SortFunc(keys, common.Address.Cmp) - for i, key := range keys { - fmt.Fprintf(out, " %d. %v\n", i, key) + { // Check the dirty-accounts + have := state.journal.dirtyAccounts() + want := checkstate.journal.dirtyAccounts() + slices.SortFunc(have, common.Address.Cmp) + slices.SortFunc(want, common.Address.Cmp) + if !slices.Equal(have, want) { + getKeys := func(keys []common.Address) string { + out := new(strings.Builder) + for i, key := range keys { + fmt.Fprintf(out, " %d. 
%v\n", i, key) + } + return out.String() } - return out.String() + haveK := getKeys(state.journal.dirtyAccounts()) + wantK := getKeys(checkstate.journal.dirtyAccounts()) + return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", haveK, wantK) } - have := getKeys(state.journal.dirties) - want := getKeys(checkstate.journal.dirties) - return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want) } return nil } @@ -704,11 +705,11 @@ func TestTouchDelete(t *testing.T) { snapshot := s.state.Snapshot() s.state.AddBalance(common.Address{}, new(uint256.Int), tracing.BalanceChangeUnspecified) - if len(s.state.journal.dirties) != 1 { + if len(s.state.journal.dirtyAccounts()) != 1 { t.Fatal("expected one dirty state object") } s.state.RevertToSnapshot(snapshot) - if len(s.state.journal.dirties) != 0 { + if len(s.state.journal.dirtyAccounts()) != 0 { t.Fatal("expected no dirty state object") } } @@ -1097,32 +1098,51 @@ func TestStateDBAccessList(t *testing.T) { } } + var ids []int + push := func(id int) { + ids = append(ids, id) + } + pop := func() int { + id := ids[len(ids)-1] + ids = ids[:len(ids)-1] + return id + } + + push(state.journal.snapshot()) // journal id 0 state.AddAddressToAccessList(addr("aa")) // 1 - state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3 + push(state.journal.snapshot()) // journal id 1 + state.AddAddressToAccessList(addr("bb")) // 2 + push(state.journal.snapshot()) // journal id 2 + state.AddSlotToAccessList(addr("bb"), slot("01")) // 3 + push(state.journal.snapshot()) // journal id 3 state.AddSlotToAccessList(addr("bb"), slot("02")) // 4 + push(state.journal.snapshot()) // journal id 4 verifyAddrs("aa", "bb") verifySlots("bb", "01", "02") // Make a copy stateCopy1 := state.Copy() - if exp, got := 4, state.journal.length(); exp != got { - t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + if exp, got := 4, state.journal.(*linearJournal).length(); exp != got { + t.Fatalf("linearJournal length 
mismatch: have %d, want %d", got, exp) } - // same again, should cause no journal entries + // same again, should cause no linearJournal entries state.AddSlotToAccessList(addr("bb"), slot("01")) state.AddSlotToAccessList(addr("bb"), slot("02")) state.AddAddressToAccessList(addr("aa")) - if exp, got := 4, state.journal.length(); exp != got { - t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + if exp, got := 4, state.journal.(*linearJournal).length(); exp != got { + t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) } // some new ones state.AddSlotToAccessList(addr("bb"), slot("03")) // 5 + push(state.journal.snapshot()) // journal id 5 state.AddSlotToAccessList(addr("aa"), slot("01")) // 6 - state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8 - state.AddAddressToAccessList(addr("cc")) - if exp, got := 8, state.journal.length(); exp != got { - t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + push(state.journal.snapshot()) // journal id 6 + state.AddAddressToAccessList(addr("cc")) // 7 + push(state.journal.snapshot()) // journal id 7 + state.AddSlotToAccessList(addr("cc"), slot("01")) // 8 + if exp, got := 8, state.journal.(*linearJournal).length(); exp != got { + t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) } verifyAddrs("aa", "bb", "cc") @@ -1131,7 +1151,7 @@ func TestStateDBAccessList(t *testing.T) { verifySlots("cc", "01") // now start rolling back changes - state.journal.revert(state, 7) + state.journal.revertToSnapshot(pop(), state) // revert to 6 if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok { t.Fatalf("slot present, expected missing") } @@ -1139,7 +1159,7 @@ func TestStateDBAccessList(t *testing.T) { verifySlots("aa", "01") verifySlots("bb", "01", "02", "03") - state.journal.revert(state, 6) + state.journal.revertToSnapshot(pop(), state) // revert to 5 if state.AddressInAccessList(addr("cc")) { t.Fatalf("addr present, expected missing") } @@ -1147,40 +1167,40 @@ 
func TestStateDBAccessList(t *testing.T) { verifySlots("aa", "01") verifySlots("bb", "01", "02", "03") - state.journal.revert(state, 5) + state.journal.revertToSnapshot(pop(), state) // revert to 4 if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01", "02", "03") - state.journal.revert(state, 4) + state.journal.revertToSnapshot(pop(), state) // revert to 3 if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01", "02") - state.journal.revert(state, 3) + state.journal.revertToSnapshot(pop(), state) // revert to 2 if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") verifySlots("bb", "01") - state.journal.revert(state, 2) + state.journal.revertToSnapshot(pop(), state) // revert to 1 if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok { t.Fatalf("slot present, expected missing") } verifyAddrs("aa", "bb") - state.journal.revert(state, 1) + state.journal.revertToSnapshot(pop(), state) // revert to 0 if state.AddressInAccessList(addr("bb")) { t.Fatalf("addr present, expected missing") } verifyAddrs("aa") - state.journal.revert(state, 0) + state.journal.revertToSnapshot(0, state) if state.AddressInAccessList(addr("aa")) { t.Fatalf("addr present, expected missing") } @@ -1251,10 +1271,10 @@ func TestStateDBTransientStorage(t *testing.T) { key := common.Hash{0x01} value := common.Hash{0x02} addr := common.Address{} - + revision := state.journal.snapshot() state.SetTransientState(addr, key, value) - if exp, got := 1, state.journal.length(); exp != got { - t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + if exp, got := 1, state.journal.(*linearJournal).length(); exp != got { + t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) } // the retrieved 
value should equal what was set if got := state.GetTransientState(addr, key); got != value { @@ -1263,7 +1283,7 @@ func TestStateDBTransientStorage(t *testing.T) { // revert the transient state being set and then check that the // value is now the empty hash - state.journal.revert(state, 0) + state.journal.revertToSnapshot(revision, state) if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got { t.Fatalf("transient storage mismatch: have %x, want %x", got, exp) } From 725031943576275c720b2c3971a6adde1ba20162 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Wed, 25 Sep 2024 10:12:33 +0200 Subject: [PATCH 2/3] core/state: test to demonstrate flaw in dirty-handling --- core/state/statedb_test.go | 50 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 919c0a525928..fd8464d28383 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -1395,3 +1395,53 @@ func TestStorageDirtiness(t *testing.T) { state.RevertToSnapshot(snap) checkDirty(common.Hash{0x1}, common.Hash{0x1}, true) } + +func TestStorageDirtiness2(t *testing.T) { + var ( + disk = rawdb.NewMemoryDatabase() + tdb = triedb.NewDatabase(disk, nil) + db = NewDatabase(tdb, nil) + state, _ = New(types.EmptyRootHash, db) + addr = common.HexToAddress("0x1") + checkDirty = func(key common.Hash, value common.Hash, dirty bool) { + t.Helper() + obj := state.getStateObject(addr) + v, exist := obj.dirtyStorage[key] + if exist != dirty { + t.Fatalf("unexpected dirty marker, want: %v, have: %v", dirty, exist) + } + if !exist { + return + } + if v != value { + t.Fatalf("unexpected storage slot, want: %x, have: %x", value, v) + } + } + ) + + { // Initiate a state, where an account has SLOT(1) = 0xA, +nonzero balance + state.CreateAccount(addr) + state.SetBalance(addr, uint256.NewInt(1), tracing.BalanceChangeUnspecified) // Prevent empty-delete + state.SetState(addr, common.Hash{0x1}, 
common.Hash{0xa}) + root, err := state.Commit(0, true) + if err != nil { + t.Fatal(err) + } + // Init phase done, load it again + if state, err = New(root, NewDatabase(tdb, nil)); err != nil { + t.Fatal(err) + } + } + // A no-op storage change, no dirty marker + state.SetState(addr, common.Hash{0x1}, common.Hash{0xa}) + checkDirty(common.Hash{0x1}, common.Hash{0xa}, false) + + // Enter new scope + snap := state.Snapshot() + state.SetState(addr, common.Hash{0x1}, common.Hash{0xb}) // SLOT(1) = 0xB + checkDirty(common.Hash{0x1}, common.Hash{0xb}, true) // Should be flagged dirty + state.RevertToSnapshot(snap) // Revert scope + + // the storage change has been set back to original, dirtiness should be revoked + checkDirty(common.Hash{0x1}, common.Hash{0x1}, false) +} From b37e8ff4d5dd3f5747a2ca5b1ad2d7fbb934a944 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Mon, 29 Jan 2024 10:18:46 +0100 Subject: [PATCH 3/3] core/state: make journalling set-based core/state: add handling for DiscardSnapshot core/state: use new journal core/state, genesis: fix flaw re discard/commit. In case the state is committed, the journal is reset, thus it is not correct to Discard/Revert snapshots at that point. 
core/state: fix nil defer in merge core/state: fix bugs in setjournal core/state: journal api changes core/state: bugfixes in sparse journal core/state: journal tests core/state: improve post-state check in journal-fuzzing test core/state: post-rebase fixups miner: remove discard-snapshot call, it's not needed since journal will be reset in Finalize core/state: fix tests core/state: lint core/state: supply origin-value when reverting storage change --- cmd/evm/internal/t8ntool/execution.go | 1 + cmd/evm/runner.go | 3 +- core/genesis.go | 1 + core/state/journal.go | 24 +- core/state/journal_api.go | 39 ++- core/state/journal_test.go | 132 +++++++ core/state/setjournal.go | 486 ++++++++++++++++++++++++++ core/state/state_object.go | 8 +- core/state/statedb.go | 13 +- core/state/statedb_test.go | 78 +++-- core/vm/evm.go | 14 +- core/vm/interface.go | 7 + tests/state_test_util.go | 2 + 13 files changed, 749 insertions(+), 59 deletions(-) create mode 100644 core/state/journal_test.go create mode 100644 core/state/setjournal.go diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go index 5fd1d6a4a6ad..ee72db68ec94 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -271,6 +271,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, } continue } + statedb.DiscardSnapshot(snapshot) includedTxs = append(includedTxs, tx) if hashError != nil { return nil, nil, nil, NewError(ErrorMissingBlockhash, hashError) diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 235fed66302a..69a52abe804a 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -158,7 +158,8 @@ func runCmd(ctx *cli.Context) error { sdb := state.NewDatabase(triedb, nil) statedb, _ = state.New(genesis.Root(), sdb) chainConfig = genesisConfig.Config - + id := statedb.Snapshot() + defer statedb.DiscardSnapshot(id) if ctx.String(SenderFlag.Name) != "" { sender = 
common.HexToAddress(ctx.String(SenderFlag.Name)) } diff --git a/core/genesis.go b/core/genesis.go index 31db49f527e4..3a776c99a02e 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -152,6 +152,7 @@ func flushAlloc(ga *types.GenesisAlloc, triedb *triedb.Database) (common.Hash, e if err != nil { return common.Hash{}, err } + for addr, account := range *ga { if account.Balance != nil { // This is not actually logged via tracer because OnGenesisBlock diff --git a/core/state/journal.go b/core/state/journal.go index f96936268a90..f6e3f5a3ed30 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -93,7 +93,6 @@ func (j *linearJournal) snapshot() int { return id } -// revertToSnapshot reverts all state changes made since the given revision. func (j *linearJournal) revertToSnapshot(revid int, s *StateDB) { // Find the snapshot in the stack of valid snapshots. idx := sort.Search(len(j.validRevisions), func(i int) bool { @@ -109,6 +108,13 @@ func (j *linearJournal) revertToSnapshot(revid int, s *StateDB) { j.validRevisions = j.validRevisions[:idx] } +// DiscardSnapshot removes the snapshot with the given id; after calling this +// method, it is no longer possible to revert to that particular snapshot, the +// changes are considered part of the parent scope. +func (j *linearJournal) DiscardSnapshot(id int) { + // +} + // append inserts a new modification entry to the end of the change linearJournal. 
func (j *linearJournal) append(entry journalEntry) { j.entries = append(j.entries, entry) @@ -168,11 +174,11 @@ func (j *linearJournal) createObject(addr common.Address) { j.append(createObjectChange{account: addr}) } -func (j *linearJournal) createContract(addr common.Address) { +func (j *linearJournal) createContract(addr common.Address, account *types.StateAccount) { j.append(createContractChange{account: addr}) } -func (j *linearJournal) destruct(addr common.Address) { +func (j *linearJournal) destruct(addr common.Address, account *types.StateAccount) { j.append(selfDestructChange{account: addr}) } @@ -197,25 +203,25 @@ func (j *linearJournal) refundChange(previous uint64) { j.append(refundChange{prev: previous}) } -func (j *linearJournal) balanceChange(addr common.Address, previous *uint256.Int) { +func (j *linearJournal) balanceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) { j.append(balanceChange{ account: addr, - prev: previous.Clone(), + prev: account.Balance.Clone(), }) } -func (j *linearJournal) setCode(address common.Address) { +func (j *linearJournal) setCode(address common.Address, account *types.StateAccount) { j.append(codeChange{account: address}) } -func (j *linearJournal) nonceChange(address common.Address, prev uint64) { +func (j *linearJournal) nonceChange(address common.Address, account *types.StateAccount, destructed, newContract bool) { j.append(nonceChange{ account: address, - prev: prev, + prev: account.Nonce, }) } -func (j *linearJournal) touchChange(address common.Address) { +func (j *linearJournal) touchChange(address common.Address, account *types.StateAccount, destructed, newContract bool) { j.append(touchChange{ account: address, }) diff --git a/core/state/journal_api.go b/core/state/journal_api.go index 0dd879fd8a9a..d88cd853135d 100644 --- a/core/state/journal_api.go +++ b/core/state/journal_api.go @@ -2,12 +2,22 @@ package state import ( "github.com/ethereum/go-ethereum/common" - 
"github.com/holiman/uint256" + "github.com/ethereum/go-ethereum/core/types" ) type journal interface { - // snapshot returns an identifier for the current revision of the state. + // The lifeycle of journalling is as follows: + // - snapshot() starts a 'scope'. + // - Tee method snapshot() may be called any number of times. + // - For each call to snapshot, there should be a corresponding call to end + // the scope via either of: + // - revertToSnapshot, which undoes the changes in the scope, or + // - discardSnapshot, which discards the ability to revert the changes in the scope. + // - This operation might merge the changes into the parent scope. + // If it does not merge the changes into the parent scope, it must create + // a new snapshot internally, in order to ensure that order of changes + // remains intact. snapshot() int // revertToSnapshot reverts all state changes made since the given revision. @@ -16,6 +26,11 @@ type journal interface { // reset clears the journal so it can be reused. reset() + // DiscardSnapshot removes the snapshot with the given id; after calling this + // method, it is no longer possible to revert to that particular snapshot, the + // changes are considered part of the parent scope. + DiscardSnapshot(revid int) + // dirtyAccounts returns a list of all accounts modified in this journal dirtyAccounts() []common.Address @@ -34,12 +49,12 @@ type journal interface { // createContract journals the creation of a new contract at addr. // OBS: This method must not be applied twice, it assumes that the pre-state // (i.e the rollback-state) is non-created. - createContract(addr common.Address) + createContract(addr common.Address, account *types.StateAccount) // destruct journals the destruction of an account in the trie. - // OBS: This method must not be applied twice -- it always assumes that the - // pre-state (i.e the rollback-state) is non-destructed. 
- destruct(addr common.Address) + // pre-state (i.e the rollback-state) is non-destructed (and, for the purpose + // of EIP-XXX (TODO lookup), created in this tx). + destruct(addr common.Address, account *types.StateAccount) // storageChange journals a change in the storage data related to addr. // It records the key and previous value of the slot. @@ -52,19 +67,19 @@ type journal interface { // refundChange journals that the refund has been changed, recording the previous value. refundChange(previous uint64) - // balanceChange journals tha the balance of addr has been changed, recording the previous value - balanceChange(addr common.Address, previous *uint256.Int) + // balanceChange journals that the balance of addr has been changed, recording the previous value + balanceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) - // JournalSetCode journals that the code of addr has been set. + // setCode journals that the code of addr has been set. // OBS: This method must not be applied twice -- it always assumes that the // pre-state (i.e the rollback-state) is "no code". - setCode(addr common.Address) + setCode(addr common.Address, account *types.StateAccount) // nonceChange journals that the nonce of addr was changed, recording the previous value. - nonceChange(addr common.Address, prev uint64) + nonceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) // touchChange journals that the account at addr was touched during execution. - touchChange(addr common.Address) + touchChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) // copy returns a deep-copied journal. copy() journal diff --git a/core/state/journal_test.go b/core/state/journal_test.go new file mode 100644 index 000000000000..76d1c936afc6 --- /dev/null +++ b/core/state/journal_test.go @@ -0,0 +1,132 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package state provides a caching layer atop the Ethereum state trie. +package state + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/holiman/uint256" +) + +func TestLinearJournalDirty(t *testing.T) { + testJournalDirty(t, newLinearJournal()) +} + +func TestSparseJournalDirty(t *testing.T) { + testJournalDirty(t, newSparseJournal()) +} + +// This test verifies some basics around journalling: the ability to +// deliver a dirty-set. 
+func testJournalDirty(t *testing.T, j journal) { + acc := &types.StateAccount{ + Nonce: 1, + Balance: new(uint256.Int), + Root: common.Hash{}, + CodeHash: nil, + } + { + j.nonceChange(common.Address{0x1}, acc, false, false) + if have, want := len(j.dirtyAccounts()), 1; have != want { + t.Errorf("wrong size of dirty accounts, have %v want %v", have, want) + } + } + { + j.storageChange(common.Address{0x2}, common.Hash{0x1}, common.Hash{0x1}, common.Hash{}) + if have, want := len(j.dirtyAccounts()), 2; have != want { + t.Errorf("wrong size of dirty accounts, have %v want %v", have, want) + } + } + { // The previous scopes should also be accounted for + j.snapshot() + if have, want := len(j.dirtyAccounts()), 2; have != want { + t.Errorf("wrong size of dirty accounts, have %v want %v", have, want) + } + } +} + +func TestLinearJournalAccessList(t *testing.T) { + testJournalAccessList(t, newLinearJournal()) +} + +func TestSparseJournalAccessList(t *testing.T) { + testJournalAccessList(t, newSparseJournal()) +} + +func testJournalAccessList(t *testing.T, j journal) { + var statedb = &StateDB{} + statedb.accessList = newAccessList() + statedb.journal = j + + { + // If the journal performs the rollback in the wrong order, this + // will cause a panic. 
+ id := j.snapshot() + statedb.AddSlotToAccessList(common.Address{0x1}, common.Hash{0x4}) + statedb.AddSlotToAccessList(common.Address{0x3}, common.Hash{0x4}) + statedb.RevertToSnapshot(id) + } + { + id := j.snapshot() + statedb.AddAddressToAccessList(common.Address{0x2}) + statedb.AddAddressToAccessList(common.Address{0x3}) + statedb.AddAddressToAccessList(common.Address{0x4}) + statedb.RevertToSnapshot(id) + if statedb.accessList.ContainsAddress(common.Address{0x2}) { + t.Fatal("should be missing") + } + } +} + +func TestLinearJournalRefunds(t *testing.T) { + testJournalRefunds(t, newLinearJournal()) +} + +func TestSparseJournalRefunds(t *testing.T) { + testJournalRefunds(t, newSparseJournal()) +} + +func testJournalRefunds(t *testing.T, j journal) { + var statedb = &StateDB{} + statedb.accessList = newAccessList() + statedb.journal = j + zero := j.snapshot() + j.refundChange(0) + j.refundChange(1) + { + id := j.snapshot() + j.refundChange(2) + j.refundChange(3) + j.revertToSnapshot(id, statedb) + if have, want := statedb.refund, uint64(2); have != want { + t.Fatalf("have %d want %d", have, want) + } + } + { + id := j.snapshot() + j.refundChange(2) + j.refundChange(3) + j.DiscardSnapshot(id) + } + j.revertToSnapshot(zero, statedb) + if have, want := statedb.refund, uint64(0); have != want { + t.Fatalf("have %d want %d", have, want) + } +} diff --git a/core/state/setjournal.go b/core/state/setjournal.go new file mode 100644 index 000000000000..63ba2eb9dab3 --- /dev/null +++ b/core/state/setjournal.go @@ -0,0 +1,486 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package state + +import ( + "bytes" + "fmt" + "maps" + "slices" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/holiman/uint256" +) + +var ( + _ journal = (*sparseJournal)(nil) +) + +// journalAccount represents the 'journalable state' of a types.StateAccount. +// Which means, all the normal fields except storage root, but also with a +// destruction-flag. +type journalAccount struct { + nonce uint64 + balance uint256.Int + codeHash []byte // nil == types.EmptyCodeHash + destructed bool + newContract bool +} + +type addrSlot struct { + addr common.Address + slot common.Hash +} + +type doubleHash struct { + origin common.Hash + prev common.Hash +} + +// scopedJournal represents all changes within a single callscope. These changes +// are either all reverted, or all committed -- they cannot be partially applied. +type scopedJournal struct { + accountChanges map[common.Address]*journalAccount + refund int64 + logs []common.Hash + + accessListAddresses []common.Address + accessListAddrSlots []addrSlot + + storageChanges map[common.Address]map[common.Hash]doubleHash + tStorageChanges map[common.Address]map[common.Hash]common.Hash +} + +func newScopedJournal() *scopedJournal { + return &scopedJournal{ + refund: -1, + } +} + +func (j *scopedJournal) deepCopy() *scopedJournal { + var cpy = &scopedJournal{ + // The accountChanges copy will copy the pointers to + // journalAccount objects: thus not actually deep copy those + // objects. That is fine: we never mutate journalAccount.
+ accountChanges: maps.Clone(j.accountChanges), + refund: j.refund, + logs: slices.Clone(j.logs), + accessListAddresses: slices.Clone(j.accessListAddresses), + accessListAddrSlots: slices.Clone(j.accessListAddrSlots), + } + if j.storageChanges != nil { + cpy.storageChanges = make(map[common.Address]map[common.Hash]doubleHash) + for addr, changes := range j.storageChanges { + cpy.storageChanges[addr] = maps.Clone(changes) + } + } + if j.tStorageChanges != nil { + cpy.tStorageChanges = make(map[common.Address]map[common.Hash]common.Hash) + for addr, changes := range j.tStorageChanges { + cpy.tStorageChanges[addr] = maps.Clone(changes) + } + } + return cpy +} + +func (j *scopedJournal) journalRefundChange(prev uint64) { + if j.refund == -1 { + // We convert from uint64 to int64 here, so that we can use -1 + // to represent "no previous value set". + // Treating refund as int64 is fine, there's no possibility for + // refund to ever exceed maxInt64. + j.refund = int64(prev) + } +} + +// journalAccountChange is the common shared implementation for all account-changes. 
+// These changes all fall back to this method: +// - balance change +// - nonce change +// - destruct-change +// - code change +// - touch change +// - creation change (in this case, the account is nil) +func (j *scopedJournal) journalAccountChange(address common.Address, account *types.StateAccount, destructed, newContract bool) { + if j.accountChanges == nil { + j.accountChanges = make(map[common.Address]*journalAccount) + } + // If the account has already been journalled, we're done here + if _, ok := j.accountChanges[address]; ok { + return + } + if account == nil { + j.accountChanges[address] = nil // created now, previously non-existent + return + } + ja := &journalAccount{ + nonce: account.Nonce, + balance: *account.Balance, + destructed: destructed, + newContract: newContract, + } + if !bytes.Equal(account.CodeHash, types.EmptyCodeHash[:]) { + ja.codeHash = account.CodeHash + } + j.accountChanges[address] = ja +} + +func (j *scopedJournal) journalLog(txHash common.Hash) { + j.logs = append(j.logs, txHash) +} + +func (j *scopedJournal) journalAccessListAddAccount(addr common.Address) { + j.accessListAddresses = append(j.accessListAddresses, addr) +} + +func (j *scopedJournal) journalAccessListAddSlot(addr common.Address, slot common.Hash) { + j.accessListAddrSlots = append(j.accessListAddrSlots, addrSlot{addr, slot}) +} + +func (j *scopedJournal) journalSetState(addr common.Address, key, prev, origin common.Hash) { + if j.storageChanges == nil { + j.storageChanges = make(map[common.Address]map[common.Hash]doubleHash) + } + changes, ok := j.storageChanges[addr] + if !ok { + changes = make(map[common.Hash]doubleHash) + j.storageChanges[addr] = changes + } + // Do not overwrite a previous value! 
+ if _, ok := changes[key]; !ok { + changes[key] = doubleHash{origin: origin, prev: prev} + } +} + +func (j *scopedJournal) journalSetTransientState(addr common.Address, key, prev common.Hash) { + if j.tStorageChanges == nil { + j.tStorageChanges = make(map[common.Address]map[common.Hash]common.Hash) + } + changes, ok := j.tStorageChanges[addr] + if !ok { + changes = make(map[common.Hash]common.Hash) + j.tStorageChanges[addr] = changes + } + // Do not overwrite a previous value! + if _, ok := changes[key]; !ok { + changes[key] = prev + } +} + +func (j *scopedJournal) revert(s *StateDB) { + // Revert refund + if j.refund != -1 { + s.refund = uint64(j.refund) + } + // Revert storage changes + for addr, changes := range j.storageChanges { + obj := s.getStateObject(addr) + for key, val := range changes { + obj.setState(key, val.prev, val.origin) + } + } + // Revert t-store changes + for addr, changes := range j.tStorageChanges { + for key, val := range changes { + s.setTransientState(addr, key, val) + } + } + + // Revert changes to accounts + for addr, data := range j.accountChanges { + if data == nil { // Reverting a create + delete(s.stateObjects, addr) + continue + } + obj := s.getStateObject(addr) + obj.setNonce(data.nonce) + // Setting 'code' to nil means it will be loaded from disk + // next time it is needed. 
We avoid nilling it unless required + journalHash := data.codeHash + if data.codeHash == nil { + if !bytes.Equal(obj.CodeHash(), types.EmptyCodeHash[:]) { + obj.setCode(types.EmptyCodeHash, nil) + } + } else { + if !bytes.Equal(obj.CodeHash(), journalHash) { + obj.setCode(common.BytesToHash(data.codeHash), nil) + } + } + obj.setBalance(&data.balance) + obj.selfDestructed = data.destructed + obj.newContract = data.newContract + } + // Revert logs + for _, txhash := range j.logs { + logs := s.logs[txhash] + if len(logs) == 1 { + delete(s.logs, txhash) + } else { + s.logs[txhash] = logs[:len(logs)-1] + } + s.logSize-- + } + // Revert access list additions + for i := len(j.accessListAddrSlots) - 1; i >= 0; i-- { + item := j.accessListAddrSlots[i] + s.accessList.DeleteSlot(item.addr, item.slot) + } + for i := len(j.accessListAddresses) - 1; i >= 0; i-- { + s.accessList.DeleteAddress(j.accessListAddresses[i]) + } +} + +func (j *scopedJournal) merge(parent *scopedJournal) { + if parent.refund == -1 { + parent.refund = j.refund + } + // Merge changes to accounts + if parent.accountChanges == nil { + parent.accountChanges = j.accountChanges + } else { + for addr, data := range j.accountChanges { + if _, present := parent.accountChanges[addr]; present { + // Nothing to do here, it's already stored in parent scope + continue + } + parent.accountChanges[addr] = data + } + } + // Merge logs + parent.logs = append(parent.logs, j.logs...) + + // Merge access list additions + parent.accessListAddrSlots = append(parent.accessListAddrSlots, j.accessListAddrSlots...) + parent.accessListAddresses = append(parent.accessListAddresses, j.accessListAddresses...)
+ + if parent.storageChanges == nil { + parent.storageChanges = j.storageChanges + } else { + // Merge storage changes + for addr, changes := range j.storageChanges { + prevChanges, ok := parent.storageChanges[addr] + if !ok { + parent.storageChanges[addr] = changes + continue + } + for k, v := range changes { + if _, ok := prevChanges[k]; !ok { + prevChanges[k] = v + } + } + } + } + if parent.tStorageChanges == nil { + parent.tStorageChanges = j.tStorageChanges + } else { + // Merge t-store changes + for addr, changes := range j.tStorageChanges { + prevChanges, ok := parent.tStorageChanges[addr] + if !ok { + parent.tStorageChanges[addr] = changes + continue + } + for k, v := range changes { + if _, ok := prevChanges[k]; !ok { + prevChanges[k] = v + } + } + } + } +} + +func (j *scopedJournal) addDirtyAccounts(set map[common.Address]any) { + // Changes due to account changes + for addr := range j.accountChanges { + set[addr] = []interface{}{} + } + // Changes due to storage changes + for addr := range j.storageChanges { + set[addr] = []interface{}{} + } +} + +// sparseJournal contains the list of state modifications applied since the last state +// commit. These are tracked to be able to be reverted in the case of an execution +// exception or request for reversal. +type sparseJournal struct { + entries []*scopedJournal // Current changes tracked by the journal + ripeMagic bool +} + +// newSparseJournal creates a new initialized sparseJournal. +func newSparseJournal() *sparseJournal { + s := new(sparseJournal) + s.snapshot() // create snapshot zero + return s +} + +// reset clears the journal, after this operation the journal can be used +// anew.
It is semantically similar to calling 'newSparseJournal', but the underlying +// slices can be reused +func (j *sparseJournal) reset() { + j.entries = j.entries[:0] + j.snapshot() +} + +func (j *sparseJournal) copy() journal { + cp := &sparseJournal{ + entries: make([]*scopedJournal, 0, len(j.entries)), + } + for _, entry := range j.entries { + cp.entries = append(cp.entries, entry.deepCopy()) + } + return cp +} + +// snapshot returns an identifier for the current revision of the state. +// OBS: A call to Snapshot is _required_ in order to initialize the journalling, +// invoking the journal-methods without having invoked Snapshot will lead to +// panic. +func (j *sparseJournal) snapshot() int { + id := len(j.entries) + j.entries = append(j.entries, newScopedJournal()) + return id +} + +// revertToSnapshot reverts all state changes made since the given revision. +func (j *sparseJournal) revertToSnapshot(id int, s *StateDB) { + if id >= len(j.entries) { + panic(fmt.Errorf("revision id %v cannot be reverted", id)) + } + // Revert the entries sequentially + for i := len(j.entries) - 1; i >= id; i-- { + entry := j.entries[i] + entry.revert(s) + } + j.entries = j.entries[:id] +} + +func (j *sparseJournal) DiscardSnapshot(id int) { + if id == 0 { + return + } + // here we must merge the 'id' with its parent. + want := len(j.entries) - 1 + have := id + if want != have { + if want == 0 && id == 1 { + // If a transaction is applied successfully, the statedb.Finalize will + // end by clearing and resetting the journal. Invoking a DiscardSnapshot + // afterwards will lead us here.
+ // Let's not panic, but it's ok to complain a bit + log.Error("Extraneous invocation to discard snapshot") + return + } else { + panic(fmt.Sprintf("journalling error, want discard(%d), have discard(%d)", want, have)) + } + } + entry := j.entries[id] + parent := j.entries[id-1] + entry.merge(parent) + j.entries = j.entries[:id] +} + +func (j *sparseJournal) journalAccountChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) { + j.entries[len(j.entries)-1].journalAccountChange(addr, account, destructed, newContract) +} + +func (j *sparseJournal) nonceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) { + j.journalAccountChange(addr, account, destructed, newContract) +} + +func (j *sparseJournal) balanceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) { + j.journalAccountChange(addr, account, destructed, newContract) +} + +func (j *sparseJournal) setCode(addr common.Address, account *types.StateAccount) { + j.journalAccountChange(addr, account, false, true) +} + +func (j *sparseJournal) createObject(addr common.Address) { + // Creating an account which is destructed, hence already exists, is not + // allowed, hence we know destructed == 'false'. + // Also, if we are creating the account now, it cannot yet be a + // newContract (that might come later) + j.journalAccountChange(addr, nil, false, false) +} + +func (j *sparseJournal) createContract(addr common.Address, account *types.StateAccount) { + // Creating an account which is destructed, hence already exists, is not + // allowed, hence we know it to be 'false'. + // Also: if we create the contract now, it cannot be previously created + j.journalAccountChange(addr, account, false, false) +} + +func (j *sparseJournal) destruct(addr common.Address, account *types.StateAccount) { + // destructing an already destructed account must not be journalled. Hence we + // know it to be 'false'. 
+ // Also: if we're allowed to destruct it, it must be `newContract:true`, OR + // the concept of newContract is unused and moot. + j.journalAccountChange(addr, account, false, true) +} + +// var ripemd = common.HexToAddress("0000000000000000000000000000000000000003") +func (j *sparseJournal) touchChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) { + j.journalAccountChange(addr, account, destructed, newContract) + if addr == ripemd { + // Explicitly put it in the dirty-cache one extra time. Ripe magic. + j.ripeMagic = true + } +} + +func (j *sparseJournal) logChange(txHash common.Hash) { + j.entries[len(j.entries)-1].journalLog(txHash) +} + +func (j *sparseJournal) refundChange(prev uint64) { + j.entries[len(j.entries)-1].journalRefundChange(prev) +} + +func (j *sparseJournal) accessListAddAccount(addr common.Address) { + j.entries[len(j.entries)-1].journalAccessListAddAccount(addr) +} + +func (j *sparseJournal) accessListAddSlot(addr common.Address, slot common.Hash) { + j.entries[len(j.entries)-1].journalAccessListAddSlot(addr, slot) +} + +func (j *sparseJournal) storageChange(addr common.Address, key, prev, origin common.Hash) { + j.entries[len(j.entries)-1].journalSetState(addr, key, prev, origin) +} + +func (j *sparseJournal) transientStateChange(addr common.Address, key, prev common.Hash) { + j.entries[len(j.entries)-1].journalSetTransientState(addr, key, prev) +} + +func (j *sparseJournal) dirtyAccounts() []common.Address { + // The dirty-set should encompass all layers + var dirty = make(map[common.Address]any) + for _, scope := range j.entries { + scope.addDirtyAccounts(dirty) + } + if j.ripeMagic { + dirty[ripemd] = []interface{}{} + } + var dirtyList = make([]common.Address, 0, len(dirty)) + for addr := range dirty { + dirtyList = append(dirtyList, addr) + } + return dirtyList +} diff --git a/core/state/state_object.go b/core/state/state_object.go index 422badb19bc3..817a79c1e984 100644 --- 
a/core/state/state_object.go +++ b/core/state/state_object.go @@ -114,7 +114,7 @@ func (s *stateObject) markSelfdestructed() { } func (s *stateObject) touch() { - s.db.journal.touchChange(s.address) + s.db.journal.touchChange(s.address, &s.data, s.selfDestructed, s.newContract) } // getTrie returns the associated storage trie. The trie will be opened if it's @@ -470,7 +470,7 @@ func (s *stateObject) SubBalance(amount *uint256.Int, reason tracing.BalanceChan } func (s *stateObject) SetBalance(amount *uint256.Int, reason tracing.BalanceChangeReason) { - s.db.journal.balanceChange(s.address, s.data.Balance) + s.db.journal.balanceChange(s.address, &s.data, s.selfDestructed, s.newContract) if s.db.logger != nil && s.db.logger.OnBalanceChange != nil { s.db.logger.OnBalanceChange(s.address, s.Balance().ToBig(), amount.ToBig(), reason) } @@ -546,7 +546,7 @@ func (s *stateObject) CodeSize() int { } func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { - s.db.journal.setCode(s.address) + s.db.journal.setCode(s.address, &s.data) if s.db.logger != nil && s.db.logger.OnCodeChange != nil { // TODO remove prevcode from this callback s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), nil, codeHash, code) @@ -561,7 +561,7 @@ func (s *stateObject) setCode(codeHash common.Hash, code []byte) { } func (s *stateObject) SetNonce(nonce uint64) { - s.db.journal.nonceChange(s.address, s.data.Nonce) + s.db.journal.nonceChange(s.address, &s.data, s.selfDestructed, s.newContract) if s.db.logger != nil && s.db.logger.OnNonceChange != nil { s.db.logger.OnNonceChange(s.address, s.data.Nonce, nonce) } diff --git a/core/state/statedb.go b/core/state/statedb.go index 3738dc2f1540..acf40d3d7358 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -180,7 +180,7 @@ func New(root common.Hash, db Database) (*StateDB, error) { mutations: make(map[common.Address]*mutation), logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), - 
journal: newLinearJournal(), + journal: newSparseJournal(), accessList: newAccessList(), transientStorage: newTransientStorage(), } @@ -500,7 +500,7 @@ func (s *StateDB) SelfDestruct(addr common.Address) { // If it is already marked as self-destructed, we do not need to add it // for journalling a second time. if !stateObject.selfDestructed { - s.journal.destruct(addr) + s.journal.destruct(addr, &stateObject.data) stateObject.markSelfdestructed() } } @@ -638,7 +638,7 @@ func (s *StateDB) CreateContract(addr common.Address) { obj := s.getStateObject(addr) if !obj.newContract { obj.newContract = true - s.journal.createContract(addr) + s.journal.createContract(addr, &obj.data) } } @@ -707,6 +707,13 @@ func (s *StateDB) Snapshot() int { return s.journal.snapshot() } +// DiscardSnapshot removes the snapshot with the given id; after calling this +// method, it is no longer possible to revert to that particular snapshot, the +// changes are considered part of the parent scope. +func (s *StateDB) DiscardSnapshot(id int) { + s.journal.DiscardSnapshot(id) +} + // RevertToSnapshot reverts all state changes made since the given revision. func (s *StateDB) RevertToSnapshot(revid int) { s.journal.revertToSnapshot(revid, s) diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index fd8464d28383..e29221c3bd96 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -55,7 +55,7 @@ func TestUpdateLeaks(t *testing.T) { sdb = NewDatabase(tdb, nil) ) state, _ := New(types.EmptyRootHash, sdb) - + state.Snapshot() // Update it with some accounts for i := byte(0); i < 255; i++ { addr := common.BytesToAddress([]byte{i}) @@ -111,7 +111,7 @@ func TestIntermediateLeaks(t *testing.T) { } // Write modifications to trie. transState.IntermediateRoot(false) - + transState.journal.snapshot() // Overwrite all the data with new values in the transient database. 
for i := byte(0); i < 255; i++ { modify(transState, common.Address{i}, i, 99) @@ -362,6 +362,12 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { { name: "SetStorage", fn: func(a testAction, s *StateDB) { + contractHash := s.GetCodeHash(addr) + emptyCode := contractHash == (common.Hash{}) || contractHash == types.EmptyCodeHash + if emptyCode { + // no-op + return + } var key, val common.Hash binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) @@ -372,12 +378,26 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { { name: "SetCode", fn: func(a testAction, s *StateDB) { - // SetCode can only be performed in case the addr does - // not already hold code + // SetCode cannot be performed if the addr already has code if c := s.GetCode(addr); len(c) > 0 { // no-op return } + // SetCode cannot be performed if the addr has just selfdestructed + if obj := s.getStateObject(addr); obj != nil { + if obj.selfDestructed { + // If it's selfdestructed, we cannot create into it + return + } + } + // SetCode requires the contract to be account + contract to be created first + if obj := s.getStateObject(addr); obj == nil { + s.createObject(addr) + } + obj := s.getStateObject(addr) + if !obj.newContract { + s.CreateContract(addr) + } code := make([]byte, 16) binary.BigEndian.PutUint64(code, uint64(a.args[0])) binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) @@ -403,6 +423,13 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { emptyCode := contractHash == (common.Hash{}) || contractHash == types.EmptyCodeHash storageRoot := s.GetStorageRoot(addr) emptyStorage := storageRoot == (common.Hash{}) || storageRoot == types.EmptyRootHash + + if obj := s.getStateObject(addr); obj != nil { + if obj.selfDestructed { + // If it's selfdestructed, we cannot create into it + return + } + } if s.GetNonce(addr) == 0 && emptyCode && emptyStorage { s.CreateContract(addr) // We also set 
some code here, to prevent the @@ -417,6 +444,15 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { { name: "SelfDestruct", fn: func(a testAction, s *StateDB) { + obj := s.getStateObject(addr) + // SelfDestruct requires the object to first exist + if obj == nil { + s.createObject(addr) + } + obj = s.getStateObject(addr) + if !obj.newContract { + s.CreateContract(addr) + } s.SelfDestruct(addr) }, }, @@ -437,15 +473,6 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { }, args: make([]int64, 1), }, - { - name: "AddPreimage", - fn: func(a testAction, s *StateDB) { - preimage := []byte{1} - hash := common.BytesToHash(preimage) - s.AddPreimage(hash, preimage) - }, - args: make([]int64, 1), - }, { name: "AddAddressToAccessList", fn: func(a testAction, s *StateDB) { @@ -463,6 +490,13 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { { name: "SetTransientState", fn: func(a testAction, s *StateDB) { + contractHash := s.GetCodeHash(addr) + emptyCode := contractHash == (common.Hash{}) || contractHash == types.EmptyCodeHash + if emptyCode { + // no-op + return + } + var key, val common.Hash binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) @@ -688,8 +722,8 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { } return out.String() } - haveK := getKeys(state.journal.dirtyAccounts()) - wantK := getKeys(checkstate.journal.dirtyAccounts()) + haveK := getKeys(have) + wantK := getKeys(want) return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", haveK, wantK) } } @@ -1122,17 +1156,12 @@ func TestStateDBAccessList(t *testing.T) { // Make a copy stateCopy1 := state.Copy() - if exp, got := 4, state.journal.(*linearJournal).length(); exp != got { - t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) - } // same again, should cause no linearJournal entries state.AddSlotToAccessList(addr("bb"), slot("01")) 
state.AddSlotToAccessList(addr("bb"), slot("02")) state.AddAddressToAccessList(addr("aa")) - if exp, got := 4, state.journal.(*linearJournal).length(); exp != got { - t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) - } + // some new ones state.AddSlotToAccessList(addr("bb"), slot("03")) // 5 push(state.journal.snapshot()) // journal id 5 @@ -1141,9 +1170,6 @@ func TestStateDBAccessList(t *testing.T) { state.AddAddressToAccessList(addr("cc")) // 7 push(state.journal.snapshot()) // journal id 7 state.AddSlotToAccessList(addr("cc"), slot("01")) // 8 - if exp, got := 8, state.journal.(*linearJournal).length(); exp != got { - t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) - } verifyAddrs("aa", "bb", "cc") verifySlots("aa", "01") @@ -1273,9 +1299,7 @@ func TestStateDBTransientStorage(t *testing.T) { addr := common.Address{} revision := state.journal.snapshot() state.SetTransientState(addr, key, value) - if exp, got := 1, state.journal.(*linearJournal).length(); exp != got { - t.Fatalf("linearJournal length mismatch: have %d, want %d", got, exp) - } + // the retrieved value should equal what was set if got := state.GetTransientState(addr, key); got != value { t.Fatalf("transient storage mismatch: have %x, want %x", got, value) diff --git a/core/vm/evm.go b/core/vm/evm.go index 616668d565cc..ed9de46bca6d 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -201,6 +201,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas if !isPrecompile && evm.chainRules.IsEIP158 && value.IsZero() { // Calling a non-existing account, don't do anything. 
+ evm.StateDB.DiscardSnapshot(snapshot) return nil, gas, nil } evm.StateDB.CreateAccount(addr) @@ -240,9 +241,8 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas gas = 0 } - // TODO: consider clearing up unused snapshots: - //} else { - // evm.StateDB.DiscardSnapshot(snapshot) + } else { + evm.StateDB.DiscardSnapshot(snapshot) } return ret, gas, err } @@ -299,6 +299,8 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte, gas = 0 } + } else { + evm.StateDB.DiscardSnapshot(snapshot) } return ret, gas, err } @@ -348,6 +350,8 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by } gas = 0 } + } else { + evm.StateDB.DiscardSnapshot(snapshot) } return ret, gas, err } @@ -410,6 +414,8 @@ func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte gas = 0 } + } else { + evm.StateDB.DiscardSnapshot(snapshot) } return ret, gas, err } @@ -521,6 +527,8 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, if err != ErrExecutionReverted { contract.UseGas(contract.Gas, evm.Config.Tracer, tracing.GasChangeCallFailedExecution) } + } else { + evm.StateDB.DiscardSnapshot(snapshot) } return ret, address, contract.Gas, err } diff --git a/core/vm/interface.go b/core/vm/interface.go index 5f426435650d..58e2075ae55a 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -83,7 +83,14 @@ type StateDB interface { Prepare(rules params.Rules, sender, coinbase common.Address, dest *common.Address, precompiles []common.Address, txAccesses types.AccessList) + // RevertToSnapshot reverts all state changes made since the given revision. RevertToSnapshot(int) + + // DiscardSnapshot removes the snapshot with the given id; after calling this + // method, it is no longer possible to revert to that particular snapshot, the + // changes are considered part of the parent scope. 
+ DiscardSnapshot(int) + // Snapshot returns an identifier for the current scope of the state. Snapshot() int AddLog(*types.Log) diff --git a/tests/state_test_util.go b/tests/state_test_util.go index cf0ce9777f67..b562a197956e 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -308,6 +308,8 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh if tracer := evm.Config.Tracer; tracer != nil && tracer.OnTxEnd != nil { evm.Config.Tracer.OnTxEnd(nil, err) } + } else { + st.StateDB.DiscardSnapshot(snapshot) } // Add 0-value mining reward. This only makes a difference in the cases // where