diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go
index 5fd1d6a4a6ad..ee72db68ec94 100644
--- a/cmd/evm/internal/t8ntool/execution.go
+++ b/cmd/evm/internal/t8ntool/execution.go
@@ -271,6 +271,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig,
}
continue
}
+ statedb.DiscardSnapshot(snapshot)
includedTxs = append(includedTxs, tx)
if hashError != nil {
return nil, nil, nil, NewError(ErrorMissingBlockhash, hashError)
diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go
index 235fed66302a..69a52abe804a 100644
--- a/cmd/evm/runner.go
+++ b/cmd/evm/runner.go
@@ -158,7 +158,8 @@ func runCmd(ctx *cli.Context) error {
sdb := state.NewDatabase(triedb, nil)
statedb, _ = state.New(genesis.Root(), sdb)
chainConfig = genesisConfig.Config
-
+ id := statedb.Snapshot()
+ defer statedb.DiscardSnapshot(id)
if ctx.String(SenderFlag.Name) != "" {
sender = common.HexToAddress(ctx.String(SenderFlag.Name))
}
diff --git a/core/genesis.go b/core/genesis.go
index 31db49f527e4..3a776c99a02e 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -152,6 +152,7 @@ func flushAlloc(ga *types.GenesisAlloc, triedb *triedb.Database) (common.Hash, e
if err != nil {
return common.Hash{}, err
}
+
for addr, account := range *ga {
if account.Balance != nil {
// This is not actually logged via tracer because OnGenesisBlock
diff --git a/core/state/journal.go b/core/state/journal.go
index a2fea6b6ecc5..f6e3f5a3ed30 100644
--- a/core/state/journal.go
+++ b/core/state/journal.go
@@ -32,33 +32,36 @@ type revision struct {
journalIndex int
}
-// journalEntry is a modification entry in the state change journal that can be
+// journalEntry is a modification entry in the state change linear journal that can be
// reverted on demand.
type journalEntry interface {
- // revert undoes the changes introduced by this journal entry.
+ // revert undoes the changes introduced by this entry.
revert(*StateDB)
- // dirtied returns the Ethereum address modified by this journal entry.
+ // dirtied returns the Ethereum address modified by this entry.
dirtied() *common.Address
- // copy returns a deep-copied journal entry.
+ // copy returns a deep-copied entry.
copy() journalEntry
}
-// journal contains the list of state modifications applied since the last state
+// linearJournal contains the list of state modifications applied since the last state
// commit. These are tracked to be able to be reverted in the case of an execution
// exception or request for reversal.
-type journal struct {
- entries []journalEntry // Current changes tracked by the journal
+type linearJournal struct {
+ entries []journalEntry // Current changes tracked by the linearJournal
dirties map[common.Address]int // Dirty accounts and the number of changes
validRevisions []revision
nextRevisionId int
}
-// newJournal creates a new initialized journal.
-func newJournal() *journal {
- return &journal{
+// compile-time interface check
+var _ journal = (*linearJournal)(nil)
+
+// newLinearJournal creates a new initialized linearJournal.
+func newLinearJournal() *linearJournal {
+ return &linearJournal{
dirties: make(map[common.Address]int),
}
}
@@ -66,39 +69,54 @@ func newJournal() *journal {
// reset clears the journal, after this operation the journal can be used anew.
// It is semantically similar to calling 'newJournal', but the underlying slices
// can be reused.
-func (j *journal) reset() {
+func (j *linearJournal) reset() {
j.entries = j.entries[:0]
j.validRevisions = j.validRevisions[:0]
clear(j.dirties)
j.nextRevisionId = 0
}
+func (j linearJournal) dirtyAccounts() []common.Address {
+ dirty := make([]common.Address, 0, len(j.dirties))
+ // flatten into list
+ for addr := range j.dirties {
+ dirty = append(dirty, addr)
+ }
+ return dirty
+}
+
// snapshot returns an identifier for the current revision of the state.
-func (j *journal) snapshot() int {
+func (j *linearJournal) snapshot() int {
id := j.nextRevisionId
j.nextRevisionId++
j.validRevisions = append(j.validRevisions, revision{id, j.length()})
return id
}
-// revertToSnapshot reverts all state changes made since the given revision.
-func (j *journal) revertToSnapshot(revid int, s *StateDB) {
+func (j *linearJournal) revertToSnapshot(revid int, s *StateDB) {
// Find the snapshot in the stack of valid snapshots.
idx := sort.Search(len(j.validRevisions), func(i int) bool {
return j.validRevisions[i].id >= revid
})
if idx == len(j.validRevisions) || j.validRevisions[idx].id != revid {
- panic(fmt.Errorf("revision id %v cannot be reverted", revid))
+ panic(fmt.Errorf("revision id %v cannot be reverted (valid revisions: %d)", revid, len(j.validRevisions)))
}
snapshot := j.validRevisions[idx].journalIndex
- // Replay the journal to undo changes and remove invalidated snapshots
+ // Replay the linearJournal to undo changes and remove invalidated snapshots
j.revert(s, snapshot)
j.validRevisions = j.validRevisions[:idx]
}
-// append inserts a new modification entry to the end of the change journal.
-func (j *journal) append(entry journalEntry) {
+// DiscardSnapshot removes the snapshot with the given id; after calling this
+// method, it is no longer possible to revert to that particular snapshot, the
+// changes are considered part of the parent scope.
+func (j *linearJournal) DiscardSnapshot(id int) {
+ //
+}
+
+// append inserts a new modification entry to the end of the change linearJournal.
+func (j *linearJournal) append(entry journalEntry) {
j.entries = append(j.entries, entry)
if addr := entry.dirtied(); addr != nil {
j.dirties[*addr]++
@@ -107,7 +125,7 @@ func (j *journal) append(entry journalEntry) {
// revert undoes a batch of journalled modifications along with any reverted
// dirty handling too.
-func (j *journal) revert(statedb *StateDB, snapshot int) {
+func (j *linearJournal) revert(statedb *StateDB, snapshot int) {
for i := len(j.entries) - 1; i >= snapshot; i-- {
// Undo the changes made by the operation
j.entries[i].revert(statedb)
@@ -125,22 +143,22 @@ func (j *journal) revert(statedb *StateDB, snapshot int) {
// dirty explicitly sets an address to dirty, even if the change entries would
// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD
// precompile consensus exception.
-func (j *journal) dirty(addr common.Address) {
+func (j *linearJournal) dirty(addr common.Address) {
j.dirties[addr]++
}
-// length returns the current number of entries in the journal.
-func (j *journal) length() int {
+// length returns the current number of entries in the linearJournal.
+func (j *linearJournal) length() int {
return len(j.entries)
}
// copy returns a deep-copied journal.
-func (j *journal) copy() *journal {
+func (j *linearJournal) copy() journal {
entries := make([]journalEntry, 0, j.length())
for i := 0; i < j.length(); i++ {
entries = append(entries, j.entries[i].copy())
}
- return &journal{
+ return &linearJournal{
entries: entries,
dirties: maps.Clone(j.dirties),
validRevisions: slices.Clone(j.validRevisions),
@@ -148,23 +166,23 @@ func (j *journal) copy() *journal {
}
}
-func (j *journal) logChange(txHash common.Hash) {
+func (j *linearJournal) logChange(txHash common.Hash) {
j.append(addLogChange{txhash: txHash})
}
-func (j *journal) createObject(addr common.Address) {
+func (j *linearJournal) createObject(addr common.Address) {
j.append(createObjectChange{account: addr})
}
-func (j *journal) createContract(addr common.Address) {
+func (j *linearJournal) createContract(addr common.Address, account *types.StateAccount) {
j.append(createContractChange{account: addr})
}
-func (j *journal) destruct(addr common.Address) {
+func (j *linearJournal) destruct(addr common.Address, account *types.StateAccount) {
j.append(selfDestructChange{account: addr})
}
-func (j *journal) storageChange(addr common.Address, key, prev, origin common.Hash) {
+func (j *linearJournal) storageChange(addr common.Address, key, prev, origin common.Hash) {
j.append(storageChange{
account: addr,
key: key,
@@ -173,7 +191,7 @@ func (j *journal) storageChange(addr common.Address, key, prev, origin common.Ha
})
}
-func (j *journal) transientStateChange(addr common.Address, key, prev common.Hash) {
+func (j *linearJournal) transientStateChange(addr common.Address, key, prev common.Hash) {
j.append(transientStorageChange{
account: addr,
key: key,
@@ -181,29 +199,29 @@ func (j *journal) transientStateChange(addr common.Address, key, prev common.Has
})
}
-func (j *journal) refundChange(previous uint64) {
+func (j *linearJournal) refundChange(previous uint64) {
j.append(refundChange{prev: previous})
}
-func (j *journal) balanceChange(addr common.Address, previous *uint256.Int) {
+func (j *linearJournal) balanceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) {
j.append(balanceChange{
account: addr,
- prev: previous.Clone(),
+ prev: account.Balance.Clone(),
})
}
-func (j *journal) setCode(address common.Address) {
+func (j *linearJournal) setCode(address common.Address, account *types.StateAccount) {
j.append(codeChange{account: address})
}
-func (j *journal) nonceChange(address common.Address, prev uint64) {
+func (j *linearJournal) nonceChange(address common.Address, account *types.StateAccount, destructed, newContract bool) {
j.append(nonceChange{
account: address,
- prev: prev,
+ prev: account.Nonce,
})
}
-func (j *journal) touchChange(address common.Address) {
+func (j *linearJournal) touchChange(address common.Address, account *types.StateAccount, destructed, newContract bool) {
j.append(touchChange{
account: address,
})
@@ -214,11 +232,11 @@ func (j *journal) touchChange(address common.Address) {
}
}
-func (j *journal) accessListAddAccount(addr common.Address) {
+func (j *linearJournal) accessListAddAccount(addr common.Address) {
j.append(accessListAddAccountChange{addr})
}
-func (j *journal) accessListAddSlot(addr common.Address, slot common.Hash) {
+func (j *linearJournal) accessListAddSlot(addr common.Address, slot common.Hash) {
j.append(accessListAddSlotChange{
address: addr,
slot: slot,
@@ -231,7 +249,7 @@ type (
account common.Address
}
// createContractChange represents an account becoming a contract-account.
- // This event happens prior to executing initcode. The journal-event simply
+ // This event happens prior to executing initcode. The linearJournal-event simply
// manages the created-flag, in order to allow same-tx destruction.
createContractChange struct {
account common.Address
@@ -457,7 +475,7 @@ func (ch addLogChange) copy() journalEntry {
func (ch accessListAddAccountChange) revert(s *StateDB) {
/*
One important invariant here, is that whenever a (addr, slot) is added, if the
- addr is not already present, the add causes two journal entries:
+ addr is not already present, the add causes two linearJournal entries:
- one for the address,
- one for the (address,slot)
Therefore, when unrolling the change, we can always blindly delete the
diff --git a/core/state/journal_api.go b/core/state/journal_api.go
new file mode 100644
index 000000000000..d88cd853135d
--- /dev/null
+++ b/core/state/journal_api.go
@@ -0,0 +1,86 @@
+package state
+
+import (
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+)
+
+type journal interface {
+ // snapshot returns an identifier for the current revision of the state.
+ // The lifeycle of journalling is as follows:
+ // - snapshot() starts a 'scope'.
+ // - Tee method snapshot() may be called any number of times.
+ // - For each call to snapshot, there should be a corresponding call to end
+ // the scope via either of:
+ // - revertToSnapshot, which undoes the changes in the scope, or
+ // - discardSnapshot, which discards the ability to revert the changes in the scope.
+ // - This operation might merge the changes into the parent scope.
+ // If it does not merge the changes into the parent scope, it must create
+ // a new snapshot internally, in order to ensure that order of changes
+ // remains intact.
+ snapshot() int
+
+ // revertToSnapshot reverts all state changes made since the given revision.
+ revertToSnapshot(revid int, s *StateDB)
+
+ // reset clears the journal so it can be reused.
+ reset()
+
+ // DiscardSnapshot removes the snapshot with the given id; after calling this
+ // method, it is no longer possible to revert to that particular snapshot, the
+ // changes are considered part of the parent scope.
+ DiscardSnapshot(revid int)
+
+ // dirtyAccounts returns a list of all accounts modified in this journal
+ dirtyAccounts() []common.Address
+
+ // accessListAddAccount journals the adding of addr to the access list
+ accessListAddAccount(addr common.Address)
+
+ // accessListAddSlot journals the adding of addr/slot to the access list
+ accessListAddSlot(addr common.Address, slot common.Hash)
+
+ // logChange journals the adding of a log related to the txHash
+ logChange(txHash common.Hash)
+
+ // createObject journals the event of a new account created in the trie.
+ createObject(addr common.Address)
+
+ // createContract journals the creation of a new contract at addr.
+ // OBS: This method must not be applied twice, it assumes that the pre-state
+ // (i.e the rollback-state) is non-created.
+ createContract(addr common.Address, account *types.StateAccount)
+
+ // destruct journals the destruction of an account in the trie.
+ // pre-state (i.e the rollback-state) is non-destructed (and, for the purpose
+ // of EIP-XXX (TODO lookup), created in this tx).
+ destruct(addr common.Address, account *types.StateAccount)
+
+ // storageChange journals a change in the storage data related to addr.
+ // It records the key and previous value of the slot.
+ storageChange(addr common.Address, key, prev, origin common.Hash)
+
+ // transientStateChange journals a change in the t-storage data related to addr.
+ // It records the key and previous value of the slot.
+ transientStateChange(addr common.Address, key, prev common.Hash)
+
+ // refundChange journals that the refund has been changed, recording the previous value.
+ refundChange(previous uint64)
+
+ // balanceChange journals that the balance of addr has been changed, recording the previous value
+ balanceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool)
+
+ // setCode journals that the code of addr has been set.
+ // OBS: This method must not be applied twice -- it always assumes that the
+ // pre-state (i.e the rollback-state) is "no code".
+ setCode(addr common.Address, account *types.StateAccount)
+
+ // nonceChange journals that the nonce of addr was changed, recording the previous value.
+ nonceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool)
+
+ // touchChange journals that the account at addr was touched during execution.
+ touchChange(addr common.Address, account *types.StateAccount, destructed, newContract bool)
+
+ // copy returns a deep-copied journal.
+ copy() journal
+}
diff --git a/core/state/journal_test.go b/core/state/journal_test.go
new file mode 100644
index 000000000000..76d1c936afc6
--- /dev/null
+++ b/core/state/journal_test.go
@@ -0,0 +1,132 @@
+// Copyright 2024 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+// Package state provides a caching layer atop the Ethereum state trie.
+package state
+
+import (
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/holiman/uint256"
+)
+
+func TestLinearJournalDirty(t *testing.T) {
+ testJournalDirty(t, newLinearJournal())
+}
+
+func TestSparseJournalDirty(t *testing.T) {
+ testJournalDirty(t, newSparseJournal())
+}
+
+// This test verifies some basics around journalling: the ability to
+// deliver a dirty-set.
+func testJournalDirty(t *testing.T, j journal) {
+ acc := &types.StateAccount{
+ Nonce: 1,
+ Balance: new(uint256.Int),
+ Root: common.Hash{},
+ CodeHash: nil,
+ }
+ {
+ j.nonceChange(common.Address{0x1}, acc, false, false)
+ if have, want := len(j.dirtyAccounts()), 1; have != want {
+ t.Errorf("wrong size of dirty accounts, have %v want %v", have, want)
+ }
+ }
+ {
+ j.storageChange(common.Address{0x2}, common.Hash{0x1}, common.Hash{0x1}, common.Hash{})
+ if have, want := len(j.dirtyAccounts()), 2; have != want {
+ t.Errorf("wrong size of dirty accounts, have %v want %v", have, want)
+ }
+ }
+ { // The previous scopes should also be accounted for
+ j.snapshot()
+ if have, want := len(j.dirtyAccounts()), 2; have != want {
+ t.Errorf("wrong size of dirty accounts, have %v want %v", have, want)
+ }
+ }
+}
+
+func TestLinearJournalAccessList(t *testing.T) {
+ testJournalAccessList(t, newLinearJournal())
+}
+
+func TestSparseJournalAccessList(t *testing.T) {
+ testJournalAccessList(t, newSparseJournal())
+}
+
+func testJournalAccessList(t *testing.T, j journal) {
+ var statedb = &StateDB{}
+ statedb.accessList = newAccessList()
+ statedb.journal = j
+
+ {
+ // If the journal performs the rollback in the wrong order, this
+ // will cause a panic.
+ id := j.snapshot()
+ statedb.AddSlotToAccessList(common.Address{0x1}, common.Hash{0x4})
+ statedb.AddSlotToAccessList(common.Address{0x3}, common.Hash{0x4})
+ statedb.RevertToSnapshot(id)
+ }
+ {
+ id := j.snapshot()
+ statedb.AddAddressToAccessList(common.Address{0x2})
+ statedb.AddAddressToAccessList(common.Address{0x3})
+ statedb.AddAddressToAccessList(common.Address{0x4})
+ statedb.RevertToSnapshot(id)
+ if statedb.accessList.ContainsAddress(common.Address{0x2}) {
+ t.Fatal("should be missing")
+ }
+ }
+}
+
+func TestLinearJournalRefunds(t *testing.T) {
+ testJournalRefunds(t, newLinearJournal())
+}
+
+func TestSparseJournalRefunds(t *testing.T) {
+ testJournalRefunds(t, newSparseJournal())
+}
+
+func testJournalRefunds(t *testing.T, j journal) {
+ var statedb = &StateDB{}
+ statedb.accessList = newAccessList()
+ statedb.journal = j
+ zero := j.snapshot()
+ j.refundChange(0)
+ j.refundChange(1)
+ {
+ id := j.snapshot()
+ j.refundChange(2)
+ j.refundChange(3)
+ j.revertToSnapshot(id, statedb)
+ if have, want := statedb.refund, uint64(2); have != want {
+ t.Fatalf("have %d want %d", have, want)
+ }
+ }
+ {
+ id := j.snapshot()
+ j.refundChange(2)
+ j.refundChange(3)
+ j.DiscardSnapshot(id)
+ }
+ j.revertToSnapshot(zero, statedb)
+ if have, want := statedb.refund, uint64(0); have != want {
+ t.Fatalf("have %d want %d", have, want)
+ }
+}
diff --git a/core/state/setjournal.go b/core/state/setjournal.go
new file mode 100644
index 000000000000..63ba2eb9dab3
--- /dev/null
+++ b/core/state/setjournal.go
@@ -0,0 +1,486 @@
+// Copyright 2024 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package state
+
+import (
+ "bytes"
+ "fmt"
+ "maps"
+ "slices"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/types"
+ "github.com/ethereum/go-ethereum/log"
+ "github.com/holiman/uint256"
+)
+
+var (
+ _ journal = (*sparseJournal)(nil)
+)
+
+// journalAccount represents the 'journable state' of a types.Account.
+// Which means, all the normal fields except storage root, but also with a
+// destruction-flag.
+type journalAccount struct {
+ nonce uint64
+ balance uint256.Int
+ codeHash []byte // nil == emptyCodeHAsh
+ destructed bool
+ newContract bool
+}
+
+type addrSlot struct {
+ addr common.Address
+ slot common.Hash
+}
+
+type doubleHash struct {
+ origin common.Hash
+ prev common.Hash
+}
+
+// scopedJournal represents all changes within a single callscope. These changes
+// are either all reverted, or all committed -- they cannot be partially applied.
+type scopedJournal struct {
+ accountChanges map[common.Address]*journalAccount
+ refund int64
+ logs []common.Hash
+
+ accessListAddresses []common.Address
+ accessListAddrSlots []addrSlot
+
+ storageChanges map[common.Address]map[common.Hash]doubleHash
+ tStorageChanges map[common.Address]map[common.Hash]common.Hash
+}
+
+func newScopedJournal() *scopedJournal {
+ return &scopedJournal{
+ refund: -1,
+ }
+}
+
+func (j *scopedJournal) deepCopy() *scopedJournal {
+ var cpy = &scopedJournal{
+ // The accountChanges copy will copy the pointers to
+ // journalAccount objects: thus not actually deep copy those
+ // objects. That is fine: we never mutate journalAccount.
+ accountChanges: maps.Clone(j.accountChanges),
+ refund: j.refund,
+ logs: slices.Clone(j.logs),
+ accessListAddresses: slices.Clone(j.accessListAddresses),
+ accessListAddrSlots: slices.Clone(j.accessListAddrSlots),
+ }
+ if j.storageChanges != nil {
+ cpy.storageChanges = make(map[common.Address]map[common.Hash]doubleHash)
+ for addr, changes := range j.storageChanges {
+ cpy.storageChanges[addr] = maps.Clone(changes)
+ }
+ }
+ if j.tStorageChanges != nil {
+ cpy.tStorageChanges = make(map[common.Address]map[common.Hash]common.Hash)
+ for addr, changes := range j.tStorageChanges {
+ cpy.tStorageChanges[addr] = maps.Clone(changes)
+ }
+ }
+ return cpy
+}
+
+func (j *scopedJournal) journalRefundChange(prev uint64) {
+ if j.refund == -1 {
+ // We convert from uint64 to int64 here, so that we can use -1
+ // to represent "no previous value set".
+ // Treating refund as int64 is fine, there's no possibility for
+ // refund to ever exceed maxInt64.
+ j.refund = int64(prev)
+ }
+}
+
+// journalAccountChange is the common shared implementation for all account-changes.
+// These changes all fall back to this method:
+// - balance change
+// - nonce change
+// - destruct-change
+// - code change
+// - touch change
+// - creation change (in this case, the account is nil)
+func (j *scopedJournal) journalAccountChange(address common.Address, account *types.StateAccount, destructed, newContract bool) {
+ if j.accountChanges == nil {
+ j.accountChanges = make(map[common.Address]*journalAccount)
+ }
+ // If the account has already been journalled, we're done here
+ if _, ok := j.accountChanges[address]; ok {
+ return
+ }
+ if account == nil {
+ j.accountChanges[address] = nil // created now, previously non-existent
+ return
+ }
+ ja := &journalAccount{
+ nonce: account.Nonce,
+ balance: *account.Balance,
+ destructed: destructed,
+ newContract: newContract,
+ }
+ if !bytes.Equal(account.CodeHash, types.EmptyCodeHash[:]) {
+ ja.codeHash = account.CodeHash
+ }
+ j.accountChanges[address] = ja
+}
+
+func (j *scopedJournal) journalLog(txHash common.Hash) {
+ j.logs = append(j.logs, txHash)
+}
+
+func (j *scopedJournal) journalAccessListAddAccount(addr common.Address) {
+ j.accessListAddresses = append(j.accessListAddresses, addr)
+}
+
+func (j *scopedJournal) journalAccessListAddSlot(addr common.Address, slot common.Hash) {
+ j.accessListAddrSlots = append(j.accessListAddrSlots, addrSlot{addr, slot})
+}
+
+func (j *scopedJournal) journalSetState(addr common.Address, key, prev, origin common.Hash) {
+ if j.storageChanges == nil {
+ j.storageChanges = make(map[common.Address]map[common.Hash]doubleHash)
+ }
+ changes, ok := j.storageChanges[addr]
+ if !ok {
+ changes = make(map[common.Hash]doubleHash)
+ j.storageChanges[addr] = changes
+ }
+ // Do not overwrite a previous value!
+ if _, ok := changes[key]; !ok {
+ changes[key] = doubleHash{origin: origin, prev: prev}
+ }
+}
+
+func (j *scopedJournal) journalSetTransientState(addr common.Address, key, prev common.Hash) {
+ if j.tStorageChanges == nil {
+ j.tStorageChanges = make(map[common.Address]map[common.Hash]common.Hash)
+ }
+ changes, ok := j.tStorageChanges[addr]
+ if !ok {
+ changes = make(map[common.Hash]common.Hash)
+ j.tStorageChanges[addr] = changes
+ }
+ // Do not overwrite a previous value!
+ if _, ok := changes[key]; !ok {
+ changes[key] = prev
+ }
+}
+
+func (j *scopedJournal) revert(s *StateDB) {
+ // Revert refund
+ if j.refund != -1 {
+ s.refund = uint64(j.refund)
+ }
+ // Revert storage changes
+ for addr, changes := range j.storageChanges {
+ obj := s.getStateObject(addr)
+ for key, val := range changes {
+ obj.setState(key, val.prev, val.origin)
+ }
+ }
+ // Revert t-store changes
+ for addr, changes := range j.tStorageChanges {
+ for key, val := range changes {
+ s.setTransientState(addr, key, val)
+ }
+ }
+
+ // Revert changes to accounts
+ for addr, data := range j.accountChanges {
+ if data == nil { // Reverting a create
+ delete(s.stateObjects, addr)
+ continue
+ }
+ obj := s.getStateObject(addr)
+ obj.setNonce(data.nonce)
+ // Setting 'code' to nil means it will be loaded from disk
+ // next time it is needed. We avoid nilling it unless required
+ journalHash := data.codeHash
+ if data.codeHash == nil {
+ if !bytes.Equal(obj.CodeHash(), types.EmptyCodeHash[:]) {
+ obj.setCode(types.EmptyCodeHash, nil)
+ }
+ } else {
+ if !bytes.Equal(obj.CodeHash(), journalHash) {
+ obj.setCode(common.BytesToHash(data.codeHash), nil)
+ }
+ }
+ obj.setBalance(&data.balance)
+ obj.selfDestructed = data.destructed
+ obj.newContract = data.newContract
+ }
+ // Revert logs
+ for _, txhash := range j.logs {
+ logs := s.logs[txhash]
+ if len(logs) == 1 {
+ delete(s.logs, txhash)
+ } else {
+ s.logs[txhash] = logs[:len(logs)-1]
+ }
+ s.logSize--
+ }
+ // Revert access list additions
+ for i := len(j.accessListAddrSlots) - 1; i >= 0; i-- {
+ item := j.accessListAddrSlots[i]
+ s.accessList.DeleteSlot(item.addr, item.slot)
+ }
+ for i := len(j.accessListAddresses) - 1; i >= 0; i-- {
+ s.accessList.DeleteAddress(j.accessListAddresses[i])
+ }
+}
+
+func (j *scopedJournal) merge(parent *scopedJournal) {
+ if parent.refund == -1 {
+ parent.refund = j.refund
+ }
+ // Revert changes to accounts
+ if parent.accountChanges == nil {
+ parent.accountChanges = j.accountChanges
+ } else {
+ for addr, data := range j.accountChanges {
+ if _, present := parent.accountChanges[addr]; present {
+ // Nothing to do here, it's already stored in parent scope
+ continue
+ }
+ parent.accountChanges[addr] = data
+ }
+ }
+ // Revert logs
+ parent.logs = append(parent.logs, j.logs...)
+
+ // Revert access list additions
+ parent.accessListAddrSlots = append(parent.accessListAddrSlots, j.accessListAddrSlots...)
+ parent.accessListAddresses = append(parent.accessListAddresses, j.accessListAddresses...)
+
+ if parent.storageChanges == nil {
+ parent.storageChanges = j.storageChanges
+ } else {
+ // Merge storage changes
+ for addr, changes := range j.storageChanges {
+ prevChanges, ok := parent.storageChanges[addr]
+ if !ok {
+ parent.storageChanges[addr] = changes
+ continue
+ }
+ for k, v := range changes {
+ if _, ok := prevChanges[k]; !ok {
+ prevChanges[k] = v
+ }
+ }
+ }
+ }
+ if parent.tStorageChanges == nil {
+ parent.tStorageChanges = j.tStorageChanges
+ } else {
+ // Revert t-store changes
+ for addr, changes := range j.tStorageChanges {
+ prevChanges, ok := parent.tStorageChanges[addr]
+ if !ok {
+ parent.tStorageChanges[addr] = changes
+ continue
+ }
+ for k, v := range changes {
+ if _, ok := prevChanges[k]; !ok {
+ prevChanges[k] = v
+ }
+ }
+ }
+ }
+}
+
+func (j *scopedJournal) addDirtyAccounts(set map[common.Address]any) {
+ // Changes due to account changes
+ for addr := range j.accountChanges {
+ set[addr] = []interface{}{}
+ }
+ // Changes due to storage changes
+ for addr := range j.storageChanges {
+ set[addr] = []interface{}{}
+ }
+}
+
+// sparseJournal contains the list of state modifications applied since the last state
+// commit. These are tracked to be able to be reverted in the case of an execution
+// exception or request for reversal.
+type sparseJournal struct {
+ entries []*scopedJournal // Current changes tracked by the journal
+ ripeMagic bool
+}
+
+// newJournal creates a new initialized journal.
+func newSparseJournal() *sparseJournal {
+ s := new(sparseJournal)
+ s.snapshot() // create snaphot zero
+ return s
+}
+
+// reset clears the journal, after this operation the journal can be used
+// anew. It is semantically similar to calling 'newJournal', but the underlying
+// slices can be reused
+func (j *sparseJournal) reset() {
+ j.entries = j.entries[:0]
+ j.snapshot()
+}
+
+func (j *sparseJournal) copy() journal {
+ cp := &sparseJournal{
+ entries: make([]*scopedJournal, 0, len(j.entries)),
+ }
+ for _, entry := range j.entries {
+ cp.entries = append(cp.entries, entry.deepCopy())
+ }
+ return cp
+}
+
+// snapshot returns an identifier for the current revision of the state.
+// OBS: A call to Snapshot is _required_ in order to initialize the journalling,
+// invoking the journal-methods without having invoked Snapshot will lead to
+// panic.
+func (j *sparseJournal) snapshot() int {
+ id := len(j.entries)
+ j.entries = append(j.entries, newScopedJournal())
+ return id
+}
+
+// revertToSnapshot reverts all state changes made since the given revision.
+func (j *sparseJournal) revertToSnapshot(id int, s *StateDB) {
+ if id >= len(j.entries) {
+ panic(fmt.Errorf("revision id %v cannot be reverted", id))
+ }
+ // Revert the entries sequentially
+ for i := len(j.entries) - 1; i >= id; i-- {
+ entry := j.entries[i]
+ entry.revert(s)
+ }
+ j.entries = j.entries[:id]
+}
+
+func (j *sparseJournal) DiscardSnapshot(id int) {
+ if id == 0 {
+ return
+ }
+ // here we must merge the 'id' with it's parent.
+ want := len(j.entries) - 1
+ have := id
+ if want != have {
+ if want == 0 && id == 1 {
+ // If a transcation is applied successfully, the statedb.Finalize will
+ // end by clearing and resetting the journal. Invoking a DiscardSnapshot
+ // afterwards will lead us here.
+ // Let's not panic, but it's ok to complain a bit
+ log.Error("Extraneous invocation to discard snapshot")
+ return
+ } else {
+ panic(fmt.Sprintf("journalling error, want discard(%d), have discard(%d)", want, have))
+ }
+ }
+ entry := j.entries[id]
+ parent := j.entries[id-1]
+ entry.merge(parent)
+ j.entries = j.entries[:id]
+}
+
+func (j *sparseJournal) journalAccountChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) {
+ j.entries[len(j.entries)-1].journalAccountChange(addr, account, destructed, newContract)
+}
+
+func (j *sparseJournal) nonceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) {
+ j.journalAccountChange(addr, account, destructed, newContract)
+}
+
+func (j *sparseJournal) balanceChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) {
+ j.journalAccountChange(addr, account, destructed, newContract)
+}
+
+func (j *sparseJournal) setCode(addr common.Address, account *types.StateAccount) {
+ j.journalAccountChange(addr, account, false, true)
+}
+
+func (j *sparseJournal) createObject(addr common.Address) {
+ // Creating an account which is destructed, hence already exists, is not
+ // allowed, hence we know destructed == 'false'.
+ // Also, if we are creating the account now, it cannot yet be a
+ // newContract (that might come later)
+ j.journalAccountChange(addr, nil, false, false)
+}
+
+func (j *sparseJournal) createContract(addr common.Address, account *types.StateAccount) {
+ // Creating an account which is destructed, hence already exists, is not
+ // allowed, hence we know it to be 'false'.
+ // Also: if we create the contract now, it cannot be previously created
+ j.journalAccountChange(addr, account, false, false)
+}
+
+func (j *sparseJournal) destruct(addr common.Address, account *types.StateAccount) {
+ // destructing an already destructed account must not be journalled. Hence we
+ // know it to be 'false'.
+ // Also: if we're allowed to destruct it, it must be `newContract:true`, OR
+ // the concept of newContract is unused and moot.
+ j.journalAccountChange(addr, account, false, true)
+}
+
+// var ripemd = common.HexToAddress("0000000000000000000000000000000000000003")
+func (j *sparseJournal) touchChange(addr common.Address, account *types.StateAccount, destructed, newContract bool) {
+ j.journalAccountChange(addr, account, destructed, newContract)
+ if addr == ripemd {
+ // Explicitly put it in the dirty-cache one extra time. Ripe magic.
+ j.ripeMagic = true
+ }
+}
+
+func (j *sparseJournal) logChange(txHash common.Hash) {
+ j.entries[len(j.entries)-1].journalLog(txHash)
+}
+
+func (j *sparseJournal) refundChange(prev uint64) {
+ j.entries[len(j.entries)-1].journalRefundChange(prev)
+}
+
+func (j *sparseJournal) accessListAddAccount(addr common.Address) {
+ j.entries[len(j.entries)-1].journalAccessListAddAccount(addr)
+}
+
+func (j *sparseJournal) accessListAddSlot(addr common.Address, slot common.Hash) {
+ j.entries[len(j.entries)-1].journalAccessListAddSlot(addr, slot)
+}
+
+func (j *sparseJournal) storageChange(addr common.Address, key, prev, origin common.Hash) {
+ j.entries[len(j.entries)-1].journalSetState(addr, key, prev, origin)
+}
+
+func (j *sparseJournal) transientStateChange(addr common.Address, key, prev common.Hash) {
+ j.entries[len(j.entries)-1].journalSetTransientState(addr, key, prev)
+}
+
+func (j *sparseJournal) dirtyAccounts() []common.Address {
+ // The dirty-set should encompass all layers
+ var dirty = make(map[common.Address]any)
+ for _, scope := range j.entries {
+ scope.addDirtyAccounts(dirty)
+ }
+ if j.ripeMagic {
+ dirty[ripemd] = []interface{}{}
+ }
+ var dirtyList = make([]common.Address, 0, len(dirty))
+ for addr := range dirty {
+ dirtyList = append(dirtyList, addr)
+ }
+ return dirtyList
+}
diff --git a/core/state/state_object.go b/core/state/state_object.go
index 422badb19bc3..817a79c1e984 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -114,7 +114,7 @@ func (s *stateObject) markSelfdestructed() {
}
func (s *stateObject) touch() {
- s.db.journal.touchChange(s.address)
+ s.db.journal.touchChange(s.address, &s.data, s.selfDestructed, s.newContract)
}
// getTrie returns the associated storage trie. The trie will be opened if it's
@@ -470,7 +470,7 @@ func (s *stateObject) SubBalance(amount *uint256.Int, reason tracing.BalanceChan
}
func (s *stateObject) SetBalance(amount *uint256.Int, reason tracing.BalanceChangeReason) {
- s.db.journal.balanceChange(s.address, s.data.Balance)
+ s.db.journal.balanceChange(s.address, &s.data, s.selfDestructed, s.newContract)
if s.db.logger != nil && s.db.logger.OnBalanceChange != nil {
s.db.logger.OnBalanceChange(s.address, s.Balance().ToBig(), amount.ToBig(), reason)
}
@@ -546,7 +546,7 @@ func (s *stateObject) CodeSize() int {
}
func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
- s.db.journal.setCode(s.address)
+ s.db.journal.setCode(s.address, &s.data)
if s.db.logger != nil && s.db.logger.OnCodeChange != nil {
// TODO remove prevcode from this callback
s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), nil, codeHash, code)
@@ -561,7 +561,7 @@ func (s *stateObject) setCode(codeHash common.Hash, code []byte) {
}
func (s *stateObject) SetNonce(nonce uint64) {
- s.db.journal.nonceChange(s.address, s.data.Nonce)
+ s.db.journal.nonceChange(s.address, &s.data, s.selfDestructed, s.newContract)
if s.db.logger != nil && s.db.logger.OnNonceChange != nil {
s.db.logger.OnNonceChange(s.address, s.data.Nonce, nonce)
}
diff --git a/core/state/statedb.go b/core/state/statedb.go
index b2b4f8fb97b1..acf40d3d7358 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -136,7 +136,7 @@ type StateDB struct {
// Journal of state modifications. This is the backbone of
// Snapshot and RevertToSnapshot.
- journal *journal
+ journal journal
// State witness if cross validation is needed
witness *stateless.Witness
@@ -180,7 +180,7 @@ func New(root common.Hash, db Database) (*StateDB, error) {
mutations: make(map[common.Address]*mutation),
logs: make(map[common.Hash][]*types.Log),
preimages: make(map[common.Hash][]byte),
- journal: newJournal(),
+ journal: newSparseJournal(),
accessList: newAccessList(),
transientStorage: newTransientStorage(),
}
@@ -500,7 +500,7 @@ func (s *StateDB) SelfDestruct(addr common.Address) {
// If it is already marked as self-destructed, we do not need to add it
// for journalling a second time.
if !stateObject.selfDestructed {
- s.journal.destruct(addr)
+ s.journal.destruct(addr, &stateObject.data)
stateObject.markSelfdestructed()
}
}
@@ -638,7 +638,7 @@ func (s *StateDB) CreateContract(addr common.Address) {
obj := s.getStateObject(addr)
if !obj.newContract {
obj.newContract = true
- s.journal.createContract(addr)
+ s.journal.createContract(addr, &obj.data)
}
}
@@ -707,6 +707,13 @@ func (s *StateDB) Snapshot() int {
return s.journal.snapshot()
}
+// DiscardSnapshot removes the snapshot with the given id; after calling this
+// method, it is no longer possible to revert to that particular snapshot, the
+// changes are considered part of the parent scope.
+func (s *StateDB) DiscardSnapshot(id int) {
+ s.journal.DiscardSnapshot(id)
+}
+
// RevertToSnapshot reverts all state changes made since the given revision.
func (s *StateDB) RevertToSnapshot(revid int) {
s.journal.revertToSnapshot(revid, s)
@@ -721,8 +728,9 @@ func (s *StateDB) GetRefund() uint64 {
// the journal as well as the refunds. Finalise, however, will not push any updates
// into the tries just yet. Only IntermediateRoot or Commit will do that.
func (s *StateDB) Finalise(deleteEmptyObjects bool) {
- addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties))
- for addr := range s.journal.dirties {
+ dirties := s.journal.dirtyAccounts()
+ addressesToPrefetch := make([][]byte, 0, len(dirties))
+ for _, addr := range dirties {
obj, exist := s.stateObjects[addr]
if !exist {
// ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 9441834c6a24..e29221c3bd96 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -55,7 +55,7 @@ func TestUpdateLeaks(t *testing.T) {
sdb = NewDatabase(tdb, nil)
)
state, _ := New(types.EmptyRootHash, sdb)
-
+ state.Snapshot()
// Update it with some accounts
for i := byte(0); i < 255; i++ {
addr := common.BytesToAddress([]byte{i})
@@ -111,7 +111,7 @@ func TestIntermediateLeaks(t *testing.T) {
}
// Write modifications to trie.
transState.IntermediateRoot(false)
-
+ transState.journal.snapshot()
// Overwrite all the data with new values in the transient database.
for i := byte(0); i < 255; i++ {
modify(transState, common.Address{i}, i, 99)
@@ -228,7 +228,7 @@ func TestCopy(t *testing.T) {
}
// TestCopyWithDirtyJournal tests if Copy can correct create a equal copied
-// stateDB with dirty journal present.
+// stateDB with dirty linearJournal present.
func TestCopyWithDirtyJournal(t *testing.T) {
db := NewDatabaseForTesting()
orig, _ := New(types.EmptyRootHash, db)
@@ -362,6 +362,12 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
{
name: "SetStorage",
fn: func(a testAction, s *StateDB) {
+ contractHash := s.GetCodeHash(addr)
+ emptyCode := contractHash == (common.Hash{}) || contractHash == types.EmptyCodeHash
+ if emptyCode {
+ // no-op
+ return
+ }
var key, val common.Hash
binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
@@ -372,12 +378,26 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
{
name: "SetCode",
fn: func(a testAction, s *StateDB) {
- // SetCode can only be performed in case the addr does
- // not already hold code
+ // SetCode cannot be performed if the addr already has code
if c := s.GetCode(addr); len(c) > 0 {
// no-op
return
}
+ // SetCode cannot be performed if the addr has just selfdestructed
+ if obj := s.getStateObject(addr); obj != nil {
+ if obj.selfDestructed {
+ // If it's selfdestructed, we cannot create into it
+ return
+ }
+ }
+ // SetCode requires the contract to be account + contract to be created first
+ if obj := s.getStateObject(addr); obj == nil {
+ s.createObject(addr)
+ }
+ obj := s.getStateObject(addr)
+ if !obj.newContract {
+ s.CreateContract(addr)
+ }
code := make([]byte, 16)
binary.BigEndian.PutUint64(code, uint64(a.args[0]))
binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
@@ -403,13 +423,20 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
emptyCode := contractHash == (common.Hash{}) || contractHash == types.EmptyCodeHash
storageRoot := s.GetStorageRoot(addr)
emptyStorage := storageRoot == (common.Hash{}) || storageRoot == types.EmptyRootHash
+
+ if obj := s.getStateObject(addr); obj != nil {
+ if obj.selfDestructed {
+ // If it's selfdestructed, we cannot create into it
+ return
+ }
+ }
if s.GetNonce(addr) == 0 && emptyCode && emptyStorage {
s.CreateContract(addr)
// We also set some code here, to prevent the
// CreateContract action from being performed twice in a row,
// which would cause a difference in state when unrolling
- // the journal. (CreateContact assumes created was false prior to
- // invocation, and the journal rollback sets it to false).
+ // the linearJournal. (CreateContact assumes created was false prior to
+ // invocation, and the linearJournal rollback sets it to false).
s.SetCode(addr, []byte{1})
}
},
@@ -417,6 +444,15 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
{
name: "SelfDestruct",
fn: func(a testAction, s *StateDB) {
+ obj := s.getStateObject(addr)
+ // SelfDestruct requires the object to first exist
+ if obj == nil {
+ s.createObject(addr)
+ }
+ obj = s.getStateObject(addr)
+ if !obj.newContract {
+ s.CreateContract(addr)
+ }
s.SelfDestruct(addr)
},
},
@@ -437,15 +473,6 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
},
args: make([]int64, 1),
},
- {
- name: "AddPreimage",
- fn: func(a testAction, s *StateDB) {
- preimage := []byte{1}
- hash := common.BytesToHash(preimage)
- s.AddPreimage(hash, preimage)
- },
- args: make([]int64, 1),
- },
{
name: "AddAddressToAccessList",
fn: func(a testAction, s *StateDB) {
@@ -463,6 +490,13 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction {
{
name: "SetTransientState",
fn: func(a testAction, s *StateDB) {
+ contractHash := s.GetCodeHash(addr)
+ emptyCode := contractHash == (common.Hash{}) || contractHash == types.EmptyCodeHash
+ if emptyCode {
+ // no-op
+ return
+ }
+
var key, val common.Hash
binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
@@ -675,22 +709,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
}
- if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) {
- getKeys := func(dirty map[common.Address]int) string {
- var keys []common.Address
- out := new(strings.Builder)
- for key := range dirty {
- keys = append(keys, key)
- }
- slices.SortFunc(keys, common.Address.Cmp)
- for i, key := range keys {
- fmt.Fprintf(out, " %d. %v\n", i, key)
+ { // Check the dirty-accounts
+ have := state.journal.dirtyAccounts()
+ want := checkstate.journal.dirtyAccounts()
+ slices.SortFunc(have, common.Address.Cmp)
+ slices.SortFunc(want, common.Address.Cmp)
+ if !slices.Equal(have, want) {
+ getKeys := func(keys []common.Address) string {
+ out := new(strings.Builder)
+ for i, key := range keys {
+ fmt.Fprintf(out, " %d. %v\n", i, key)
+ }
+ return out.String()
}
- return out.String()
+ haveK := getKeys(have)
+ wantK := getKeys(want)
+ return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", haveK, wantK)
}
- have := getKeys(state.journal.dirties)
- want := getKeys(checkstate.journal.dirties)
- return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want)
}
return nil
}
@@ -704,11 +739,11 @@ func TestTouchDelete(t *testing.T) {
snapshot := s.state.Snapshot()
s.state.AddBalance(common.Address{}, new(uint256.Int), tracing.BalanceChangeUnspecified)
- if len(s.state.journal.dirties) != 1 {
+ if len(s.state.journal.dirtyAccounts()) != 1 {
t.Fatal("expected one dirty state object")
}
s.state.RevertToSnapshot(snapshot)
- if len(s.state.journal.dirties) != 0 {
+ if len(s.state.journal.dirtyAccounts()) != 0 {
t.Fatal("expected no dirty state object")
}
}
@@ -1097,33 +1132,44 @@ func TestStateDBAccessList(t *testing.T) {
}
}
+ var ids []int
+ push := func(id int) {
+ ids = append(ids, id)
+ }
+ pop := func() int {
+ id := ids[len(ids)-1]
+ ids = ids[:len(ids)-1]
+ return id
+ }
+
+ push(state.journal.snapshot()) // journal id 0
state.AddAddressToAccessList(addr("aa")) // 1
- state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3
+ push(state.journal.snapshot()) // journal id 1
+ state.AddAddressToAccessList(addr("bb")) // 2
+ push(state.journal.snapshot()) // journal id 2
+ state.AddSlotToAccessList(addr("bb"), slot("01")) // 3
+ push(state.journal.snapshot()) // journal id 3
state.AddSlotToAccessList(addr("bb"), slot("02")) // 4
+ push(state.journal.snapshot()) // journal id 4
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02")
// Make a copy
stateCopy1 := state.Copy()
- if exp, got := 4, state.journal.length(); exp != got {
- t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
- }
- // same again, should cause no journal entries
+ // same again, should cause no linearJournal entries
state.AddSlotToAccessList(addr("bb"), slot("01"))
state.AddSlotToAccessList(addr("bb"), slot("02"))
state.AddAddressToAccessList(addr("aa"))
- if exp, got := 4, state.journal.length(); exp != got {
- t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
- }
+
// some new ones
state.AddSlotToAccessList(addr("bb"), slot("03")) // 5
+ push(state.journal.snapshot()) // journal id 5
state.AddSlotToAccessList(addr("aa"), slot("01")) // 6
- state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8
- state.AddAddressToAccessList(addr("cc"))
- if exp, got := 8, state.journal.length(); exp != got {
- t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
- }
+ push(state.journal.snapshot()) // journal id 6
+ state.AddAddressToAccessList(addr("cc")) // 7
+ push(state.journal.snapshot()) // journal id 7
+ state.AddSlotToAccessList(addr("cc"), slot("01")) // 8
verifyAddrs("aa", "bb", "cc")
verifySlots("aa", "01")
@@ -1131,7 +1177,7 @@ func TestStateDBAccessList(t *testing.T) {
verifySlots("cc", "01")
// now start rolling back changes
- state.journal.revert(state, 7)
+ state.journal.revertToSnapshot(pop(), state) // revert to 6
if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
@@ -1139,7 +1185,7 @@ func TestStateDBAccessList(t *testing.T) {
verifySlots("aa", "01")
verifySlots("bb", "01", "02", "03")
- state.journal.revert(state, 6)
+ state.journal.revertToSnapshot(pop(), state) // revert to 5
if state.AddressInAccessList(addr("cc")) {
t.Fatalf("addr present, expected missing")
}
@@ -1147,40 +1193,40 @@ func TestStateDBAccessList(t *testing.T) {
verifySlots("aa", "01")
verifySlots("bb", "01", "02", "03")
- state.journal.revert(state, 5)
+ state.journal.revertToSnapshot(pop(), state) // revert to 4
if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02", "03")
- state.journal.revert(state, 4)
+ state.journal.revertToSnapshot(pop(), state) // revert to 3
if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02")
- state.journal.revert(state, 3)
+ state.journal.revertToSnapshot(pop(), state) // revert to 2
if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01")
- state.journal.revert(state, 2)
+ state.journal.revertToSnapshot(pop(), state) // revert to 1
if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
- state.journal.revert(state, 1)
+ state.journal.revertToSnapshot(pop(), state) // revert to 0
if state.AddressInAccessList(addr("bb")) {
t.Fatalf("addr present, expected missing")
}
verifyAddrs("aa")
- state.journal.revert(state, 0)
+ state.journal.revertToSnapshot(0, state)
if state.AddressInAccessList(addr("aa")) {
t.Fatalf("addr present, expected missing")
}
@@ -1251,11 +1297,9 @@ func TestStateDBTransientStorage(t *testing.T) {
key := common.Hash{0x01}
value := common.Hash{0x02}
addr := common.Address{}
-
+ revision := state.journal.snapshot()
state.SetTransientState(addr, key, value)
- if exp, got := 1, state.journal.length(); exp != got {
- t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
- }
+
// the retrieved value should equal what was set
if got := state.GetTransientState(addr, key); got != value {
t.Fatalf("transient storage mismatch: have %x, want %x", got, value)
@@ -1263,7 +1307,7 @@ func TestStateDBTransientStorage(t *testing.T) {
// revert the transient state being set and then check that the
// value is now the empty hash
- state.journal.revert(state, 0)
+ state.journal.revertToSnapshot(revision, state)
if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got {
t.Fatalf("transient storage mismatch: have %x, want %x", got, exp)
}
@@ -1375,3 +1419,53 @@ func TestStorageDirtiness(t *testing.T) {
state.RevertToSnapshot(snap)
checkDirty(common.Hash{0x1}, common.Hash{0x1}, true)
}
+
+func TestStorageDirtiness2(t *testing.T) {
+ var (
+ disk = rawdb.NewMemoryDatabase()
+ tdb = triedb.NewDatabase(disk, nil)
+ db = NewDatabase(tdb, nil)
+ state, _ = New(types.EmptyRootHash, db)
+ addr = common.HexToAddress("0x1")
+ checkDirty = func(key common.Hash, value common.Hash, dirty bool) {
+ t.Helper()
+ obj := state.getStateObject(addr)
+ v, exist := obj.dirtyStorage[key]
+ if exist != dirty {
+ t.Fatalf("unexpected dirty marker, want: %v, have: %v", dirty, exist)
+ }
+ if !exist {
+ return
+ }
+ if v != value {
+ t.Fatalf("unexpected storage slot, want: %x, have: %x", value, v)
+ }
+ }
+ )
+
+ { // Initiate a state, where an account has SLOT(1) = 0xA, +nonzero balance
+ state.CreateAccount(addr)
+ state.SetBalance(addr, uint256.NewInt(1), tracing.BalanceChangeUnspecified) // Prevent empty-delete
+ state.SetState(addr, common.Hash{0x1}, common.Hash{0xa})
+ root, err := state.Commit(0, true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Init phase done, load it again
+ if state, err = New(root, NewDatabase(tdb, nil)); err != nil {
+ t.Fatal(err)
+ }
+ }
+ // A no-op storage change, no dirty marker
+ state.SetState(addr, common.Hash{0x1}, common.Hash{0xa})
+ checkDirty(common.Hash{0x1}, common.Hash{0xa}, false)
+
+ // Enter new scope
+ snap := state.Snapshot()
+ state.SetState(addr, common.Hash{0x1}, common.Hash{0xb}) // SLOT(1) = 0xB
+ checkDirty(common.Hash{0x1}, common.Hash{0xb}, true) // Should be flagged dirty
+ state.RevertToSnapshot(snap) // Revert scope
+
+ // the storage change has been set back to original, dirtiness should be revoked
+ checkDirty(common.Hash{0x1}, common.Hash{0x1}, false)
+}
diff --git a/core/vm/evm.go b/core/vm/evm.go
index 616668d565cc..ed9de46bca6d 100644
--- a/core/vm/evm.go
+++ b/core/vm/evm.go
@@ -201,6 +201,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas
if !isPrecompile && evm.chainRules.IsEIP158 && value.IsZero() {
// Calling a non-existing account, don't do anything.
+ evm.StateDB.DiscardSnapshot(snapshot)
return nil, gas, nil
}
evm.StateDB.CreateAccount(addr)
@@ -240,9 +241,8 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas
gas = 0
}
- // TODO: consider clearing up unused snapshots:
- //} else {
- // evm.StateDB.DiscardSnapshot(snapshot)
+ } else {
+ evm.StateDB.DiscardSnapshot(snapshot)
}
return ret, gas, err
}
@@ -299,6 +299,8 @@ func (evm *EVM) CallCode(caller ContractRef, addr common.Address, input []byte,
gas = 0
}
+ } else {
+ evm.StateDB.DiscardSnapshot(snapshot)
}
return ret, gas, err
}
@@ -348,6 +350,8 @@ func (evm *EVM) DelegateCall(caller ContractRef, addr common.Address, input []by
}
gas = 0
}
+ } else {
+ evm.StateDB.DiscardSnapshot(snapshot)
}
return ret, gas, err
}
@@ -410,6 +414,8 @@ func (evm *EVM) StaticCall(caller ContractRef, addr common.Address, input []byte
gas = 0
}
+ } else {
+ evm.StateDB.DiscardSnapshot(snapshot)
}
return ret, gas, err
}
@@ -521,6 +527,8 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64,
if err != ErrExecutionReverted {
contract.UseGas(contract.Gas, evm.Config.Tracer, tracing.GasChangeCallFailedExecution)
}
+ } else {
+ evm.StateDB.DiscardSnapshot(snapshot)
}
return ret, address, contract.Gas, err
}
diff --git a/core/vm/interface.go b/core/vm/interface.go
index 5f426435650d..58e2075ae55a 100644
--- a/core/vm/interface.go
+++ b/core/vm/interface.go
@@ -83,7 +83,14 @@ type StateDB interface {
Prepare(rules params.Rules, sender, coinbase common.Address, dest *common.Address, precompiles []common.Address, txAccesses types.AccessList)
+ // RevertToSnapshot reverts all state changes made since the given revision.
RevertToSnapshot(int)
+
+ // DiscardSnapshot removes the snapshot with the given id; after calling this
+ // method, it is no longer possible to revert to that particular snapshot, the
+ // changes are considered part of the parent scope.
+ DiscardSnapshot(int)
+ // Snapshot returns an identifier for the current scope of the state.
Snapshot() int
AddLog(*types.Log)
diff --git a/tests/state_test_util.go b/tests/state_test_util.go
index cf0ce9777f67..b562a197956e 100644
--- a/tests/state_test_util.go
+++ b/tests/state_test_util.go
@@ -308,6 +308,8 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh
if tracer := evm.Config.Tracer; tracer != nil && tracer.OnTxEnd != nil {
evm.Config.Tracer.OnTxEnd(nil, err)
}
+ } else {
+ st.StateDB.DiscardSnapshot(snapshot)
}
// Add 0-value mining reward. This only makes a difference in the cases
// where