From 1659eed6ec77f27b857b1b33e009397cc163883e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 7 Apr 2025 10:47:14 +0200 Subject: [PATCH 01/41] macaroons: remove context.TODO() in tests We want `context.TODO()` to be high signal in the code-base. It should signal clearly that work is required to thread parent context through to the call-site. So to keep the signal-to-noise ratio high, we remove any context.TODO() calls from tests since these will never need to be replace by a parent context. --- macaroons/service_test.go | 15 +++++++++------ macaroons/store_test.go | 13 ++++++++----- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/macaroons/service_test.go b/macaroons/service_test.go index b7694ab3e5..40d164dd43 100644 --- a/macaroons/service_test.go +++ b/macaroons/service_test.go @@ -55,6 +55,7 @@ func setupTestRootKeyStorage(t *testing.T) kvdb.Backend { // TestNewService tests the creation of the macaroon service. func TestNewService(t *testing.T) { t.Parallel() + ctx := context.Background() // First, initialize a dummy DB file with a store that the service // can read from. Make sure the file is removed in the end. @@ -74,13 +75,13 @@ func TestNewService(t *testing.T) { require.NoError(t, err, "Error unlocking root key storage") // Third, check if the created service can bake macaroons. - _, err = service.NewMacaroon(context.TODO(), nil, testOperation) + _, err = service.NewMacaroon(ctx, nil, testOperation) if err != macaroons.ErrMissingRootKeyID { t.Fatalf("Received %v instead of ErrMissingRootKeyID", err) } macaroon, err := service.NewMacaroon( - context.TODO(), macaroons.DefaultRootKeyID, testOperation, + ctx, macaroons.DefaultRootKeyID, testOperation, ) require.NoError(t, err, "Error creating macaroon from service") if macaroon.Namespace().String() != "std:" { @@ -108,6 +109,7 @@ func TestNewService(t *testing.T) { // incoming context. func TestValidateMacaroon(t *testing.T) { t.Parallel() + ctx := context.Background() // First, initialize the service and unlock it. db := setupTestRootKeyStorage(t) @@ -124,7 +126,7 @@ func TestValidateMacaroon(t *testing.T) { // Then, create a new macaroon that we can serialize. macaroon, err := service.NewMacaroon( - context.TODO(), macaroons.DefaultRootKeyID, testOperation, + ctx, macaroons.DefaultRootKeyID, testOperation, testOperationURI, ) require.NoError(t, err, "Error creating macaroon from service") @@ -136,7 +138,7 @@ func TestValidateMacaroon(t *testing.T) { md := metadata.New(map[string]string{ "macaroon": hex.EncodeToString(macaroonBinary), }) - mockContext := metadata.NewIncomingContext(context.Background(), md) + mockContext := metadata.NewIncomingContext(ctx, md) // Finally, validate the macaroon against the required permissions. err = service.ValidateMacaroon( @@ -155,6 +157,7 @@ func TestValidateMacaroon(t *testing.T) { // TestListMacaroonIDs checks that ListMacaroonIDs returns the expected result. func TestListMacaroonIDs(t *testing.T) { t.Parallel() + ctx := context.Background() // First, initialize a dummy DB file with a store that the service // can read from. Make sure the file is removed in the end. @@ -176,12 +179,12 @@ func TestListMacaroonIDs(t *testing.T) { // Third, make 3 new macaroons with different root key IDs. 
expectedIDs := [][]byte{{1}, {2}, {3}} for _, v := range expectedIDs { - _, err := service.NewMacaroon(context.TODO(), v, testOperation) + _, err := service.NewMacaroon(ctx, v, testOperation) require.NoError(t, err, "Error creating macaroon from service") } // Finally, check that calling List return the expected values. - ids, _ := service.ListMacaroonIDs(context.TODO()) + ids, _ := service.ListMacaroonIDs(ctx) require.Equal(t, expectedIDs, ids, "root key IDs mismatch") } diff --git a/macaroons/store_test.go b/macaroons/store_test.go index 37dce71100..7a25621408 100644 --- a/macaroons/store_test.go +++ b/macaroons/store_test.go @@ -58,12 +58,15 @@ func openTestStore(t *testing.T, tempDir string) *macaroons.RootKeyStorage { // TestStore tests the normal use cases of the store like creating, unlocking, // reading keys and closing it. func TestStore(t *testing.T) { + t.Parallel() + ctx := context.Background() + tempDir, store := newTestStore(t) - _, _, err := store.RootKey(context.TODO()) + _, _, err := store.RootKey(ctx) require.Equal(t, macaroons.ErrStoreLocked, err) - _, err = store.Get(context.TODO(), nil) + _, err = store.Get(ctx, nil) require.Equal(t, macaroons.ErrStoreLocked, err) pw := []byte("weks") @@ -72,18 +75,18 @@ func TestStore(t *testing.T) { // Check ErrContextRootKeyID is returned when no root key ID found in // context. - _, _, err = store.RootKey(context.TODO()) + _, _, err = store.RootKey(ctx) require.Equal(t, macaroons.ErrContextRootKeyID, err) // Check ErrMissingRootKeyID is returned when empty root key ID is used. emptyKeyID := make([]byte, 0) - badCtx := macaroons.ContextWithRootKeyID(context.TODO(), emptyKeyID) + badCtx := macaroons.ContextWithRootKeyID(ctx, emptyKeyID) _, _, err = store.RootKey(badCtx) require.Equal(t, macaroons.ErrMissingRootKeyID, err) // Create a context with illegal root key ID value. encryptedKeyID := []byte("enckey") - badCtx = macaroons.ContextWithRootKeyID(context.TODO(), encryptedKeyID) + badCtx = macaroons.ContextWithRootKeyID(ctx, encryptedKeyID) _, _, err = store.RootKey(badCtx) require.Equal(t, macaroons.ErrKeyValueForbidden, err) From d52f7299619bb9c867b4d019a6a8e4c56b03f623 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 7 Apr 2025 10:49:41 +0200 Subject: [PATCH 02/41] kvdb/etcd: remove context.TODO() from test helpers We want `context.TODO()` to be high signal in the code-base. It should signal clearly that work is required to thread parent context through to the call-site. So to keep the signal-to-noise ratio high, we remove any context.TODO() calls from tests since these will never need to be replace by a parent context. After this commit, there is only a single context.TODO() left in the code-base. 
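For illustration, a minimal sketch of the test pattern these first two patches converge on. The helper and package names below are invented and not part of lnd: the idea is that a test declares a single context.Background() at the top and reuses it, deriving timeouts from it where needed, so that any context.TODO() left in non-test code keeps signalling that a real parent context still needs to be threaded through.

package example

import (
	"context"
	"testing"
	"time"
)

// queryStore stands in for any store method that takes a context; it is
// illustrative only.
func queryStore(ctx context.Context, key string) (string, error) {
	select {
	case <-ctx.Done():
		return "", ctx.Err()
	default:
		return "value-for-" + key, nil
	}
}

// TestQueryStore mirrors the pattern adopted in the test diffs above and
// below: one context.Background() per test, reused for every call.
func TestQueryStore(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	// Timeouts are derived from the same base context, as in the etcd
	// fixture changes that follow.
	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	if _, err := queryStore(ctx, "k"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}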
--- kvdb/backend.go | 2 +- kvdb/etcd/db_test.go | 2 +- kvdb/etcd/fixture.go | 14 ++++++++++---- kvdb/etcd/readwrite_tx_test.go | 4 ++-- kvdb/etcd/walletdb_interface_test.go | 2 +- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/kvdb/backend.go b/kvdb/backend.go index 269d5d8215..3354cc8210 100644 --- a/kvdb/backend.go +++ b/kvdb/backend.go @@ -274,7 +274,7 @@ func GetTestBackend(path, name string) (Backend, func(), error) { return nil, empty, err } backend, err := Open( - EtcdBackendName, context.TODO(), etcdConfig, + EtcdBackendName, context.Background(), etcdConfig, ) return backend, cancel, err diff --git a/kvdb/etcd/db_test.go b/kvdb/etcd/db_test.go index 59e29fc94a..9ef68d9fe5 100644 --- a/kvdb/etcd/db_test.go +++ b/kvdb/etcd/db_test.go @@ -19,7 +19,7 @@ func TestDump(t *testing.T) { f := NewEtcdTestFixture(t) - db, err := newEtcdBackend(context.TODO(), f.BackendConfig()) + db, err := newEtcdBackend(context.Background(), f.BackendConfig()) require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { diff --git a/kvdb/etcd/fixture.go b/kvdb/etcd/fixture.go index b7a697fad4..c19047b39c 100644 --- a/kvdb/etcd/fixture.go +++ b/kvdb/etcd/fixture.go @@ -76,7 +76,7 @@ func (f *EtcdTestFixture) NewBackend(singleWriter bool) walletdb.DB { cfg.SingleWriter = true } - db, err := newEtcdBackend(context.TODO(), cfg) + db, err := newEtcdBackend(context.Background(), cfg) require.NoError(f.t, err) return db @@ -84,7 +84,9 @@ func (f *EtcdTestFixture) NewBackend(singleWriter bool) walletdb.DB { // Put puts a string key/value into the test etcd database. func (f *EtcdTestFixture) Put(key, value string) { - ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout) + ctx, cancel := context.WithTimeout( + context.Background(), testEtcdTimeout, + ) defer cancel() _, err := f.cli.Put(ctx, key, value) @@ -95,7 +97,9 @@ func (f *EtcdTestFixture) Put(key, value string) { // Get queries a key and returns the stored value from the test etcd database. func (f *EtcdTestFixture) Get(key string) string { - ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout) + ctx, cancel := context.WithTimeout( + context.Background(), testEtcdTimeout, + ) defer cancel() resp, err := f.cli.Get(ctx, key) @@ -112,7 +116,9 @@ func (f *EtcdTestFixture) Get(key string) string { // Dump scans and returns all key/values from the test etcd database. 
func (f *EtcdTestFixture) Dump() map[string]string { - ctx, cancel := context.WithTimeout(context.TODO(), testEtcdTimeout) + ctx, cancel := context.WithTimeout( + context.Background(), testEtcdTimeout, + ) defer cancel() resp, err := f.cli.Get(ctx, "\x00", clientv3.WithFromKey()) diff --git a/kvdb/etcd/readwrite_tx_test.go b/kvdb/etcd/readwrite_tx_test.go index b5758f7a3f..d66b2e2512 100644 --- a/kvdb/etcd/readwrite_tx_test.go +++ b/kvdb/etcd/readwrite_tx_test.go @@ -16,7 +16,7 @@ func TestChangeDuringManualTx(t *testing.T) { f := NewEtcdTestFixture(t) - db, err := newEtcdBackend(context.TODO(), f.BackendConfig()) + db, err := newEtcdBackend(context.Background(), f.BackendConfig()) require.NoError(t, err) tx, err := db.BeginReadWriteTx() @@ -44,7 +44,7 @@ func TestChangeDuringUpdate(t *testing.T) { f := NewEtcdTestFixture(t) - db, err := newEtcdBackend(context.TODO(), f.BackendConfig()) + db, err := newEtcdBackend(context.Background(), f.BackendConfig()) require.NoError(t, err) count := 0 diff --git a/kvdb/etcd/walletdb_interface_test.go b/kvdb/etcd/walletdb_interface_test.go index 13c57e337d..483becbb2f 100644 --- a/kvdb/etcd/walletdb_interface_test.go +++ b/kvdb/etcd/walletdb_interface_test.go @@ -15,5 +15,5 @@ import ( func TestWalletDBInterface(t *testing.T) { f := NewEtcdTestFixture(t) cfg := f.BackendConfig() - walletdbtest.TestInterface(t, dbType, context.TODO(), &cfg) + walletdbtest.TestInterface(t, dbType, context.Background(), &cfg) } From 62db6e2a98eeb2508e831c23a3677f2832314321 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 7 Apr 2025 09:34:26 +0200 Subject: [PATCH 03/41] lnd: pass context to `newServer` and `server.Start` In preparation for starting to thread a single parent context through LND, we update the main `server.Start` method to take a context so that it can later pass it to any subsytem's Start method it calls. We also pass the context to `newServer` since it makes some calls that will eventually reach the DB (for example the graph db). --- lnd.go | 4 ++-- server.go | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lnd.go b/lnd.go index e63ddc1965..3afa8c2fba 100644 --- a/lnd.go +++ b/lnd.go @@ -626,7 +626,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // Set up the core server which will listen for incoming peer // connections. server, err := newServer( - cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, + ctx, cfg, cfg.Listeners, dbs, activeChainControl, &idKeyDesc, activeChainControl.Cfg.WalletUnlockParams.ChansToRestore, multiAcceptor, torController, tlsManager, leaderElector, implCfg, @@ -758,7 +758,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // case the startup of the subservers do not behave as expected. errChan := make(chan error) go func() { - errChan <- server.Start() + errChan <- server.Start(ctx) }() defer func() { diff --git a/server.go b/server.go index e816c3ca4c..911c57e427 100644 --- a/server.go +++ b/server.go @@ -554,7 +554,9 @@ func noiseDial(idKey keychain.SingleKeyECDH, // newServer creates a new instance of the server which is to listen using the // passed listener address. 
-func newServer(cfg *Config, listenAddrs []net.Addr, +// +//nolint:funlen +func newServer(_ context.Context, cfg *Config, listenAddrs []net.Addr, dbs *DatabaseInstances, cc *chainreg.ChainControl, nodeKeyDesc *keychain.KeyDescriptor, chansToRestore walletunlocker.ChannelsToRecover, @@ -2217,7 +2219,7 @@ func (s *server) startLowLevelServices() error { // NOTE: This function is safe for concurrent access. // //nolint:funlen -func (s *server) Start() error { +func (s *server) Start(_ context.Context) error { // Get the current blockbeat. beat, err := s.getStartingBeat() if err != nil { From 9f6740e638b7cad5459e04be3cc8c60ae25d4077 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 7 Apr 2025 09:54:20 +0200 Subject: [PATCH 04/41] discovery: thread context through to gossiper Pass the parent LND context to the gossiper, let it derive a child context that gets cancelled on Stop. Pass the context through to any methods that will eventually thread it through to any graph DB calls. One `context.TODO()` is added here - this will be removed in the next commit. NOTE: for any internal methods that the context gets passed to, if those methods already listen on the gossiper's `quit` channel, then then don't need to also listen on the passed context's Done() channel because the quit channel is closed at the same time that the context is cancelled. --- discovery/gossiper.go | 130 ++++++++++++++++++++++--------------- discovery/gossiper_test.go | 9 ++- server.go | 4 +- 3 files changed, 85 insertions(+), 58 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index cdeaa29426..ee74a59381 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -2,6 +2,7 @@ package discovery import ( "bytes" + "context" "errors" "fmt" "strings" @@ -475,9 +476,6 @@ type AuthenticatedGossiper struct { // held. bestHeight uint32 - quit chan struct{} - wg sync.WaitGroup - // cfg is a copy of the configuration struct that the gossiper service // was initialized with. cfg *Config @@ -555,6 +553,10 @@ type AuthenticatedGossiper struct { vb *ValidationBarrier sync.Mutex + + cancel fn.Option[context.CancelFunc] + quit chan struct{} + wg sync.WaitGroup } // New creates a new AuthenticatedGossiper instance, initialized with the @@ -600,7 +602,11 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper NotifyWhenOnline: cfg.NotifyWhenOnline, NotifyWhenOffline: cfg.NotifyWhenOffline, MessageStore: cfg.MessageStore, - IsMsgStale: gossiper.isMsgStale, + IsMsgStale: func(message lnwire.Message) bool { + ctx := context.TODO() + + return gossiper.isMsgStale(ctx, message) + }, }) return gossiper @@ -641,16 +647,19 @@ func (d *AuthenticatedGossiper) PropagateChanPolicyUpdate( // Start spawns network messages handler goroutine and registers on new block // notifications in order to properly handle the premature announcements. -func (d *AuthenticatedGossiper) Start() error { +func (d *AuthenticatedGossiper) Start(ctx context.Context) error { var err error d.started.Do(func() { + ctx, cancel := context.WithCancel(ctx) + d.cancel = fn.Some(cancel) + log.Info("Authenticated Gossiper starting") - err = d.start() + err = d.start(ctx) }) return err } -func (d *AuthenticatedGossiper) start() error { +func (d *AuthenticatedGossiper) start(ctx context.Context) error { // First we register for new notifications of newly discovered blocks. // We do this immediately so we'll later be able to consume any/all // blocks which were discovered. 
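For illustration, a minimal, self-contained sketch of the lifecycle pattern this commit gives the gossiper. The type and field names here are hypothetical, and the cancel function is held in a plain field rather than the fn.Option[context.CancelFunc] used in the diff. The point is that Start derives a child context which Stop cancels at the same moment it closes quit, so internal goroutines that already select on quit do not also have to watch ctx.Done().

package example

import (
	"context"
	"sync"
)

// subsystem is a stripped-down, hypothetical stand-in for a component such
// as the gossiper: Start derives a child context that is cancelled on Stop,
// and long-running goroutines watch the quit channel and/or the context.
type subsystem struct {
	started sync.Once
	stopped sync.Once

	cancel context.CancelFunc
	quit   chan struct{}
	wg     sync.WaitGroup
}

func newSubsystem() *subsystem {
	return &subsystem{quit: make(chan struct{})}
}

// Start derives a child context from the caller's parent context so that
// cancellation of the parent, or a call to Stop, tears the handler down.
func (s *subsystem) Start(ctx context.Context) error {
	s.started.Do(func() {
		ctx, cancel := context.WithCancel(ctx)
		s.cancel = cancel

		s.wg.Add(1)
		go s.handler(ctx)
	})

	return nil
}

// Stop cancels the child context and closes quit together, mirroring the
// gossiper's stop() above (which guards the cancel call with fn.Option's
// WhenSome instead of a nil check).
func (s *subsystem) Stop() {
	s.stopped.Do(func() {
		if s.cancel != nil {
			s.cancel()
		}
		close(s.quit)
		s.wg.Wait()
	})
}

func (s *subsystem) handler(ctx context.Context) {
	defer s.wg.Done()

	// A real handler would also service work channels here; this sketch
	// only demonstrates the two equivalent shutdown paths.
	select {
	case <-ctx.Done():
	case <-s.quit:
	}
}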
@@ -680,7 +689,7 @@ func (d *AuthenticatedGossiper) start() error { // Start receiving blocks in its dedicated goroutine. d.wg.Add(2) go d.syncBlockHeight() - go d.networkHandler() + go d.networkHandler(ctx) return nil } @@ -835,6 +844,7 @@ func (d *AuthenticatedGossiper) stop() { d.banman.stop() + d.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(d.quit) d.wg.Wait() @@ -1337,7 +1347,7 @@ func (d *AuthenticatedGossiper) splitAnnouncementBatches( // split size, and then sends out all items to the set of target peers. Locally // generated announcements are always sent before remotely generated // announcements. -func (d *AuthenticatedGossiper) splitAndSendAnnBatch( +func (d *AuthenticatedGossiper) splitAndSendAnnBatch(ctx context.Context, annBatch msgsToBroadcast) { // delayNextBatch is a helper closure that blocks for `SubBatchDelay` @@ -1374,7 +1384,7 @@ func (d *AuthenticatedGossiper) splitAndSendAnnBatch( // Now send the remote announcements. for _, annBatch := range remoteBatches { - d.sendRemoteBatch(annBatch) + d.sendRemoteBatch(ctx, annBatch) delayNextBatch() } }() @@ -1398,7 +1408,9 @@ func (d *AuthenticatedGossiper) sendLocalBatch(annBatch []msgWithSenders) { // sendRemoteBatch broadcasts a list of remotely generated announcements to our // peers. -func (d *AuthenticatedGossiper) sendRemoteBatch(annBatch []msgWithSenders) { +func (d *AuthenticatedGossiper) sendRemoteBatch(_ context.Context, + annBatch []msgWithSenders) { + syncerPeers := d.syncMgr.GossipSyncers() // We'll first attempt to filter out this new message for all peers @@ -1431,7 +1443,7 @@ func (d *AuthenticatedGossiper) sendRemoteBatch(annBatch []msgWithSenders) { // broadcasting our latest topology state to all connected peers. // // NOTE: This MUST be run as a goroutine. -func (d *AuthenticatedGossiper) networkHandler() { +func (d *AuthenticatedGossiper) networkHandler(ctx context.Context) { defer d.wg.Done() // Initialize empty deDupedAnnouncements to store announcement batch. @@ -1446,7 +1458,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // To start, we'll first check to see if there are any stale channel or // node announcements that we need to re-transmit. - if err := d.retransmitStaleAnns(time.Now()); err != nil { + if err := d.retransmitStaleAnns(ctx, time.Now()); err != nil { log.Errorf("Unable to rebroadcast stale announcements: %v", err) } @@ -1463,7 +1475,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // the affected channels and also update the underlying // graph with the new state. newChanUpdates, err := d.processChanPolicyUpdate( - policyUpdate.edgesToUpdate, + ctx, policyUpdate.edgesToUpdate, ) policyUpdate.errChan <- err if err != nil { @@ -1488,7 +1500,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // messages that we'll process serially. case *lnwire.AnnounceSignatures1: emittedAnnouncements, _ := d.processNetworkAnnouncement( - announcement, + ctx, announcement, ) log.Debugf("Processed network message %s, "+ "returned len(announcements)=%v", @@ -1528,7 +1540,7 @@ func (d *AuthenticatedGossiper) networkHandler() { d.wg.Add(1) go d.handleNetworkMessages( - announcement, &announcements, annJobID, + ctx, announcement, &announcements, annJobID, ) // The trickle timer has ticked, which indicates we should @@ -1551,7 +1563,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // announcements, we'll blast them out w/o regard for // our peer's policies so we ensure they propagate // properly. 
- d.splitAndSendAnnBatch(announcementBatch) + d.splitAndSendAnnBatch(ctx, announcementBatch) // The retransmission timer has ticked which indicates that we // should check if we need to prune or re-broadcast any of our @@ -1560,7 +1572,7 @@ func (d *AuthenticatedGossiper) networkHandler() { // have been dropped, or not properly propagated through the // network. case tick := <-d.cfg.RetransmitTicker.Ticks(): - if err := d.retransmitStaleAnns(tick); err != nil { + if err := d.retransmitStaleAnns(ctx, tick); err != nil { log.Errorf("unable to rebroadcast stale "+ "announcements: %v", err) } @@ -1578,8 +1590,8 @@ func (d *AuthenticatedGossiper) networkHandler() { // signal its dependants and add the new announcements to the announce batch. // // NOTE: must be run as a goroutine. -func (d *AuthenticatedGossiper) handleNetworkMessages(nMsg *networkMsg, - deDuped *deDupedAnnouncements, jobID JobID) { +func (d *AuthenticatedGossiper) handleNetworkMessages(ctx context.Context, + nMsg *networkMsg, deDuped *deDupedAnnouncements, jobID JobID) { defer d.wg.Done() defer d.vb.CompleteJob() @@ -1607,7 +1619,7 @@ func (d *AuthenticatedGossiper) handleNetworkMessages(nMsg *networkMsg, // Process the network announcement to determine if this is either a // new announcement from our PoV or an edges to a prior vertex/edge we // previously proceeded. - newAnns, allow := d.processNetworkAnnouncement(nMsg) + newAnns, allow := d.processNetworkAnnouncement(ctx, nMsg) log.Tracef("Processed network message %s, returned "+ "len(announcements)=%v, allowDependents=%v", @@ -1681,7 +1693,9 @@ func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message, // stale iff, the last timestamp of its rebroadcast is older than the // RebroadcastInterval. We also check if a refreshed node announcement should // be resent. -func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { +func (d *AuthenticatedGossiper) retransmitStaleAnns(ctx context.Context, + now time.Time) error { + // Iterate over all of our channels and check if any of them fall // within the prune interval or re-broadcast interval. type updateTuple struct { @@ -1753,7 +1767,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // Re-sign and update the channel on disk and retrieve our // ChannelUpdate to broadcast. chanAnn, chanUpdate, err := d.updateChannel( - chanToUpdate.info, chanToUpdate.edge, + ctx, chanToUpdate.info, chanToUpdate.edge, ) if err != nil { return fmt.Errorf("unable to update channel: %w", err) @@ -1794,7 +1808,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // Before broadcasting the refreshed node announcement, add it // to our own graph. - if err := d.addNode(&newNodeAnn); err != nil { + if err := d.addNode(ctx, &newNodeAnn); err != nil { log.Errorf("Unable to add refreshed node announcement "+ "to graph: %v", err) } @@ -1820,7 +1834,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // processChanPolicyUpdate generates a new set of channel updates for the // provided list of edges and updates the backing ChannelGraphSource. -func (d *AuthenticatedGossiper) processChanPolicyUpdate( +func (d *AuthenticatedGossiper) processChanPolicyUpdate(ctx context.Context, edgesToUpdate []EdgeWithInfo) ([]networkMsg, error) { var chanUpdates []networkMsg @@ -1829,7 +1843,7 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate( // we'll re-sign and update the backing ChannelGraphSource, and // retrieve our ChannelUpdate to broadcast. 
_, chanUpdate, err := d.updateChannel( - edgeInfo.Info, edgeInfo.Edge, + ctx, edgeInfo.Info, edgeInfo.Edge, ) if err != nil { return nil, err @@ -1922,7 +1936,7 @@ func remotePubFromChanInfo(chanInfo *models.ChannelEdgeInfo, // situation in the case where we create a channel, but for some reason fail // to receive the remote peer's proof, while the remote peer is able to fully // assemble the proof and craft the ChannelAnnouncement. -func (d *AuthenticatedGossiper) processRejectedEdge( +func (d *AuthenticatedGossiper) processRejectedEdge(_ context.Context, chanAnnMsg *lnwire.ChannelAnnouncement1, proof *models.ChannelAuthProof) ([]networkMsg, error) { @@ -2010,8 +2024,8 @@ func (d *AuthenticatedGossiper) fetchPKScript(chanID *lnwire.ShortChannelID) ( // addNode processes the given node announcement, and adds it to our channel // graph. -func (d *AuthenticatedGossiper) addNode(msg *lnwire.NodeAnnouncement, - op ...batch.SchedulerOption) error { +func (d *AuthenticatedGossiper) addNode(_ context.Context, + msg *lnwire.NodeAnnouncement, op ...batch.SchedulerOption) error { if err := netann.ValidateNodeAnn(msg); err != nil { return fmt.Errorf("unable to validate node announcement: %w", @@ -2086,7 +2100,7 @@ func (d *AuthenticatedGossiper) isPremature(chanID lnwire.ShortChannelID, // be returned which should be broadcasted to the rest of the network. The // boolean returned indicates whether any dependents of the announcement should // attempt to be processed as well. -func (d *AuthenticatedGossiper) processNetworkAnnouncement( +func (d *AuthenticatedGossiper) processNetworkAnnouncement(ctx context.Context, nMsg *networkMsg) ([]networkMsg, bool) { // If this is a remote update, we set the scheduler option to lazily @@ -2101,26 +2115,26 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // information about a node in one of the channels we know about, or a // updating previously advertised information. case *lnwire.NodeAnnouncement: - return d.handleNodeAnnouncement(nMsg, msg, schedulerOp) + return d.handleNodeAnnouncement(ctx, nMsg, msg, schedulerOp) // A new channel announcement has arrived, this indicates the // *creation* of a new channel within the network. This only advertises // the existence of a channel and not yet the routing policies in // either direction of the channel. case *lnwire.ChannelAnnouncement1: - return d.handleChanAnnouncement(nMsg, msg, schedulerOp...) + return d.handleChanAnnouncement(ctx, nMsg, msg, schedulerOp...) // A new authenticated channel edge update has arrived. This indicates // that the directional information for an already known channel has // been updated. case *lnwire.ChannelUpdate1: - return d.handleChanUpdate(nMsg, msg, schedulerOp) + return d.handleChanUpdate(ctx, nMsg, msg, schedulerOp) // A new signature announcement has been received. This indicates // willingness of nodes involved in the funding of a channel to // announce this new channel to the rest of the world. case *lnwire.AnnounceSignatures1: - return d.handleAnnSig(nMsg, msg) + return d.handleAnnSig(ctx, nMsg, msg) default: err := errors.New("wrong type of the announcement") @@ -2134,7 +2148,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // // NOTE: only the NodeKey1Bytes and NodeKey2Bytes members of the ChannelEdgeInfo // should be inspected. 
-func (d *AuthenticatedGossiper) processZombieUpdate( +func (d *AuthenticatedGossiper) processZombieUpdate(_ context.Context, chanInfo *models.ChannelEdgeInfo, scid lnwire.ShortChannelID, msg *lnwire.ChannelUpdate1) error { @@ -2192,7 +2206,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( // fetchNodeAnn fetches the latest signed node announcement from our point of // view for the node with the given public key. -func (d *AuthenticatedGossiper) fetchNodeAnn( +func (d *AuthenticatedGossiper) fetchNodeAnn(_ context.Context, pubKey [33]byte) (*lnwire.NodeAnnouncement, error) { node, err := d.cfg.Graph.FetchLightningNode(pubKey) @@ -2205,7 +2219,9 @@ func (d *AuthenticatedGossiper) fetchNodeAnn( // isMsgStale determines whether a message retrieved from the backing // MessageStore is seen as stale by the current graph. -func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { +func (d *AuthenticatedGossiper) isMsgStale(_ context.Context, + msg lnwire.Message) bool { + switch msg := msg.(type) { case *lnwire.AnnounceSignatures1: chanInfo, _, _, err := d.cfg.Graph.GetChannelByID( @@ -2272,7 +2288,8 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. -func (d *AuthenticatedGossiper) updateChannel(info *models.ChannelEdgeInfo, +func (d *AuthenticatedGossiper) updateChannel(_ context.Context, + info *models.ChannelEdgeInfo, edge *models.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement1, *lnwire.ChannelUpdate1, error) { @@ -2414,8 +2431,8 @@ func (d *AuthenticatedGossiper) latestHeight() uint32 { } // handleNodeAnnouncement processes a new node announcement. -func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, - nodeAnn *lnwire.NodeAnnouncement, +func (d *AuthenticatedGossiper) handleNodeAnnouncement(ctx context.Context, + nMsg *networkMsg, nodeAnn *lnwire.NodeAnnouncement, ops []batch.SchedulerOption) ([]networkMsg, bool) { timestamp := time.Unix(int64(nodeAnn.Timestamp), 0) @@ -2432,7 +2449,7 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, return nil, true } - if err := d.addNode(nodeAnn, ops...); err != nil { + if err := d.addNode(ctx, nodeAnn, ops...); err != nil { log.Debugf("Adding node: %x got error: %v", nodeAnn.NodeID, err) @@ -2487,8 +2504,10 @@ func (d *AuthenticatedGossiper) handleNodeAnnouncement(nMsg *networkMsg, } // handleChanAnnouncement processes a new channel announcement. -func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, - ann *lnwire.ChannelAnnouncement1, +// +//nolint:funlen +func (d *AuthenticatedGossiper) handleChanAnnouncement(ctx context.Context, + nMsg *networkMsg, ann *lnwire.ChannelAnnouncement1, ops ...batch.SchedulerOption) ([]networkMsg, bool) { scid := ann.ShortChannelID @@ -2680,7 +2699,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // add an alias ChannelAnnouncement from the gossiper. if !(d.cfg.AssumeChannelValid || d.cfg.IsAlias(scid)) { //nolint:nestif op, capacity, script, err := d.validateFundingTransaction( - ann, tapscriptRoot, + ctx, ann, tapscriptRoot, ) if err != nil { defer d.channelMtx.Unlock(scid.ToUint64()) @@ -2802,7 +2821,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, if graph.IsError(err, graph.ErrIgnored) { // Attempt to process the rejected message to see if we // get any new announcements. 
- anns, rErr := d.processRejectedEdge(ann, proof) + anns, rErr := d.processRejectedEdge(ctx, ann, proof) if rErr != nil { key := newRejectCacheKey( scid.ToUint64(), @@ -2945,8 +2964,10 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, } // handleChanUpdate processes a new channel update. -func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, - upd *lnwire.ChannelUpdate1, +// +//nolint:funlen +func (d *AuthenticatedGossiper) handleChanUpdate(ctx context.Context, + nMsg *networkMsg, upd *lnwire.ChannelUpdate1, ops []batch.SchedulerOption) ([]networkMsg, bool) { log.Debugf("Processing ChannelUpdate: peer=%v, short_chan_id=%v, ", @@ -3052,7 +3073,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, break case errors.Is(err, graphdb.ErrZombieEdge): - err = d.processZombieUpdate(chanInfo, graphScid, upd) + err = d.processZombieUpdate(ctx, chanInfo, graphScid, upd) if err != nil { log.Debug(err) nMsg.err <- err @@ -3346,8 +3367,11 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, } // handleAnnSig processes a new announcement signatures message. -func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, - ann *lnwire.AnnounceSignatures1) ([]networkMsg, bool) { +// +//nolint:funlen +func (d *AuthenticatedGossiper) handleAnnSig(ctx context.Context, + nMsg *networkMsg, ann *lnwire.AnnounceSignatures1) ([]networkMsg, + bool) { needBlockHeight := ann.ShortChannelID.BlockHeight + d.cfg.ProofMatureDelta @@ -3631,7 +3655,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, // it since the source gets skipped. This isn't necessary for channel // updates and announcement signatures since we send those directly to // our channel counterparty through the gossiper's reliable sender. - node1Ann, err := d.fetchNodeAnn(chanInfo.NodeKey1Bytes) + node1Ann, err := d.fetchNodeAnn(ctx, chanInfo.NodeKey1Bytes) if err != nil { log.Debugf("Unable to fetch node announcement for %x: %v", chanInfo.NodeKey1Bytes, err) @@ -3645,7 +3669,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(nMsg *networkMsg, } } - node2Ann, err := d.fetchNodeAnn(chanInfo.NodeKey2Bytes) + node2Ann, err := d.fetchNodeAnn(ctx, chanInfo.NodeKey2Bytes) if err != nil { log.Debugf("Unable to fetch node announcement for %x: %v", chanInfo.NodeKey2Bytes, err) @@ -3700,7 +3724,7 @@ func (d *AuthenticatedGossiper) ShouldDisconnect(pubkey *btcec.PublicKey) ( // transaction from chain to ensure that it exists, is not spent and matches // the channel announcement proof. The transaction's outpoint and value are // returned if we can glean them from the work done in this method. 
-func (d *AuthenticatedGossiper) validateFundingTransaction( +func (d *AuthenticatedGossiper) validateFundingTransaction(_ context.Context, ann *lnwire.ChannelAnnouncement1, tapscriptRoot fn.Option[chainhash.Hash]) (wire.OutPoint, btcutil.Amount, []byte, error) { diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index b48ea8dbdf..b3ca738ebf 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -2,6 +2,7 @@ package discovery import ( "bytes" + "context" "encoding/hex" "fmt" prand "math/rand" @@ -993,7 +994,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( ScidCloser: newMockScidCloser(isChanPeer), }, selfKeyDesc) - if err := gossiper.Start(); err != nil { + if err := gossiper.Start(context.Background()); err != nil { return nil, fmt.Errorf("unable to start router: %w", err) } @@ -1680,7 +1681,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { KeyLocator: ctx.gossiper.selfKeyLoc, }) require.NoError(t, err, "unable to recreate gossiper") - if err := gossiper.Start(); err != nil { + if err := gossiper.Start(context.Background()); err != nil { t.Fatalf("unable to start recreated gossiper: %v", err) } defer gossiper.Stop() @@ -4756,7 +4757,9 @@ func assertChanChainRejection(t *testing.T, ctx *testCtx, err: errChan, } - _, added := ctx.gossiper.handleChanAnnouncement(nMsg, edge) + _, added := ctx.gossiper.handleChanAnnouncement( + context.Background(), nMsg, edge, + ) require.False(t, added) select { diff --git a/server.go b/server.go index 911c57e427..f8beba1dd0 100644 --- a/server.go +++ b/server.go @@ -2219,7 +2219,7 @@ func (s *server) startLowLevelServices() error { // NOTE: This function is safe for concurrent access. // //nolint:funlen -func (s *server) Start(_ context.Context) error { +func (s *server) Start(ctx context.Context) error { // Get the current blockbeat. beat, err := s.getStartingBeat() if err != nil { @@ -2390,7 +2390,7 @@ func (s *server) Start(_ context.Context) error { // The authGossiper depends on the chanRouter and therefore // should be started after it. cleanup = cleanup.add(s.authGossiper.Stop) - if err := s.authGossiper.Start(); err != nil { + if err := s.authGossiper.Start(ctx); err != nil { startErr = err return } From 1c9c9d8224790ac7f724ad9c79ba3fa841710fe8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 7 Apr 2025 10:03:29 +0200 Subject: [PATCH 05/41] discovery: pass context through to reliable sender And remove a context.TODO() that was added in the previous commit. --- discovery/gossiper.go | 14 ++++------- discovery/reliable_sender.go | 37 ++++++++++++++++++---------- discovery/reliable_sender_test.go | 41 +++++++++++++++++-------------- 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index ee74a59381..ac95d55ba9 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -602,11 +602,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper NotifyWhenOnline: cfg.NotifyWhenOnline, NotifyWhenOffline: cfg.NotifyWhenOffline, MessageStore: cfg.MessageStore, - IsMsgStale: func(message lnwire.Message) bool { - ctx := context.TODO() - - return gossiper.isMsgStale(ctx, message) - }, + IsMsgStale: gossiper.isMsgStale, }) return gossiper @@ -678,7 +674,7 @@ func (d *AuthenticatedGossiper) start(ctx context.Context) error { // Start the reliable sender. 
In case we had any pending messages ready // to be sent when the gossiper was last shut down, we must continue on // our quest to deliver them to their respective peers. - if err := d.reliableSender.Start(); err != nil { + if err := d.reliableSender.Start(ctx); err != nil { return err } @@ -1889,7 +1885,7 @@ func (d *AuthenticatedGossiper) processChanPolicyUpdate(ctx context.Context, edgeInfo.Info, chanUpdate.ChannelFlags, ) err := d.reliableSender.sendMessage( - chanUpdate, remotePubKey, + ctx, chanUpdate, remotePubKey, ) if err != nil { log.Errorf("Unable to reliably send %v for "+ @@ -3333,7 +3329,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(ctx context.Context, // Now we'll attempt to send the channel update message // reliably to the remote peer in the background, so that we // don't block if the peer happens to be offline at the moment. - err := d.reliableSender.sendMessage(upd, remotePubKey) + err := d.reliableSender.sendMessage(ctx, upd, remotePubKey) if err != nil { err := fmt.Errorf("unable to reliably send %v for "+ "channel=%v to peer=%x: %v", upd.MsgType(), @@ -3470,7 +3466,7 @@ func (d *AuthenticatedGossiper) handleAnnSig(ctx context.Context, // Since the remote peer might not be online we'll call a // method that will attempt to deliver the proof when it comes // online. - err := d.reliableSender.sendMessage(ann, remotePubKey) + err := d.reliableSender.sendMessage(ctx, ann, remotePubKey) if err != nil { err := fmt.Errorf("unable to reliably send %v for "+ "channel=%v to peer=%x: %v", ann.MsgType(), diff --git a/discovery/reliable_sender.go b/discovery/reliable_sender.go index b4d32e73fd..57f2f28ff3 100644 --- a/discovery/reliable_sender.go +++ b/discovery/reliable_sender.go @@ -1,8 +1,10 @@ package discovery import ( + "context" "sync" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwire" ) @@ -28,7 +30,7 @@ type reliableSenderCfg struct { // IsMsgStale determines whether a message retrieved from the backing // MessageStore is seen as stale by the current graph. - IsMsgStale func(lnwire.Message) bool + IsMsgStale func(context.Context, lnwire.Message) bool } // peerManager contains the set of channels required for the peerHandler to @@ -59,8 +61,9 @@ type reliableSender struct { activePeers map[[33]byte]peerManager activePeersMtx sync.Mutex - wg sync.WaitGroup - quit chan struct{} + wg sync.WaitGroup + quit chan struct{} + cancel fn.Option[context.CancelFunc] } // newReliableSender returns a new reliableSender backed by the given config. @@ -73,10 +76,13 @@ func newReliableSender(cfg *reliableSenderCfg) *reliableSender { } // Start spawns message handlers for any peers with pending messages. -func (s *reliableSender) Start() error { +func (s *reliableSender) Start(ctx context.Context) error { var err error s.start.Do(func() { - err = s.resendPendingMsgs() + ctx, cancel := context.WithCancel(ctx) + s.cancel = fn.Some(cancel) + + err = s.resendPendingMsgs(ctx) }) return err } @@ -87,6 +93,7 @@ func (s *reliableSender) Stop() { log.Debugf("reliableSender is stopping") defer log.Debugf("reliableSender stopped") + s.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(s.quit) s.wg.Wait() }) @@ -96,7 +103,9 @@ func (s *reliableSender) Stop() { // event that the peer is currently offline, this will only write the message to // disk. Once the peer reconnects, this message, along with any others pending, // will be sent to the peer. 
-func (s *reliableSender) sendMessage(msg lnwire.Message, peerPubKey [33]byte) error { +func (s *reliableSender) sendMessage(ctx context.Context, msg lnwire.Message, + peerPubKey [33]byte) error { + // We'll start by persisting the message to disk. This allows us to // resend the message upon restarts and peer reconnections. if err := s.cfg.MessageStore.AddMessage(msg, peerPubKey); err != nil { @@ -106,7 +115,7 @@ func (s *reliableSender) sendMessage(msg lnwire.Message, peerPubKey [33]byte) er // Then, we'll spawn a peerHandler for this peer to handle resending its // pending messages while taking into account its connection lifecycle. spawnHandler: - msgHandler, ok := s.spawnPeerHandler(peerPubKey) + msgHandler, ok := s.spawnPeerHandler(ctx, peerPubKey) // If the handler wasn't previously active, we can exit now as we know // that the message will be sent once the peer online notification is @@ -134,7 +143,7 @@ spawnHandler: // spawnPeerMsgHandler spawns a peerHandler for the given peer if there isn't // one already active. The boolean returned signals whether there was already // one active or not. -func (s *reliableSender) spawnPeerHandler( +func (s *reliableSender) spawnPeerHandler(ctx context.Context, peerPubKey [33]byte) (peerManager, bool) { s.activePeersMtx.Lock() @@ -152,7 +161,7 @@ func (s *reliableSender) spawnPeerHandler( // peerHandler. if !ok { s.wg.Add(1) - go s.peerHandler(msgHandler, peerPubKey) + go s.peerHandler(ctx, msgHandler, peerPubKey) } return msgHandler, ok @@ -164,7 +173,9 @@ func (s *reliableSender) spawnPeerHandler( // offline will be queued and sent once the peer reconnects. // // NOTE: This must be run as a goroutine. -func (s *reliableSender) peerHandler(peerMgr peerManager, peerPubKey [33]byte) { +func (s *reliableSender) peerHandler(ctx context.Context, peerMgr peerManager, + peerPubKey [33]byte) { + defer s.wg.Done() // We'll start by requesting a notification for when the peer @@ -252,7 +263,7 @@ out: // check whether it's stale. This guarantees that // AnnounceSignatures are sent at least once if we happen to // already have signatures for both parties. - if s.cfg.IsMsgStale(msg) { + if s.cfg.IsMsgStale(ctx, msg) { err := s.cfg.MessageStore.DeleteMessage(msg, peerPubKey) if err != nil { log.Errorf("Unable to remove stale %v message "+ @@ -321,7 +332,7 @@ out: // resendPendingMsgs retrieves and sends all of the messages within the message // store that should be reliably sent to their respective peers. -func (s *reliableSender) resendPendingMsgs() error { +func (s *reliableSender) resendPendingMsgs(ctx context.Context) error { // Fetch all of the peers for which we have pending messages for and // spawn a peerMsgHandler for each. Once the peer is seen as online, all // of the pending messages will be sent. 
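To show the shape of this change in isolation, here is a hedged sketch of widening a config callback to accept a context and threading the caller's context into the per-peer goroutine. The names (message, senderCfg, sender) are invented for the example; in the real code the callback is IsMsgStale and the context originates from the gossiper's Start.

package example

import "context"

// message stands in for lnwire.Message in this sketch.
type message interface{}

// senderCfg mirrors the shape of the callback change above: the staleness
// check now receives a context so that the graph DB lookups behind it can
// honour cancellation.
type senderCfg struct {
	isMsgStale func(context.Context, message) bool
}

// sender is a toy reliable sender used only to show how the context flows
// from sendMessage through to the spawned per-peer handler.
type sender struct {
	cfg senderCfg
}

func (s *sender) sendMessage(ctx context.Context, msg message,
	peer [33]byte) error {

	// The real sender persists msg to its message store first; that step
	// is omitted here. The same ctx is handed to the handler goroutine.
	go s.peerHandler(ctx, msg, peer)

	return nil
}

func (s *sender) peerHandler(ctx context.Context, msg message, _ [33]byte) {
	if s.cfg.isMsgStale(ctx, msg) {
		// Stale messages are deleted from the store in the real
		// implementation; this sketch simply drops them.
		return
	}
}

A test stub then just ignores the extra argument, e.g. IsMsgStale: func(context.Context, lnwire.Message) bool { return false }, exactly as the reliable_sender_test.go changes below do.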
@@ -331,7 +342,7 @@ func (s *reliableSender) resendPendingMsgs() error { } for peer := range peers { - s.spawnPeerHandler(peer) + s.spawnPeerHandler(ctx, peer) } return nil diff --git a/discovery/reliable_sender_test.go b/discovery/reliable_sender_test.go index 19fdaa1cad..fc94d57f35 100644 --- a/discovery/reliable_sender_test.go +++ b/discovery/reliable_sender_test.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "fmt" "sync/atomic" "testing" @@ -11,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" ) // newTestReliableSender creates a new reliable sender instance used for @@ -32,7 +34,7 @@ func newTestReliableSender(t *testing.T) *reliableSender { return c }, MessageStore: newMockMessageStore(), - IsMsgStale: func(lnwire.Message) bool { + IsMsgStale: func(context.Context, lnwire.Message) bool { return false }, } @@ -69,6 +71,7 @@ func assertMsgsSent(t *testing.T, msgChan chan lnwire.Message, // a peer while taking into account its connection lifecycle works as expected. func TestReliableSenderFlow(t *testing.T) { t.Parallel() + ctx := context.Background() reliableSender := newTestReliableSender(t) @@ -98,9 +101,8 @@ func TestReliableSenderFlow(t *testing.T) { msg1 := randChannelUpdate() var peerPubKey [33]byte copy(peerPubKey[:], pubKey.SerializeCompressed()) - if err := reliableSender.sendMessage(msg1, peerPubKey); err != nil { - t.Fatalf("unable to reliably send message: %v", err) - } + err := reliableSender.sendMessage(ctx, msg1, peerPubKey) + require.NoError(t, err) // Since there isn't a peerHandler for this peer currently active due to // this being the first message being sent reliably, we should expect to @@ -114,9 +116,8 @@ func TestReliableSenderFlow(t *testing.T) { // We'll then attempt to send another additional message reliably. msg2 := randAnnounceSignatures() - if err := reliableSender.sendMessage(msg2, peerPubKey); err != nil { - t.Fatalf("unable to reliably send message: %v", err) - } + err = reliableSender.sendMessage(ctx, msg2, peerPubKey) + require.NoError(t, err) // This should not however request another peer online notification as // the peerHandler has already been started and is waiting for the @@ -145,9 +146,8 @@ func TestReliableSenderFlow(t *testing.T) { // Then, we'll send one more message reliably. msg3 := randChannelUpdate() - if err := reliableSender.sendMessage(msg3, peerPubKey); err != nil { - t.Fatalf("unable to reliably send message: %v", err) - } + err = reliableSender.sendMessage(ctx, msg3, peerPubKey) + require.NoError(t, err) // Again, this should not request another peer online notification // request since we are currently waiting for the peer to be offline. @@ -188,6 +188,7 @@ func TestReliableSenderFlow(t *testing.T) { // them as stale. func TestReliableSenderStaleMessages(t *testing.T) { t.Parallel() + ctx := context.Background() reliableSender := newTestReliableSender(t) @@ -206,7 +207,9 @@ func TestReliableSenderStaleMessages(t *testing.T) { // We'll also override IsMsgStale to mark all messages as stale as we're // interested in testing the stale message behavior. 
- reliableSender.cfg.IsMsgStale = func(_ lnwire.Message) bool { + reliableSender.cfg.IsMsgStale = func(_ context.Context, + _ lnwire.Message) bool { + return true } @@ -215,9 +218,8 @@ func TestReliableSenderStaleMessages(t *testing.T) { msg1 := randAnnounceSignatures() var peerPubKey [33]byte copy(peerPubKey[:], pubKey.SerializeCompressed()) - if err := reliableSender.sendMessage(msg1, peerPubKey); err != nil { - t.Fatalf("unable to reliably send message: %v", err) - } + err := reliableSender.sendMessage(ctx, msg1, peerPubKey) + require.NoError(t, err) // Since there isn't a peerHandler for this peer currently active due to // this being the first message being sent reliably, we should expect to @@ -245,7 +247,7 @@ func TestReliableSenderStaleMessages(t *testing.T) { // message store since it is seen as stale and has been sent at least // once. Once the message is removed, the peerHandler should be torn // down as there are no longer any pending messages within the store. - err := wait.NoError(func() error { + err = wait.NoError(func() error { msgs, err := reliableSender.cfg.MessageStore.MessagesForPeer( peerPubKey, ) @@ -265,14 +267,15 @@ func TestReliableSenderStaleMessages(t *testing.T) { } // Override IsMsgStale to no longer mark messages as stale. - reliableSender.cfg.IsMsgStale = func(_ lnwire.Message) bool { + reliableSender.cfg.IsMsgStale = func(_ context.Context, + _ lnwire.Message) bool { + return false } // We'll request the message to be sent reliably. - if err := reliableSender.sendMessage(msg2, peerPubKey); err != nil { - t.Fatalf("unable to reliably send message: %v", err) - } + err = reliableSender.sendMessage(ctx, msg2, peerPubKey) + require.NoError(t, err) // We should see an online notification request indicating that a new // peerHandler has been spawned since it was previously torn down. From 5193a9f82c4ebfca63897e3177839f60fdae0810 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 7 Apr 2025 10:17:35 +0200 Subject: [PATCH 06/41] discovery: thread contexts to syncer The `GossiperSyncer` makes various calls to the `ChannelGraphTimeSeries` interface which threads through to the graph DB. So in preparation for threading context through to all the methods on that interface, we update the GossipSyncer accordingly by passing contexts through. Two `context.TODO()`s are added in this commit. They will be removed in the upcoming commits. --- discovery/gossiper.go | 8 +++-- discovery/sync_manager.go | 2 +- discovery/syncer.go | 61 +++++++++++++++++++------------- discovery/syncer_test.go | 73 ++++++++++++++++++++++++--------------- 4 files changed, 89 insertions(+), 55 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index ac95d55ba9..c1f81df534 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -860,6 +860,8 @@ func (d *AuthenticatedGossiper) stop() { func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, peer lnpeer.Peer) chan error { + ctx := context.TODO() + log.Debugf("Processing remote msg %T from peer=%x", msg, peer.PubKey()) errChan := make(chan error, 1) @@ -907,7 +909,7 @@ func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, // If we've found the message target, then we'll dispatch the // message directly to it. 
- if err := syncer.ApplyGossipFilter(m); err != nil { + if err := syncer.ApplyGossipFilter(ctx, m); err != nil { log.Warnf("Unable to apply gossip filter for peer=%x: "+ "%v", peer.PubKey(), err) @@ -1404,7 +1406,7 @@ func (d *AuthenticatedGossiper) sendLocalBatch(annBatch []msgWithSenders) { // sendRemoteBatch broadcasts a list of remotely generated announcements to our // peers. -func (d *AuthenticatedGossiper) sendRemoteBatch(_ context.Context, +func (d *AuthenticatedGossiper) sendRemoteBatch(ctx context.Context, annBatch []msgWithSenders) { syncerPeers := d.syncMgr.GossipSyncers() @@ -1413,7 +1415,7 @@ func (d *AuthenticatedGossiper) sendRemoteBatch(_ context.Context, // that have active gossip syncers active. for pub, syncer := range syncerPeers { log.Tracef("Sending messages batch to GossipSyncer(%s)", pub) - syncer.FilterGossipMsgs(annBatch...) + syncer.FilterGossipMsgs(ctx, annBatch...) } for _, msgChunk := range annBatch { diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go index b1cb208cd9..a38065d7ec 100644 --- a/discovery/sync_manager.go +++ b/discovery/sync_manager.go @@ -380,7 +380,7 @@ func (m *SyncManager) syncerHandler() { } m.syncersMu.Unlock() - s.Start() + s.Start(context.TODO()) // Once we create the GossipSyncer, we'll signal to the // caller that they can proceed since the SyncManager's diff --git a/discovery/syncer.go b/discovery/syncer.go index da37490505..16a2fa720c 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -405,20 +405,22 @@ func newGossipSyncer(cfg gossipSyncerCfg, sema chan struct{}) *GossipSyncer { // Start starts the GossipSyncer and any goroutines that it needs to carry out // its duties. -func (g *GossipSyncer) Start() { +func (g *GossipSyncer) Start(ctx context.Context) { g.started.Do(func() { log.Debugf("Starting GossipSyncer(%x)", g.cfg.peerPub[:]) + ctx, _ := g.cg.Create(ctx) + // TODO(conner): only spawn channelGraphSyncer if remote // supports gossip queries, and only spawn replyHandler if we // advertise support if !g.cfg.noSyncChannels { g.cg.WgAdd(1) - go g.channelGraphSyncer() + go g.channelGraphSyncer(ctx) } if !g.cfg.noReplyQueries { g.cg.WgAdd(1) - go g.replyHandler() + go g.replyHandler(ctx) } }) } @@ -437,9 +439,11 @@ func (g *GossipSyncer) Stop() { // handleSyncingChans handles the state syncingChans for the GossipSyncer. When // in this state, we will send a QueryChannelRange msg to our peer and advance // the syncer's state to waitingQueryRangeReply. -func (g *GossipSyncer) handleSyncingChans() { +func (g *GossipSyncer) handleSyncingChans(ctx context.Context) { // Prepare the query msg. - queryRangeMsg, err := g.genChanRangeQuery(g.genHistoricalChanRangeQuery) + queryRangeMsg, err := g.genChanRangeQuery( + ctx, g.genHistoricalChanRangeQuery, + ) if err != nil { log.Errorf("Unable to gen chan range query: %v", err) return @@ -456,7 +460,6 @@ func (g *GossipSyncer) handleSyncingChans() { // Send the msg to the remote peer, which is non-blocking as // `sendToPeer` only queues the msg in Brontide. - ctx, _ := g.cg.Create(context.Background()) err = g.cfg.sendToPeer(ctx, queryRangeMsg) if err != nil { log.Errorf("Unable to send chan range query: %v", err) @@ -471,7 +474,7 @@ func (g *GossipSyncer) handleSyncingChans() { // channelGraphSyncer is the main goroutine responsible for ensuring that we // properly channel graph state with the remote peer, and also that we only // send them messages which actually pass their defined update horizon. 
-func (g *GossipSyncer) channelGraphSyncer() { +func (g *GossipSyncer) channelGraphSyncer(ctx context.Context) { defer g.cg.WgDone() for { @@ -488,7 +491,7 @@ func (g *GossipSyncer) channelGraphSyncer() { // understand, as we'll as responding to any other queries by // them. case syncingChans: - g.handleSyncingChans() + g.handleSyncingChans(ctx) // In this state, we've sent out our initial channel range // query and are waiting for the final response from the remote @@ -507,7 +510,9 @@ func (g *GossipSyncer) channelGraphSyncer() { // for the new channels. queryReply, ok := msg.(*lnwire.ReplyChannelRange) if ok { - err := g.processChanRangeReply(queryReply) + err := g.processChanRangeReply( + ctx, queryReply, + ) if err != nil { log.Errorf("Unable to "+ "process chan range "+ @@ -630,13 +635,13 @@ func (g *GossipSyncer) channelGraphSyncer() { // from the state machine maintained on the same node. // // NOTE: This method MUST be run as a goroutine. -func (g *GossipSyncer) replyHandler() { +func (g *GossipSyncer) replyHandler(ctx context.Context) { defer g.cg.WgDone() for { select { case msg := <-g.queryMsgs: - err := g.replyPeerQueries(msg) + err := g.replyPeerQueries(ctx, msg) switch { case err == ErrGossipSyncerExiting: return @@ -759,7 +764,9 @@ func isLegacyReplyChannelRange(query *lnwire.QueryChannelRange, // processChanRangeReply is called each time the GossipSyncer receives a new // reply to the initial range query to discover new channels that it didn't // previously know of. -func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) error { +func (g *GossipSyncer) processChanRangeReply(_ context.Context, + msg *lnwire.ReplyChannelRange) error { + // isStale returns whether the timestamp is too far into the past. isStale := func(timestamp time.Time) bool { return time.Since(timestamp) > graph.DefaultChannelPruneExpiry @@ -948,7 +955,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // party when we're kicking off the channel graph synchronization upon // connection. The historicalQuery boolean can be used to generate a query from // the genesis block of the chain. -func (g *GossipSyncer) genChanRangeQuery( +func (g *GossipSyncer) genChanRangeQuery(_ context.Context, historicalQuery bool) (*lnwire.QueryChannelRange, error) { // First, we'll query our channel graph time series for its highest @@ -1005,18 +1012,20 @@ func (g *GossipSyncer) genChanRangeQuery( // replyPeerQueries is called in response to any query by the remote peer. // We'll examine our state and send back our best response. -func (g *GossipSyncer) replyPeerQueries(msg lnwire.Message) error { +func (g *GossipSyncer) replyPeerQueries(ctx context.Context, + msg lnwire.Message) error { + switch msg := msg.(type) { // In this state, we'll also handle any incoming channel range queries // from the remote peer as they're trying to sync their state as well. case *lnwire.QueryChannelRange: - return g.replyChanRangeQuery(msg) + return g.replyChanRangeQuery(ctx, msg) // If the remote peer skips straight to requesting new channels that // they don't know of, then we'll ensure that we also handle this case. case *lnwire.QueryShortChanIDs: - return g.replyShortChanIDs(msg) + return g.replyShortChanIDs(ctx, msg) default: return fmt.Errorf("unknown message: %T", msg) @@ -1028,7 +1037,9 @@ func (g *GossipSyncer) replyPeerQueries(msg lnwire.Message) error { // meet the channel range, then chunk our responses to the remote node. 
We also // ensure that our final fragment carries the "complete" bit to indicate the // end of our streaming response. -func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) error { +func (g *GossipSyncer) replyChanRangeQuery(_ context.Context, + query *lnwire.QueryChannelRange) error { + // Before responding, we'll check to ensure that the remote peer is // querying for the same chain that we're on. If not, we'll send back a // response with a complete value of zero to indicate we're on a @@ -1209,7 +1220,9 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro // node for information concerning a set of short channel ID's. Our response // will be sent in a streaming chunked manner to ensure that we remain below // the current transport level message size. -func (g *GossipSyncer) replyShortChanIDs(query *lnwire.QueryShortChanIDs) error { +func (g *GossipSyncer) replyShortChanIDs(ctx context.Context, + query *lnwire.QueryShortChanIDs) error { + // Before responding, we'll check to ensure that the remote peer is // querying for the same chain that we're on. If not, we'll send back a // response with a complete value of zero to indicate we're on a @@ -1219,8 +1232,6 @@ func (g *GossipSyncer) replyShortChanIDs(query *lnwire.QueryShortChanIDs) error "chain=%v, we're on chain=%v", query.ChainHash, g.cfg.chainHash) - ctx, _ := g.cg.Create(context.Background()) - return g.cfg.sendToPeerSync(ctx, &lnwire.ReplyShortChanIDsEnd{ ChainHash: query.ChainHash, Complete: 0, @@ -1261,8 +1272,6 @@ func (g *GossipSyncer) replyShortChanIDs(query *lnwire.QueryShortChanIDs) error // Regardless of whether we had any messages to reply with, send over // the sentinel message to signal that the stream has terminated. - ctx, _ := g.cg.Create(context.Background()) - return g.cfg.sendToPeerSync(ctx, &lnwire.ReplyShortChanIDsEnd{ ChainHash: query.ChainHash, Complete: 1, @@ -1272,7 +1281,9 @@ func (g *GossipSyncer) replyShortChanIDs(query *lnwire.QueryShortChanIDs) error // ApplyGossipFilter applies a gossiper filter sent by the remote node to the // state machine. Once applied, we'll ensure that we don't forward any messages // to the peer that aren't within the time range of the filter. -func (g *GossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) error { +func (g *GossipSyncer) ApplyGossipFilter(_ context.Context, + filter *lnwire.GossipTimestampRange) error { + g.Lock() g.remoteUpdateHorizon = filter @@ -1351,7 +1362,9 @@ func (g *GossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) er // FilterGossipMsgs takes a set of gossip messages, and only send it to a peer // iff the message is within the bounds of their set gossip filter. If the peer // doesn't have a gossip filter set, then no messages will be forwarded. -func (g *GossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { +func (g *GossipSyncer) FilterGossipMsgs(_ context.Context, + msgs ...msgWithSenders) { + // If the peer doesn't have an update horizon set, then we won't send // it any new update messages. if g.remoteUpdateHorizon == nil { diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index a63997909f..32a90ae5fe 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -228,6 +228,7 @@ func newTestSyncer(hID lnwire.ShortChannelID, // doesn't have a horizon set, then we won't send any incoming messages to it. 
func TestGossipSyncerFilterGossipMsgsNoHorizon(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. @@ -249,7 +250,7 @@ func TestGossipSyncerFilterGossipMsgsNoHorizon(t *testing.T) { // We'll then attempt to filter the set of messages through the target // peer. - syncer.FilterGossipMsgs(msgs...) + syncer.FilterGossipMsgs(ctx, msgs...) // As the remote peer doesn't yet have a gossip timestamp set, we // shouldn't receive any outbound messages. @@ -273,6 +274,7 @@ func unixStamp(a int64) uint32 { // channel ann that already has a channel update on disk. func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. @@ -384,7 +386,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { }() // We'll then instruct the gossiper to filter this set of messages. - syncer.FilterGossipMsgs(msgs...) + syncer.FilterGossipMsgs(ctx, msgs...) // Out of all the messages we sent in, we should only get 2 of them // back. @@ -415,6 +417,7 @@ func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { // messages which are within their desired time horizon. func TestGossipSyncerApplyNoHistoricalGossipFilter(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. @@ -451,7 +454,7 @@ func TestGossipSyncerApplyNoHistoricalGossipFilter(t *testing.T) { }() // We'll now attempt to apply the gossip filter for the remote peer. - syncer.ApplyGossipFilter(remoteHorizon) + syncer.ApplyGossipFilter(ctx, remoteHorizon) // Ensure that the syncer's remote horizon was properly updated. if !reflect.DeepEqual(syncer.remoteUpdateHorizon, remoteHorizon) { @@ -475,6 +478,7 @@ func TestGossipSyncerApplyNoHistoricalGossipFilter(t *testing.T) { // within their desired time horizon. func TestGossipSyncerApplyGossipFilter(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. @@ -515,7 +519,7 @@ func TestGossipSyncerApplyGossipFilter(t *testing.T) { }() // We'll now attempt to apply the gossip filter for the remote peer. - err := syncer.ApplyGossipFilter(remoteHorizon) + err := syncer.ApplyGossipFilter(ctx, remoteHorizon) require.NoError(t, err, "unable to apply filter") // There should be no messages in the message queue as we didn't send @@ -563,7 +567,7 @@ func TestGossipSyncerApplyGossipFilter(t *testing.T) { errCh <- nil } }() - err = syncer.ApplyGossipFilter(remoteHorizon) + err = syncer.ApplyGossipFilter(ctx, remoteHorizon) require.NoError(t, err, "unable to apply filter") // We should get back the exact same message. @@ -594,6 +598,7 @@ func TestGossipSyncerApplyGossipFilter(t *testing.T) { // channels and complete=0. func TestGossipSyncerQueryChannelRangeWrongChainHash(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. 
@@ -609,7 +614,7 @@ func TestGossipSyncerQueryChannelRangeWrongChainHash(t *testing.T) { FirstBlockHeight: 0, NumBlocks: math.MaxUint32, } - err := syncer.replyChanRangeQuery(query) + err := syncer.replyChanRangeQuery(ctx, query) require.NoError(t, err, "unable to process short chan ID's") select { @@ -646,6 +651,7 @@ func TestGossipSyncerQueryChannelRangeWrongChainHash(t *testing.T) { // complete=0. func TestGossipSyncerReplyShortChanIDsWrongChainHash(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. @@ -656,7 +662,7 @@ func TestGossipSyncerReplyShortChanIDsWrongChainHash(t *testing.T) { // We'll now ask the syncer to reply to a chan ID query, but for a // chain that it isn't aware of. - err := syncer.replyShortChanIDs(&lnwire.QueryShortChanIDs{ + err := syncer.replyShortChanIDs(ctx, &lnwire.QueryShortChanIDs{ ChainHash: *chaincfg.SimNetParams.GenesisHash, }) require.NoError(t, err, "unable to process short chan ID's") @@ -695,6 +701,7 @@ func TestGossipSyncerReplyShortChanIDsWrongChainHash(t *testing.T) { // announcements, as well as an ending ReplyShortChanIDsEnd message. func TestGossipSyncerReplyShortChanIDs(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. @@ -745,7 +752,7 @@ func TestGossipSyncerReplyShortChanIDs(t *testing.T) { // With our set up above complete, we'll now attempt to obtain a reply // from the channel syncer for our target chan ID query. - err := syncer.replyShortChanIDs(&lnwire.QueryShortChanIDs{ + err := syncer.replyShortChanIDs(ctx, &lnwire.QueryShortChanIDs{ ShortChanIDs: queryChanIDs, }) require.NoError(t, err, "unable to query for chan IDs") @@ -800,6 +807,7 @@ func TestGossipSyncerReplyShortChanIDs(t *testing.T) { // the remote peer. func TestGossipSyncerReplyChanRangeQuery(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll use a smaller chunk size so we can easily test all the edge // cases. @@ -866,7 +874,7 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) { }() // With our goroutine active, we'll now issue the query. - if err := syncer.replyChanRangeQuery(query); err != nil { + if err := syncer.replyChanRangeQuery(ctx, query); err != nil { t.Fatalf("unable to issue query: %v", err) } @@ -971,6 +979,7 @@ func TestGossipSyncerReplyChanRangeQuery(t *testing.T) { // executed with the correct block range. func TestGossipSyncerReplyChanRangeQueryBlockRange(t *testing.T) { t.Parallel() + ctx := context.Background() // First create our test gossip syncer that will handle and // respond to the test queries @@ -1052,7 +1061,8 @@ func TestGossipSyncerReplyChanRangeQueryBlockRange(t *testing.T) { // will be reached go func() { for _, query := range queryReqs { - if err := syncer.replyChanRangeQuery(query); err != nil { + err := syncer.replyChanRangeQuery(ctx, query) + if err != nil { errCh <- fmt.Errorf("unable to issue query: %w", err) return @@ -1083,6 +1093,7 @@ func TestGossipSyncerReplyChanRangeQueryBlockRange(t *testing.T) { // back a single response that signals completion. func TestGossipSyncerReplyChanRangeQueryNoNewChans(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll now create our test gossip syncer that will shortly respond to // our canned query. 
@@ -1121,7 +1132,7 @@ func TestGossipSyncerReplyChanRangeQueryNoNewChans(t *testing.T) { }() // With our goroutine active, we'll now issue the query. - if err := syncer.replyChanRangeQuery(query); err != nil { + if err := syncer.replyChanRangeQuery(ctx, query); err != nil { t.Fatalf("unable to issue query: %v", err) } @@ -1162,6 +1173,7 @@ func TestGossipSyncerReplyChanRangeQueryNoNewChans(t *testing.T) { // channel ID, we properly generate an correct initial channel range response. func TestGossipSyncerGenChanRangeQuery(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. @@ -1174,7 +1186,7 @@ func TestGossipSyncerGenChanRangeQuery(t *testing.T) { // If we now ask the syncer to generate an initial range query, it // should return a start height that's back chanRangeQueryBuffer // blocks. - rangeQuery, err := syncer.genChanRangeQuery(false) + rangeQuery, err := syncer.genChanRangeQuery(ctx, false) require.NoError(t, err, "unable to resp") firstHeight := uint32(startingHeight - chanRangeQueryBuffer) @@ -1190,7 +1202,7 @@ func TestGossipSyncerGenChanRangeQuery(t *testing.T) { // Generating a historical range query should result in a start height // of 0. - rangeQuery, err = syncer.genChanRangeQuery(true) + rangeQuery, err = syncer.genChanRangeQuery(ctx, true) require.NoError(t, err, "unable to resp") if rangeQuery.FirstBlockHeight != 0 { t.Fatalf("incorrect chan range query: expected %v, %v", 0, @@ -1222,6 +1234,7 @@ func TestGossipSyncerProcessChanRangeReply(t *testing.T) { // each reply instead. func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { t.Parallel() + ctx := context.Background() // First, we'll create a GossipSyncer instance with a canned sendToPeer // message to allow us to intercept their potential sends. @@ -1234,7 +1247,7 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { startingState := syncer.state - query, err := syncer.genChanRangeQuery(true) + query, err := syncer.genChanRangeQuery(ctx, true) require.NoError(t, err, "unable to generate channel range query") currentTimestamp := time.Now().Unix() @@ -1359,13 +1372,13 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { // We'll begin by sending the syncer a set of non-complete channel // range replies. - if err := syncer.processChanRangeReply(replies[0]); err != nil { + if err := syncer.processChanRangeReply(ctx, replies[0]); err != nil { t.Fatalf("unable to process reply: %v", err) } - if err := syncer.processChanRangeReply(replies[1]); err != nil { + if err := syncer.processChanRangeReply(ctx, replies[1]); err != nil { t.Fatalf("unable to process reply: %v", err) } - if err := syncer.processChanRangeReply(replies[2]); err != nil { + if err := syncer.processChanRangeReply(ctx, replies[2]); err != nil { t.Fatalf("unable to process reply: %v", err) } @@ -1427,7 +1440,7 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { // If we send the final message, then we should transition to // queryNewChannels as we've sent a non-empty set of new channels. - if err := syncer.processChanRangeReply(replies[3]); err != nil { + if err := syncer.processChanRangeReply(ctx, replies[3]); err != nil { t.Fatalf("unable to process reply: %v", err) } @@ -1690,6 +1703,7 @@ func queryBatch(t *testing.T, // them. 
func TestGossipSyncerRoutineSync(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll modify the chunk size to be a smaller value, so we can ensure // our chunk parsing works properly. With this value we should get 3 @@ -1704,13 +1718,13 @@ func TestGossipSyncerRoutineSync(t *testing.T) { msgChan1, syncer1, chanSeries1 := newTestSyncer( highestID, defaultEncoding, chunkSize, true, false, ) - syncer1.Start() + syncer1.Start(ctx) defer syncer1.Stop() msgChan2, syncer2, chanSeries2 := newTestSyncer( highestID, defaultEncoding, chunkSize, false, true, ) - syncer2.Start() + syncer2.Start(ctx) defer syncer2.Stop() // Although both nodes are at the same height, syncer will have 3 chan @@ -1837,6 +1851,7 @@ func TestGossipSyncerRoutineSync(t *testing.T) { // final state and not perform any channel queries. func TestGossipSyncerAlreadySynced(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll modify the chunk size to be a smaller value, so we can ensure // our chunk parsing works properly. With this value we should get 3 @@ -1852,13 +1867,13 @@ func TestGossipSyncerAlreadySynced(t *testing.T) { msgChan1, syncer1, chanSeries1 := newTestSyncer( highestID, defaultEncoding, chunkSize, ) - syncer1.Start() + syncer1.Start(ctx) defer syncer1.Stop() msgChan2, syncer2, chanSeries2 := newTestSyncer( highestID, defaultEncoding, chunkSize, ) - syncer2.Start() + syncer2.Start(ctx) defer syncer2.Stop() // The channel state of both syncers will be identical. They should @@ -2058,6 +2073,7 @@ func TestGossipSyncerAlreadySynced(t *testing.T) { // carries out its duties when accepting a new sync transition request. func TestGossipSyncerSyncTransitions(t *testing.T) { t.Parallel() + ctx := context.Background() assertMsgSent := func(t *testing.T, msgChan chan []lnwire.Message, msg lnwire.Message) { @@ -2178,7 +2194,7 @@ func TestGossipSyncerSyncTransitions(t *testing.T) { // We'll then start the syncer in order to process the // request. - syncer.Start() + syncer.Start(ctx) defer syncer.Stop() syncer.ProcessSyncTransition(test.finalSyncType) @@ -2203,6 +2219,7 @@ func TestGossipSyncerSyncTransitions(t *testing.T) { // historical sync with the remote peer. func TestGossipSyncerHistoricalSync(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll create a new gossip syncer and manually override its state to // chansSynced. This is necessary as the syncer can only process @@ -2214,7 +2231,7 @@ func TestGossipSyncerHistoricalSync(t *testing.T) { syncer.setSyncType(PassiveSync) syncer.setSyncState(chansSynced) - syncer.Start() + syncer.Start(ctx) defer syncer.Stop() syncer.historicalSync() @@ -2247,6 +2264,7 @@ func TestGossipSyncerHistoricalSync(t *testing.T) { // syncer reaches its terminal chansSynced state. func TestGossipSyncerSyncedSignal(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll create a new gossip syncer and manually override its state to // chansSynced. @@ -2261,7 +2279,7 @@ func TestGossipSyncerSyncedSignal(t *testing.T) { signalChan := syncer.ResetSyncedSignal() // Starting the gossip syncer should cause the signal to be delivered. - syncer.Start() + syncer.Start(ctx) select { case <-signalChan: @@ -2280,7 +2298,7 @@ func TestGossipSyncerSyncedSignal(t *testing.T) { syncer.setSyncState(chansSynced) - syncer.Start() + syncer.Start(ctx) defer syncer.Stop() signalChan = syncer.ResetSyncedSignal() @@ -2299,6 +2317,7 @@ func TestGossipSyncerSyncedSignal(t *testing.T) { // said limit are not processed. 
func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { t.Parallel() + ctx := context.Background() msgChan, syncer, chanSeries := newTestSyncer( lnwire.ShortChannelID{BlockHeight: latestKnownHeight}, @@ -2309,7 +2328,7 @@ func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { // the sake of testing. syncer.cfg.maxQueryChanRangeReplies = 100 - syncer.Start() + syncer.Start(ctx) defer syncer.Stop() // Upon initialization, the syncer should submit a QueryChannelRange From 3101f2a66efd526a0656f256cbd6f13848130e0a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 7 Apr 2025 10:22:24 +0200 Subject: [PATCH 07/41] discovery: thread contexts through sync manager Here, we remove one context.TODO() by threading a context through to the SyncManager. --- discovery/gossiper.go | 2 +- discovery/sync_manager.go | 18 ++++++++++++------ discovery/sync_manager_test.go | 19 ++++++++++--------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index c1f81df534..6232a92ce8 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -678,7 +678,7 @@ func (d *AuthenticatedGossiper) start(ctx context.Context) error { return err } - d.syncMgr.Start() + d.syncMgr.Start(ctx) d.banman.start() diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go index a38065d7ec..c825d27fd2 100644 --- a/discovery/sync_manager.go +++ b/discovery/sync_manager.go @@ -8,6 +8,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -200,8 +201,9 @@ type SyncManager struct { // number of queries. rateLimiter *rate.Limiter - wg sync.WaitGroup - quit chan struct{} + wg sync.WaitGroup + quit chan struct{} + cancel fn.Option[context.CancelFunc] } // newSyncManager constructs a new SyncManager backed by the given config. @@ -246,10 +248,13 @@ func newSyncManager(cfg *SyncManagerCfg) *SyncManager { } // Start starts the SyncManager in order to properly carry out its duties. -func (m *SyncManager) Start() { +func (m *SyncManager) Start(ctx context.Context) { m.start.Do(func() { + ctx, cancel := context.WithCancel(ctx) + m.cancel = fn.Some(cancel) + m.wg.Add(1) - go m.syncerHandler() + go m.syncerHandler(ctx) }) } @@ -259,6 +264,7 @@ func (m *SyncManager) Stop() { log.Debugf("SyncManager is stopping") defer log.Debugf("SyncManager stopped") + m.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(m.quit) m.wg.Wait() @@ -282,7 +288,7 @@ func (m *SyncManager) Stop() { // much of the public network as possible. // // NOTE: This must be run as a goroutine. 
-func (m *SyncManager) syncerHandler() { +func (m *SyncManager) syncerHandler(ctx context.Context) { defer m.wg.Done() m.cfg.RotateTicker.Resume() @@ -380,7 +386,7 @@ func (m *SyncManager) syncerHandler() { } m.syncersMu.Unlock() - s.Start(context.TODO()) + s.Start(ctx) // Once we create the GossipSyncer, we'll signal to the // caller that they can proceed since the SyncManager's diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go index 4aff5b6315..b8ef931977 100644 --- a/discovery/sync_manager_test.go +++ b/discovery/sync_manager_test.go @@ -2,6 +2,7 @@ package discovery import ( "bytes" + "context" "fmt" "io" "reflect" @@ -82,7 +83,7 @@ func TestSyncManagerNumActiveSyncers(t *testing.T) { } syncMgr := newPinnedTestSyncManager(numActiveSyncers, pinnedSyncers) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // First we'll start by adding the pinned syncers. These should @@ -134,7 +135,7 @@ func TestSyncManagerNewActiveSyncerAfterDisconnect(t *testing.T) { // We'll create our test sync manager to have two active syncers. syncMgr := newTestSyncManager(2) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // The first will be an active syncer that performs a historical sync @@ -187,7 +188,7 @@ func TestSyncManagerRotateActiveSyncerCandidate(t *testing.T) { // We'll create our sync manager with three active syncers. syncMgr := newTestSyncManager(1) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // The first syncer registered always performs a historical sync. @@ -235,7 +236,7 @@ func TestSyncManagerNoInitialHistoricalSync(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(0) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should not expect any messages from the peer. @@ -269,7 +270,7 @@ func TestSyncManagerInitialHistoricalSync(t *testing.T) { t.Fatal("expected graph to not be considered as synced") } - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -338,7 +339,7 @@ func TestSyncManagerHistoricalSyncOnReconnect(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(2) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -372,7 +373,7 @@ func TestSyncManagerForceHistoricalSync(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(1) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -410,7 +411,7 @@ func TestSyncManagerGraphSyncedAfterHistoricalSyncReplacement(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(1) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -468,7 +469,7 @@ func TestSyncManagerWaitUntilInitialHistoricalSync(t *testing.T) { // We'll start by creating our test sync manager which will hold up to // 2 active syncers. syncMgr := newTestSyncManager(numActiveSyncers) - syncMgr.Start() + syncMgr.Start(context.Background()) defer syncMgr.Stop() // We'll go ahead and create our syncers. 
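Note on the pattern in patch 07/41: the hunks above give SyncManager a cancellable child context that is created in Start and torn down in Stop. Below is a minimal, self-contained sketch of that lifecycle, using illustrative stand-in names (manager, handler); it is not lnd's actual type. The real SyncManager additionally wraps the cancel function in fn.Option and keeps its quit channel, as shown in the diff.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// manager mirrors the lifecycle the patch gives SyncManager: Start derives a
// child context, stores its cancel func, and hands the context to the
// long-running goroutine; Stop cancels that context and waits for exit.
type manager struct {
	start  sync.Once
	stop   sync.Once
	wg     sync.WaitGroup
	quit   chan struct{}
	cancel context.CancelFunc
}

func newManager() *manager {
	return &manager{quit: make(chan struct{})}
}

// Start derives a cancellable context from the caller's parent context and
// launches the handler goroutine with it.
func (m *manager) Start(ctx context.Context) {
	m.start.Do(func() {
		ctx, cancel := context.WithCancel(ctx)
		m.cancel = cancel

		m.wg.Add(1)
		go m.handler(ctx)
	})
}

// Stop cancels the child context, closes the legacy quit channel and waits
// for the handler goroutine to return.
func (m *manager) Stop() {
	m.stop.Do(func() {
		if m.cancel != nil {
			m.cancel()
		}
		close(m.quit)
		m.wg.Wait()
	})
}

// handler stands in for syncerHandler: it receives the context as a parameter
// and selects on ctx.Done() (or the legacy quit channel), so no context.TODO()
// is needed inside the loop.
func (m *manager) handler(ctx context.Context) {
	defer m.wg.Done()

	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// Pass ctx down to any call that needs it, as the
			// real syncerHandler now does with s.Start(ctx).
		case <-ctx.Done():
			return
		case <-m.quit:
			return
		}
	}
}

func main() {
	m := newManager()
	m.Start(context.Background())
	time.Sleep(120 * time.Millisecond)
	m.Stop()
	fmt.Println("manager stopped cleanly")
}

The same shape repeats across these patches: the caller owns the parent context, Start derives and stores a cancel func, and the goroutine watches ctx.Done(), so context.TODO() never needs to appear inside the handler.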
From 1bc66db2288e6b204b1cad4d3d713ae1ad17776c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 7 Apr 2025 11:13:01 +0200 Subject: [PATCH 08/41] discovery: pass context to ProcessRemoteAnnouncement With this, we move a context.TODO() out of the gossiper and into the brontide package - this will be removed in a future PR which focuses on threading contexts through that code. --- discovery/gossiper.go | 10 +- discovery/gossiper_test.go | 766 ++++++++++++++++++++----------------- discovery/syncer_test.go | 2 +- peer/brontide.go | 6 +- 4 files changed, 431 insertions(+), 353 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 6232a92ce8..ab12524fb6 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -857,10 +857,8 @@ func (d *AuthenticatedGossiper) stop() { // then added to a queue for batched trickled announcement to all connected // peers. Remote channel announcements should contain the announcement proof // and be fully validated. -func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, - peer lnpeer.Peer) chan error { - - ctx := context.TODO() +func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(ctx context.Context, + msg lnwire.Message, peer lnpeer.Peer) chan error { log.Debugf("Processing remote msg %T from peer=%x", msg, peer.PubKey()) @@ -950,8 +948,12 @@ func (d *AuthenticatedGossiper) ProcessRemoteAnnouncement(msg lnwire.Message, // If the peer that sent us this error is quitting, then we don't need // to send back an error and can return immediately. + // TODO(elle): the peer should now just rely on canceling the passed + // context. case <-peer.QuitSignal(): return nil + case <-ctx.Done(): + return nil case <-d.quit: nMsg.err <- ErrGossiperShuttingDown } diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index b3ca738ebf..ef7f2f21f6 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -1020,9 +1020,10 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( // the router subsystem. func TestProcessAnnouncement(t *testing.T) { t.Parallel() + ctx := context.Background() timestamp := testTimestamp - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") assertSenderExistence := func(sender *btcec.PublicKey, msg msgWithSenders) { @@ -1038,11 +1039,11 @@ func TestProcessAnnouncement(t *testing.T) { // First, we'll craft a valid remote channel announcement and send it to // the gossiper so that it can be processed. - ca, err := ctx.createRemoteChannelAnnouncement(0) + ca, err := tCtx.createRemoteChannelAnnouncement(0) require.NoError(t, err, "can't create channel announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement(ctx, ca, nodePeer): case <-time.After(2 * time.Second): t.Fatal("remote announcement not processed") } @@ -1051,13 +1052,13 @@ func TestProcessAnnouncement(t *testing.T) { // The announcement should be broadcast and included in our local view // of the graph. 
select { - case msg := <-ctx.broadcastedMessage: + case msg := <-tCtx.broadcastedMessage: assertSenderExistence(nodePeer.IdentityKey(), msg) case <-time.After(2 * trickleDelay): t.Fatal("announcement wasn't proceeded") } - if len(ctx.router.infos) != 1 { + if len(tCtx.router.infos) != 1 { t.Fatalf("edge wasn't added to router: %v", err) } @@ -1068,7 +1069,7 @@ func TestProcessAnnouncement(t *testing.T) { // We send an invalid channel update and expect it to fail. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ua, nodePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement(ctx, ua, nodePeer): case <-time.After(2 * time.Second): t.Fatal("remote announcement not processed") } @@ -1077,7 +1078,7 @@ func TestProcessAnnouncement(t *testing.T) { // We should not broadcast the channel update. select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("gossiper should not have broadcast channel update") case <-time.After(2 * trickleDelay): } @@ -1088,7 +1089,7 @@ func TestProcessAnnouncement(t *testing.T) { require.NoError(t, err, "can't create update announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ua, nodePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement(ctx, ua, nodePeer): case <-time.After(2 * time.Second): t.Fatal("remote announcement not processed") } @@ -1096,13 +1097,13 @@ func TestProcessAnnouncement(t *testing.T) { // The channel policy should be broadcast to the rest of the network. select { - case msg := <-ctx.broadcastedMessage: + case msg := <-tCtx.broadcastedMessage: assertSenderExistence(nodePeer.IdentityKey(), msg) case <-time.After(2 * trickleDelay): t.Fatal("announcement wasn't proceeded") } - if len(ctx.router.edges) != 1 { + if len(tCtx.router.edges) != 1 { t.Fatalf("edge update wasn't added to router: %v", err) } @@ -1111,7 +1112,7 @@ func TestProcessAnnouncement(t *testing.T) { require.NoError(t, err, "can't create node announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(na, nodePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement(ctx, na, nodePeer): case <-time.After(2 * time.Second): t.Fatal("remote announcement not processed") } @@ -1120,13 +1121,13 @@ func TestProcessAnnouncement(t *testing.T) { // It should also be broadcast to the network and included in our local // view of the graph. select { - case msg := <-ctx.broadcastedMessage: + case msg := <-tCtx.broadcastedMessage: assertSenderExistence(nodePeer.IdentityKey(), msg) case <-time.After(2 * trickleDelay): t.Fatal("announcement wasn't proceeded") } - if len(ctx.router.nodes) != 1 { + if len(tCtx.router.nodes) != 1 { t.Fatalf("node wasn't added to router: %v", err) } } @@ -1135,10 +1136,11 @@ func TestProcessAnnouncement(t *testing.T) { // propagated to the router subsystem. func TestPrematureAnnouncement(t *testing.T) { t.Parallel() + ctx := context.Background() timestamp := testTimestamp - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") _, err = createNodeAnnouncement(remoteKeyPriv1, timestamp) @@ -1150,18 +1152,18 @@ func TestPrematureAnnouncement(t *testing.T) { // remote side, but block height of this announcement is greater than // highest know to us, for that reason it should be ignored and not // added to the router. 
- ca, err := ctx.createRemoteChannelAnnouncement( + ca, err := tCtx.createRemoteChannelAnnouncement( 1, withFundingTxPrep(fundingTxPrepTypeNone), ) require.NoError(t, err, "can't create channel announcement") select { - case <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer): + case <-tCtx.gossiper.ProcessRemoteAnnouncement(ctx, ca, nodePeer): case <-time.After(time.Second): t.Fatal("announcement was not processed") } - if len(ctx.router.infos) != 0 { + if len(tCtx.router.infos) != 0 { t.Fatal("edge was added to router") } } @@ -1170,69 +1172,70 @@ func TestPrematureAnnouncement(t *testing.T) { // properly processes partial and fully announcement signatures message. func TestSignatureAnnouncementLocalFirst(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") // Set up a channel that we can use to inspect the messages sent // directly from the gossiper. sentMsgs := make(chan lnwire.Message, 10) - ctx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(target [33]byte, - peerChan chan<- lnpeer.Peer) { + tCtx.gossiper.reliableSender.cfg.NotifyWhenOnline = func( + target [33]byte, peerChan chan<- lnpeer.Peer) { pk, _ := btcec.ParsePubKey(target[:]) select { case peerChan <- &mockPeer{ - pk, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + pk, sentMsgs, tCtx.gossiper.quit, atomic.Bool{}, }: - case <-ctx.gossiper.quit: + case <-tCtx.gossiper.quit: } } - batch, err := ctx.createLocalAnnouncements(0) + batch, err := tCtx.createLocalAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) require.NoError(t, err, "unable to parse pubkey") remotePeer := &mockPeer{ - remoteKey, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + remoteKey, sentMsgs, tCtx.gossiper.quit, atomic.Bool{}, } // Recreate lightning network topology. Initialize router with channel // between two nodes. 
select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process channel ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process channel update") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process node ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -1248,29 +1251,29 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { } select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process channel update") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.nodeAnn2, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.nodeAnn2, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process node ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -1278,20 +1281,22 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { // Pretending that we receive local channel announcement from funding // manager, thereby kick off the announcement exchange process. 
select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.localProofAnn): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( + batch.localProofAnn, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process local proof") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("announcements were broadcast") case <-time.After(2 * trickleDelay): } number := 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(*channeldb.WaitingProof) error { number++ return nil @@ -1308,8 +1313,8 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { } select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.remoteProofAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") @@ -1318,14 +1323,14 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { for i := 0; i < 5; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(time.Second): t.Fatal("announcement wasn't broadcast") } } number = 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(*channeldb.WaitingProof) error { number++ return nil @@ -1346,33 +1351,34 @@ func TestSignatureAnnouncementLocalFirst(t *testing.T) { // processes announcement with unknown channel ids. func TestOrphanSignatureAnnouncement(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") // Set up a channel that we can use to inspect the messages sent // directly from the gossiper. sentMsgs := make(chan lnwire.Message, 10) - ctx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(target [33]byte, - peerChan chan<- lnpeer.Peer) { + tCtx.gossiper.reliableSender.cfg.NotifyWhenOnline = func( + target [33]byte, peerChan chan<- lnpeer.Peer) { pk, _ := btcec.ParsePubKey(target[:]) select { case peerChan <- &mockPeer{ - pk, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + pk, sentMsgs, tCtx.gossiper.quit, atomic.Bool{}, }: - case <-ctx.gossiper.quit: + case <-tCtx.gossiper.quit: } } - batch, err := ctx.createLocalAnnouncements(0) + batch, err := tCtx.createLocalAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) require.NoError(t, err, "unable to parse pubkey") remotePeer := &mockPeer{ - remoteKey, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + remoteKey, sentMsgs, tCtx.gossiper.quit, atomic.Bool{}, } // Pretending that we receive local channel announcement from funding @@ -1380,15 +1386,16 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { // this case the announcement should be added in the orphan batch // because we haven't announce the channel yet. 
select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(batch.remoteProofAnn, - remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to proceed announcement") number := 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(*channeldb.WaitingProof) error { number++ return nil @@ -1407,7 +1414,7 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { // Recreate lightning network topology. Initialize router with channel // between two nodes. select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } @@ -1415,32 +1422,32 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { require.NoError(t, err, "unable to process") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process node ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -1456,28 +1463,29 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { } select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(batch.chanUpdAnn2, - remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process node ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.nodeAnn2, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.nodeAnn2, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -1485,7 +1493,9 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { // After that we process local announcement, and waiting to receive // the channel announcement. 
select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.localProofAnn): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( + batch.localProofAnn, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -1503,14 +1513,14 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { // should be broadcasting the final channel announcements. for i := 0; i < 5; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(time.Second): t.Fatal("announcement wasn't broadcast") } } number = 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(p *channeldb.WaitingProof) error { number++ return nil @@ -1533,11 +1543,12 @@ func TestOrphanSignatureAnnouncement(t *testing.T) { // assembled. func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") - batch, err := ctx.createLocalAnnouncements(0) + batch, err := tCtx.createLocalAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) @@ -1546,7 +1557,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { // Set up a channel to intercept the messages sent to the remote peer. sentToPeer := make(chan lnwire.Message, 1) remotePeer := &mockPeer{ - remoteKey, sentToPeer, ctx.gossiper.quit, atomic.Bool{}, + remoteKey, sentToPeer, tCtx.gossiper.quit, atomic.Bool{}, } // Since the reliable send to the remote peer of the local channel proof @@ -1554,7 +1565,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { // channel through which it gets sent to control exactly when to // dispatch it. notifyPeers := make(chan chan<- lnpeer.Peer, 1) - ctx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(peer [33]byte, + tCtx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(peer [33]byte, connectedChan chan<- lnpeer.Peer) { notifyPeers <- connectedChan } @@ -1562,13 +1573,13 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { // Recreate lightning network topology. Initialize router with channel // between two nodes. select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process channel ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -1576,7 +1587,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { // Pretending that we receive local channel announcement from funding // manager, thereby kick off the announcement exchange process. select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement( + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( batch.localProofAnn, ): case <-time.After(2 * time.Second): @@ -1598,7 +1609,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { // The proof should not be broadcast yet since we're still missing the // remote party's. 
select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("announcements were broadcast") case <-time.After(2 * trickleDelay): } @@ -1611,7 +1622,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { } number := 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(*channeldb.WaitingProof) error { number++ return nil @@ -1630,7 +1641,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { // Restart the gossiper and restore its original NotifyWhenOnline and // NotifyWhenOffline methods. This should trigger a new attempt to send // the message to the peer. - ctx.gossiper.Stop() + require.NoError(t, tCtx.gossiper.Stop()) isAlias := func(lnwire.ShortChannelID) bool { return false @@ -1654,19 +1665,19 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { //nolint:ll gossiper := New(Config{ - Notifier: ctx.gossiper.cfg.Notifier, - Broadcast: ctx.gossiper.cfg.Broadcast, - NotifyWhenOnline: ctx.gossiper.reliableSender.cfg.NotifyWhenOnline, - NotifyWhenOffline: ctx.gossiper.reliableSender.cfg.NotifyWhenOffline, - FetchSelfAnnouncement: ctx.gossiper.cfg.FetchSelfAnnouncement, - UpdateSelfAnnouncement: ctx.gossiper.cfg.UpdateSelfAnnouncement, - Graph: ctx.gossiper.cfg.Graph, + Notifier: tCtx.gossiper.cfg.Notifier, + Broadcast: tCtx.gossiper.cfg.Broadcast, + NotifyWhenOnline: tCtx.gossiper.reliableSender.cfg.NotifyWhenOnline, + NotifyWhenOffline: tCtx.gossiper.reliableSender.cfg.NotifyWhenOffline, + FetchSelfAnnouncement: tCtx.gossiper.cfg.FetchSelfAnnouncement, + UpdateSelfAnnouncement: tCtx.gossiper.cfg.UpdateSelfAnnouncement, + Graph: tCtx.gossiper.cfg.Graph, TrickleDelay: trickleDelay, RetransmitTicker: ticker.NewForce(retransmitDelay), RebroadcastInterval: rebroadcastInterval, ProofMatureDelta: proofMatureDelta, - WaitingProofStore: ctx.gossiper.cfg.WaitingProofStore, - MessageStore: ctx.gossiper.cfg.MessageStore, + WaitingProofStore: tCtx.gossiper.cfg.WaitingProofStore, + MessageStore: tCtx.gossiper.cfg.MessageStore, RotateTicker: ticker.NewForce(DefaultSyncerRotationInterval), HistoricalSyncTicker: ticker.NewForce(DefaultHistoricalSyncInterval), NumActiveSyncers: 3, @@ -1677,8 +1688,8 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { FindBaseByAlias: findBaseByAlias, GetAlias: getAlias, }, &keychain.KeyDescriptor{ - PubKey: ctx.gossiper.selfKey, - KeyLocator: ctx.gossiper.selfKeyLoc, + PubKey: tCtx.gossiper.selfKey, + KeyLocator: tCtx.gossiper.selfKeyLoc, }) require.NoError(t, err, "unable to recreate gossiper") if err := gossiper.Start(context.Background()); err != nil { @@ -1690,8 +1701,8 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { // broadcast. gossiper.syncMgr.markGraphSynced() - ctx.gossiper = gossiper - remotePeer.quit = ctx.gossiper.quit + tCtx.gossiper = gossiper + remotePeer.quit = tCtx.gossiper.quit // After starting up, the gossiper will see that it has a proof in the // WaitingProofStore, and will retry sending its part to the remote. @@ -1729,8 +1740,8 @@ out: // Now exchanging the remote channel proof, the channel announcement // broadcast should continue as normal. 
select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.remoteProofAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") @@ -1740,13 +1751,13 @@ out: } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(time.Second): t.Fatal("announcement wasn't broadcast") } number = 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(*channeldb.WaitingProof) error { number++ return nil @@ -1768,11 +1779,12 @@ out: // the full proof (ChannelAnnouncement) to the remote peer. func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") - batch, err := ctx.createLocalAnnouncements(0) + batch, err := tCtx.createLocalAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) @@ -1782,12 +1794,12 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { // gossiper to the remote peer. sentToPeer := make(chan lnwire.Message, 1) remotePeer := &mockPeer{ - remoteKey, sentToPeer, ctx.gossiper.quit, atomic.Bool{}, + remoteKey, sentToPeer, tCtx.gossiper.quit, atomic.Bool{}, } // Override NotifyWhenOnline to return the remote peer which we expect // meesages to be sent to. - ctx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(peer [33]byte, + tCtx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(peer [33]byte, peerChan chan<- lnpeer.Peer) { peerChan <- remotePeer @@ -1796,7 +1808,7 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { // Recreate lightning network topology. Initialize router with channel // between two nodes. 
select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement( + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( batch.chanAnn, ): case <-time.After(2 * time.Second): @@ -1804,13 +1816,13 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { } require.NoError(t, err, "unable to process channel ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement( + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( batch.chanUpdAnn1, ): case <-time.After(2 * time.Second): @@ -1818,7 +1830,7 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { } require.NoError(t, err, "unable to process channel update") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -1831,7 +1843,7 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { } select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement( + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( batch.nodeAnn1, ): case <-time.After(2 * time.Second): @@ -1841,35 +1853,34 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { t.Fatalf("unable to process node ann:%v", err) } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process channel update") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } - select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.nodeAnn2, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.nodeAnn2, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process node ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -1877,7 +1888,7 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { // Pretending that we receive local channel announcement from funding // manager, thereby kick off the announcement exchange process. select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement( + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( batch.localProofAnn, ): case <-time.After(2 * time.Second): @@ -1886,8 +1897,8 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { require.NoError(t, err, "unable to process local proof") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.remoteProofAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") @@ -1905,14 +1916,14 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { // All channel and node announcements should be broadcast. 
for i := 0; i < 5; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(time.Second): t.Fatal("announcement wasn't broadcast") } } number := 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(*channeldb.WaitingProof) error { number++ return nil @@ -1931,8 +1942,8 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) { // Now give the gossiper the remote proof yet again. This should // trigger a send of the full ChannelAnnouncement. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.remoteProofAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") @@ -2210,25 +2221,26 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { // announcements for nodes who do not intend to publicly advertise themselves. func TestForwardPrivateNodeAnnouncement(t *testing.T) { t.Parallel() + ctx := context.Background() const ( startingHeight = 100 timestamp = 123456 ) - ctx, err := createTestCtx(t, startingHeight, false) + tCtx, err := createTestCtx(t, startingHeight, false) require.NoError(t, err, "can't create context") // We'll start off by processing a channel announcement without a proof // (i.e., an unadvertised channel), followed by a node announcement for // this same channel announcement. - chanAnn := ctx.createAnnouncementWithoutProof( + chanAnn := tCtx.createAnnouncementWithoutProof( startingHeight-2, selfKeyDesc.PubKey, remoteKeyPub1, ) pubKey := remoteKeyPriv1.PubKey() select { - case err := <-ctx.gossiper.ProcessLocalAnnouncement(chanAnn): + case err := <-tCtx.gossiper.ProcessLocalAnnouncement(chanAnn): if err != nil { t.Fatalf("unable to process local announcement: %v", err) } @@ -2239,7 +2251,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { // The gossiper should not broadcast the announcement due to it not // having its announcement signatures. select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("gossiper should not have broadcast channel announcement") case <-time.After(2 * trickleDelay): } @@ -2248,7 +2260,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { require.NoError(t, err, "unable to create node announcement") select { - case err := <-ctx.gossiper.ProcessLocalAnnouncement(nodeAnn): + case err := <-tCtx.gossiper.ProcessLocalAnnouncement(nodeAnn): if err != nil { t.Fatalf("unable to process remote announcement: %v", err) } @@ -2259,7 +2271,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { // The gossiper should also not broadcast the node announcement due to // it not being part of any advertised channels. select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("gossiper should not have broadcast node announcement") case <-time.After(2 * trickleDelay): } @@ -2268,14 +2280,16 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { // by opening a public channel on the network. We'll create a // ChannelAnnouncement and hand it off to the gossiper in order to // process it. 
- remoteChanAnn, err := ctx.createRemoteChannelAnnouncement( + remoteChanAnn, err := tCtx.createRemoteChannelAnnouncement( startingHeight - 1, ) require.NoError(t, err, "unable to create remote channel announcement") peer := &mockPeer{pubKey, nil, nil, atomic.Bool{}} select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement(remoteChanAnn, peer): + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, remoteChanAnn, peer, + ): if err != nil { t.Fatalf("unable to process remote announcement: %v", err) } @@ -2284,7 +2298,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(2 * trickleDelay): t.Fatal("gossiper should have broadcast the channel announcement") } @@ -2295,7 +2309,9 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { require.NoError(t, err, "unable to create node announcement") select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement(nodeAnn, peer): + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, nodeAnn, peer, + ): if err != nil { t.Fatalf("unable to process remote announcement: %v", err) } @@ -2304,7 +2320,7 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(2 * trickleDelay): t.Fatal("gossiper should have broadcast the node announcement") } @@ -2314,13 +2330,14 @@ func TestForwardPrivateNodeAnnouncement(t *testing.T) { // zombie edges. func TestRejectZombieEdge(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll start by creating our test context with a batch of // announcements. - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "unable to create test context") - batch, err := ctx.createRemoteAnnouncements(0) + batch, err := tCtx.createRemoteAnnouncements(0) require.NoError(t, err, "unable to create announcements") remotePeer := &mockPeer{pk: remoteKeyPriv2.PubKey()} @@ -2330,8 +2347,8 @@ func TestRejectZombieEdge(t *testing.T) { processAnnouncements := func(isZombie bool) { t.Helper() - errChan := ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanAnn, remotePeer, + errChan := tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanAnn, remotePeer, ) select { case err := <-errChan: @@ -2347,7 +2364,7 @@ func TestRejectZombieEdge(t *testing.T) { t.Fatal("expected to process channel announcement") } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: if isZombie { t.Fatal("expected to not broadcast zombie " + "channel announcement") @@ -2359,8 +2376,8 @@ func TestRejectZombieEdge(t *testing.T) { } } - errChan = ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, remotePeer, + errChan = tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, ) select { case err := <-errChan: @@ -2376,7 +2393,7 @@ func TestRejectZombieEdge(t *testing.T) { t.Fatal("expected to process channel update") } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: if isZombie { t.Fatal("expected to not broadcast zombie " + "channel update") @@ -2393,7 +2410,7 @@ func TestRejectZombieEdge(t *testing.T) { // zombie within the router. This should reject any announcements for // this edge while it remains as a zombie. 
chanID := batch.chanAnn.ShortChannelID - err = ctx.router.MarkEdgeZombie( + err = tCtx.router.MarkEdgeZombie( chanID, batch.chanAnn.NodeID1, batch.chanAnn.NodeID2, ) if err != nil { @@ -2404,7 +2421,7 @@ func TestRejectZombieEdge(t *testing.T) { // If we then mark the edge as live, the edge's zombie status should be // overridden and the announcements should be processed. - if err := ctx.router.MarkEdgeLive(chanID); err != nil { + if err := tCtx.router.MarkEdgeLive(chanID); err != nil { t.Fatalf("unable mark channel %v as zombie: %v", chanID, err) } @@ -2415,13 +2432,14 @@ func TestRejectZombieEdge(t *testing.T) { // becomes live by receiving a fresh update. func TestProcessZombieEdgeNowLive(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll start by creating our test context with a batch of // announcements. - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "unable to create test context") - batch, err := ctx.createRemoteAnnouncements(0) + batch, err := tCtx.createRemoteAnnouncements(0) require.NoError(t, err, "unable to create announcements") remotePeer := &mockPeer{pk: remoteKeyPriv1.PubKey()} @@ -2435,8 +2453,8 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { processAnnouncement := func(ann lnwire.Message, isZombie, expectsErr bool) { t.Helper() - errChan := ctx.gossiper.ProcessRemoteAnnouncement( - ann, remotePeer, + errChan := tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, ann, remotePeer, ) var err error @@ -2454,7 +2472,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { } select { - case msgWithSenders := <-ctx.broadcastedMessage: + case msgWithSenders := <-tCtx.broadcastedMessage: if isZombie { t.Fatal("expected to not broadcast zombie " + "channel message") @@ -2483,7 +2501,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { // want to allow a new update from the second node to allow the entire // edge to be resurrected. chanID := batch.chanAnn.ShortChannelID - err = ctx.router.MarkEdgeZombie( + err = tCtx.router.MarkEdgeZombie( chanID, [33]byte{}, batch.chanAnn.NodeID2, ) if err != nil { @@ -2500,7 +2518,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { processAnnouncement(batch.chanUpdAnn1, true, true) // At this point, the channel should still be considered a zombie. - _, _, _, err = ctx.router.GetChannelByID(chanID) + _, _, _, err = tCtx.router.GetChannelByID(chanID) require.ErrorIs(t, err, graphdb.ErrZombieEdge) // Attempting to process the current channel update should fail due to @@ -2532,12 +2550,12 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { // until the channel announcement is. Since the channel update indicates // a fresh new update, the gossiper should stash it until it sees the // corresponding channel announcement. 
- updateErrChan := ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, remotePeer, + updateErrChan := tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, ) select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("expected to not broadcast live channel update " + "without announcement") case <-time.After(2 * trickleDelay): @@ -2560,7 +2578,7 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { } select { - case msgWithSenders := <-ctx.broadcastedMessage: + case msgWithSenders := <-tCtx.broadcastedMessage: assertMessage(t, batch.chanUpdAnn2, msgWithSenders.msg) case <-time.After(2 * trickleDelay): t.Fatal("expected to broadcast live channel update") @@ -2572,11 +2590,12 @@ func TestProcessZombieEdgeNowLive(t *testing.T) { // be reprocessed later, after our ChannelAnnouncement. func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") - batch, err := ctx.createLocalAnnouncements(0) + batch, err := tCtx.createLocalAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) @@ -2586,12 +2605,12 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { // directly from the gossiper. sentMsgs := make(chan lnwire.Message, 10) remotePeer := &mockPeer{ - remoteKey, sentMsgs, ctx.gossiper.quit, atomic.Bool{}, + remoteKey, sentMsgs, tCtx.gossiper.quit, atomic.Bool{}, } // Override NotifyWhenOnline to return the remote peer which we expect // messages to be sent to. - ctx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(peer [33]byte, + tCtx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(peer [33]byte, peerChan chan<- lnpeer.Peer) { peerChan <- remotePeer @@ -2600,19 +2619,21 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { // Recreate the case where the remote node is sending us its ChannelUpdate // before we have been able to process our own ChannelAnnouncement and // ChannelUpdate. - errRemoteAnn := ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, remotePeer, + errRemoteAnn := tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, ) select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } - err = <-ctx.gossiper.ProcessRemoteAnnouncement(batch.nodeAnn2, remotePeer) + err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.nodeAnn2, remotePeer, + ) require.NoError(t, err, "unable to process node ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -2621,7 +2642,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { // we did not already know about, it should have been added // to the map of premature ChannelUpdates. Check that nothing // was added to the graph. - chanInfo, e1, e2, err := ctx.router.GetChannelByID(batch.chanUpdAnn1.ShortChannelID) + chanInfo, e1, e2, err := tCtx.router.GetChannelByID( + batch.chanUpdAnn1.ShortChannelID, + ) if !errors.Is(err, graphdb.ErrEdgeNotFound) { t.Fatalf("Expected ErrEdgeNotFound, got: %v", err) } @@ -2637,32 +2660,32 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { // Recreate lightning network topology. 
Initialize router with channel // between two nodes. - err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanAnn) + err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanAnn) if err != nil { t.Fatalf("unable to process :%v", err) } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } - err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1) + err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1) if err != nil { t.Fatalf("unable to process :%v", err) } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } - err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1) + err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1) if err != nil { t.Fatalf("unable to process :%v", err) } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -2689,7 +2712,7 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { } // Check that the ChannelEdgePolicy was added to the graph. - chanInfo, e1, e2, err = ctx.router.GetChannelByID( + chanInfo, e1, e2, err = tCtx.router.GetChannelByID( batch.chanUpdAnn1.ShortChannelID, ) require.NoError(t, err, "unable to get channel from router") @@ -2705,19 +2728,19 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { // Pretending that we receive local channel announcement from funding // manager, thereby kick off the announcement exchange process. - err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.localProofAnn) + err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.localProofAnn) if err != nil { t.Fatalf("unable to process :%v", err) } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("announcements were broadcast") case <-time.After(2 * trickleDelay): } number := 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(*channeldb.WaitingProof) error { number++ return nil @@ -2733,8 +2756,8 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { t.Fatal("wrong number of objects in storage") } - err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.remoteProofAnn, remotePeer, + err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, ) if err != nil { t.Fatalf("unable to process :%v", err) @@ -2742,14 +2765,14 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { for i := 0; i < 4; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(time.Second): t.Fatal("announcement wasn't broadcast") } } number = 0 - if err := ctx.gossiper.cfg.WaitingProofStore.ForAll( + if err := tCtx.gossiper.cfg.WaitingProofStore.ForAll( func(*channeldb.WaitingProof) error { number++ return nil @@ -2771,8 +2794,9 @@ func TestReceiveRemoteChannelUpdateFirst(t *testing.T) { // currently know of. func TestExtraDataChannelAnnouncementValidation(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2783,7 +2807,7 @@ func TestExtraDataChannelAnnouncementValidation(t *testing.T) { // that we don't know of ourselves, but should still include in the // final signature check. 
extraBytes := []byte("gotta validate this still!") - ca, err := ctx.createRemoteChannelAnnouncement( + ca, err := tCtx.createRemoteChannelAnnouncement( 0, withExtraBytes(extraBytes), ) require.NoError(t, err, "can't create channel announcement") @@ -2791,7 +2815,9 @@ func TestExtraDataChannelAnnouncementValidation(t *testing.T) { // We'll now send the announcement to the main gossiper. We should be // able to validate this announcement to problem. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ca, remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, ca, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -2805,9 +2831,10 @@ func TestExtraDataChannelAnnouncementValidation(t *testing.T) { // know of. func TestExtraDataChannelUpdateValidation(t *testing.T) { t.Parallel() + ctx := context.Background() timestamp := testTimestamp - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2817,7 +2844,7 @@ func TestExtraDataChannelUpdateValidation(t *testing.T) { // In this scenario, we'll create two announcements, one regular // channel announcement, and another channel update announcement, that // has additional data that we won't be interpreting. - chanAnn, err := ctx.createRemoteChannelAnnouncement(0) + chanAnn, err := tCtx.createRemoteChannelAnnouncement(0) require.NoError(t, err, "unable to create chan ann") chanUpdAnn1, err := createUpdateAnnouncement( 0, 0, remoteKeyPriv1, timestamp, @@ -2833,21 +2860,27 @@ func TestExtraDataChannelUpdateValidation(t *testing.T) { // We should be able to properly validate all three messages without // any issue. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanAnn, remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, chanAnn, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn1, remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, chanUpdAnn1, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn2, remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, chanUpdAnn2, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -2859,8 +2892,9 @@ func TestExtraDataChannelUpdateValidation(t *testing.T) { // currently know of. 
func TestExtraDataNodeAnnouncementValidation(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") remotePeer := &mockPeer{ @@ -2877,7 +2911,9 @@ func TestExtraDataNodeAnnouncementValidation(t *testing.T) { require.NoError(t, err, "can't create node announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(nodeAnn, remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, nodeAnn, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -2929,11 +2965,12 @@ func assertProcessAnnouncement(t *testing.T, result chan error) { // the retransmit ticker ticks. func TestRetransmit(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") - batch, err := ctx.createLocalAnnouncements(0) + batch, err := tCtx.createLocalAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) @@ -2944,39 +2981,39 @@ func TestRetransmit(t *testing.T) { // announcement. No messages should be broadcasted yet, since no proof // has been exchanged. assertProcessAnnouncement( - t, ctx.gossiper.ProcessLocalAnnouncement(batch.chanAnn), + t, tCtx.gossiper.ProcessLocalAnnouncement(batch.chanAnn), ) - assertBroadcast(t, ctx, 0) + assertBroadcast(t, tCtx, 0) assertProcessAnnouncement( - t, ctx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1), + t, tCtx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1), ) - assertBroadcast(t, ctx, 0) + assertBroadcast(t, tCtx, 0) assertProcessAnnouncement( - t, ctx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1), + t, tCtx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1), ) - assertBroadcast(t, ctx, 0) + assertBroadcast(t, tCtx, 0) // Add the remote channel update to the gossiper. Similarly, nothing // should be broadcasted. assertProcessAnnouncement( - t, ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, remotePeer, + t, tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, ), ) - assertBroadcast(t, ctx, 0) + assertBroadcast(t, tCtx, 0) // Now add the local and remote proof to the gossiper, which should // trigger a broadcast of the announcements. assertProcessAnnouncement( - t, ctx.gossiper.ProcessLocalAnnouncement(batch.localProofAnn), + t, tCtx.gossiper.ProcessLocalAnnouncement(batch.localProofAnn), ) - assertBroadcast(t, ctx, 0) + assertBroadcast(t, tCtx, 0) assertProcessAnnouncement( - t, ctx.gossiper.ProcessRemoteAnnouncement( - batch.remoteProofAnn, remotePeer, + t, tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, ), ) @@ -2988,7 +3025,7 @@ func TestRetransmit(t *testing.T) { t.Helper() num := chanAnns + chanUpds + nodeAnns - anns := assertBroadcast(t, ctx, num) + anns := assertBroadcast(t, tCtx, num) // Count the received announcements. var chanAnn, chanUpd, nodeAnn int @@ -3016,12 +3053,15 @@ func TestRetransmit(t *testing.T) { // update. checkAnnouncements(t, 1, 2, 1) + retransmit, ok := tCtx.gossiper.cfg.RetransmitTicker.(*ticker.Force) + require.True(t, ok) + // Now let the retransmit ticker tick, which should trigger updates to // be rebroadcast. 
now := time.Unix(int64(testTimestamp), 0) future := now.Add(rebroadcastInterval + 10*time.Second) select { - case ctx.gossiper.cfg.RetransmitTicker.(*ticker.Force).Force <- future: + case retransmit.Force <- future: case <-time.After(2 * time.Second): t.Fatalf("unable to force tick") } @@ -3035,11 +3075,12 @@ func TestRetransmit(t *testing.T) { // no existing channels in the graph do not get forwarded. func TestNodeAnnouncementNoChannels(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") - batch, err := ctx.createRemoteAnnouncements(0) + batch, err := tCtx.createRemoteAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) @@ -3048,8 +3089,9 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { // Process the remote node announcement. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(batch.nodeAnn2, - remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.nodeAnn2, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -3058,7 +3100,7 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { // Since no channels or node announcements were already in the graph, // the node announcement should be ignored, and not forwarded. select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -3066,16 +3108,18 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { // Now add the node's channel to the graph by processing the channel // announcement and channel update. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(batch.chanAnn, - remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanAnn, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(batch.chanUpdAnn2, - remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -3083,7 +3127,9 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { // Now process the node announcement again. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(batch.nodeAnn2, remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.nodeAnn2, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -3093,7 +3139,7 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { // the channel announcement and update be. for i := 0; i < 3; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(time.Second): t.Fatal("announcement wasn't broadcast") } @@ -3102,15 +3148,16 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { // Processing the same node announcement again should be ignored, as it // is stale. 
select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(batch.nodeAnn2, - remotePeer): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.nodeAnn2, remotePeer, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process announcement") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -3120,11 +3167,12 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { // validate the msg flags and max HTLC field of a ChannelUpdate. func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, 0, false) + tCtx, err := createTestCtx(t, 0, false) require.NoError(t, err, "can't create context") - processRemoteAnnouncement := ctx.gossiper.ProcessRemoteAnnouncement + processRemoteAnnouncement := tCtx.gossiper.ProcessRemoteAnnouncement chanUpdateHeight := uint32(0) timestamp := uint32(123456) @@ -3132,11 +3180,11 @@ func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { // In this scenario, we'll test whether the message flags field in a // channel update is properly handled. - chanAnn, err := ctx.createRemoteChannelAnnouncement(chanUpdateHeight) + chanAnn, err := tCtx.createRemoteChannelAnnouncement(chanUpdateHeight) require.NoError(t, err, "can't create channel announcement") select { - case err = <-processRemoteAnnouncement(chanAnn, nodePeer): + case err = <-processRemoteAnnouncement(ctx, chanAnn, nodePeer): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -3156,7 +3204,7 @@ func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { } select { - case err = <-processRemoteAnnouncement(chanUpdAnn, nodePeer): + case err = <-processRemoteAnnouncement(ctx, chanUpdAnn, nodePeer): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -3173,7 +3221,7 @@ func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { } select { - case err = <-processRemoteAnnouncement(chanUpdAnn, nodePeer): + case err = <-processRemoteAnnouncement(ctx, chanUpdAnn, nodePeer): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -3189,7 +3237,7 @@ func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { } select { - case err = <-processRemoteAnnouncement(chanUpdAnn, nodePeer): + case err = <-processRemoteAnnouncement(ctx, chanUpdAnn, nodePeer): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -3206,7 +3254,7 @@ func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { } select { - case err = <-processRemoteAnnouncement(chanUpdAnn, nodePeer): + case err = <-processRemoteAnnouncement(ctx, chanUpdAnn, nodePeer): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } @@ -3217,13 +3265,14 @@ func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { // channel is always sent upon the remote party reconnecting. func TestSendChannelUpdateReliably(t *testing.T) { t.Parallel() + ctx := context.Background() // We'll start by creating our test context and a batch of // announcements. 
- ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "unable to create test context") - batch, err := ctx.createLocalAnnouncements(0) + batch, err := tCtx.createLocalAnnouncements(0) require.NoError(t, err, "can't generate announcements") // We'll also create two keys, one for ourselves and another for the @@ -3236,7 +3285,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // gossiper to the remote peer. sentToPeer := make(chan lnwire.Message, 1) remotePeer := &mockPeer{ - remoteKey, sentToPeer, ctx.gossiper.quit, atomic.Bool{}, + remoteKey, sentToPeer, tCtx.gossiper.quit, atomic.Bool{}, } // Since we first wait to be notified of the peer before attempting to @@ -3244,13 +3293,13 @@ func TestSendChannelUpdateReliably(t *testing.T) { // NotifyWhenOffline to instead give us access to the channel that will // receive the notification. notifyOnline := make(chan chan<- lnpeer.Peer, 1) - ctx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(_ [33]byte, + tCtx.gossiper.reliableSender.cfg.NotifyWhenOnline = func(_ [33]byte, peerChan chan<- lnpeer.Peer) { notifyOnline <- peerChan } notifyOffline := make(chan chan struct{}, 1) - ctx.gossiper.reliableSender.cfg.NotifyWhenOffline = func( + tCtx.gossiper.reliableSender.cfg.NotifyWhenOffline = func( _ [33]byte) <-chan struct{} { c := make(chan struct{}, 1) @@ -3275,7 +3324,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // Process the channel announcement for which we'll send a channel // update for. select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): case <-time.After(2 * time.Second): t.Fatal("did not process local channel announcement") } @@ -3283,14 +3332,14 @@ func TestSendChannelUpdateReliably(t *testing.T) { // It should not be broadcast due to not having an announcement proof. select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } // Now, we'll process the channel update. select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1): case <-time.After(2 * time.Second): t.Fatal("did not process local channel update") } @@ -3299,7 +3348,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // It should also not be broadcast due to the announcement not having an // announcement proof. select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -3348,7 +3397,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // With the new update created, we'll go ahead and process it. select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement( + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( batch.chanUpdAnn1, ): case <-time.After(2 * time.Second): @@ -3359,7 +3408,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // It should also not be broadcast due to the announcement not having an // announcement proof. select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -3387,7 +3436,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // We'll then exchange proofs with the remote peer in order to announce // the channel. 
select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement( + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( batch.localProofAnn, ): case <-time.After(2 * time.Second): @@ -3397,7 +3446,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // No messages should be broadcast as we don't have the full proof yet. select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -3406,8 +3455,8 @@ func TestSendChannelUpdateReliably(t *testing.T) { assertMsgSent(batch.localProofAnn) select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.remoteProofAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote channel proof") @@ -3418,7 +3467,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // channel has been announced. for i := 0; i < 2; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(2 * trickleDelay): t.Fatal("expected channel to be announced") } @@ -3440,7 +3489,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // directly since the reliable sender only applies when the channel is // not announced. select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement( + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( newChannelUpdate, ): case <-time.After(2 * time.Second): @@ -3448,7 +3497,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { } require.NoError(t, err, "unable to process local channel update") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(2 * trickleDelay): t.Fatal("channel update was not broadcast") } @@ -3495,7 +3544,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { // Since the messages above are now deemed as stale, they should be // removed from the message store. err = wait.NoError(func() error { - msgs, err := ctx.gossiper.cfg.MessageStore.Messages() + msgs, err := tCtx.gossiper.cfg.MessageStore.Messages() if err != nil { return fmt.Errorf("unable to retrieve pending "+ "messages: %v", err) @@ -3533,7 +3582,9 @@ func sendRemoteMsg(t *testing.T, ctx *testCtx, msg lnwire.Message, t.Helper() select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement(msg, remotePeer): + case err := <-ctx.gossiper.ProcessRemoteAnnouncement( + context.Background(), msg, remotePeer, + ): if err != nil { t.Fatalf("unable to process channel msg: %v", err) } @@ -3933,14 +3984,15 @@ func (m *SyncManager) markGraphSyncing() { // initial historical sync has completed. func TestBroadcastAnnsAfterGraphSynced(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, 10, false) + tCtx, err := createTestCtx(t, 10, false) require.NoError(t, err, "can't create context") // We'll mark the graph as not synced. This should prevent us from // broadcasting any messages we've received as part of our initial // historical sync. 
- ctx.gossiper.syncMgr.markGraphSyncing() + tCtx.gossiper.syncMgr.markGraphSyncing() assertBroadcast := func(msg lnwire.Message, isRemote bool, shouldBroadcast bool) { @@ -3952,11 +4004,11 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) { } var errChan chan error if isRemote { - errChan = ctx.gossiper.ProcessRemoteAnnouncement( - msg, nodePeer, + errChan = tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, msg, nodePeer, ) } else { - errChan = ctx.gossiper.ProcessLocalAnnouncement(msg) + errChan = tCtx.gossiper.ProcessLocalAnnouncement(msg) } select { @@ -3970,7 +4022,7 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) { } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: if !shouldBroadcast { t.Fatal("gossip message was broadcast") } @@ -3983,7 +4035,7 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) { // A remote channel announcement should not be broadcast since the graph // has not yet been synced. - chanAnn1, err := ctx.createRemoteChannelAnnouncement(0) + chanAnn1, err := tCtx.createRemoteChannelAnnouncement(0) require.NoError(t, err, "unable to create channel announcement") assertBroadcast(chanAnn1, true, false) @@ -3995,9 +4047,9 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) { // Mark the graph as synced, which should allow the channel announcement // should to be broadcast. - ctx.gossiper.syncMgr.markGraphSynced() + tCtx.gossiper.syncMgr.markGraphSynced() - chanAnn2, err := ctx.createRemoteChannelAnnouncement(1) + chanAnn2, err := tCtx.createRemoteChannelAnnouncement(1) require.NoError(t, err, "unable to create channel announcement") assertBroadcast(chanAnn2, true, true) } @@ -4010,15 +4062,16 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) { // is tested by TestRateLimitChannelUpdates. func TestRateLimitDeDup(t *testing.T) { t.Parallel() + ctx := context.Background() // Create our test harness. const blockHeight = 100 - ctx, err := createTestCtx(t, blockHeight, false) + tCtx, err := createTestCtx(t, blockHeight, false) require.NoError(t, err, "can't create context") - ctx.gossiper.cfg.RebroadcastInterval = time.Hour + tCtx.gossiper.cfg.RebroadcastInterval = time.Hour var findBaseByAliasCount atomic.Int32 - ctx.gossiper.cfg.FindBaseByAlias = func(alias lnwire.ShortChannelID) ( + tCtx.gossiper.cfg.FindBaseByAlias = func(alias lnwire.ShortChannelID) ( lnwire.ShortChannelID, error) { findBaseByAliasCount.Add(1) @@ -4027,33 +4080,33 @@ func TestRateLimitDeDup(t *testing.T) { } getUpdateEdgeCount := func() int { - ctx.router.mu.Lock() - defer ctx.router.mu.Unlock() + tCtx.router.mu.Lock() + defer tCtx.router.mu.Unlock() - return ctx.router.updateEdgeCount + return tCtx.router.updateEdgeCount } // We set the burst to 2 here. The very first update should not count // towards this _and_ any duplicates should also not count towards it. - ctx.gossiper.cfg.MaxChannelUpdateBurst = 2 - ctx.gossiper.cfg.ChannelUpdateInterval = time.Minute + tCtx.gossiper.cfg.MaxChannelUpdateBurst = 2 + tCtx.gossiper.cfg.ChannelUpdateInterval = time.Minute // The graph should start empty. - require.Empty(t, ctx.router.infos) - require.Empty(t, ctx.router.edges) + require.Empty(t, tCtx.router.infos) + require.Empty(t, tCtx.router.edges) // We'll create a batch of signed announcements, including updates for // both sides, for a channel and process them. They should all be // forwarded as this is our first time learning about the channel. 
- batch, err := ctx.createRemoteAnnouncements(blockHeight) + batch, err := tCtx.createRemoteAnnouncements(blockHeight) require.NoError(t, err) nodePeer1 := &mockPeer{ remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}, } select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanAnn, nodePeer1, + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanAnn, nodePeer1, ): require.NoError(t, err) case <-time.After(time.Second): @@ -4061,8 +4114,8 @@ func TestRateLimitDeDup(t *testing.T) { } select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn1, nodePeer1, + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn1, nodePeer1, ): require.NoError(t, err) case <-time.After(time.Second): @@ -4073,8 +4126,8 @@ func TestRateLimitDeDup(t *testing.T) { remoteKeyPriv2.PubKey(), nil, nil, atomic.Bool{}, } select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, nodePeer2, + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, nodePeer2, ): require.NoError(t, err) case <-time.After(time.Second): @@ -4084,21 +4137,21 @@ func TestRateLimitDeDup(t *testing.T) { timeout := time.After(2 * trickleDelay) for i := 0; i < 3; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-timeout: t.Fatal("expected announcement to be broadcast") } } shortChanID := batch.chanAnn.ShortChannelID.ToUint64() - require.Contains(t, ctx.router.infos, shortChanID) - require.Contains(t, ctx.router.edges, shortChanID) + require.Contains(t, tCtx.router.infos, shortChanID) + require.Contains(t, tCtx.router.edges, shortChanID) // Before we send anymore updates, we want to let our test harness // hang during GetChannelByID so that we can ensure that two threads are // waiting for the chan. pause := make(chan struct{}) - ctx.router.pauseGetChannelByID <- pause + tCtx.router.pauseGetChannelByID <- pause // Take note of how many times FindBaseByAlias has been called. // It should be 2 since we have processed two channel updates. @@ -4123,10 +4176,14 @@ func TestRateLimitDeDup(t *testing.T) { // succession. We wait for both to have hit the FindBaseByAlias check // before we un-pause the GetChannelByID call. go func() { - ctx.gossiper.ProcessRemoteAnnouncement(&update, nodePeer1) + tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, &update, nodePeer1, + ) }() go func() { - ctx.gossiper.ProcessRemoteAnnouncement(&update, nodePeer1) + tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, &update, nodePeer1, + ) }() // We know that both are being processed once the count for @@ -4164,7 +4221,7 @@ func TestRateLimitDeDup(t *testing.T) { t.Helper() select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: require.True(t, shouldBroadcast) case <-time.After(2 * trickleDelay): require.False(t, shouldBroadcast) @@ -4173,8 +4230,8 @@ func TestRateLimitDeDup(t *testing.T) { processUpdate := func(msg lnwire.Message, peer lnpeer.Peer) { select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement( - msg, peer, + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, msg, peer, ): require.NoError(t, err) case <-time.After(time.Second): @@ -4202,31 +4259,32 @@ func TestRateLimitDeDup(t *testing.T) { // channel updates. func TestRateLimitChannelUpdates(t *testing.T) { t.Parallel() + ctx := context.Background() // Create our test harness. 
const blockHeight = 100 - ctx, err := createTestCtx(t, blockHeight, false) + tCtx, err := createTestCtx(t, blockHeight, false) require.NoError(t, err, "can't create context") - ctx.gossiper.cfg.RebroadcastInterval = time.Hour - ctx.gossiper.cfg.MaxChannelUpdateBurst = 5 - ctx.gossiper.cfg.ChannelUpdateInterval = 5 * time.Second + tCtx.gossiper.cfg.RebroadcastInterval = time.Hour + tCtx.gossiper.cfg.MaxChannelUpdateBurst = 5 + tCtx.gossiper.cfg.ChannelUpdateInterval = 5 * time.Second // The graph should start empty. - require.Empty(t, ctx.router.infos) - require.Empty(t, ctx.router.edges) + require.Empty(t, tCtx.router.infos) + require.Empty(t, tCtx.router.edges) // We'll create a batch of signed announcements, including updates for // both sides, for a channel and process them. They should all be // forwarded as this is our first time learning about the channel. - batch, err := ctx.createRemoteAnnouncements(blockHeight) + batch, err := tCtx.createRemoteAnnouncements(blockHeight) require.NoError(t, err) nodePeer1 := &mockPeer{ remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}, } select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanAnn, nodePeer1, + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanAnn, nodePeer1, ): require.NoError(t, err) case <-time.After(time.Second): @@ -4234,8 +4292,8 @@ func TestRateLimitChannelUpdates(t *testing.T) { } select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn1, nodePeer1, + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn1, nodePeer1, ): require.NoError(t, err) case <-time.After(time.Second): @@ -4246,8 +4304,8 @@ func TestRateLimitChannelUpdates(t *testing.T) { remoteKeyPriv2.PubKey(), nil, nil, atomic.Bool{}, } select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, nodePeer2, + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, nodePeer2, ): require.NoError(t, err) case <-time.After(time.Second): @@ -4257,15 +4315,15 @@ func TestRateLimitChannelUpdates(t *testing.T) { timeout := time.After(2 * trickleDelay) for i := 0; i < 3; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-timeout: t.Fatal("expected announcement to be broadcast") } } shortChanID := batch.chanAnn.ShortChannelID.ToUint64() - require.Contains(t, ctx.router.infos, shortChanID) - require.Contains(t, ctx.router.edges, shortChanID) + require.Contains(t, tCtx.router.infos, shortChanID) + require.Contains(t, tCtx.router.edges, shortChanID) // We'll define a helper to assert whether updates should be rate // limited or not depending on their contents. @@ -4275,14 +4333,16 @@ func TestRateLimitChannelUpdates(t *testing.T) { t.Helper() select { - case err := <-ctx.gossiper.ProcessRemoteAnnouncement(update, peer): + case err := <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, update, peer, + ): require.NoError(t, err) case <-time.After(time.Second): t.Fatal("remote announcement not processed") } select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: if shouldRateLimit { t.Fatal("unexpected channel update broadcast") } @@ -4305,7 +4365,7 @@ func TestRateLimitChannelUpdates(t *testing.T) { keepAliveUpdate := *batch.chanUpdAnn1 keepAliveUpdate.Timestamp = uint32( time.Unix(int64(batch.chanUpdAnn1.Timestamp), 0). 
- Add(ctx.gossiper.cfg.RebroadcastInterval).Unix(), + Add(tCtx.gossiper.cfg.RebroadcastInterval).Unix(), ) require.NoError(t, signUpdate(remoteKeyPriv1, &keepAliveUpdate)) assertRateLimit(&keepAliveUpdate, nodePeer1, false) @@ -4316,7 +4376,7 @@ func TestRateLimitChannelUpdates(t *testing.T) { // seconds with a max burst of 5 per direction. We'll process the max // burst of one direction first. None of these should be rate limited. updateSameDirection := keepAliveUpdate - for i := uint32(0); i < uint32(ctx.gossiper.cfg.MaxChannelUpdateBurst); i++ { + for i := uint32(0); i < uint32(tCtx.gossiper.cfg.MaxChannelUpdateBurst); i++ { //nolint:ll updateSameDirection.Timestamp++ updateSameDirection.BaseFee++ require.NoError(t, signUpdate(remoteKeyPriv1, &updateSameDirection)) @@ -4339,8 +4399,8 @@ func TestRateLimitChannelUpdates(t *testing.T) { // Wait for the next interval to tick. Since we've only waited for one, // only one more update is allowed. - <-time.After(ctx.gossiper.cfg.ChannelUpdateInterval) - for i := 0; i < ctx.gossiper.cfg.MaxChannelUpdateBurst; i++ { + <-time.After(tCtx.gossiper.cfg.ChannelUpdateInterval) + for i := 0; i < tCtx.gossiper.cfg.MaxChannelUpdateBurst; i++ { updateSameDirection.Timestamp++ updateSameDirection.BaseFee++ require.NoError(t, signUpdate(remoteKeyPriv1, &updateSameDirection)) @@ -4354,11 +4414,12 @@ func TestRateLimitChannelUpdates(t *testing.T) { // about our own channels when coming from a remote peer. func TestIgnoreOwnAnnouncement(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") - batch, err := ctx.createLocalAnnouncements(0) + batch, err := tCtx.createLocalAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) @@ -4367,8 +4428,8 @@ func TestIgnoreOwnAnnouncement(t *testing.T) { // Try to let the remote peer tell us about the channel we are part of. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanAnn, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") @@ -4383,66 +4444,66 @@ func TestIgnoreOwnAnnouncement(t *testing.T) { // update. No messages should be broadcast yet, since we don't have // the announcement signatures. 
select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanAnn): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process channel ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.chanUpdAnn1): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process channel update") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement(batch.nodeAnn1): case <-time.After(2 * time.Second): t.Fatal("did not process local announcement") } require.NoError(t, err, "unable to process node ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } // We should accept the remote's channel update and node announcement. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanUpdAnn2, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanUpdAnn2, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process channel update") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("channel update announcement was broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.nodeAnn2, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.nodeAnn2, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process node ann") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("node announcement was broadcast") case <-time.After(2 * trickleDelay): } @@ -4450,21 +4511,23 @@ func TestIgnoreOwnAnnouncement(t *testing.T) { // Now we exchange the proofs, the messages will be broadcasted to the // network. 
select { - case err = <-ctx.gossiper.ProcessLocalAnnouncement(batch.localProofAnn): + case err = <-tCtx.gossiper.ProcessLocalAnnouncement( + batch.localProofAnn, + ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") } require.NoError(t, err, "unable to process local proof") select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: t.Fatal("announcements were broadcast") case <-time.After(2 * trickleDelay): } select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.remoteProofAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.remoteProofAnn, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") @@ -4473,7 +4536,7 @@ func TestIgnoreOwnAnnouncement(t *testing.T) { for i := 0; i < 5; i++ { select { - case <-ctx.broadcastedMessage: + case <-tCtx.broadcastedMessage: case <-time.After(time.Second): t.Fatal("announcement wasn't broadcast") } @@ -4482,8 +4545,8 @@ func TestIgnoreOwnAnnouncement(t *testing.T) { // Finally, we again check that we'll ignore the remote giving us // announcements about our own channel. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanAnn, remotePeer, ): case <-time.After(2 * time.Second): t.Fatal("did not process remote announcement") @@ -4498,13 +4561,14 @@ func TestIgnoreOwnAnnouncement(t *testing.T) { // error. func TestRejectCacheChannelAnn(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, proofMatureDelta, false) + tCtx, err := createTestCtx(t, proofMatureDelta, false) require.NoError(t, err, "can't create context") // First, we create a channel announcement to send over to our test // peer. - batch, err := ctx.createRemoteAnnouncements(0) + batch, err := tCtx.createRemoteAnnouncements(0) require.NoError(t, err, "can't generate announcements") remoteKey, err := btcec.ParsePubKey(batch.nodeAnn2.NodeID[:]) @@ -4514,12 +4578,12 @@ func TestRejectCacheChannelAnn(t *testing.T) { // Before sending over the announcement, we'll modify it such that we // know it will always fail. chanID := batch.chanAnn.ShortChannelID.ToUint64() - ctx.router.queueValidationFail(chanID) + tCtx.router.queueValidationFail(chanID) // If we process the batch the first time we should get an error. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanAnn, remotePeer, ): require.NotNil(t, err) case <-time.After(2 * time.Second): @@ -4529,8 +4593,8 @@ func TestRejectCacheChannelAnn(t *testing.T) { // If we process it a *second* time, then we should get an error saying // we rejected it already. select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - batch.chanAnn, remotePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, batch.chanAnn, remotePeer, ): errStr := err.Error() require.Contains(t, errStr, "recently rejected") @@ -4578,8 +4642,9 @@ func TestFutureMsgCacheEviction(t *testing.T) { // channel announcements are banned properly. 
func TestChanAnnBanningNonChanPeer(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, 1000, false) + tCtx, err := createTestCtx(t, 1000, false) require.NoError(t, err, "can't create context") nodePeer1 := &mockPeer{ @@ -4594,15 +4659,15 @@ func TestChanAnnBanningNonChanPeer(t *testing.T) { // Craft a valid channel announcement for a channel we don't // have. We will ensure that it fails validation by modifying // the tx script. - ca, err := ctx.createRemoteChannelAnnouncement( + ca, err := tCtx.createRemoteChannelAnnouncement( uint32(i), withFundingTxPrep(fundingTxPrepTypeInvalidOutput), ) require.NoError(t, err, "can't create channel announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - ca, nodePeer1, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, ca, nodePeer1, ): require.ErrorIs(t, err, ErrInvalidFundingOutput) @@ -4612,20 +4677,23 @@ func TestChanAnnBanningNonChanPeer(t *testing.T) { } // The peer should be banned now. - require.True(t, ctx.gossiper.isBanned(nodePeer1.PubKey())) + require.True(t, tCtx.gossiper.isBanned(nodePeer1.PubKey())) // Assert that nodePeer has been disconnected. require.True(t, nodePeer1.disconnected.Load()) // Mark the UTXO as spent so that we get the ErrChannelSpent error and // can thus tests that the gossiper ignores closed channels. - ca, err := ctx.createRemoteChannelAnnouncement( + ca, err := tCtx.createRemoteChannelAnnouncement( 101, withFundingTxPrep(fundingTxPrepTypeSpent), ) require.NoError(t, err, "can't create channel announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer2): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, ca, nodePeer2, + ): + require.ErrorIs(t, err, ErrChannelSpent) case <-time.After(2 * time.Second): @@ -4633,7 +4701,7 @@ func TestChanAnnBanningNonChanPeer(t *testing.T) { } // Check that the announcement's scid is marked as closed. - isClosed, err := ctx.gossiper.cfg.ScidCloser.IsClosedScid( + isClosed, err := tCtx.gossiper.cfg.ScidCloser.IsClosedScid( ca.ShortChannelID, ) require.Nil(t, err) @@ -4645,16 +4713,19 @@ func TestChanAnnBanningNonChanPeer(t *testing.T) { sourceToPub(nodePeer2.IdentityKey()), ) - ctx.gossiper.recentRejects.Delete(key) + tCtx.gossiper.recentRejects.Delete(key) // The validateFundingTransaction method will mark this channel // as a zombie if any error occurs in the chanvalidate.Validate call. // For the sake of the rest of the test, however, we mark it as live // here. - _ = ctx.router.MarkEdgeLive(ca.ShortChannelID) + _ = tCtx.router.MarkEdgeLive(ca.ShortChannelID) select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement(ca, nodePeer2): + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, ca, nodePeer2, + ): + require.ErrorContains(t, err, "ignoring closed channel") case <-time.After(2 * time.Second): @@ -4666,8 +4737,9 @@ func TestChanAnnBanningNonChanPeer(t *testing.T) { // get disconnected. func TestChanAnnBanningChanPeer(t *testing.T) { t.Parallel() + ctx := context.Background() - ctx, err := createTestCtx(t, 1000, true) + tCtx, err := createTestCtx(t, 1000, true) require.NoError(t, err, "can't create context") nodePeer := &mockPeer{remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{}} @@ -4677,15 +4749,15 @@ func TestChanAnnBanningChanPeer(t *testing.T) { // Craft a valid channel announcement for a channel we don't // have. We will ensure that it fails validation by modifying // the router. 
- ca, err := ctx.createRemoteChannelAnnouncement( + ca, err := tCtx.createRemoteChannelAnnouncement( uint32(i), withFundingTxPrep(fundingTxPrepTypeInvalidOutput), ) require.NoError(t, err, "can't create channel announcement") select { - case err = <-ctx.gossiper.ProcessRemoteAnnouncement( - ca, nodePeer, + case err = <-tCtx.gossiper.ProcessRemoteAnnouncement( + ctx, ca, nodePeer, ): require.ErrorIs(t, err, ErrInvalidFundingOutput) @@ -4695,7 +4767,7 @@ func TestChanAnnBanningChanPeer(t *testing.T) { } // The peer should be banned now. - require.True(t, ctx.gossiper.isBanned(nodePeer.PubKey())) + require.True(t, tCtx.gossiper.isBanned(nodePeer.PubKey())) // Assert that the peer wasn't disconnected. require.False(t, nodePeer.disconnected.Load()) diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 32a90ae5fe..366dc26d27 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -454,7 +454,7 @@ func TestGossipSyncerApplyNoHistoricalGossipFilter(t *testing.T) { }() // We'll now attempt to apply the gossip filter for the remote peer. - syncer.ApplyGossipFilter(ctx, remoteHorizon) + require.NoError(t, syncer.ApplyGossipFilter(ctx, remoteHorizon)) // Ensure that the syncer's remote horizon was properly updated. if !reflect.DeepEqual(syncer.remoteUpdateHorizon, remoteHorizon) { diff --git a/peer/brontide.go b/peer/brontide.go index bfc603ae8a..4ba70323f2 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -1969,9 +1969,13 @@ func newChanMsgStream(p *Brontide, cid lnwire.ChannelID) *msgStream { // channel announcements. func newDiscMsgStream(p *Brontide) *msgStream { apply := func(msg lnwire.Message) { + // TODO(elle): thread contexts through the peer system properly + // so that a parent context can be passed in here. + ctx := context.TODO() + // TODO(yy): `ProcessRemoteAnnouncement` returns an error chan // and we need to process it. - p.cfg.AuthGossiper.ProcessRemoteAnnouncement(msg, p) + p.cfg.AuthGossiper.ProcessRemoteAnnouncement(ctx, msg, p) } return newMsgStream( From 3f9c554c4dbd615a817722e404c3ce95c7be3d0c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 8 Apr 2025 07:05:00 +0200 Subject: [PATCH 09/41] discovery: pass context through to bootstrapper SampleNodeAddrs Since the ChannelGraphBootstrapper implementation makes a call to the graph DB. --- discovery/bootstrapper.go | 19 +++++++++++++------ server.go | 16 +++++++++------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/discovery/bootstrapper.go b/discovery/bootstrapper.go index 0d370663d6..d07a5f852b 100644 --- a/discovery/bootstrapper.go +++ b/discovery/bootstrapper.go @@ -2,6 +2,7 @@ package discovery import ( "bytes" + "context" "crypto/rand" "crypto/sha256" "errors" @@ -36,8 +37,9 @@ type NetworkPeerBootstrapper interface { // denotes how many valid peer addresses to return. The passed set of // node nodes allows the caller to ignore a set of nodes perhaps // because they already have connections established. - SampleNodeAddrs(numAddrs uint32, - ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) + SampleNodeAddrs(ctx context.Context, numAddrs uint32, + ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, + error) // Name returns a human readable string which names the concrete // implementation of the NetworkPeerBootstrapper. @@ -50,7 +52,8 @@ type NetworkPeerBootstrapper interface { // bootstrapper will be queried successively until the target amount is met. 
If // the ignore map is populated, then the bootstrappers will be instructed to // skip those nodes. -func MultiSourceBootstrap(ignore map[autopilot.NodeID]struct{}, numAddrs uint32, +func MultiSourceBootstrap(ctx context.Context, + ignore map[autopilot.NodeID]struct{}, numAddrs uint32, bootstrappers ...NetworkPeerBootstrapper) ([]*lnwire.NetAddress, error) { // We'll randomly shuffle our bootstrappers before querying them in @@ -73,7 +76,9 @@ func MultiSourceBootstrap(ignore map[autopilot.NodeID]struct{}, numAddrs uint32, // the number of address remaining that we need to fetch. numAddrsLeft := numAddrs - uint32(len(addrs)) log.Tracef("Querying for %v addresses", numAddrsLeft) - netAddrs, err := bootstrapper.SampleNodeAddrs(numAddrsLeft, ignore) + netAddrs, err := bootstrapper.SampleNodeAddrs( + ctx, numAddrsLeft, ignore, + ) if err != nil { // If we encounter an error with a bootstrapper, then // we'll continue on to the next available @@ -152,7 +157,8 @@ func NewGraphBootstrapper(cg autopilot.ChannelGraph) (NetworkPeerBootstrapper, e // many valid peer addresses to return. // // NOTE: Part of the NetworkPeerBootstrapper interface. -func (c *ChannelGraphBootstrapper) SampleNodeAddrs(numAddrs uint32, +func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, + numAddrs uint32, ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { // We'll merge the ignore map with our currently selected map in order @@ -382,7 +388,8 @@ func (d *DNSSeedBootstrapper) fallBackSRVLookup(soaShim string, // network peer bootstrapper source. The num addrs field passed in denotes how // many valid peer addresses to return. The set of DNS seeds are used // successively to retrieve eligible target nodes. -func (d *DNSSeedBootstrapper) SampleNodeAddrs(numAddrs uint32, +func (d *DNSSeedBootstrapper) SampleNodeAddrs(_ context.Context, + numAddrs uint32, ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { var netAddrs []*lnwire.NetAddress diff --git a/server.go b/server.go index f8beba1dd0..2f48e4a0da 100644 --- a/server.go +++ b/server.go @@ -2628,7 +2628,9 @@ func (s *server) Start(ctx context.Context) error { } s.wg.Add(1) - go s.peerBootstrapper(defaultMinPeers, bootstrappers) + go s.peerBootstrapper( + ctx, defaultMinPeers, bootstrappers, + ) } else { srvrLog.Infof("Auto peer bootstrapping is disabled") } @@ -3075,7 +3077,7 @@ func (s *server) createBootstrapIgnorePeers() map[autopilot.NodeID]struct{} { // invariant, we ensure that our node is connected to a diverse set of peers // and that nodes newly joining the network receive an up to date network view // as soon as possible. -func (s *server) peerBootstrapper(numTargetPeers uint32, +func (s *server) peerBootstrapper(ctx context.Context, numTargetPeers uint32, bootstrappers []discovery.NetworkPeerBootstrapper) { defer s.wg.Done() @@ -3085,7 +3087,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, // We'll start off by aggressively attempting connections to peers in // order to be a part of the network as soon as possible. - s.initialPeerBootstrap(ignoreList, numTargetPeers, bootstrappers) + s.initialPeerBootstrap(ctx, ignoreList, numTargetPeers, bootstrappers) // Once done, we'll attempt to maintain our target minimum number of // peers. 
@@ -3163,7 +3165,7 @@ func (s *server) peerBootstrapper(numTargetPeers uint32, ignoreList = s.createBootstrapIgnorePeers() peerAddrs, err := discovery.MultiSourceBootstrap( - ignoreList, numNeeded*2, bootstrappers..., + ctx, ignoreList, numNeeded*2, bootstrappers..., ) if err != nil { srvrLog.Errorf("Unable to retrieve bootstrap "+ @@ -3212,8 +3214,8 @@ const bootstrapBackOffCeiling = time.Minute * 5 // initialPeerBootstrap attempts to continuously connect to peers on startup // until the target number of peers has been reached. This ensures that nodes // receive an up to date network view as soon as possible. -func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, - numTargetPeers uint32, +func (s *server) initialPeerBootstrap(ctx context.Context, + ignore map[autopilot.NodeID]struct{}, numTargetPeers uint32, bootstrappers []discovery.NetworkPeerBootstrapper) { srvrLog.Debugf("Init bootstrap with targetPeers=%v, bootstrappers=%v, "+ @@ -3272,7 +3274,7 @@ func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{}, // in order to reach our target. peersNeeded := numTargetPeers - numActivePeers bootstrapAddrs, err := discovery.MultiSourceBootstrap( - ignore, peersNeeded, bootstrappers..., + ctx, ignore, peersNeeded, bootstrappers..., ) if err != nil { srvrLog.Errorf("Unable to retrieve initial bootstrap "+ From 67a81a2b1c341ba28f0f177e7dd2728f5e8a49bb Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 9 Apr 2025 12:18:45 +0200 Subject: [PATCH 10/41] discovery: remove unnecessary context.Background() calls --- discovery/syncer.go | 33 +++++++++++++-------------------- discovery/syncer_test.go | 4 ++-- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/discovery/syncer.go b/discovery/syncer.go index 16a2fa720c..abbf61923c 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -536,7 +536,7 @@ func (g *GossipSyncer) channelGraphSyncer(ctx context.Context) { // First, we'll attempt to continue our channel // synchronization by continuing to send off another // query chunk. - done := g.synchronizeChanIDs() + done := g.synchronizeChanIDs(ctx) // If this wasn't our last query, then we'll need to // transition to our waiting state. @@ -596,7 +596,7 @@ func (g *GossipSyncer) channelGraphSyncer(ctx context.Context) { syncType.IsActiveSync() { err := g.sendGossipTimestampRange( - time.Now(), math.MaxUint32, + ctx, time.Now(), math.MaxUint32, ) if err != nil { log.Errorf("Unable to send update "+ @@ -616,7 +616,7 @@ func (g *GossipSyncer) channelGraphSyncer(ctx context.Context) { case syncerIdle: select { case req := <-g.syncTransitionReqs: - req.errChan <- g.handleSyncTransition(req) + req.errChan <- g.handleSyncTransition(ctx, req) case req := <-g.historicalSyncReqs: g.handleHistoricalSync(req) @@ -662,8 +662,8 @@ func (g *GossipSyncer) replyHandler(ctx context.Context) { // sendGossipTimestampRange constructs and sets a GossipTimestampRange for the // syncer and sends it to the remote peer. 
-func (g *GossipSyncer) sendGossipTimestampRange(firstTimestamp time.Time, - timestampRange uint32) error { +func (g *GossipSyncer) sendGossipTimestampRange(ctx context.Context, + firstTimestamp time.Time, timestampRange uint32) error { endTimestamp := firstTimestamp.Add( time.Duration(timestampRange) * time.Second, @@ -678,7 +678,6 @@ func (g *GossipSyncer) sendGossipTimestampRange(firstTimestamp time.Time, TimestampRange: timestampRange, } - ctx, _ := g.cg.Create(context.Background()) if err := g.cfg.sendToPeer(ctx, localUpdateHorizon); err != nil { return err } @@ -698,7 +697,7 @@ func (g *GossipSyncer) sendGossipTimestampRange(firstTimestamp time.Time, // been queried for with a response received. We'll chunk our requests as // required to ensure they fit into a single message. We may re-renter this // state in the case that chunking is required. -func (g *GossipSyncer) synchronizeChanIDs() bool { +func (g *GossipSyncer) synchronizeChanIDs(ctx context.Context) bool { // If we're in this state yet there are no more new channels to query // for, then we'll transition to our final synced state and return true // to signal that we're fully synchronized. @@ -735,7 +734,6 @@ func (g *GossipSyncer) synchronizeChanIDs() bool { // With our chunk obtained, we'll send over our next query, then return // false indicating that we're net yet fully synced. - ctx, _ := g.cg.Create(context.Background()) err := g.cfg.sendToPeer(ctx, &lnwire.QueryShortChanIDs{ ChainHash: g.cfg.chainHash, EncodingType: lnwire.EncodingSortedPlain, @@ -1037,7 +1035,7 @@ func (g *GossipSyncer) replyPeerQueries(ctx context.Context, // meet the channel range, then chunk our responses to the remote node. We also // ensure that our final fragment carries the "complete" bit to indicate the // end of our streaming response. -func (g *GossipSyncer) replyChanRangeQuery(_ context.Context, +func (g *GossipSyncer) replyChanRangeQuery(ctx context.Context, query *lnwire.QueryChannelRange) error { // Before responding, we'll check to ensure that the remote peer is @@ -1049,8 +1047,6 @@ func (g *GossipSyncer) replyChanRangeQuery(_ context.Context, "chain=%v, we're on chain=%v", query.ChainHash, g.cfg.chainHash) - ctx, _ := g.cg.Create(context.Background()) - return g.cfg.sendToPeerSync(ctx, &lnwire.ReplyChannelRange{ ChainHash: query.ChainHash, FirstBlockHeight: query.FirstBlockHeight, @@ -1124,8 +1120,6 @@ func (g *GossipSyncer) replyChanRangeQuery(_ context.Context, ) } - ctx, _ := g.cg.Create(context.Background()) - return g.cfg.sendToPeerSync(ctx, &lnwire.ReplyChannelRange{ ChainHash: query.ChainHash, NumBlocks: numBlocks, @@ -1263,7 +1257,6 @@ func (g *GossipSyncer) replyShortChanIDs(ctx context.Context, // each one individually and synchronously to throttle the sends and // perform buffering of responses in the syncer as opposed to the peer. for _, msg := range replyMsgs { - ctx, _ := g.cg.Create(context.Background()) err := g.cfg.sendToPeerSync(ctx, msg) if err != nil { return err @@ -1281,7 +1274,7 @@ func (g *GossipSyncer) replyShortChanIDs(ctx context.Context, // ApplyGossipFilter applies a gossiper filter sent by the remote node to the // state machine. Once applied, we'll ensure that we don't forward any messages // to the peer that aren't within the time range of the filter. 
-func (g *GossipSyncer) ApplyGossipFilter(_ context.Context, +func (g *GossipSyncer) ApplyGossipFilter(ctx context.Context, filter *lnwire.GossipTimestampRange) error { g.Lock() @@ -1340,7 +1333,6 @@ func (g *GossipSyncer) ApplyGossipFilter(_ context.Context, defer returnSema() for _, msg := range newUpdatestoSend { - ctx, _ := g.cg.Create(context.Background()) err := g.cfg.sendToPeerSync(ctx, msg) switch { case err == ErrGossipSyncerExiting: @@ -1362,7 +1354,7 @@ func (g *GossipSyncer) ApplyGossipFilter(_ context.Context, // FilterGossipMsgs takes a set of gossip messages, and only send it to a peer // iff the message is within the bounds of their set gossip filter. If the peer // doesn't have a gossip filter set, then no messages will be forwarded. -func (g *GossipSyncer) FilterGossipMsgs(_ context.Context, +func (g *GossipSyncer) FilterGossipMsgs(ctx context.Context, msgs ...msgWithSenders) { // If the peer doesn't have an update horizon set, then we won't send @@ -1485,7 +1477,6 @@ func (g *GossipSyncer) FilterGossipMsgs(_ context.Context, return } - ctx, _ := g.cg.Create(context.Background()) if err = g.cfg.sendToPeer(ctx, msgsToSend...); err != nil { log.Errorf("unable to send gossip msgs: %v", err) } @@ -1586,7 +1577,9 @@ func (g *GossipSyncer) ProcessSyncTransition(newSyncType SyncerType) error { // // NOTE: The gossip syncer might have another sync state as a result of this // transition. -func (g *GossipSyncer) handleSyncTransition(req *syncTransitionReq) error { +func (g *GossipSyncer) handleSyncTransition(ctx context.Context, + req *syncTransitionReq) error { + // Return early from any NOP sync transitions. syncType := g.SyncType() if syncType == req.newSyncType { @@ -1621,7 +1614,7 @@ func (g *GossipSyncer) handleSyncTransition(req *syncTransitionReq) error { req.newSyncType) } - err := g.sendGossipTimestampRange(firstTimestamp, timestampRange) + err := g.sendGossipTimestampRange(ctx, firstTimestamp, timestampRange) if err != nil { return fmt.Errorf("unable to send local update horizon: %w", err) diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 366dc26d27..13071d4b01 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -1495,7 +1495,7 @@ func TestGossipSyncerSynchronizeChanIDs(t *testing.T) { for i := 0; i < chunkSize*2; i += 2 { // With our set up complete, we'll request a sync of chan ID's. - done := syncer.synchronizeChanIDs() + done := syncer.synchronizeChanIDs(context.Background()) // At this point, we shouldn't yet be done as only 2 items // should have been queried for. @@ -1542,7 +1542,7 @@ func TestGossipSyncerSynchronizeChanIDs(t *testing.T) { } // If we issue another query, the syncer should tell us that it's done. - done := syncer.synchronizeChanIDs() + done := syncer.synchronizeChanIDs(context.Background()) if done { t.Fatalf("syncer should be finished!") } From 350de0711fbc4c043df8e3913a7f49f235ec3250 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 9 Apr 2025 15:08:46 +0200 Subject: [PATCH 11/41] discovery: listen on ctx in any select For any method that takes a context that has a select that listens on the systems quit channel, we should also listen on the ctx since we should not need to worry about if this context is derived internally or externally. 
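In practice the change is mechanical: every such select gains an extra
`case <-ctx.Done()` branch next to the existing quit-channel case and
returns in the same way. A minimal, generic sketch of the pattern, where
the `wait` helper, the `quit` channel and `errExiting` are illustrative
stand-ins rather than code from this patch:

    package example

    import (
        "context"
        "errors"
    )

    var errExiting = errors.New("component exiting")

    // wait returns as soon as either the component's own quit channel
    // is closed or the caller-supplied context is cancelled, whichever
    // happens first.
    func wait(ctx context.Context, quit <-chan struct{}) error {
        select {
        case <-quit:
            return errExiting

        case <-ctx.Done():
            return ctx.Err()
        }
    }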
--- discovery/sync_manager.go | 3 +++ discovery/syncer.go | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go index c825d27fd2..6ed96ac015 100644 --- a/discovery/sync_manager.go +++ b/discovery/sync_manager.go @@ -538,6 +538,9 @@ func (m *SyncManager) syncerHandler(ctx context.Context) { case <-m.quit: return + + case <-ctx.Done(): + return } } } diff --git a/discovery/syncer.go b/discovery/syncer.go index abbf61923c..0b4e7030b3 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -527,6 +527,9 @@ func (g *GossipSyncer) channelGraphSyncer(ctx context.Context) { case <-g.cg.Done(): return + + case <-ctx.Done(): + return } // We'll enter this state once we've discovered which channels @@ -577,6 +580,9 @@ func (g *GossipSyncer) channelGraphSyncer(ctx context.Context) { case <-g.cg.Done(): return + + case <-ctx.Done(): + return } // This is our final terminal state where we'll only reply to @@ -623,6 +629,9 @@ func (g *GossipSyncer) channelGraphSyncer(ctx context.Context) { case <-g.cg.Done(): return + + case <-ctx.Done(): + return } } } @@ -656,6 +665,9 @@ func (g *GossipSyncer) replyHandler(ctx context.Context) { case <-g.cg.Done(): return + + case <-ctx.Done(): + return } } } @@ -1298,6 +1310,8 @@ func (g *GossipSyncer) ApplyGossipFilter(ctx context.Context, case <-g.syncerSema: case <-g.cg.Done(): return ErrGossipSyncerExiting + case <-ctx.Done(): + return ctx.Err() } // We don't put this in a defer because if the goroutine is launched, @@ -1370,6 +1384,8 @@ func (g *GossipSyncer) FilterGossipMsgs(ctx context.Context, select { case <-g.cg.Done(): return + case <-ctx.Done(): + return default: } From 6157399a56266a687a4f79fc07f2bc5701d1f7f0 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 16:37:42 +0200 Subject: [PATCH 12/41] graph/db: test clean-up This commit cleans up the graph test code by removing unused kvdb type parameters from the `createTextVertex` and `createLightningNode` helper methods. We also pass in the testing parameter now so that we dont need to check the error each time we call `createTestVertex`. --- graph/db/graph_test.go | 179 +++++++++++++++-------------------------- 1 file changed, 63 insertions(+), 116 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 754da5eef3..e2543f443e 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -66,9 +66,7 @@ var ( } ) -func createLightningNode(_ kvdb.Backend, priv *btcec.PrivateKey) ( - *models.LightningNode, error) { - +func createLightningNode(priv *btcec.PrivateKey) *models.LightningNode { updateTime := prand.Int63() pub := priv.PubKey().SerializeCompressed() @@ -83,16 +81,16 @@ func createLightningNode(_ kvdb.Backend, priv *btcec.PrivateKey) ( } copy(n.PubKeyBytes[:], priv.PubKey().SerializeCompressed()) - return n, nil + return n } -func createTestVertex(db kvdb.Backend) (*models.LightningNode, error) { +func createTestVertex(t testing.TB) *models.LightningNode { + t.Helper() + priv, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } + require.NoError(t, err) - return createLightningNode(db, priv) + return createLightningNode(priv) } func TestNodeInsertionAndDeletion(t *testing.T) { @@ -229,8 +227,7 @@ func TestAliasLookup(t *testing.T) { // We'd like to test the alias index within the database, so first // create a new test node. 
- testNode, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + testNode := createTestVertex(t) // Add the node to the graph's database, this should also insert an // entry into the alias index for this node. @@ -250,8 +247,7 @@ func TestAliasLookup(t *testing.T) { } // Ensure that looking up a non-existent alias results in an error. - node, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node := createTestVertex(t) nodePub, err = node.PubKey() require.NoError(t, err, "unable to generate pubkey") _, err = graph.LookupAlias(nodePub) @@ -268,8 +264,7 @@ func TestSourceNode(t *testing.T) { // We'd like to test the setting/getting of the source node, so we // first create a fake node to use within the test. - testNode, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + testNode := createTestVertex(t) // Attempt to fetch the source node, this should return an error as the // source node hasn't yet been set. @@ -300,10 +295,8 @@ func TestEdgeInsertionDeletion(t *testing.T) { // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) + node2 := createTestVertex(t) // In addition to the fake vertexes we create some fake channel // identifiers. @@ -424,18 +417,15 @@ func TestDisconnectBlockAtHeight(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") - sourceNode, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create source node") + sourceNode := createTestVertex(t) if err := graph.SetSourceNode(sourceNode); err != nil { t.Fatalf("unable to set source node: %v", err) } // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) + node2 := createTestVertex(t) // In addition to the fake vertexes we create some fake channel // identifiers. @@ -691,14 +681,12 @@ func TestEdgeInfoUpdates(t *testing.T) { // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. 
- node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } assertNodeInCache(t, graph, node1, testFeatures) - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -1219,14 +1207,13 @@ func TestGraphCacheTraversal(t *testing.T) { require.Equal(t, numChannels*2*(numNodes-1), numNodeChans) } -func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, +func fillTestGraph(t testing.TB, graph *ChannelGraph, numNodes, numChannels int) (map[uint64]struct{}, []*models.LightningNode) { nodes := make([]*models.LightningNode, numNodes) nodeIndex := map[string]struct{}{} for i := 0; i < numNodes; i++ { - node, err := createTestVertex(graph.db) - require.NoError(t, err) + node := createTestVertex(t) nodes[i] = node nodeIndex[node.Alias] = struct{}{} @@ -1418,8 +1405,7 @@ func TestGraphPruning(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") - sourceNode, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create source node") + sourceNode := createTestVertex(t) if err := graph.SetSourceNode(sourceNode); err != nil { t.Fatalf("unable to set source node: %v", err) } @@ -1430,10 +1416,7 @@ func TestGraphPruning(t *testing.T) { const numNodes = 5 graphNodes := make([]*models.LightningNode, numNodes) for i := 0; i < numNodes; i++ { - node, err := createTestVertex(graph.db) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } + node := createTestVertex(t) if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) @@ -1623,10 +1606,8 @@ func TestHighestChanID(t *testing.T) { // Next, we'll insert two channels into the database, with each channel // connecting the same two nodes. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) + node2 := createTestVertex(t) // The first channel with be at height 10, while the other will be at // height 100. @@ -1686,13 +1667,11 @@ func TestChanUpdatesInHorizon(t *testing.T) { } // We'll start by creating two nodes which will seed our test graph. 
- node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -1854,10 +1833,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { const numNodes = 10 nodeAnns := make([]models.LightningNode, 0, numNodes) for i := 0; i < numNodes; i++ { - nodeAnn, err := createTestVertex(graph.db) - if err != nil { - t.Fatalf("unable to create test vertex: %v", err) - } + nodeAnn := createTestVertex(t) // The node ann will use the current end time as its last // update them, then we'll add 10 seconds in order to create @@ -2051,13 +2027,11 @@ func TestFilterKnownChanIDs(t *testing.T) { }, filteredIDs) // We'll start by creating two nodes which will seed our test graph. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -2203,11 +2177,11 @@ func TestStressTestChannelGraphAPI(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err) - node1, err := createTestVertex(graph.db) + node1 := createTestVertex(t) require.NoError(t, err, "unable to create test node") require.NoError(t, graph.AddLightningNode(node1)) - node2, err := createTestVertex(graph.db) + node2 := createTestVertex(t) require.NoError(t, err, "unable to create test node") require.NoError(t, graph.AddLightningNode(node2)) @@ -2493,12 +2467,10 @@ func TestFilterChannelRange(t *testing.T) { // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. - node1, err := createTestVertex(graph.db) - require.NoError(t, err) + node1 := createTestVertex(t) require.NoError(t, graph.AddLightningNode(node1)) - node2, err := createTestVertex(graph.db) - require.NoError(t, err) + node2 := createTestVertex(t) require.NoError(t, graph.AddLightningNode(node2)) // If we try to filter a channel range before we have any channels @@ -2712,13 +2684,11 @@ func TestFetchChanInfos(t *testing.T) { // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -2820,13 +2790,11 @@ func TestIncompleteChannelPolicies(t *testing.T) { require.NoError(t, err, "unable to make test database") // Create two nodes. 
- node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -2921,21 +2889,18 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") - sourceNode, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create source node") + sourceNode := createTestVertex(t) if err := graph.SetSourceNode(sourceNode); err != nil { t.Fatalf("unable to set source node: %v", err) } // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -3068,8 +3033,7 @@ func TestPruneGraphNodes(t *testing.T) { // We'll start off by inserting our source node, to ensure that it's // the only node left after we prune the graph. - sourceNode, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create source node") + sourceNode := createTestVertex(t) if err := graph.SetSourceNode(sourceNode); err != nil { t.Fatalf("unable to set source node: %v", err) } @@ -3077,18 +3041,15 @@ func TestPruneGraphNodes(t *testing.T) { // With the source node inserted, we'll now add three nodes to the // channel graph, at the end of the scenario, only two of these nodes // should still be in the graph. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } - node3, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node3 := createTestVertex(t) if err := graph.AddLightningNode(node3); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -3140,13 +3101,11 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { // To start, we'll create two nodes, and only add one of them to the // channel graph. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) // We'll now create an edge between the two nodes, as a result, node2 // should be inserted into the database as a shell node. 
@@ -3179,8 +3138,7 @@ func TestNodePruningUpdateIndexDeletion(t *testing.T) { // We'll first populate our graph with a single node that will be // removed shortly. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -3235,24 +3193,21 @@ func TestNodeIsPublic(t *testing.T) { // some graphs but not others, etc.). aliceGraph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") - aliceNode, err := createTestVertex(aliceGraph.db) - require.NoError(t, err, "unable to create test node") + aliceNode := createTestVertex(t) if err := aliceGraph.SetSourceNode(aliceNode); err != nil { t.Fatalf("unable to set source node: %v", err) } bobGraph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") - bobNode, err := createTestVertex(bobGraph.db) - require.NoError(t, err, "unable to create test node") + bobNode := createTestVertex(t) if err := bobGraph.SetSourceNode(bobNode); err != nil { t.Fatalf("unable to set source node: %v", err) } carolGraph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") - carolNode, err := createTestVertex(carolGraph.db) - require.NoError(t, err, "unable to create test node") + carolNode := createTestVertex(t) if err := carolGraph.SetSourceNode(carolNode); err != nil { t.Fatalf("unable to set source node: %v", err) } @@ -3370,15 +3325,13 @@ func TestDisabledChannelIDs(t *testing.T) { require.NoError(t, err, "unable to make test database") // Create first node and add it to the graph. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } // Create second node and add it to the graph. - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -3457,13 +3410,11 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. - node1, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node1 := createTestVertex(t) if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(graph.db) - require.NoError(t, err, "unable to create test node") + node2 := createTestVertex(t) edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) if err := graph.AddLightningNode(node2); err != nil { @@ -3618,9 +3569,9 @@ func TestGraphZombieIndex(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err, "unable to create test database") - node1, err := createTestVertex(graph.db) + node1 := createTestVertex(t) require.NoError(t, err, "unable to create test vertex") - node2, err := createTestVertex(graph.db) + node2 := createTestVertex(t) require.NoError(t, err, "unable to create test vertex") // Swap the nodes if the second's pubkey is smaller than the first. @@ -3787,11 +3738,7 @@ func TestLightningNodeSigVerification(t *testing.T) { } // Create a LightningNode from the same private key. 
- graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") - - node, err := createLightningNode(graph.db, priv) - require.NoError(t, err, "unable to create node") + node := createLightningNode(priv) // And finally check that we can verify the same signature from the // pubkey returned from the lightning node. @@ -3828,16 +3775,16 @@ func TestBatchedAddChannelEdge(t *testing.T) { graph, err := MakeTestGraph(t) require.Nil(t, err) - sourceNode, err := createTestVertex(graph.db) + sourceNode := createTestVertex(t) require.Nil(t, err) err = graph.SetSourceNode(sourceNode) require.Nil(t, err) // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. - node1, err := createTestVertex(graph.db) + node1 := createTestVertex(t) require.Nil(t, err) - node2, err := createTestVertex(graph.db) + node2 := createTestVertex(t) require.Nil(t, err) // In addition to the fake vertexes we create some fake channel @@ -3911,11 +3858,11 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. - node1, err := createTestVertex(graph.db) + node1 := createTestVertex(t) require.Nil(t, err) err = graph.AddLightningNode(node1) require.Nil(t, err) - node2, err := createTestVertex(graph.db) + node2 := createTestVertex(t) require.Nil(t, err) err = graph.AddLightningNode(node2) require.Nil(t, err) @@ -4022,11 +3969,11 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) { // option turned off. graph.graphCache = nil - node1, err := createTestVertex(graph.db) + node1 := createTestVertex(t) require.Nil(t, err) err = graph.AddLightningNode(node1) require.Nil(t, err) - node2, err := createTestVertex(graph.db) + node2 := createTestVertex(t) require.Nil(t, err) err = graph.AddLightningNode(node2) require.Nil(t, err) From b624a6a74ed5821cb249d99e8b45d036863a07c1 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 16:40:34 +0200 Subject: [PATCH 13/41] graph/db: remove kvdb param from test helper Remove the kvdb.Backend parameter from the `createChannelEdge` helper. This is all in preparation for having the unit tests run against any DB backend. --- graph/db/graph_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index e2543f443e..eb16684cae 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -597,7 +597,7 @@ func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo, } } -func createChannelEdge(db kvdb.Backend, node1, node2 *models.LightningNode) ( +func createChannelEdge(node1, node2 *models.LightningNode) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) { @@ -693,7 +693,7 @@ func TestEdgeInfoUpdates(t *testing.T) { assertNodeInCache(t, graph, node2, testFeatures) // Create an edge and add it to the db. - edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(node1, node2) // Make sure inserting the policy at this point, before the edge info // is added, will fail. @@ -3337,7 +3337,7 @@ func TestDisabledChannelIDs(t *testing.T) { } // Adding a new channel edge to the graph. 
- edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(node1, node2) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -3416,7 +3416,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { } node2 := createTestVertex(t) - edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(node1, node2) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -3580,7 +3580,7 @@ func TestGraphZombieIndex(t *testing.T) { node1, node2 = node2, node1 } - edge, _, _ := createChannelEdge(graph.db, node1, node2) + edge, _, _ := createChannelEdge(node1, node2) require.NoError(t, graph.AddChannelEdge(edge)) // Since the edge is known the graph and it isn't a zombie, IsZombieEdge @@ -3868,7 +3868,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { require.Nil(t, err) // Create an edge and add it to the db. - edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(node1, node2) // Make sure inserting the policy at this point, before the edge info // is added, will fail. @@ -3979,7 +3979,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) { require.Nil(t, err) // Create an edge and add it to the db. - edgeInfo, e1, e2 := createChannelEdge(graph.db, node1, node2) + edgeInfo, e1, e2 := createChannelEdge(node1, node2) // Because of lexigraphical sorting and the usage of random node keys in // this test, we need to determine which edge belongs to node 1 at From e46c9f88c014812b29245efd4eb795a44d700243 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 16:45:19 +0200 Subject: [PATCH 14/41] graph/db: remove kvdb.Backend from test helpers Remove unused kvdb.Backend param from `randEdgePolicy` and `newEdgePolicy` test helpers. --- graph/db/graph_test.go | 42 ++++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index eb16684cae..155429b62d 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -954,10 +954,10 @@ func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, } } -func randEdgePolicy(chanID uint64, db kvdb.Backend) *models.ChannelEdgePolicy { +func randEdgePolicy(chanID uint64) *models.ChannelEdgePolicy { update := prand.Int63() - return newEdgePolicy(chanID, db, update) + return newEdgePolicy(chanID, update) } func copyEdgePolicy(p *models.ChannelEdgePolicy) *models.ChannelEdgePolicy { @@ -977,9 +977,7 @@ func copyEdgePolicy(p *models.ChannelEdgePolicy) *models.ChannelEdgePolicy { } } -func newEdgePolicy(chanID uint64, db kvdb.Backend, - updateTime int64) *models.ChannelEdgePolicy { - +func newEdgePolicy(chanID uint64, updateTime int64) *models.ChannelEdgePolicy { return &models.ChannelEdgePolicy{ ChannelID: chanID, LastUpdate: time.Unix(updateTime, 0), @@ -1277,7 +1275,7 @@ func fillTestGraph(t testing.TB, graph *ChannelGraph, numNodes, // Create and add an edge with random data that points // from node1 -> node2. - edge := randEdgePolicy(chanID, graph.db) + edge := randEdgePolicy(chanID) edge.ChannelFlags = 0 edge.ToNode = node2.PubKeyBytes edge.SigBytes = testSig.Serialize() @@ -1285,7 +1283,7 @@ func fillTestGraph(t testing.TB, graph *ChannelGraph, numNodes, // Create another random edge that points from // node2 -> node1 this time. 
- edge = randEdgePolicy(chanID, graph.db) + edge = randEdgePolicy(chanID) edge.ChannelFlags = 1 edge.ToNode = node1.PubKeyBytes edge.SigBytes = testSig.Serialize() @@ -1476,7 +1474,7 @@ func TestGraphPruning(t *testing.T) { // Create and add an edge with random data that points from // node_i -> node_i+1 - edge := randEdgePolicy(chanID, graph.db) + edge := randEdgePolicy(chanID) edge.ChannelFlags = 0 edge.ToNode = graphNodes[i].PubKeyBytes edge.SigBytes = testSig.Serialize() @@ -1486,7 +1484,7 @@ func TestGraphPruning(t *testing.T) { // Create another random edge that points from node_i+1 -> // node_i this time. - edge = randEdgePolicy(chanID, graph.db) + edge = randEdgePolicy(chanID) edge.ChannelFlags = 1 edge.ToNode = graphNodes[i].PubKeyBytes edge.SigBytes = testSig.Serialize() @@ -1696,7 +1694,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { endTime = endTime.Add(time.Second * 10) edge1 := newEdgePolicy( - chanID.ToUint64(), graph.db, edge1UpdateTime.Unix(), + chanID.ToUint64(), edge1UpdateTime.Unix(), ) edge1.ChannelFlags = 0 edge1.ToNode = node2.PubKeyBytes @@ -1706,7 +1704,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { } edge2 := newEdgePolicy( - chanID.ToUint64(), graph.db, edge2UpdateTime.Unix(), + chanID.ToUint64(), edge2UpdateTime.Unix(), ) edge2.ChannelFlags = 1 edge2.ToNode = node1.PubKeyBytes @@ -2712,9 +2710,7 @@ func TestFetchChanInfos(t *testing.T) { updateTime := endTime endTime = updateTime.Add(time.Second * 10) - edge1 := newEdgePolicy( - chanID.ToUint64(), graph.db, updateTime.Unix(), - ) + edge1 := newEdgePolicy(chanID.ToUint64(), updateTime.Unix()) edge1.ChannelFlags = 0 edge1.ToNode = node2.PubKeyBytes edge1.SigBytes = testSig.Serialize() @@ -2722,9 +2718,7 @@ func TestFetchChanInfos(t *testing.T) { t.Fatalf("unable to update edge: %v", err) } - edge2 := newEdgePolicy( - chanID.ToUint64(), graph.db, updateTime.Unix(), - ) + edge2 := newEdgePolicy(chanID.ToUint64(), updateTime.Unix()) edge2.ChannelFlags = 1 edge2.ToNode = node1.PubKeyBytes edge2.SigBytes = testSig.Serialize() @@ -2851,9 +2845,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { // unknown. updateTime := time.Unix(1234, 0) - edgePolicy := newEdgePolicy( - chanID.ToUint64(), graph.db, updateTime.Unix(), - ) + edgePolicy := newEdgePolicy(chanID.ToUint64(), updateTime.Unix()) edgePolicy.ChannelFlags = 0 edgePolicy.ToNode = node2.PubKeyBytes edgePolicy.SigBytes = testSig.Serialize() @@ -2866,9 +2858,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { // Create second policy and assert that both policies are reported // as present. - edgePolicy = newEdgePolicy( - chanID.ToUint64(), graph.db, updateTime.Unix(), - ) + edgePolicy = newEdgePolicy(chanID.ToUint64(), updateTime.Unix()) edgePolicy.ChannelFlags = 1 edgePolicy.ToNode = node1.PubKeyBytes edgePolicy.SigBytes = testSig.Serialize() @@ -2912,7 +2902,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Fatalf("unable to add edge: %v", err) } - edge1 := randEdgePolicy(chanID.ToUint64(), graph.db) + edge1 := randEdgePolicy(chanID.ToUint64()) edge1.ChannelFlags = 0 edge1.ToNode = node1.PubKeyBytes edge1.SigBytes = testSig.Serialize() @@ -2921,7 +2911,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { } edge1 = copyEdgePolicy(edge1) // Avoid read/write race conditions. 
- edge2 := randEdgePolicy(chanID.ToUint64(), graph.db) + edge2 := randEdgePolicy(chanID.ToUint64()) edge2.ChannelFlags = 1 edge2.ToNode = node2.PubKeyBytes edge2.SigBytes = testSig.Serialize() @@ -3063,7 +3053,7 @@ func TestPruneGraphNodes(t *testing.T) { // We'll now insert an advertised edge, but it'll only be the edge that // points from the first to the second node. - edge1 := randEdgePolicy(chanID.ToUint64(), graph.db) + edge1 := randEdgePolicy(chanID.ToUint64()) edge1.ChannelFlags = 0 edge1.ToNode = node1.PubKeyBytes edge1.SigBytes = testSig.Serialize() From 033ad381649fff03a1cd5ecd25966bed2952fb78 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 9 Apr 2025 14:05:24 +0200 Subject: [PATCH 15/41] graph/db: use only exported KVStore ForEachNode method in tests Replace all tests calls to the private `forEachNode` method on the `KVStore` with the exported ForEachNode method. This is in preparation for having the tests run against an abstract DB backend. --- graph/db/graph_test.go | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 155429b62d..5e445f10d2 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1103,13 +1103,11 @@ func TestGraphTraversalCacheable(t *testing.T) { // Create a map of all nodes with the iteration we know works (because // it is tested in another test). nodeMap := make(map[route.Vertex]struct{}) - err = graph.forEachNode( - func(tx kvdb.RTx, n *models.LightningNode) error { - nodeMap[n.PubKeyBytes] = struct{}{} + err = graph.ForEachNode(func(tx NodeRTx) error { + nodeMap[tx.Node().PubKeyBytes] = struct{}{} - return nil - }, - ) + return nil + }) require.NoError(t, err) require.Len(t, nodeMap, numNodes) @@ -1225,12 +1223,10 @@ func fillTestGraph(t testing.TB, graph *ChannelGraph, numNodes, // Iterate over each node as returned by the graph, if all nodes are // reached, then the map created above should be empty. - err := graph.forEachNode( - func(_ kvdb.RTx, node *models.LightningNode) error { - delete(nodeIndex, node.Alias) - return nil - }, - ) + err := graph.ForEachNode(func(tx NodeRTx) error { + delete(nodeIndex, tx.Node().Alias) + return nil + }) require.NoError(t, err) require.Len(t, nodeIndex, 0) @@ -1337,12 +1333,11 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { numNodes := 0 - err := graph.forEachNode( - func(_ kvdb.RTx, _ *models.LightningNode) error { - numNodes++ - return nil - }, - ) + err := graph.ForEachNode(func(tx NodeRTx) error { + numNodes++ + + return nil + }) if err != nil { _, _, line, _ := runtime.Caller(1) t.Fatalf("line %v: unable to scan nodes: %v", line, err) From 643f696e2447a6e803e92585655480d5759a3e7b Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 9 Apr 2025 06:25:29 +0200 Subject: [PATCH 16/41] autopilot: start threading contexts through The `GraphSource` interface in the `autopilot` package is directly implemented by the `graphdb.KVStore` and so we will eventually thread contexts through to this interface. So in this commit, we start updating the autopilot system to thread contexts through in preparation for passing the context through to any calls made to the GraphSource. Two context.TODOs are added here which will be addressed in follow up commits. 
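The recurring shape of these changes is that an iterator method now
takes a context and hands it back to every callback invocation, so a
caller can forward its own context and stop a traversal early once that
context is cancelled. A simplified sketch of that shape, using made-up
`Graph`, `Node` and `countNodes` names rather than the real autopilot
types:

    package example

    import "context"

    // Node is a placeholder for the autopilot node abstraction.
    type Node interface {
        PubKey() [33]byte
    }

    // Graph mirrors the reworked iterator: the context flows into the
    // method and is passed back to each callback invocation.
    type Graph interface {
        ForEachNode(ctx context.Context,
            cb func(ctx context.Context, n Node) error) error
    }

    // countNodes forwards its context into the traversal and aborts
    // early if that context is cancelled.
    func countNodes(ctx context.Context, g Graph) (int, error) {
        var n int
        err := g.ForEachNode(ctx, func(ctx context.Context, _ Node) error {
            if err := ctx.Err(); err != nil {
                return err
            }
            n++

            return nil
        })

        return n, err
    }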
--- autopilot/agent.go | 34 +++++++++----- autopilot/agent_test.go | 3 +- autopilot/betweenness_centrality.go | 5 ++- autopilot/graph.go | 33 +++++++++----- autopilot/interface.go | 6 ++- autopilot/manager.go | 18 ++++++-- autopilot/prefattach.go | 19 +++++--- autopilot/prefattach_test.go | 62 ++++++++++++++++---------- autopilot/simple_graph.go | 25 +++++++---- discovery/bootstrapper.go | 6 ++- lnd.go | 2 +- lnrpc/autopilotrpc/autopilot_server.go | 2 +- rpcserver.go | 2 +- 13 files changed, 146 insertions(+), 71 deletions(-) diff --git a/autopilot/agent.go b/autopilot/agent.go index d9c35a685f..7b951a66d4 100644 --- a/autopilot/agent.go +++ b/autopilot/agent.go @@ -2,6 +2,7 @@ package autopilot import ( "bytes" + "context" "fmt" "math/rand" "net" @@ -11,6 +12,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) @@ -166,8 +168,9 @@ type Agent struct { pendingOpens map[NodeID]LocalChannel pendingMtx sync.Mutex - quit chan struct{} - wg sync.WaitGroup + quit chan struct{} + wg sync.WaitGroup + cancel fn.Option[context.CancelFunc] } // New creates a new instance of the Agent instantiated using the passed @@ -199,20 +202,23 @@ func New(cfg Config, initialState []LocalChannel) (*Agent, error) { // Start starts the agent along with any goroutines it needs to perform its // normal duties. -func (a *Agent) Start() error { +func (a *Agent) Start(ctx context.Context) error { var err error a.started.Do(func() { - err = a.start() + ctx, cancel := context.WithCancel(ctx) + a.cancel = fn.Some(cancel) + + err = a.start(ctx) }) return err } -func (a *Agent) start() error { +func (a *Agent) start(ctx context.Context) error { rand.Seed(time.Now().Unix()) log.Infof("Autopilot Agent starting") a.wg.Add(1) - go a.controller() + go a.controller(ctx) return nil } @@ -230,6 +236,7 @@ func (a *Agent) Stop() error { func (a *Agent) stop() error { log.Infof("Autopilot Agent stopping") + a.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(a.quit) a.wg.Wait() @@ -401,7 +408,7 @@ func mergeChanState(pendingChans map[NodeID]LocalChannel, // and external state changes as a result of decisions it makes w.r.t channel // allocation, or attributes affecting its control loop being updated by the // backing Lightning Node. -func (a *Agent) controller() { +func (a *Agent) controller(ctx context.Context) { defer a.wg.Done() // We'll start off by assigning our starting balance, and injecting @@ -502,6 +509,9 @@ func (a *Agent) controller() { // immediately. case <-a.quit: return + + case <-ctx.Done(): + return } a.pendingMtx.Lock() @@ -539,7 +549,7 @@ func (a *Agent) controller() { log.Infof("Triggering attachment directive dispatch, "+ "total_funds=%v", a.totalBalance) - err := a.openChans(availableFunds, numChans, totalChans) + err := a.openChans(ctx, availableFunds, numChans, totalChans) if err != nil { log.Errorf("Unable to open channels: %v", err) } @@ -548,8 +558,8 @@ func (a *Agent) controller() { // openChans queries the agent's heuristic for a set of channel candidates, and // attempts to open channels to them. -func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, - totalChans []LocalChannel) error { +func (a *Agent) openChans(ctx context.Context, availableFunds btcutil.Amount, + numChans uint32, totalChans []LocalChannel) error { // As channel size we'll use the maximum channel size available. 
chanSize := a.cfg.Constraints.MaxChanSize() @@ -598,7 +608,9 @@ func (a *Agent) openChans(availableFunds btcutil.Amount, numChans uint32, selfPubBytes := a.cfg.Self.SerializeCompressed() nodes := make(map[NodeID]struct{}) addresses := make(map[NodeID][]net.Addr) - if err := a.cfg.Graph.ForEachNode(func(node Node) error { + if err := a.cfg.Graph.ForEachNode(ctx, func(_ context.Context, + node Node) error { + nID := NodeID(node.PubKey()) // If we come across ourselves, them we'll continue in diff --git a/autopilot/agent_test.go b/autopilot/agent_test.go index 39e86906e2..9b3c30bb53 100644 --- a/autopilot/agent_test.go +++ b/autopilot/agent_test.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "errors" "fmt" "net" @@ -220,7 +221,7 @@ func setup(t *testing.T, initialChans []LocalChannel) *testContext { // With the autopilot agent and all its dependencies we'll start the // primary controller goroutine. - if err := agent.Start(); err != nil { + if err := agent.Start(context.Background()); err != nil { t.Fatalf("unable to start agent: %v", err) } diff --git a/autopilot/betweenness_centrality.go b/autopilot/betweenness_centrality.go index db45bcf665..2a45fe1f13 100644 --- a/autopilot/betweenness_centrality.go +++ b/autopilot/betweenness_centrality.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "fmt" "sync" ) @@ -169,7 +170,9 @@ func betweennessCentrality(g *SimpleGraph, s int, centrality []float64) { // Refresh recalculates and stores centrality values. func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error { - cache, err := NewSimpleGraph(graph) + ctx := context.TODO() + + cache, err := NewSimpleGraph(ctx, graph) if err != nil { return err } diff --git a/autopilot/graph.go b/autopilot/graph.go index c8b54082ad..d20c6316ec 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "encoding/hex" "net" "sort" @@ -80,7 +81,9 @@ func (d *dbNode) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. -func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { +func (d *dbNode) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep, _ *models.ChannelEdgePolicy) error { @@ -108,7 +111,7 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { }, } - return cb(edge) + return cb(ctx, edge) }) } @@ -117,7 +120,9 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. -func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { +func (d *databaseChannelGraph) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error { // We'll skip over any node that doesn't have any advertised // addresses. As we won't be able to reach them to actually @@ -129,7 +134,8 @@ func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { node := &dbNode{ tx: nodeTx, } - return cb(node) + + return cb(ctx, node) }) } @@ -185,7 +191,9 @@ func (nc dbNodeCached) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. 
-func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { +func (nc dbNodeCached) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + for cid, channel := range nc.channels { edge := ChannelEdge{ ChanID: lnwire.NewShortChanIDFromInt(cid), @@ -195,7 +203,7 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { }, } - if err := cb(edge); err != nil { + if err := cb(ctx, edge); err != nil { return err } } @@ -208,7 +216,9 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. -func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error { +func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + return dc.db.ForEachNodeCached(func(n route.Vertex, channels map[uint64]*graphdb.DirectedChannel) error { @@ -217,7 +227,8 @@ func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error { node: n, channels: channels, } - return cb(node) + + return cb(ctx, node) } return nil }) @@ -262,9 +273,11 @@ func (m memNode) Addrs() []net.Addr { // describes the active channel. // // NOTE: Part of the autopilot.Node interface. -func (m memNode) ForEachChannel(cb func(ChannelEdge) error) error { +func (m memNode) ForEachChannel(ctx context.Context, + cb func(context.Context, ChannelEdge) error) error { + for _, channel := range m.chans { - if err := cb(channel); err != nil { + if err := cb(ctx, channel); err != nil { return err } } diff --git a/autopilot/interface.go b/autopilot/interface.go index 35182a7600..0991d9864e 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "net" "github.com/btcsuite/btcd/btcec/v2" @@ -35,7 +36,8 @@ type Node interface { // iterate through all edges emanating from/to the target node. For // each active channel, this function should be called with the // populated ChannelEdge that describes the active channel. - ForEachChannel(func(ChannelEdge) error) error + ForEachChannel(context.Context, func(context.Context, + ChannelEdge) error) error } // LocalChannel is a simple struct which contains relevant details of a @@ -83,7 +85,7 @@ type ChannelGraph interface { // ForEachNode is a higher-order function that should be called once // for each connected node within the channel graph. If the passed // callback returns an error, then execution should be terminated. - ForEachNode(func(Node) error) error + ForEachNode(context.Context, func(context.Context, Node) error) error } // NodeScore is a tuple mapping a NodeID to a score indicating the preference diff --git a/autopilot/manager.go b/autopilot/manager.go index 0463f98d99..036bf3a31e 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -1,11 +1,13 @@ package autopilot import ( + "context" "fmt" "sync" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -53,8 +55,9 @@ type Manager struct { // disabled. 
pilot *Agent - quit chan struct{} - wg sync.WaitGroup + quit chan struct{} + wg sync.WaitGroup + cancel fn.Option[context.CancelFunc] sync.Mutex } @@ -80,6 +83,7 @@ func (m *Manager) Stop() error { log.Errorf("Unable to stop pilot: %v", err) } + m.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(m.quit) m.wg.Wait() }) @@ -96,7 +100,7 @@ func (m *Manager) IsActive() bool { // StartAgent creates and starts an autopilot agent from the Manager's // config. -func (m *Manager) StartAgent() error { +func (m *Manager) StartAgent(ctx context.Context) error { m.Lock() defer m.Unlock() @@ -104,6 +108,8 @@ func (m *Manager) StartAgent() error { if m.pilot != nil { return nil } + ctx, cancel := context.WithCancel(ctx) + m.cancel = fn.Some(cancel) // Next, we'll fetch the current state of open channels from the // database to use as initial state for the auto-pilot agent. @@ -119,7 +125,7 @@ func (m *Manager) StartAgent() error { return err } - if err := pilot.Start(); err != nil { + if err := pilot.Start(ctx); err != nil { return err } @@ -163,6 +169,8 @@ func (m *Manager) StartAgent() error { return case <-m.quit: return + case <-ctx.Done(): + return } } @@ -233,6 +241,8 @@ func (m *Manager) StartAgent() error { return case <-m.quit: return + case <-ctx.Done(): + return } } }() diff --git a/autopilot/prefattach.go b/autopilot/prefattach.go index 4f4ff635fa..4f55e87ea2 100644 --- a/autopilot/prefattach.go +++ b/autopilot/prefattach.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" prand "math/rand" "time" @@ -82,14 +83,18 @@ func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( map[NodeID]*NodeScore, error) { + ctx := context.TODO() + // We first run though the graph once in order to find the median // channel size. var ( allChans []btcutil.Amount seenChans = make(map[uint64]struct{}) ) - if err := g.ForEachNode(func(n Node) error { - err := n.ForEachChannel(func(e ChannelEdge) error { + if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error { + err := n.ForEachChannel(ctx, func(_ context.Context, + e ChannelEdge) error { + if _, ok := seenChans[e.ChanID.ToUint64()]; ok { return nil } @@ -114,15 +119,19 @@ func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, // the graph. var maxChans int nodeChanNum := make(map[NodeID]int) - if err := g.ForEachNode(func(n Node) error { + if err := g.ForEachNode(ctx, func(ctx context.Context, n Node) error { var nodeChans int - err := n.ForEachChannel(func(e ChannelEdge) error { + err := n.ForEachChannel(ctx, func(_ context.Context, + e ChannelEdge) error { + // Since connecting to nodes with a lot of small // channels actually worsens our connectivity in the // graph (we will potentially waste time trying to use // these useless channels in path finding), we decrease // the counter for such channels. - if e.Capacity < medianChanSize/minMedianChanSizeFraction { + if e.Capacity < + medianChanSize/minMedianChanSizeFraction { + nodeChans-- return nil } diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index f20c3a480b..7dec5f49f4 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -2,6 +2,7 @@ package autopilot import ( "bytes" + "context" "errors" prand "math/rand" "net" @@ -126,6 +127,7 @@ func TestPrefAttachmentSelectEmptyGraph(t *testing.T) { // and the funds are appropriately allocated across each peer. 
func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -156,10 +158,12 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { // Get the score for all nodes found in the graph at // this point. nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - nodes[n.PubKey()] = struct{}{} - return nil - }) + err = graph.ForEachNode(ctx, + func(_ context.Context, n Node) error { + nodes[n.PubKey()] = struct{}{} + return nil + }, + ) require.NoError(t1, err) require.Len(t1, nodes, 3) @@ -207,6 +211,7 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { // allocate all funds to each vertex (up to the max channel size). func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -245,22 +250,25 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { numNodes := 0 twoChans := false nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - numNodes++ - nodes[n.PubKey()] = struct{}{} - numChans := 0 - err := n.ForEachChannel(func(c ChannelEdge) error { - numChans++ + err = graph.ForEachNode( + ctx, func(ctx context.Context, n Node) error { + numNodes++ + nodes[n.PubKey()] = struct{}{} + numChans := 0 + err := n.ForEachChannel(ctx, + func(_ context.Context, c ChannelEdge) error { //nolint:ll + numChans++ + return nil + }, + ) + if err != nil { + return err + } + + twoChans = twoChans || (numChans == 2) + return nil }) - if err != nil { - return err - } - - twoChans = twoChans || (numChans == 2) - - return nil - }) require.NoError(t1, err) require.EqualValues(t1, 3, numNodes) @@ -313,6 +321,7 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { // of zero during scoring. func TestPrefAttachmentSelectSkipNodes(t *testing.T) { t.Parallel() + ctx := context.Background() prand.Seed(time.Now().Unix()) @@ -335,10 +344,13 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) { require.NoError(t1, err) nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(func(n Node) error { - nodes[n.PubKey()] = struct{}{} - return nil - }) + err = graph.ForEachNode( + ctx, func(_ context.Context, n Node) error { + nodes[n.PubKey()] = struct{}{} + + return nil + }, + ) require.NoError(t1, err) require.Len(t1, nodes, 2) @@ -583,9 +595,11 @@ func newMemChannelGraph() *memChannelGraph { // error, then execution should be terminated. // // NOTE: Part of the autopilot.ChannelGraph interface. -func (m *memChannelGraph) ForEachNode(cb func(Node) error) error { +func (m *memChannelGraph) ForEachNode(ctx context.Context, + cb func(context.Context, Node) error) error { + for _, node := range m.graph { - if err := cb(node); err != nil { + if err := cb(ctx, node); err != nil { return err } } diff --git a/autopilot/simple_graph.go b/autopilot/simple_graph.go index 4d294b3f22..f028db3c75 100644 --- a/autopilot/simple_graph.go +++ b/autopilot/simple_graph.go @@ -1,5 +1,7 @@ package autopilot +import "context" + // diameterCutoff is used to discard nodes in the diameter calculation. // It is the multiplier for the eccentricity of the highest-degree node, // serving as a cutoff to discard all nodes with a smaller hop distance. This @@ -20,7 +22,7 @@ type SimpleGraph struct { // NewSimpleGraph creates a simplified graph from the current channel graph. // Returns an error if the channel graph iteration fails due to underlying // failure. 
-func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) { +func NewSimpleGraph(ctx context.Context, g ChannelGraph) (*SimpleGraph, error) { nodes := make(map[NodeID]int) adj := make(map[int][]int) nextIndex := 0 @@ -42,17 +44,22 @@ func NewSimpleGraph(g ChannelGraph) (*SimpleGraph, error) { return nodeIndex } - // Iterate over each node and each channel and update the adj and the node - // index. - err := g.ForEachNode(func(node Node) error { + // Iterate over each node and each channel and update the adj and the + // node index. + err := g.ForEachNode(ctx, func(ctx context.Context, node Node) error { u := getNodeIndex(node) - return node.ForEachChannel(func(edge ChannelEdge) error { - v := getNodeIndex(edge.Peer) + return node.ForEachChannel( + ctx, func(_ context.Context, + edge ChannelEdge) error { + + v := getNodeIndex(edge.Peer) + + adj[u] = append(adj[u], v) - adj[u] = append(adj[u], v) - return nil - }) + return nil + }, + ) }) if err != nil { return nil, err diff --git a/discovery/bootstrapper.go b/discovery/bootstrapper.go index d07a5f852b..1a0f997351 100644 --- a/discovery/bootstrapper.go +++ b/discovery/bootstrapper.go @@ -161,6 +161,8 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, numAddrs uint32, ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { + ctx := context.TODO() + // We'll merge the ignore map with our currently selected map in order // to ensure we don't return any duplicate nodes. for n := range ignore { @@ -183,7 +185,9 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, errFound = fmt.Errorf("found node") ) - err := c.chanGraph.ForEachNode(func(node autopilot.Node) error { + err := c.chanGraph.ForEachNode(ctx, func(_ context.Context, + node autopilot.Node) error { + nID := autopilot.NodeID(node.PubKey()) if _, ok := c.tried[nID]; ok { return nil diff --git a/lnd.go b/lnd.go index 3afa8c2fba..41bd3ca4ba 100644 --- a/lnd.go +++ b/lnd.go @@ -788,7 +788,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // active, then we'll start the autopilot agent immediately. It will be // stopped together with the autopilot service. if cfg.Autopilot.Active { - if err := atplManager.StartAgent(); err != nil { + if err := atplManager.StartAgent(ctx); err != nil { return mkErr("unable to start autopilot agent", err) } } diff --git a/lnrpc/autopilotrpc/autopilot_server.go b/lnrpc/autopilotrpc/autopilot_server.go index 23d0ff5f36..3e3c6f8f7c 100644 --- a/lnrpc/autopilotrpc/autopilot_server.go +++ b/lnrpc/autopilotrpc/autopilot_server.go @@ -205,7 +205,7 @@ func (s *Server) ModifyStatus(ctx context.Context, var err error if in.Enable { - err = s.manager.StartAgent() + err = s.manager.StartAgent(ctx) } else { err = s.manager.StopAgent() } diff --git a/rpcserver.go b/rpcserver.go index 6e288604f2..d51cf1cab9 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -7140,7 +7140,7 @@ func (r *rpcServer) GetNetworkInfo(ctx context.Context, // Graph diameter. channelGraph := autopilot.ChannelGraphFromCachedDatabase(graph) - simpleGraph, err := autopilot.NewSimpleGraph(channelGraph) + simpleGraph, err := autopilot.NewSimpleGraph(ctx, channelGraph) if err != nil { return nil, err } From dea22650ad7a572166a842efaf1bdff781acbf30 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 9 Apr 2025 06:31:05 +0200 Subject: [PATCH 17/41] autopilot: continue threading context Remove one context.TODO and add one more. 
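Put differently, the TODO migrates one layer up the call stack: the
converted callee now receives a real context parameter, while a caller
whose signature has not yet been widened marks the missing plumbing with
`context.TODO()` at the call-site until a later commit converts it too.
A small illustrative sketch of that intermediate state (both function
names below are made up for the example):

    package example

    import "context"

    // refreshMetric is the converted callee: it now receives its
    // context from the caller instead of creating one internally.
    func refreshMetric(ctx context.Context) error {
        // Real work would honour ctx here; the sketch just reports
        // whether it has already been cancelled.
        return ctx.Err()
    }

    // scoreNodes stands in for a caller that has not been converted
    // yet, so it flags the missing plumbing with context.TODO().
    func scoreNodes() error {
        ctx := context.TODO()

        return refreshMetric(ctx)
    }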
--- autopilot/betweenness_centrality.go | 4 ++-- autopilot/betweenness_centrality_test.go | 11 +++++++++-- autopilot/interface.go | 2 +- autopilot/top_centrality.go | 5 ++++- rpcserver.go | 2 +- 5 files changed, 17 insertions(+), 7 deletions(-) diff --git a/autopilot/betweenness_centrality.go b/autopilot/betweenness_centrality.go index 2a45fe1f13..c3f0a3d022 100644 --- a/autopilot/betweenness_centrality.go +++ b/autopilot/betweenness_centrality.go @@ -169,8 +169,8 @@ func betweennessCentrality(g *SimpleGraph, s int, centrality []float64) { } // Refresh recalculates and stores centrality values. -func (bc *BetweennessCentrality) Refresh(graph ChannelGraph) error { - ctx := context.TODO() +func (bc *BetweennessCentrality) Refresh(ctx context.Context, + graph ChannelGraph) error { cache, err := NewSimpleGraph(ctx, graph) if err != nil { diff --git a/autopilot/betweenness_centrality_test.go b/autopilot/betweenness_centrality_test.go index 8e3e07ce2d..b1de27834b 100644 --- a/autopilot/betweenness_centrality_test.go +++ b/autopilot/betweenness_centrality_test.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "fmt" "testing" @@ -30,6 +31,9 @@ func TestBetweennessCentralityMetricConstruction(t *testing.T) { // Tests that empty graph results in empty centrality result. func TestBetweennessCentralityEmptyGraph(t *testing.T) { + t.Parallel() + ctx := context.Background() + centralityMetric, err := NewBetweennessCentralityMetric(1) require.NoError( t, err, @@ -42,7 +46,7 @@ func TestBetweennessCentralityEmptyGraph(t *testing.T) { require.NoError(t, err, "unable to create graph") success := t.Run(chanGraph.name, func(t1 *testing.T) { - err = centralityMetric.Refresh(graph) + err = centralityMetric.Refresh(ctx, graph) require.NoError(t1, err) centrality := centralityMetric.GetMetric(false) @@ -59,6 +63,9 @@ func TestBetweennessCentralityEmptyGraph(t *testing.T) { // Test betweenness centrality calculating using an example graph. func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { + t.Parallel() + ctx := context.Background() + workers := []int{1, 3, 9, 100} tests := []struct { @@ -100,7 +107,7 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { t1, graph, centralityTestGraph, ) - err = metric.Refresh(graph) + err = metric.Refresh(ctx, graph) require.NoError(t1, err) for _, expected := range tests { diff --git a/autopilot/interface.go b/autopilot/interface.go index 0991d9864e..ae803632f3 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -157,7 +157,7 @@ type NodeMetric interface { Name() string // Refresh refreshes the metric values based on the current graph. - Refresh(graph ChannelGraph) error + Refresh(ctx context.Context, graph ChannelGraph) error // GetMetric returns the latest value of this metric. Values in the // map are per node and can be in arbitrary domain. If normalize is diff --git a/autopilot/top_centrality.go b/autopilot/top_centrality.go index 65157c6212..90bf66ae8f 100644 --- a/autopilot/top_centrality.go +++ b/autopilot/top_centrality.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "runtime" "github.com/btcsuite/btcd/btcutil" @@ -54,8 +55,10 @@ func (g *TopCentrality) NodeScores(graph ChannelGraph, chans []LocalChannel, chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( map[NodeID]*NodeScore, error) { + ctx := context.TODO() + // Calculate betweenness centrality for the whole graph. 
- if err := g.centralityMetric.Refresh(graph); err != nil { + if err := g.centralityMetric.Refresh(ctx, graph); err != nil { return nil, err } diff --git a/rpcserver.go b/rpcserver.go index d51cf1cab9..72cf953bf4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6857,7 +6857,7 @@ func (r *rpcServer) GetNodeMetrics(ctx context.Context, if err != nil { return nil, err } - if err := centralityMetric.Refresh(channelGraph); err != nil { + if err := centralityMetric.Refresh(ctx, channelGraph); err != nil { return nil, err } From 5c1d21a9d219686cfe49b9c007db8f262ca8b7d8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 9 Apr 2025 06:36:02 +0200 Subject: [PATCH 18/41] autopilot: update AttachmentHeuristics with context Continue threading context through the autopilot system and remove the remaining context.TODOs. --- autopilot/agent.go | 2 +- autopilot/agent_test.go | 6 +++--- autopilot/combinedattach.go | 9 +++++---- autopilot/externalscoreattach.go | 7 ++++--- autopilot/externalscoreattach_test.go | 4 +++- autopilot/interface.go | 2 +- autopilot/manager.go | 13 +++++++------ autopilot/prefattach.go | 8 +++----- autopilot/prefattach_test.go | 14 ++++++++------ autopilot/top_centrality.go | 8 +++----- autopilot/top_centrality_test.go | 3 ++- lnrpc/autopilotrpc/autopilot_server.go | 2 +- 12 files changed, 41 insertions(+), 37 deletions(-) diff --git a/autopilot/agent.go b/autopilot/agent.go index 7b951a66d4..95e95c55f0 100644 --- a/autopilot/agent.go +++ b/autopilot/agent.go @@ -648,7 +648,7 @@ func (a *Agent) openChans(ctx context.Context, availableFunds btcutil.Amount, // graph. log.Debugf("Scoring %d nodes for chan_size=%v", len(nodes), chanSize) scores, err := a.cfg.Heuristic.NodeScores( - a.cfg.Graph, totalChans, chanSize, nodes, + ctx, a.cfg.Graph, totalChans, chanSize, nodes, ) if err != nil { return fmt.Errorf("unable to calculate node scores : %w", err) diff --git a/autopilot/agent_test.go b/autopilot/agent_test.go index 9b3c30bb53..47305a2ebc 100644 --- a/autopilot/agent_test.go +++ b/autopilot/agent_test.go @@ -86,9 +86,9 @@ func (m *mockHeuristic) Name() string { return "mock" } -func (m *mockHeuristic) NodeScores(g ChannelGraph, chans []LocalChannel, - chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( - map[NodeID]*NodeScore, error) { +func (m *mockHeuristic) NodeScores(_ context.Context, g ChannelGraph, + chans []LocalChannel, chanSize btcutil.Amount, + nodes map[NodeID]struct{}) (map[NodeID]*NodeScore, error) { if m.nodeScoresArgs != nil { directive := directiveArg{ diff --git a/autopilot/combinedattach.go b/autopilot/combinedattach.go index b43856d242..9064a82297 100644 --- a/autopilot/combinedattach.go +++ b/autopilot/combinedattach.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "fmt" "github.com/btcsuite/btcd/btcutil" @@ -70,9 +71,9 @@ func (c *WeightedCombAttachment) Name() string { // is the maximum possible improvement in connectivity. // // NOTE: This is a part of the AttachmentHeuristic interface. -func (c *WeightedCombAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, - chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( - map[NodeID]*NodeScore, error) { +func (c *WeightedCombAttachment) NodeScores(ctx context.Context, g ChannelGraph, + chans []LocalChannel, chanSize btcutil.Amount, + nodes map[NodeID]struct{}) (map[NodeID]*NodeScore, error) { // We now query each heuristic to determine the score they give to the // nodes for the given channel size. 
@@ -81,7 +82,7 @@ func (c *WeightedCombAttachment) NodeScores(g ChannelGraph, chans []LocalChannel log.Tracef("Getting scores from sub heuristic %v", h.Name()) s, err := h.NodeScores( - g, chans, chanSize, nodes, + ctx, g, chans, chanSize, nodes, ) if err != nil { return nil, fmt.Errorf("unable to get sub score: %w", diff --git a/autopilot/externalscoreattach.go b/autopilot/externalscoreattach.go index 1144d7d4a5..a979e15ee8 100644 --- a/autopilot/externalscoreattach.go +++ b/autopilot/externalscoreattach.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "fmt" "sync" @@ -80,9 +81,9 @@ func (s *ExternalScoreAttachment) SetNodeScores(targetHeuristic string, // not known will get a score of 0. // // NOTE: This is a part of the AttachmentHeuristic interface. -func (s *ExternalScoreAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, - chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( - map[NodeID]*NodeScore, error) { +func (s *ExternalScoreAttachment) NodeScores(_ context.Context, g ChannelGraph, + chans []LocalChannel, chanSize btcutil.Amount, + nodes map[NodeID]struct{}) (map[NodeID]*NodeScore, error) { existingPeers := make(map[NodeID]struct{}) for _, c := range chans { diff --git a/autopilot/externalscoreattach_test.go b/autopilot/externalscoreattach_test.go index aa0f87fb26..bef50a6746 100644 --- a/autopilot/externalscoreattach_test.go +++ b/autopilot/externalscoreattach_test.go @@ -1,6 +1,7 @@ package autopilot_test import ( + "context" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -22,6 +23,7 @@ func randKey() (*btcec.PublicKey, error) { // ExternalScoreAttachment correctly reflects the scores we set last. func TestSetNodeScores(t *testing.T) { t.Parallel() + ctx := context.Background() const name = "externalscore" @@ -62,7 +64,7 @@ func TestSetNodeScores(t *testing.T) { q[nID] = struct{}{} } resp, err := h.NodeScores( - nil, nil, btcutil.Amount(btcutil.SatoshiPerBitcoin), q, + ctx, nil, nil, btcutil.Amount(btcutil.SatoshiPerBitcoin), q, ) if err != nil { t.Fatal(err) diff --git a/autopilot/interface.go b/autopilot/interface.go index ae803632f3..db25d9bb21 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -144,7 +144,7 @@ type AttachmentHeuristic interface { // // NOTE: A NodeID not found in the returned map is implicitly given a // score of 0. - NodeScores(g ChannelGraph, chans []LocalChannel, + NodeScores(ctx context.Context, g ChannelGraph, chans []LocalChannel, chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( map[NodeID]*NodeScore, error) } diff --git a/autopilot/manager.go b/autopilot/manager.go index 036bf3a31e..c0ca40b559 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -276,8 +276,8 @@ func (m *Manager) StopAgent() error { } // QueryHeuristics queries the available autopilot heuristics for node scores. -func (m *Manager) QueryHeuristics(nodes []NodeID, localState bool) ( - HeuristicScores, error) { +func (m *Manager) QueryHeuristics(ctx context.Context, nodes []NodeID, + localState bool) (HeuristicScores, error) { m.Lock() defer m.Unlock() @@ -288,7 +288,8 @@ func (m *Manager) QueryHeuristics(nodes []NodeID, localState bool) ( } log.Debugf("Querying heuristics for %d nodes", len(n)) - return m.queryHeuristics(n, localState) + + return m.queryHeuristics(ctx, n, localState) } // HeuristicScores is an alias for a map that maps heuristic names to a map of @@ -299,8 +300,8 @@ type HeuristicScores map[string]map[NodeID]float64 // the agent's current active heuristic. // // NOTE: Must be called with the manager's lock. 
-func (m *Manager) queryHeuristics(nodes map[NodeID]struct{}, localState bool) ( - HeuristicScores, error) { +func (m *Manager) queryHeuristics(ctx context.Context, + nodes map[NodeID]struct{}, localState bool) (HeuristicScores, error) { // If we want to take the local state into action when querying the // heuristics, we fetch it. If not we'll just pass an empty slice to @@ -348,7 +349,7 @@ func (m *Manager) queryHeuristics(nodes map[NodeID]struct{}, localState bool) ( } s, err := h.NodeScores( - m.cfg.PilotCfg.Graph, totalChans, chanSize, nodes, + ctx, m.cfg.PilotCfg.Graph, totalChans, chanSize, nodes, ) if err != nil { return nil, fmt.Errorf("unable to get sub score: %w", diff --git a/autopilot/prefattach.go b/autopilot/prefattach.go index 4f55e87ea2..76d814a46a 100644 --- a/autopilot/prefattach.go +++ b/autopilot/prefattach.go @@ -79,11 +79,9 @@ func (p *PrefAttachment) Name() string { // given to nodes already having high connectivity in the graph. // // NOTE: This is a part of the AttachmentHeuristic interface. -func (p *PrefAttachment) NodeScores(g ChannelGraph, chans []LocalChannel, - chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( - map[NodeID]*NodeScore, error) { - - ctx := context.TODO() +func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph, + chans []LocalChannel, chanSize btcutil.Amount, + nodes map[NodeID]struct{}) (map[NodeID]*NodeScore, error) { // We first run though the graph once in order to find the median // channel size. diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index 7dec5f49f4..b79a396b1c 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -88,6 +88,8 @@ var chanGraphs = []struct { // TestPrefAttachmentSelectEmptyGraph ensures that when passed an // empty graph, the NodeSores function always returns a score of 0. func TestPrefAttachmentSelectEmptyGraph(t *testing.T) { + t.Parallel() + ctx := context.Background() prefAttach := NewPrefAttachment() // Create a random public key, which we will query to get a score for. @@ -108,7 +110,7 @@ func TestPrefAttachmentSelectEmptyGraph(t *testing.T) { // attempt to get the score for this one node. const walletFunds = btcutil.SatoshiPerBitcoin scores, err := prefAttach.NodeScores( - graph, nil, walletFunds, nodes, + ctx, graph, nil, walletFunds, nodes, ) require.NoError(t1, err) @@ -172,7 +174,7 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { // attempt to get our candidates channel score given // the current state of the graph. candidates, err := prefAttach.NodeScores( - graph, nil, maxChanSize, nodes, + ctx, graph, nil, maxChanSize, nodes, ) require.NoError(t1, err) @@ -280,7 +282,7 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { // result, the heuristic should try to greedily // allocate funds to channels. scores, err := prefAttach.NodeScores( - graph, nil, maxChanSize, nodes, + ctx, graph, nil, maxChanSize, nodes, ) require.NoError(t1, err) @@ -298,7 +300,7 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { // candidates of that size. const remBalance = btcutil.SatoshiPerBitcoin * 0.5 scores, err = prefAttach.NodeScores( - graph, nil, remBalance, nodes, + ctx, graph, nil, remBalance, nodes, ) require.NoError(t1, err) @@ -358,7 +360,7 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) { // With our graph created, we'll now get the scores for // all nodes in the graph. 
scores, err := prefAttach.NodeScores( - graph, nil, maxChanSize, nodes, + ctx, graph, nil, maxChanSize, nodes, ) require.NoError(t1, err) @@ -386,7 +388,7 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) { // then all nodes should have a score of zero, since we // already got channels to them. scores, err = prefAttach.NodeScores( - graph, chans, maxChanSize, nodes, + ctx, graph, chans, maxChanSize, nodes, ) require.NoError(t1, err) diff --git a/autopilot/top_centrality.go b/autopilot/top_centrality.go index 90bf66ae8f..e96b340978 100644 --- a/autopilot/top_centrality.go +++ b/autopilot/top_centrality.go @@ -51,11 +51,9 @@ func (g *TopCentrality) Name() string { // As our current implementation of betweenness centrality is non-incremental, // NodeScores will recalculate the centrality values on every call, which is // slow for large graphs. -func (g *TopCentrality) NodeScores(graph ChannelGraph, chans []LocalChannel, - chanSize btcutil.Amount, nodes map[NodeID]struct{}) ( - map[NodeID]*NodeScore, error) { - - ctx := context.TODO() +func (g *TopCentrality) NodeScores(ctx context.Context, graph ChannelGraph, + chans []LocalChannel, chanSize btcutil.Amount, + nodes map[NodeID]struct{}) (map[NodeID]*NodeScore, error) { // Calculate betweenness centrality for the whole graph. if err := g.centralityMetric.Refresh(ctx, graph); err != nil { diff --git a/autopilot/top_centrality_test.go b/autopilot/top_centrality_test.go index 9e82427614..e8ca3616d1 100644 --- a/autopilot/top_centrality_test.go +++ b/autopilot/top_centrality_test.go @@ -1,6 +1,7 @@ package autopilot import ( + "context" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -58,7 +59,7 @@ func testTopCentrality(t *testing.T, graph testGraph, // Attempt to get centrality scores and expect // that the result equals with the expected set. scores, err := topCentrality.NodeScores( - graph, channels, chanSize, nodes, + context.Background(), graph, channels, chanSize, nodes, ) require.NoError(t, err) diff --git a/lnrpc/autopilotrpc/autopilot_server.go b/lnrpc/autopilotrpc/autopilot_server.go index 3e3c6f8f7c..67e6c283a4 100644 --- a/lnrpc/autopilotrpc/autopilot_server.go +++ b/lnrpc/autopilotrpc/autopilot_server.go @@ -236,7 +236,7 @@ func (s *Server) QueryScores(ctx context.Context, in *QueryScoresRequest) ( // Query the heuristics. heuristicScores, err := s.manager.QueryHeuristics( - nodes, !in.IgnoreLocalState, + ctx, nodes, !in.IgnoreLocalState, ) if err != nil { return nil, err From 75bf82b1bce4659935139b97ebd4988e0d74cb0a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 29 Mar 2025 14:07:08 +0200 Subject: [PATCH 19/41] graph/db: remove unused Wipe method Later on we will create an interface for the persisted graph data. We want this interface to be as small and as neat as possible. In preparation for this, we remove this unused `Wipe` method. --- graph/db/kv_store.go | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index aa120a39a7..50d507185a 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -299,29 +299,6 @@ var graphTopLevelBuckets = [][]byte{ closedScidBucket, } -// Wipe completely deletes all saved state within all used buckets within the -// database. The deletion is done in a single transaction, therefore this -// operation is fully atomic. 
-func (c *KVStore) Wipe() error { - err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { - for _, tlb := range graphTopLevelBuckets { - err := tx.DeleteTopLevelBucket(tlb) - if err != nil && - !errors.Is(err, kvdb.ErrBucketNotFound) { - - return err - } - } - - return nil - }, func() {}) - if err != nil { - return err - } - - return initKVStore(c.db) -} - // createChannelDB creates and initializes a fresh version of In // the case that the target path has not yet been created or doesn't yet exist, // then the path is created. Additionally, all required top-level buckets used From d381f03637bb2c97332ff1ceaa55ae004bb424a1 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 26 Mar 2025 17:23:51 +0200 Subject: [PATCH 20/41] graph/db: introduce ForEachSourceNodeChannel In preparation for creating a clean interface for the graph store, we want to hide anything that is DB specific from the exposed methods on the interface. Currently the `ForEachNodeChannel` and the `FetchOtherNode` methods of the `KVStore` expose a `kvdb.RTx` parameter which is bbolt specific. There is only one call-site of `ForEachNodeChannel` actually makes use of the passed `kvdb.RTx` parameter, and that is in the `establishPersistentConnections` method of the `server` which then passes the tx parameter to `FetchOtherNode`. So to clean-up the interface such that the `kvdb.RTx` is no longer exposed: we instead create one new method called `ForEachSourceNodeChannel` which can be used to replace the above mentioned call-site. So as of this commit, all the remaining call-site of `ForEachNodeChannel` pass in a nil param for `kvdb.RTx` - meaning we can remove the parameter in a future commit. --- graph/db/graph_test.go | 93 +++++++++++++++++++++++++++++++++++++++++- graph/db/kv_store.go | 38 +++++++++++++++++ server.go | 31 +++----------- 3 files changed, 136 insertions(+), 26 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 5e445f10d2..825096b062 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -618,7 +618,7 @@ func createChannelEdge(node1, node2 *models.LightningNode) ( chanID := uint64(prand.Int63()) outpoint := wire.OutPoint{ Hash: rev, - Index: 9, + Index: prand.Uint32(), } // Add the new edge to the database, this should proceed without any @@ -991,6 +991,97 @@ func newEdgePolicy(chanID uint64, updateTime int64) *models.ChannelEdgePolicy { } } +// TestForEachSourceNodeChannel tests that the ForEachSourceNodeChannel +// correctly iterates through the channels of the set source node. +func TestForEachSourceNodeChannel(t *testing.T) { + t.Parallel() + + graph, err := MakeTestGraph(t) + require.NoError(t, err, "unable to make test database") + + // Create a source node (A) and set it as such in the DB. + nodeA := createTestVertex(t) + require.NoError(t, graph.SetSourceNode(nodeA)) + + // Now, create a few more nodes (B, C, D) along with some channels + // between them. We'll create the following graph: + // + // A -- B -- D + // | + // C + // + // The graph includes a channel (B-D) that does not belong to the source + // node along with 2 channels (A-B and A-C) that do belong to the source + // node. For the A-B channel, we will let the source node set an + // outgoing policy but for the A-C channel, we will set only an incoming + // policy. 
+ + nodeB := createTestVertex(t) + nodeC := createTestVertex(t) + nodeD := createTestVertex(t) + + abEdge, abPolicy1, abPolicy2 := createChannelEdge(nodeA, nodeB) + require.NoError(t, graph.AddChannelEdge(abEdge)) + acEdge, acPolicy1, acPolicy2 := createChannelEdge(nodeA, nodeC) + require.NoError(t, graph.AddChannelEdge(acEdge)) + bdEdge, _, _ := createChannelEdge(nodeB, nodeD) + require.NoError(t, graph.AddChannelEdge(bdEdge)) + + // Figure out which of the policies returned above are node A's so that + // we know which to persist. + // + // First, set the outgoing policy for the A-B channel. + abPolicyAOutgoing := abPolicy1 + if !bytes.Equal(abPolicy1.ToNode[:], nodeB.PubKeyBytes[:]) { + abPolicyAOutgoing = abPolicy2 + } + require.NoError(t, graph.UpdateEdgePolicy(abPolicyAOutgoing)) + + // Now, set the incoming policy for the A-C channel. + acPolicyAIncoming := acPolicy1 + if !bytes.Equal(acPolicy1.ToNode[:], nodeA.PubKeyBytes[:]) { + acPolicyAIncoming = acPolicy2 + } + require.NoError(t, graph.UpdateEdgePolicy(acPolicyAIncoming)) + + type sourceNodeChan struct { + otherNode route.Vertex + havePolicy bool + } + + // Put together our expected source node channels. + expectedSrcChans := map[wire.OutPoint]*sourceNodeChan{ + abEdge.ChannelPoint: { + otherNode: nodeB.PubKeyBytes, + havePolicy: true, + }, + acEdge.ChannelPoint: { + otherNode: nodeC.PubKeyBytes, + havePolicy: false, + }, + } + + // Now, we'll use the ForEachSourceNodeChannel and assert that it + // returns the expected data in the call-back. + err = graph.ForEachSourceNodeChannel(func(chanPoint wire.OutPoint, + havePolicy bool, otherNode *models.LightningNode) error { + + require.Contains(t, expectedSrcChans, chanPoint) + expected := expectedSrcChans[chanPoint] + + require.Equal( + t, expected.otherNode[:], otherNode.PubKeyBytes[:], + ) + require.Equal(t, expected.havePolicy, havePolicy) + + delete(expectedSrcChans, chanPoint) + + return nil + }) + require.NoError(t, err) + require.Empty(t, expectedSrcChans) +} + func TestGraphTraversal(t *testing.T) { t.Parallel() diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 50d507185a..68beb7a5d7 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -3109,6 +3109,44 @@ func (c *KVStore) ForEachNodeChannel(nodePub route.Vertex, return nodeTraversal(nil, nodePub[:], c.db, cb) } +// ForEachSourceNodeChannel iterates through all channels of the source node, +// executing the passed callback on each. The callback is provided with the +// channel's outpoint, whether we have a policy for the channel and the channel +// peer's node information. +func (c *KVStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, + havePolicy bool, otherNode *models.LightningNode) error) error { + + return kvdb.View(c.db, func(tx kvdb.RTx) error { + nodes := tx.ReadBucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + node, err := c.sourceNode(nodes) + if err != nil { + return err + } + + return nodeTraversal( + tx, node.PubKeyBytes[:], c.db, func(tx kvdb.RTx, + info *models.ChannelEdgeInfo, + policy, _ *models.ChannelEdgePolicy) error { + + peer, err := c.FetchOtherNode( + tx, info, node.PubKeyBytes[:], + ) + if err != nil { + return err + } + + return cb( + info.ChannelPoint, policy != nil, peer, + ) + }, + ) + }, func() {}) +} + // ForEachNodeChannelTx iterates through all channels of the given node, // executing the passed callback with an edge info structure and the policies // of each end of the channel. 
The first edge policy is the outgoing edge *to* diff --git a/server.go b/server.go index 2f48e4a0da..395100962d 100644 --- a/server.go +++ b/server.go @@ -3553,36 +3553,17 @@ func (s *server) establishPersistentConnections() error { // After checking our previous connections for addresses to connect to, // iterate through the nodes in our channel graph to find addresses // that have been added via NodeAnnouncement messages. - sourceNode, err := s.graphDB.SourceNode() - if err != nil { - return fmt.Errorf("failed to fetch source node: %w", err) - } - // TODO(roasbeef): instead iterate over link nodes and query graph for // each of the nodes. - selfPub := s.identityECDH.PubKey().SerializeCompressed() - err = s.graphDB.ForEachNodeChannel(sourceNode.PubKeyBytes, func( - tx kvdb.RTx, - chanInfo *models.ChannelEdgeInfo, - policy, _ *models.ChannelEdgePolicy) error { + err = s.graphDB.ForEachSourceNodeChannel(func(chanPoint wire.OutPoint, + havePolicy bool, channelPeer *models.LightningNode) error { // If the remote party has announced the channel to us, but we // haven't yet, then we won't have a policy. However, we don't // need this to connect to the peer, so we'll log it and move on. - if policy == nil { + if !havePolicy { srvrLog.Warnf("No channel policy found for "+ - "ChannelPoint(%v): ", chanInfo.ChannelPoint) - } - - // We'll now fetch the peer opposite from us within this - // channel so we can queue up a direct connection to them. - channelPeer, err := s.graphDB.FetchOtherNode( - tx, chanInfo, selfPub, - ) - if err != nil { - return fmt.Errorf("unable to fetch channel peer for "+ - "ChannelPoint(%v): %v", chanInfo.ChannelPoint, - err) + "ChannelPoint(%v): ", chanPoint) } pubStr := string(channelPeer.PubKeyBytes[:]) @@ -3642,8 +3623,8 @@ func (s *server) establishPersistentConnections() error { return nil }) if err != nil { - srvrLog.Errorf("Failed to iterate channels for node %x", - sourceNode.PubKeyBytes) + srvrLog.Errorf("Failed to iterate over source node channels: "+ + "%v", err) if !errors.Is(err, graphdb.ErrGraphNoEdgesFound) && !errors.Is(err, graphdb.ErrEdgeNotFound) { From d079f864d275377db610d21ec11a47ede23bb197 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 27 Mar 2025 11:59:48 +0200 Subject: [PATCH 21/41] graph/db: unexport various methods that expose `kvdb.RTx` Unexport the KVStore `FetchOtherNode` and `ForEachNodeChannelTx` methods so that fewer exposed methods are leaking implementation details. 
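For readers less familiar with Go, the rename is doing all the work here: visibility is decided purely by the case of an identifier's first letter, so lower-casing FetchOtherNode and ForEachNodeChannelTx confines them (and the kvdb.RTx in their signatures) to the graph/db package. A minimal sketch of the idea, with simplified stand-in types and method signatures rather than the real ones:

    package graphdb

    // Store stands in for the KVStore.
    type Store struct{}

    // fetchOtherNode starts with a lower-case letter, so it is unexported:
    // only code inside this package can call it, which keeps any
    // backend-specific transaction parameter out of the public API.
    func (s *Store) fetchOtherNode( /* tx kvdb.RTx, ... */ ) error {
        return nil
    }

    // ForEachSourceNodeChannel shows the shape of the exported
    // alternative: callers get the data they need without ever seeing a
    // transaction. (The real callback carries the channel point, a policy
    // flag and the peer node; it is simplified here.)
    func (s *Store) ForEachSourceNodeChannel(cb func() error) error {
        // Internally this is free to use fetchOtherNode and read txns.
        if err := s.fetchOtherNode(); err != nil {
            return err
        }

        return cb()
    }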
--- graph/db/graph_test.go | 2 +- graph/db/kv_store.go | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 825096b062..f7d87a1ced 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1219,7 +1219,7 @@ func TestGraphTraversalCacheable(t *testing.T) { err = graph.db.View(func(tx kvdb.RTx) error { for _, node := range nodes { - err := graph.ForEachNodeChannelTx(tx, node, + err := graph.forEachNodeChannelTx(tx, node, func(tx kvdb.RTx, info *models.ChannelEdgeInfo, policy *models.ChannelEdgePolicy, policy2 *models.ChannelEdgePolicy) error { //nolint:ll diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 68beb7a5d7..a0399ff908 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -561,7 +561,7 @@ func (c *KVStore) ForEachNodeCached(cb func(node route.Vertex, channels := make(map[uint64]*DirectedChannel) - err := c.ForEachNodeChannelTx(tx, node.PubKeyBytes, + err := c.forEachNodeChannelTx(tx, node.PubKeyBytes, func(tx kvdb.RTx, e *models.ChannelEdgeInfo, p1 *models.ChannelEdgePolicy, p2 *models.ChannelEdgePolicy) error { @@ -2850,7 +2850,7 @@ func (c *KVStore) isPublic(tx kvdb.RTx, nodePub route.Vertex, // used to terminate the check early. nodeIsPublic := false errDone := errors.New("done") - err := c.ForEachNodeChannelTx(tx, nodePub, func(tx kvdb.RTx, + err := c.forEachNodeChannelTx(tx, nodePub, func(tx kvdb.RTx, info *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -3132,7 +3132,7 @@ func (c *KVStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, info *models.ChannelEdgeInfo, policy, _ *models.ChannelEdgePolicy) error { - peer, err := c.FetchOtherNode( + peer, err := c.fetchOtherNode( tx, info, node.PubKeyBytes[:], ) if err != nil { @@ -3147,7 +3147,7 @@ func (c *KVStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, }, func() {}) } -// ForEachNodeChannelTx iterates through all channels of the given node, +// forEachNodeChannelTx iterates through all channels of the given node, // executing the passed callback with an edge info structure and the policies // of each end of the channel. The first edge policy is the outgoing edge *to* // the connecting node, while the second is the incoming edge *from* the @@ -3160,7 +3160,7 @@ func (c *KVStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, // should be passed as the first argument. Otherwise, the first argument should // be nil and a fresh transaction will be created to execute the graph // traversal. -func (c *KVStore) ForEachNodeChannelTx(tx kvdb.RTx, +func (c *KVStore) forEachNodeChannelTx(tx kvdb.RTx, nodePub route.Vertex, cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { @@ -3168,11 +3168,11 @@ func (c *KVStore) ForEachNodeChannelTx(tx kvdb.RTx, return nodeTraversal(tx, nodePub[:], c.db, cb) } -// FetchOtherNode attempts to fetch the full LightningNode that's opposite of +// fetchOtherNode attempts to fetch the full LightningNode that's opposite of // the target node in the channel. This is useful when one knows the pubkey of // one of the nodes, and wishes to obtain the full LightningNode for the other // end of the channel. 
-func (c *KVStore) FetchOtherNode(tx kvdb.RTx,
+func (c *KVStore) fetchOtherNode(tx kvdb.RTx,
 	channel *models.ChannelEdgeInfo, thisNodeKey []byte) (
 	*models.LightningNode, error) {

@@ -4702,7 +4702,7 @@ func (c *chanGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
 func (c *chanGraphNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo,
 	*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {

-	return c.db.ForEachNodeChannelTx(c.tx, c.node.PubKeyBytes,
+	return c.db.forEachNodeChannelTx(c.tx, c.node.PubKeyBytes,
 		func(_ kvdb.RTx, info *models.ChannelEdgeInfo,
 			policy1, policy2 *models.ChannelEdgePolicy) error {

From 14ce1ffe0eaa5aa6f96542d74d82fd819d416bf8 Mon Sep 17 00:00:00 2001
From: Elle Mouton
Date: Sat, 5 Apr 2025 16:53:02 +0200
Subject: [PATCH 22/41] graph/db: use only exported KVStore methods in tests

Replace all calls to bbolt specific methods on the KVStore to instead
use exported methods on the KVStore that are more db-agnostic.
---
 graph/db/graph_test.go | 29 +++++++++--------------------
 1 file changed, 9 insertions(+), 20 deletions(-)

diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go
index f7d87a1ced..26b7bede19 100644
--- a/graph/db/graph_test.go
+++ b/graph/db/graph_test.go
@@ -1217,26 +1217,15 @@ func TestGraphTraversalCacheable(t *testing.T) {
 	require.NoError(t, err)
 	require.Len(t, nodeMap, 0)

-	err = graph.db.View(func(tx kvdb.RTx) error {
-		for _, node := range nodes {
-			err := graph.forEachNodeChannelTx(tx, node,
-				func(tx kvdb.RTx, info *models.ChannelEdgeInfo,
-					policy *models.ChannelEdgePolicy,
-					policy2 *models.ChannelEdgePolicy) error { //nolint:ll
-
-					delete(chanIndex, info.ChannelID)
-					return nil
-				},
-			)
-			if err != nil {
-				return err
-			}
-		}
-
-		return nil
-	}, func() {})
-
-	require.NoError(t, err)
+	for _, node := range nodes {
+		err = graph.ForEachNodeDirectedChannel(
+			node, func(d *DirectedChannel) error {
+				delete(chanIndex, d.ChannelID)
+				return nil
+			},
+		)
+		require.NoError(t, err)
+	}
 	require.Len(t, chanIndex, 0)
 }

From 302a1342ddf0fe954e0f062c12e726ca61ba45f3 Mon Sep 17 00:00:00 2001
From: Elle Mouton
Date: Wed, 26 Mar 2025 17:28:10 +0200
Subject: [PATCH 23/41] multi: remove kvdb.RTx from ForEachNodeChannel

Since we have now removed all call-sites that make use of this
parameter, we can remove it. This helps hide DB-specific details from
the interface we will introduce for the graph store.
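The payoff of dropping the transaction argument, in miniature: once the callback no longer mentions kvdb.RTx, nothing in the interface names a storage engine, so a future SQL-backed store can satisfy it unchanged. A rough before/after shape, using simplified stand-ins (Edge, Policy and Vertex below are not the real models/route types):

    package sketch

    // Simplified stand-ins for models.ChannelEdgeInfo,
    // models.ChannelEdgePolicy and route.Vertex.
    type (
        Edge   struct{}
        Policy struct{}
        Vertex [33]byte
    )

    // Before: the callback leaked a bbolt read transaction, tying every
    // implementation and every caller to the kvdb backend:
    //
    //    ForEachNodeChannel(node Vertex,
    //        cb func(kvdb.RTx, *Edge, *Policy, *Policy) error) error
    //
    // After: the signature is backend-agnostic.
    type GraphStore interface {
        ForEachNodeChannel(node Vertex,
            cb func(*Edge, *Policy, *Policy) error) error
    }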
--- autopilot/prefattach_test.go | 2 +- graph/builder.go | 4 +--- graph/db/graph_test.go | 7 +++---- graph/db/kv_store.go | 9 +++++++-- graph/interfaces.go | 8 +++----- rpcserver.go | 5 ++--- server.go | 3 +-- 7 files changed, 18 insertions(+), 20 deletions(-) diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index b79a396b1c..2e3b22ff3f 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -736,7 +736,7 @@ func (t *testNodeTx) Node() *models.LightningNode { func (t *testNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { - return t.db.db.ForEachNodeChannel(t.node.PubKeyBytes, func(_ kvdb.RTx, + return t.db.db.ForEachNodeChannel(t.node.PubKeyBytes, func( edge *models.ChannelEdgeInfo, policy1, policy2 *models.ChannelEdgePolicy) error { diff --git a/graph/builder.go b/graph/builder.go index f92b523b00..3350eeb331 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -13,7 +13,6 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -1276,8 +1275,7 @@ func (b *Builder) ForAllOutgoingChannels(cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy) error) error { return b.cfg.Graph.ForEachNodeChannel(b.cfg.SelfNode, - func(_ kvdb.RTx, c *models.ChannelEdgeInfo, - e *models.ChannelEdgePolicy, + func(c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { if e == nil { diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 26b7bede19..ce8b68ade0 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1141,7 +1141,7 @@ func TestGraphTraversal(t *testing.T) { numNodeChans := 0 firstNode, secondNode := nodeList[0], nodeList[1] err = graph.ForEachNodeChannel(firstNode.PubKeyBytes, - func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, + func(_ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { // All channels between first and second node should @@ -2882,7 +2882,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { calls := 0 err := graph.ForEachNodeChannel(node.PubKeyBytes, - func(_ kvdb.RTx, _ *models.ChannelEdgeInfo, outEdge, + func(_ *models.ChannelEdgeInfo, outEdge, inEdge *models.ChannelEdgePolicy) error { if !expectedOut && outEdge != nil { @@ -4001,8 +4001,7 @@ func BenchmarkForEachChannel(b *testing.B) { require.NoError(b, err) for _, n := range nodes { - cb := func(tx kvdb.RTx, - info *models.ChannelEdgeInfo, + cb := func(info *models.ChannelEdgeInfo, policy *models.ChannelEdgePolicy, policy2 *models.ChannelEdgePolicy) error { //nolint:ll diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index a0399ff908..104c65b366 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -3103,10 +3103,15 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // // Unknown policies are passed into the callback as nil values. 
func (c *KVStore) ForEachNodeChannel(nodePub route.Vertex, - cb func(kvdb.RTx, *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { - return nodeTraversal(nil, nodePub[:], c.db, cb) + return nodeTraversal(nil, nodePub[:], c.db, func(_ kvdb.RTx, + info *models.ChannelEdgeInfo, policy, + policy2 *models.ChannelEdgePolicy) error { + + return cb(info, policy, policy2) + }) } // ForEachSourceNodeChannel iterates through all channels of the source node, diff --git a/graph/interfaces.go b/graph/interfaces.go index e351514bba..52897d8f5b 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -8,7 +8,6 @@ import ( "github.com/lightningnetwork/lnd/batch" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -254,10 +253,9 @@ type DB interface { // to the caller. // // Unknown policies are passed into the callback as nil values. - ForEachNodeChannel(nodePub route.Vertex, cb func(kvdb.RTx, - *models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error) error + ForEachNodeChannel(nodePub route.Vertex, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error // AddEdgeProof sets the proof of an existing edge in the graph // database. diff --git a/rpcserver.go b/rpcserver.go index 72cf953bf4..b009b2d6b4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -56,7 +56,6 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnrpc" @@ -6959,7 +6958,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, ) err = graph.ForEachNodeChannel(node.PubKeyBytes, - func(_ kvdb.RTx, edge *models.ChannelEdgeInfo, + func(edge *models.ChannelEdgeInfo, c1, c2 *models.ChannelEdgePolicy) error { numChannels++ @@ -7641,7 +7640,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, var feeReports []*lnrpc.ChannelFeeReport err = channelGraph.ForEachNodeChannel(selfNode.PubKeyBytes, - func(_ kvdb.RTx, chanInfo *models.ChannelEdgeInfo, + func(chanInfo *models.ChannelEdgeInfo, edgePolicy, _ *models.ChannelEdgePolicy) error { // Self node should always have policies for its diff --git a/server.go b/server.go index 395100962d..7f5219f998 100644 --- a/server.go +++ b/server.go @@ -51,7 +51,6 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lncfg" "github.com/lightningnetwork/lnd/lnencrypt" "github.com/lightningnetwork/lnd/lnpeer" @@ -1212,7 +1211,7 @@ func newServer(_ context.Context, cfg *Config, listenAddrs []net.Addr, *models.ChannelEdgePolicy) error) error { return s.graphDB.ForEachNodeChannel(selfVertex, - func(_ kvdb.RTx, c *models.ChannelEdgeInfo, + func(c *models.ChannelEdgeInfo, e *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { From 8b87367026e91b0fa6bcca67997c546c01d9ebaf Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 11 Apr 2025 
09:54:28 +0200 Subject: [PATCH 24/41] discovery: revert passing ctx through to Start methods --- discovery/gossiper.go | 8 ++++---- discovery/gossiper_test.go | 4 ++-- discovery/reliable_sender.go | 4 ++-- discovery/sync_manager.go | 21 ++++++--------------- discovery/sync_manager_test.go | 19 +++++++++---------- discovery/syncer.go | 4 ++-- discovery/syncer_test.go | 24 +++++++++--------------- server.go | 2 +- 8 files changed, 35 insertions(+), 51 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index ab12524fb6..4b2f4f220e 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -643,10 +643,10 @@ func (d *AuthenticatedGossiper) PropagateChanPolicyUpdate( // Start spawns network messages handler goroutine and registers on new block // notifications in order to properly handle the premature announcements. -func (d *AuthenticatedGossiper) Start(ctx context.Context) error { +func (d *AuthenticatedGossiper) Start() error { var err error d.started.Do(func() { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(context.Background()) d.cancel = fn.Some(cancel) log.Info("Authenticated Gossiper starting") @@ -674,11 +674,11 @@ func (d *AuthenticatedGossiper) start(ctx context.Context) error { // Start the reliable sender. In case we had any pending messages ready // to be sent when the gossiper was last shut down, we must continue on // our quest to deliver them to their respective peers. - if err := d.reliableSender.Start(ctx); err != nil { + if err := d.reliableSender.Start(); err != nil { return err } - d.syncMgr.Start(ctx) + d.syncMgr.Start() d.banman.start() diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index ef7f2f21f6..3990109916 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -994,7 +994,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( ScidCloser: newMockScidCloser(isChanPeer), }, selfKeyDesc) - if err := gossiper.Start(context.Background()); err != nil { + if err := gossiper.Start(); err != nil { return nil, fmt.Errorf("unable to start router: %w", err) } @@ -1692,7 +1692,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { KeyLocator: tCtx.gossiper.selfKeyLoc, }) require.NoError(t, err, "unable to recreate gossiper") - if err := gossiper.Start(context.Background()); err != nil { + if err := gossiper.Start(); err != nil { t.Fatalf("unable to start recreated gossiper: %v", err) } defer gossiper.Stop() diff --git a/discovery/reliable_sender.go b/discovery/reliable_sender.go index 57f2f28ff3..357654a65e 100644 --- a/discovery/reliable_sender.go +++ b/discovery/reliable_sender.go @@ -76,10 +76,10 @@ func newReliableSender(cfg *reliableSenderCfg) *reliableSender { } // Start spawns message handlers for any peers with pending messages. 
-func (s *reliableSender) Start(ctx context.Context) error { +func (s *reliableSender) Start() error { var err error s.start.Do(func() { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(context.Background()) s.cancel = fn.Some(cancel) err = s.resendPendingMsgs(ctx) diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go index 6ed96ac015..b1cb208cd9 100644 --- a/discovery/sync_manager.go +++ b/discovery/sync_manager.go @@ -8,7 +8,6 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -201,9 +200,8 @@ type SyncManager struct { // number of queries. rateLimiter *rate.Limiter - wg sync.WaitGroup - quit chan struct{} - cancel fn.Option[context.CancelFunc] + wg sync.WaitGroup + quit chan struct{} } // newSyncManager constructs a new SyncManager backed by the given config. @@ -248,13 +246,10 @@ func newSyncManager(cfg *SyncManagerCfg) *SyncManager { } // Start starts the SyncManager in order to properly carry out its duties. -func (m *SyncManager) Start(ctx context.Context) { +func (m *SyncManager) Start() { m.start.Do(func() { - ctx, cancel := context.WithCancel(ctx) - m.cancel = fn.Some(cancel) - m.wg.Add(1) - go m.syncerHandler(ctx) + go m.syncerHandler() }) } @@ -264,7 +259,6 @@ func (m *SyncManager) Stop() { log.Debugf("SyncManager is stopping") defer log.Debugf("SyncManager stopped") - m.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(m.quit) m.wg.Wait() @@ -288,7 +282,7 @@ func (m *SyncManager) Stop() { // much of the public network as possible. // // NOTE: This must be run as a goroutine. -func (m *SyncManager) syncerHandler(ctx context.Context) { +func (m *SyncManager) syncerHandler() { defer m.wg.Done() m.cfg.RotateTicker.Resume() @@ -386,7 +380,7 @@ func (m *SyncManager) syncerHandler(ctx context.Context) { } m.syncersMu.Unlock() - s.Start(ctx) + s.Start() // Once we create the GossipSyncer, we'll signal to the // caller that they can proceed since the SyncManager's @@ -538,9 +532,6 @@ func (m *SyncManager) syncerHandler(ctx context.Context) { case <-m.quit: return - - case <-ctx.Done(): - return } } } diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go index b8ef931977..4aff5b6315 100644 --- a/discovery/sync_manager_test.go +++ b/discovery/sync_manager_test.go @@ -2,7 +2,6 @@ package discovery import ( "bytes" - "context" "fmt" "io" "reflect" @@ -83,7 +82,7 @@ func TestSyncManagerNumActiveSyncers(t *testing.T) { } syncMgr := newPinnedTestSyncManager(numActiveSyncers, pinnedSyncers) - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // First we'll start by adding the pinned syncers. These should @@ -135,7 +134,7 @@ func TestSyncManagerNewActiveSyncerAfterDisconnect(t *testing.T) { // We'll create our test sync manager to have two active syncers. syncMgr := newTestSyncManager(2) - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // The first will be an active syncer that performs a historical sync @@ -188,7 +187,7 @@ func TestSyncManagerRotateActiveSyncerCandidate(t *testing.T) { // We'll create our sync manager with three active syncers. syncMgr := newTestSyncManager(1) - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // The first syncer registered always performs a historical sync. 
@@ -236,7 +235,7 @@ func TestSyncManagerNoInitialHistoricalSync(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(0) - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // We should not expect any messages from the peer. @@ -270,7 +269,7 @@ func TestSyncManagerInitialHistoricalSync(t *testing.T) { t.Fatal("expected graph to not be considered as synced") } - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -339,7 +338,7 @@ func TestSyncManagerHistoricalSyncOnReconnect(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(2) - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -373,7 +372,7 @@ func TestSyncManagerForceHistoricalSync(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(1) - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -411,7 +410,7 @@ func TestSyncManagerGraphSyncedAfterHistoricalSyncReplacement(t *testing.T) { t.Parallel() syncMgr := newTestSyncManager(1) - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // We should expect to see a QueryChannelRange message with a @@ -469,7 +468,7 @@ func TestSyncManagerWaitUntilInitialHistoricalSync(t *testing.T) { // We'll start by creating our test sync manager which will hold up to // 2 active syncers. syncMgr := newTestSyncManager(numActiveSyncers) - syncMgr.Start(context.Background()) + syncMgr.Start() defer syncMgr.Stop() // We'll go ahead and create our syncers. diff --git a/discovery/syncer.go b/discovery/syncer.go index 0b4e7030b3..b27cc8c718 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -405,11 +405,11 @@ func newGossipSyncer(cfg gossipSyncerCfg, sema chan struct{}) *GossipSyncer { // Start starts the GossipSyncer and any goroutines that it needs to carry out // its duties. -func (g *GossipSyncer) Start(ctx context.Context) { +func (g *GossipSyncer) Start() { g.started.Do(func() { log.Debugf("Starting GossipSyncer(%x)", g.cfg.peerPub[:]) - ctx, _ := g.cg.Create(ctx) + ctx, _ := g.cg.Create(context.Background()) // TODO(conner): only spawn channelGraphSyncer if remote // supports gossip queries, and only spawn replyHandler if we diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index 13071d4b01..44e8d6d701 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -1703,7 +1703,6 @@ func queryBatch(t *testing.T, // them. func TestGossipSyncerRoutineSync(t *testing.T) { t.Parallel() - ctx := context.Background() // We'll modify the chunk size to be a smaller value, so we can ensure // our chunk parsing works properly. With this value we should get 3 @@ -1718,13 +1717,13 @@ func TestGossipSyncerRoutineSync(t *testing.T) { msgChan1, syncer1, chanSeries1 := newTestSyncer( highestID, defaultEncoding, chunkSize, true, false, ) - syncer1.Start(ctx) + syncer1.Start() defer syncer1.Stop() msgChan2, syncer2, chanSeries2 := newTestSyncer( highestID, defaultEncoding, chunkSize, false, true, ) - syncer2.Start(ctx) + syncer2.Start() defer syncer2.Stop() // Although both nodes are at the same height, syncer will have 3 chan @@ -1851,7 +1850,6 @@ func TestGossipSyncerRoutineSync(t *testing.T) { // final state and not perform any channel queries. 
func TestGossipSyncerAlreadySynced(t *testing.T) { t.Parallel() - ctx := context.Background() // We'll modify the chunk size to be a smaller value, so we can ensure // our chunk parsing works properly. With this value we should get 3 @@ -1867,13 +1865,13 @@ func TestGossipSyncerAlreadySynced(t *testing.T) { msgChan1, syncer1, chanSeries1 := newTestSyncer( highestID, defaultEncoding, chunkSize, ) - syncer1.Start(ctx) + syncer1.Start() defer syncer1.Stop() msgChan2, syncer2, chanSeries2 := newTestSyncer( highestID, defaultEncoding, chunkSize, ) - syncer2.Start(ctx) + syncer2.Start() defer syncer2.Stop() // The channel state of both syncers will be identical. They should @@ -2073,7 +2071,6 @@ func TestGossipSyncerAlreadySynced(t *testing.T) { // carries out its duties when accepting a new sync transition request. func TestGossipSyncerSyncTransitions(t *testing.T) { t.Parallel() - ctx := context.Background() assertMsgSent := func(t *testing.T, msgChan chan []lnwire.Message, msg lnwire.Message) { @@ -2194,7 +2191,7 @@ func TestGossipSyncerSyncTransitions(t *testing.T) { // We'll then start the syncer in order to process the // request. - syncer.Start(ctx) + syncer.Start() defer syncer.Stop() syncer.ProcessSyncTransition(test.finalSyncType) @@ -2219,7 +2216,6 @@ func TestGossipSyncerSyncTransitions(t *testing.T) { // historical sync with the remote peer. func TestGossipSyncerHistoricalSync(t *testing.T) { t.Parallel() - ctx := context.Background() // We'll create a new gossip syncer and manually override its state to // chansSynced. This is necessary as the syncer can only process @@ -2231,7 +2227,7 @@ func TestGossipSyncerHistoricalSync(t *testing.T) { syncer.setSyncType(PassiveSync) syncer.setSyncState(chansSynced) - syncer.Start(ctx) + syncer.Start() defer syncer.Stop() syncer.historicalSync() @@ -2264,7 +2260,6 @@ func TestGossipSyncerHistoricalSync(t *testing.T) { // syncer reaches its terminal chansSynced state. func TestGossipSyncerSyncedSignal(t *testing.T) { t.Parallel() - ctx := context.Background() // We'll create a new gossip syncer and manually override its state to // chansSynced. @@ -2279,7 +2274,7 @@ func TestGossipSyncerSyncedSignal(t *testing.T) { signalChan := syncer.ResetSyncedSignal() // Starting the gossip syncer should cause the signal to be delivered. - syncer.Start(ctx) + syncer.Start() select { case <-signalChan: @@ -2298,7 +2293,7 @@ func TestGossipSyncerSyncedSignal(t *testing.T) { syncer.setSyncState(chansSynced) - syncer.Start(ctx) + syncer.Start() defer syncer.Stop() signalChan = syncer.ResetSyncedSignal() @@ -2317,7 +2312,6 @@ func TestGossipSyncerSyncedSignal(t *testing.T) { // said limit are not processed. func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { t.Parallel() - ctx := context.Background() msgChan, syncer, chanSeries := newTestSyncer( lnwire.ShortChannelID{BlockHeight: latestKnownHeight}, @@ -2328,7 +2322,7 @@ func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { // the sake of testing. syncer.cfg.maxQueryChanRangeReplies = 100 - syncer.Start(ctx) + syncer.Start() defer syncer.Stop() // Upon initialization, the syncer should submit a QueryChannelRange diff --git a/server.go b/server.go index 7f5219f998..8608e50642 100644 --- a/server.go +++ b/server.go @@ -2389,7 +2389,7 @@ func (s *server) Start(ctx context.Context) error { // The authGossiper depends on the chanRouter and therefore // should be started after it. 
cleanup = cleanup.add(s.authGossiper.Stop) - if err := s.authGossiper.Start(ctx); err != nil { + if err := s.authGossiper.Start(); err != nil { startErr = err return } From d01b39235ad294ccaff7912478b37ce1b9590b99 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 11 Apr 2025 09:57:24 +0200 Subject: [PATCH 25/41] autopilot: revert passing ctx to Start methods --- autopilot/agent.go | 4 ++-- autopilot/agent_test.go | 2 +- autopilot/manager.go | 17 ++++------------- lnd.go | 2 +- lnrpc/autopilotrpc/autopilot_server.go | 4 ++-- 5 files changed, 10 insertions(+), 19 deletions(-) diff --git a/autopilot/agent.go b/autopilot/agent.go index 95e95c55f0..1b70580a3e 100644 --- a/autopilot/agent.go +++ b/autopilot/agent.go @@ -202,10 +202,10 @@ func New(cfg Config, initialState []LocalChannel) (*Agent, error) { // Start starts the agent along with any goroutines it needs to perform its // normal duties. -func (a *Agent) Start(ctx context.Context) error { +func (a *Agent) Start() error { var err error a.started.Do(func() { - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(context.Background()) a.cancel = fn.Some(cancel) err = a.start(ctx) diff --git a/autopilot/agent_test.go b/autopilot/agent_test.go index 47305a2ebc..82c21a2ffc 100644 --- a/autopilot/agent_test.go +++ b/autopilot/agent_test.go @@ -221,7 +221,7 @@ func setup(t *testing.T, initialChans []LocalChannel) *testContext { // With the autopilot agent and all its dependencies we'll start the // primary controller goroutine. - if err := agent.Start(context.Background()); err != nil { + if err := agent.Start(); err != nil { t.Fatalf("unable to start agent: %v", err) } diff --git a/autopilot/manager.go b/autopilot/manager.go index c0ca40b559..600a1055b2 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -7,7 +7,6 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -55,9 +54,8 @@ type Manager struct { // disabled. pilot *Agent - quit chan struct{} - wg sync.WaitGroup - cancel fn.Option[context.CancelFunc] + quit chan struct{} + wg sync.WaitGroup sync.Mutex } @@ -83,7 +81,6 @@ func (m *Manager) Stop() error { log.Errorf("Unable to stop pilot: %v", err) } - m.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) close(m.quit) m.wg.Wait() }) @@ -100,7 +97,7 @@ func (m *Manager) IsActive() bool { // StartAgent creates and starts an autopilot agent from the Manager's // config. -func (m *Manager) StartAgent(ctx context.Context) error { +func (m *Manager) StartAgent() error { m.Lock() defer m.Unlock() @@ -108,8 +105,6 @@ func (m *Manager) StartAgent(ctx context.Context) error { if m.pilot != nil { return nil } - ctx, cancel := context.WithCancel(ctx) - m.cancel = fn.Some(cancel) // Next, we'll fetch the current state of open channels from the // database to use as initial state for the auto-pilot agent. 
@@ -125,7 +120,7 @@ func (m *Manager) StartAgent(ctx context.Context) error { return err } - if err := pilot.Start(ctx); err != nil { + if err := pilot.Start(); err != nil { return err } @@ -169,8 +164,6 @@ func (m *Manager) StartAgent(ctx context.Context) error { return case <-m.quit: return - case <-ctx.Done(): - return } } @@ -241,8 +234,6 @@ func (m *Manager) StartAgent(ctx context.Context) error { return case <-m.quit: return - case <-ctx.Done(): - return } } }() diff --git a/lnd.go b/lnd.go index 41bd3ca4ba..3afa8c2fba 100644 --- a/lnd.go +++ b/lnd.go @@ -788,7 +788,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, implCfg *ImplementationCfg, // active, then we'll start the autopilot agent immediately. It will be // stopped together with the autopilot service. if cfg.Autopilot.Active { - if err := atplManager.StartAgent(ctx); err != nil { + if err := atplManager.StartAgent(); err != nil { return mkErr("unable to start autopilot agent", err) } } diff --git a/lnrpc/autopilotrpc/autopilot_server.go b/lnrpc/autopilotrpc/autopilot_server.go index 67e6c283a4..e5952949ec 100644 --- a/lnrpc/autopilotrpc/autopilot_server.go +++ b/lnrpc/autopilotrpc/autopilot_server.go @@ -198,14 +198,14 @@ func (s *Server) Status(ctx context.Context, // ModifyStatus activates the current autopilot agent, if active. // // NOTE: Part of the AutopilotServer interface. -func (s *Server) ModifyStatus(ctx context.Context, +func (s *Server) ModifyStatus(_ context.Context, in *ModifyStatusRequest) (*ModifyStatusResponse, error) { log.Debugf("Setting agent enabled=%v", in.Enable) var err error if in.Enable { - err = s.manager.StartAgent(ctx) + err = s.manager.StartAgent() } else { err = s.manager.StopAgent() } From dedc51fb3d13598c14fc477e4546c5a7a15dcac4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 16:58:51 +0200 Subject: [PATCH 26/41] graph/db: let test alias be UTF-8 compatible Later on we will store the Alias as a Text field in our sql impls of the graph db. For Postgres, this field then MUST be a valid UTF-8 string. 
This is also the case in general for the alias according to [bolt 7](https://github.com/lightning/bolts/blob/e1fa25cf00446f3a6a6abbbc9a617cae5b75e39f/07-routing-gossip.md?plain=1#L313) --- graph/db/graph_test.go | 2 +- graph/notifications_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index ce8b68ade0..26dd411791 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -75,7 +75,7 @@ func createLightningNode(priv *btcec.PrivateKey) *models.LightningNode { AuthSigBytes: testSig.Serialize(), LastUpdate: time.Unix(updateTime, 0), Color: color.RGBA{1, 2, 3, 0}, - Alias: "kek" + string(pub[:]), + Alias: "kek" + hex.EncodeToString(pub), Features: testFeatures, Addresses: testAddrs, } diff --git a/graph/notifications_test.go b/graph/notifications_test.go index 0e2ec7afba..ff4649cfe0 100644 --- a/graph/notifications_test.go +++ b/graph/notifications_test.go @@ -89,7 +89,7 @@ func createTestNode(t *testing.T) *models.LightningNode { LastUpdate: time.Unix(updateTime, 0), Addresses: testAddrs, Color: color.RGBA{1, 2, 3, 0}, - Alias: "kek" + string(pub[:]), + Alias: "kek" + hex.EncodeToString(pub), AuthSigBytes: testSig.Serialize(), Features: testFeatures, } From 424f63b3103599ea37b61d5c3dbaa78676f82736 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 17:16:05 +0200 Subject: [PATCH 27/41] graph/db: update the `compareNodes` helper - Let it do a proper comparison of the full structs passed in. - Pass in a testing parameter so we can remove the returned error. - Make sure the callers are passing in the expected and result parameters in the correct order. - Fix a bug: the compareNodes was not comparing the Features field of the LightningNode structs. Now that it does, one test needed to be updated to properly set the expected Features fields. --- graph/db/graph_test.go | 70 +++++++++++++----------------------------- 1 file changed, 21 insertions(+), 49 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 26dd411791..de6f3a009f 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -78,6 +78,7 @@ func createLightningNode(priv *btcec.PrivateKey) *models.LightningNode { Alias: "kek" + hex.EncodeToString(pub), Features: testFeatures, Addresses: testAddrs, + ExtraOpaqueData: make([]byte, 0), } copy(n.PubKeyBytes[:], priv.PubKey().SerializeCompressed()) @@ -133,9 +134,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { } // The two nodes should match exactly! - if err := compareNodes(node, dbNode); err != nil { - t.Fatalf("nodes don't match: %v", err) - } + compareNodes(t, node, dbNode) // Next, delete the node from the graph, this should purge all data // related to the node. @@ -192,8 +191,9 @@ func TestPartialNode(t *testing.T) { HaveNodeAnnouncement: false, LastUpdate: time.Unix(0, 0), PubKeyBytes: pubKey1, + Features: lnwire.EmptyFeatureVector(), } - require.NoError(t, compareNodes(dbNode1, expectedNode1)) + compareNodes(t, expectedNode1, dbNode1) _, exists, err = graph.HasLightningNode(dbNode2.PubKeyBytes) require.NoError(t, err) @@ -205,8 +205,9 @@ func TestPartialNode(t *testing.T) { HaveNodeAnnouncement: false, LastUpdate: time.Unix(0, 0), PubKeyBytes: pubKey2, + Features: lnwire.EmptyFeatureVector(), } - require.NoError(t, compareNodes(dbNode2, expectedNode2)) + compareNodes(t, expectedNode2, dbNode2) // Next, delete the node from the graph, this should purge all data // related to the node. 
@@ -282,9 +283,7 @@ func TestSourceNode(t *testing.T) { // the one we set above. sourceNode, err := graph.SourceNode() require.NoError(t, err, "unable to fetch source node") - if err := compareNodes(testNode, sourceNode); err != nil { - t.Fatalf("nodes don't match: %v", err) - } + compareNodes(t, testNode, sourceNode) } func TestEdgeInsertionDeletion(t *testing.T) { @@ -1985,10 +1984,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { } for i := 0; i < len(resp); i++ { - err := compareNodes(&queryCase.resp[i], &resp[i]) - if err != nil { - t.Fatal(err) - } + compareNodes(t, &queryCase.resp[i], &resp[i]) } } } @@ -3222,9 +3218,7 @@ func TestNodePruningUpdateIndexDeletion(t *testing.T) { t.Fatalf("should have 1 nodes instead have: %v", len(nodesInHorizon)) } - if err := compareNodes(node1, &nodesInHorizon[0]); err != nil { - t.Fatalf("nodes don't match: %v", err) - } + compareNodes(t, node1, &nodesInHorizon[0]) // We'll now delete the node from the graph, this should result in it // being removed from the update index as well. @@ -3691,41 +3685,19 @@ func TestGraphZombieIndex(t *testing.T) { assertNumZombies(t, graph, 1) } -// compareNodes is used to compare two LightningNodes while excluding the -// Features struct, which cannot be compared as the semantics for reserializing -// the featuresMap have not been defined. -func compareNodes(a, b *models.LightningNode) error { - if a.LastUpdate != b.LastUpdate { - return fmt.Errorf("node LastUpdate doesn't match: expected "+ - "%v, got %v", a.LastUpdate, b.LastUpdate) - } - if !reflect.DeepEqual(a.Addresses, b.Addresses) { - return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+ - "got %#v", a.Addresses, b.Addresses) - } - if !reflect.DeepEqual(a.PubKeyBytes, b.PubKeyBytes) { - return fmt.Errorf("PubKey doesn't match: expected %#v, \n "+ - "got %#v", a.PubKeyBytes, b.PubKeyBytes) - } - if !reflect.DeepEqual(a.Color, b.Color) { - return fmt.Errorf("Color doesn't match: expected %#v, \n "+ - "got %#v", a.Color, b.Color) - } - if !reflect.DeepEqual(a.Alias, b.Alias) { - return fmt.Errorf("Alias doesn't match: expected %#v, \n "+ - "got %#v", a.Alias, b.Alias) - } - if !reflect.DeepEqual(a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) { - return fmt.Errorf("HaveNodeAnnouncement doesn't match: "+ - "expected %#v, got %#v", a.HaveNodeAnnouncement, - b.HaveNodeAnnouncement) - } - if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { - return fmt.Errorf("extra data doesn't match: %v vs %v", - a.ExtraOpaqueData, b.ExtraOpaqueData) - } +// compareNodes is used to compare two LightningNodes. +func compareNodes(t *testing.T, a, b *models.LightningNode) { + t.Helper() - return nil + // Call the PubKey method for each node to ensure that the internal + // `pubKey` field is set for both objects and so require.Equals can + // then be used to compare the structs. + _, err := a.PubKey() + require.NoError(t, err) + _, err = b.PubKey() + require.NoError(t, err) + + require.Equal(t, a, b) } // compareEdgePolicies is used to compare two ChannelEdgePolices using From 4ab0699e103eb514550ff183809e06f802795b7a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 30 Mar 2025 11:30:45 +0200 Subject: [PATCH 28/41] graph: test cleanup Remove the kvdb.Backend return type of the `makeTestGraph` helper. This is in preparation for the helper being used to create a test graph backed by a DB other than bbolt. 
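For reference, the post-cleanup shape of such a helper (a generic sketch in a `_test.go` file with placeholder names, not lnd's actual code) is to take `*testing.T`, fail internally with `require`, and register teardown via `t.Cleanup`, so callers receive only the value they care about:

```go
package example

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/require"
)

// store stands in for the resource under test (e.g. a graph DB handle).
type store struct{ f *os.File }

func openStore(dir string) (*store, error) {
	f, err := os.Create(filepath.Join(dir, "store.db"))
	if err != nil {
		return nil, err
	}

	return &store{f: f}, nil
}

func (s *store) Close() error { return s.f.Close() }

// makeTestStore mirrors the shape makeTestGraph takes after this commit:
// no error or backend return values, just the value the test needs.
func makeTestStore(t *testing.T) *store {
	t.Helper()

	s, err := openStore(t.TempDir())
	require.NoError(t, err)

	t.Cleanup(func() {
		require.NoError(t, s.Close())
	})

	return s
}

func TestMakeTestStore(t *testing.T) {
	require.NotNil(t, makeTestStore(t))
}
```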
--- graph/builder_test.go | 30 +++++++++++------------------- graph/notifications_test.go | 23 ++++++++--------------- 2 files changed, 19 insertions(+), 34 deletions(-) diff --git a/graph/builder_test.go b/graph/builder_test.go index b813c17e22..aed30e0973 100644 --- a/graph/builder_test.go +++ b/graph/builder_test.go @@ -1352,10 +1352,7 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( testAddrs = append(testAddrs, testAddr) // Next, create a temporary graph database for usage within the test. - graph, graphBackend, err := makeTestGraph(t, useCache) - if err != nil { - return nil, err - } + graph := makeTestGraph(t, useCache) aliasMap := make(map[string]route.Vertex) privKeyMap := make(map[string]*btcec.PrivateKey) @@ -1562,12 +1559,11 @@ func parseTestGraph(t *testing.T, useCache bool, path string) ( } return &testGraphInstance{ - graph: graph, - graphBackend: graphBackend, - aliasMap: aliasMap, - privKeyMap: privKeyMap, - channelIDs: channelIDs, - links: links, + graph: graph, + aliasMap: aliasMap, + privKeyMap: privKeyMap, + channelIDs: channelIDs, + links: links, }, nil } @@ -1730,10 +1726,7 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, testAddrs = append(testAddrs, testAddr) // Next, create a temporary graph database for usage within the test. - graph, graphBackend, err := makeTestGraph(t, useCache) - if err != nil { - return nil, err - } + graph := makeTestGraph(t, useCache) aliasMap := make(map[string]route.Vertex) privKeyMap := make(map[string]*btcec.PrivateKey) @@ -1947,11 +1940,10 @@ func createTestGraphFromChannels(t *testing.T, useCache bool, } return &testGraphInstance{ - graph: graph, - graphBackend: graphBackend, - aliasMap: aliasMap, - privKeyMap: privKeyMap, - links: links, + graph: graph, + aliasMap: aliasMap, + privKeyMap: privKeyMap, + links: links, }, nil } diff --git a/graph/notifications_test.go b/graph/notifications_test.go index ff4649cfe0..ace578376a 100644 --- a/graph/notifications_test.go +++ b/graph/notifications_test.go @@ -1034,8 +1034,7 @@ type testCtx struct { func createTestCtxSingleNode(t *testing.T, startingHeight uint32) *testCtx { - graph, graphBackend, err := makeTestGraph(t, true) - require.NoError(t, err, "failed to make test graph") + graph := makeTestGraph(t, true) sourceNode := createTestNode(t) @@ -1044,8 +1043,7 @@ func createTestCtxSingleNode(t *testing.T, ) graphInstance := &testGraphInstance{ - graph: graph, - graphBackend: graphBackend, + graph: graph, } return createTestCtxFromGraphInstance( @@ -1086,14 +1084,12 @@ func (c *testCtx) RestartBuilder(t *testing.T) { // makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing // purposes. -func makeTestGraph(t *testing.T, useCache bool) (*graphdb.ChannelGraph, - kvdb.Backend, error) { +func makeTestGraph(t *testing.T, useCache bool) *graphdb.ChannelGraph { + t.Helper() // Create channelgraph for the first time. 
backend, backendCleanup, err := kvdb.GetTestBackend(t.TempDir(), "cgr") - if err != nil { - return nil, nil, err - } + require.NoError(t, err) t.Cleanup(backendCleanup) @@ -1101,20 +1097,17 @@ func makeTestGraph(t *testing.T, useCache bool) (*graphdb.ChannelGraph, &graphdb.Config{KVDB: backend}, graphdb.WithUseGraphCache(useCache), ) - if err != nil { - return nil, nil, err - } + require.NoError(t, err) require.NoError(t, graph.Start()) t.Cleanup(func() { require.NoError(t, graph.Stop()) }) - return graph, backend, nil + return graph } type testGraphInstance struct { - graph *graphdb.ChannelGraph - graphBackend kvdb.Backend + graph *graphdb.ChannelGraph // aliasMap is a map from a node's alias to its public key. This type is // provided in order to allow easily look up from the human memorable From 149abb96f658b4bbfcbb0a1085b7fb1653ad6dd9 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 19:38:20 +0200 Subject: [PATCH 29/41] channeldb: remove graph calls from tests The channeldb no longer depends on the graph. So remove the use of MakeTestGraph from tests. --- channeldb/db_test.go | 22 +++++++++++----------- channeldb/meta_test.go | 8 -------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/channeldb/db_test.go b/channeldb/db_test.go index e175ea1fb1..9c8c34d586 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -14,7 +14,6 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" - graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -185,21 +184,22 @@ func TestMultiSourceAddrsForNode(t *testing.T) { fullDB, err := MakeTestDB(t) require.NoError(t, err, "unable to make test database") - graph, err := graphdb.MakeTestGraph(t) - require.NoError(t, err) + graph := newMockAddrSource(t) + t.Cleanup(func() { + graph.AssertExpectations(t) + }) - // We'll make a test vertex to insert into the database, as the source - // node, but this node will only have half the number of addresses it - // usually does. + // We'll make a test vertex, but this node will only have half the + // number of addresses it usually does. testNode := createTestVertex(t) - require.NoError(t, err, "unable to create test node") - testNode.Addresses = []net.Addr{testAddr} - require.NoError(t, graph.SetSourceNode(testNode)) + nodePub, err := testNode.PubKey() + require.NoError(t, err) + graph.On("AddrsForNode", nodePub).Return( + true, []net.Addr{testAddr}, nil, + ).Once() // Next, we'll make a link node with the same pubkey, but with an // additional address. - nodePub, err := testNode.PubKey() - require.NoError(t, err, "unable to recv node pub") linkNode := NewLinkNode( fullDB.channelStateDB.linkNodeDB, wire.MainNet, nodePub, anotherAddr, diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index 5b6bd29a94..97f6f0489f 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -6,7 +6,6 @@ import ( "github.com/btcsuite/btcwallet/walletdb" "github.com/go-errors/errors" - graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/kvdb" "github.com/stretchr/testify/require" ) @@ -22,13 +21,6 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), } cdb.dryRun = dryRun - // Create a test node that will be our source node. 
- testNode := createTestVertex(t) - - graph, err := graphdb.MakeTestGraph(t) - require.NoError(t, err) - require.NoError(t, graph.SetSourceNode(testNode)) - // beforeMigration usually used for populating the database // with test data. beforeMigration(cdb) From f1db30a89a5e592be94d187ba8cf488ca86ffaef Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 30 Mar 2025 11:26:02 +0200 Subject: [PATCH 30/41] batch: dont expose kvdb.RwTx in batch.SchedulerOptions Currently, a few of the graph KVStore methods take the `batch.SchedulerOptions` param. This is only used to set the LazyAdd option. A SchedulerOption is a functional option that takes a `batch.Request` which has bolt-specific fields in it. This commit restructures things a bit such that the `batch.Request` type is no longer part of the `batch.SchedulerOptions` - this will make it easier to implement the graph store with a different DB backend. --- batch/interface.go | 32 +++++++++++++++++++++++++++++--- batch/scheduler.go | 6 +++++- graph/db/kv_store.go | 25 ++++++------------------- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/batch/interface.go b/batch/interface.go index cd58a148c8..2a92fbed2b 100644 --- a/batch/interface.go +++ b/batch/interface.go @@ -5,6 +5,9 @@ import "github.com/lightningnetwork/lnd/kvdb" // Request defines an operation that can be batched into a single bbolt // transaction. type Request struct { + // Opts holds various configuration options for a scheduled request. + Opts *SchedulerOptions + // Reset is called before each invocation of Update and is used to clear // any possible modifications to local state as a result of previous // calls to Update that were not committed due to a concurrent batch @@ -25,22 +28,45 @@ type Request struct { // // NOTE: This field is optional. OnCommit func(commitErr error) error +} +// SchedulerOptions holds various configuration options for a scheduled request. +type SchedulerOptions struct { // lazy should be true if we don't have to immediately execute this // request when it comes in. This means that it can be scheduled later, // allowing larger batches. lazy bool } +// NewDefaultSchedulerOpts returns a new SchedulerOptions with default values. +func NewDefaultSchedulerOpts() *SchedulerOptions { + return &SchedulerOptions{ + lazy: false, + } +} + +// NewSchedulerOptions returns a new SchedulerOptions with the given options +// applied on top of the default options. +func NewSchedulerOptions(options ...SchedulerOption) *SchedulerOptions { + opts := NewDefaultSchedulerOpts() + for _, o := range options { + o(opts) + } + + return opts +} + // SchedulerOption is a type that can be used to supply options to a scheduled // request. -type SchedulerOption func(r *Request) +type SchedulerOption func(*SchedulerOptions) // LazyAdd will make the request be executed lazily, added to the next batch to // reduce db contention. +// +// NOTE: This is currently a no-op for any DB backend other than bbolt. func LazyAdd() SchedulerOption { - return func(r *Request) { - r.lazy = true + return func(opts *SchedulerOptions) { + opts.lazy = true } } diff --git a/batch/scheduler.go b/batch/scheduler.go index 4a01adda71..b91ee615c1 100644 --- a/batch/scheduler.go +++ b/batch/scheduler.go @@ -43,6 +43,10 @@ func NewTimeScheduler(db kvdb.Backend, locker sync.Locker, // // NOTE: Part of the Scheduler interface. 
func (s *TimeScheduler) Execute(r *Request) error { + if r.Opts == nil { + r.Opts = NewDefaultSchedulerOpts() + } + req := request{ Request: r, errChan: make(chan error, 1), @@ -62,7 +66,7 @@ func (s *TimeScheduler) Execute(r *Request) error { s.b.reqs = append(s.b.reqs, &req) // If this is a non-lazy request, we'll execute the batch immediately. - if !r.lazy { + if !r.Opts.lazy { go s.b.trigger() } diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 104c65b366..34be8cfc03 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -848,18 +848,15 @@ func (c *KVStore) SetSourceNode(node *models.LightningNode) error { // // TODO(roasbeef): also need sig of announcement. func (c *KVStore) AddLightningNode(node *models.LightningNode, - op ...batch.SchedulerOption) error { + opts ...batch.SchedulerOption) error { r := &batch.Request{ + Opts: batch.NewSchedulerOptions(opts...), Update: func(tx kvdb.RwTx) error { return addLightningNode(tx, node) }, } - for _, f := range op { - f(r) - } - return c.nodeScheduler.Execute(r) } @@ -986,10 +983,11 @@ func (c *KVStore) deleteLightningNode(nodes kvdb.RwBucket, // supports. The chanPoint and chanID are used to uniquely identify the edge // globally within the database. func (c *KVStore) AddChannelEdge(edge *models.ChannelEdgeInfo, - op ...batch.SchedulerOption) error { + opts ...batch.SchedulerOption) error { var alreadyExists bool r := &batch.Request{ + Opts: batch.NewSchedulerOptions(opts...), Reset: func() { alreadyExists = false }, @@ -1019,14 +1017,6 @@ func (c *KVStore) AddChannelEdge(edge *models.ChannelEdgeInfo, }, } - for _, f := range op { - if f == nil { - return fmt.Errorf("nil scheduler option was used") - } - - f(r) - } - return c.chanScheduler.Execute(r) } @@ -2696,7 +2686,7 @@ func makeZombiePubkeys(info *models.ChannelEdgeInfo, // determined by the lexicographical ordering of the identity public keys of the // nodes on either side of the channel. func (c *KVStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy, - op ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) { + opts ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) { var ( isUpdate1 bool @@ -2705,6 +2695,7 @@ func (c *KVStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy, ) r := &batch.Request{ + Opts: batch.NewSchedulerOptions(opts...), Reset: func() { isUpdate1 = false edgeNotFound = false @@ -2738,10 +2729,6 @@ func (c *KVStore) UpdateEdgePolicy(edge *models.ChannelEdgePolicy, }, } - for _, f := range op { - f(r) - } - err := c.chanScheduler.Execute(r) return from, to, err From 83eb7e3fed2312b465b7478023f403650e937011 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 17:53:14 +0200 Subject: [PATCH 31/41] graph/db: introduce the V1Store interface In this commit, we introduce the `V1Store` interface which the existing `graphdb.KVStore` implements today. The idea is to eventually create a SQL DB backed implementation of this interface. 
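As a rough sketch of the pattern this enables (toy types only; the real `V1Store` interface is far larger), two backends can implement one store interface, each checked with a compile-time assertion like the `var _ V1Store = (*KVStore)(nil)` line this patch adds to `kv_store.go`:

```go
package main

import "fmt"

// nodeStore is a toy stand-in for the (much larger) V1Store interface:
// anything that can look up a node alias by public key satisfies it.
type nodeStore interface {
	LookupAlias(pub [33]byte) (string, error)
}

// kvStore is an illustrative kv-style backend.
type kvStore struct {
	aliases map[[33]byte]string
}

func (k *kvStore) LookupAlias(pub [33]byte) (string, error) {
	alias, ok := k.aliases[pub]
	if !ok {
		return "", fmt.Errorf("node not found")
	}

	return alias, nil
}

// sqlStore is a placeholder for the SQL-backed implementation alluded to
// above; it only needs to satisfy the same interface.
type sqlStore struct{}

func (s *sqlStore) LookupAlias(pub [33]byte) (string, error) {
	return "", fmt.Errorf("not implemented yet")
}

// Compile-time assertions, mirroring the one added for KVStore: both
// backends must satisfy the shared interface.
var (
	_ nodeStore = (*kvStore)(nil)
	_ nodeStore = (*sqlStore)(nil)
)

func main() {
	var store nodeStore = &kvStore{
		aliases: map[[33]byte]string{{0x02}: "alice"},
	}

	alias, err := store.LookupAlias([33]byte{0x02})
	fmt.Println(alias, err)
}
```

The value of the abstraction is that callers such as `ChannelGraph` can program against `V1Store` without caring which backend sits behind it.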
--- docs/release-notes/release-notes-0.20.0.md | 5 + graph/db/interfaces.go | 327 +++++++++++++++++++++ graph/db/kv_store.go | 4 + 3 files changed, 336 insertions(+) diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index 1ee2e05fab..94bcd18a67 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -31,6 +31,11 @@ # Improvements ## Functional Updates +* Graph Store SQL implementation and migration project: + * Introduce an [abstract graph + store](https://github.com/lightningnetwork/lnd/pull/9791) interface. + + ## RPC Updates ## lncli Updates diff --git a/graph/db/interfaces.go b/graph/db/interfaces.go index f5dce71ca0..25d7257bea 100644 --- a/graph/db/interfaces.go +++ b/graph/db/interfaces.go @@ -1,6 +1,13 @@ package graphdb import ( + "net" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -37,3 +44,323 @@ type NodeTraverser interface { // FetchNodeFeatures returns the features of the given node. FetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) } + +// V1Store represents the main interface for the channel graph database for all +// channels and nodes gossiped via the V1 gossip protocol as defined in BOLT 7. +type V1Store interface { //nolint:interfacebloat + NodeTraverser + + // AddLightningNode adds a vertex/node to the graph database. If the + // node is not in the database from before, this will add a new, + // unconnected one to the graph. If it is present from before, this will + // update that node's information. Note that this method is expected to + // only be called to update an already present node from a node + // announcement, or to insert a node found in a channel update. + AddLightningNode(node *models.LightningNode, + op ...batch.SchedulerOption) error + + // AddrsForNode returns all known addresses for the target node public + // key that the graph DB is aware of. The returned boolean indicates if + // the given node is unknown to the graph DB or not. + AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) + + // ForEachSourceNodeChannel iterates through all channels of the source + // node, executing the passed callback on each. The call-back is + // provided with the channel's outpoint, whether we have a policy for + // the channel and the channel peer's node information. + ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, + havePolicy bool, otherNode *models.LightningNode) error) error + + // ForEachNodeChannel iterates through all channels of the given node, + // executing the passed callback with an edge info structure and the + // policies of each end of the channel. The first edge policy is the + // outgoing edge *to* the connecting node, while the second is the + // incoming edge *from* the connecting node. If the callback returns an + // error, then the iteration is halted with the error propagated back up + // to the caller. + // + // Unknown policies are passed into the callback as nil values. 
+ ForEachNodeChannel(nodePub route.Vertex, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error + + // ForEachNodeCached is similar to forEachNode, but it returns + // DirectedChannel data to the call-back. + // + // NOTE: The callback contents MUST not be modified. + ForEachNodeCached(cb func(node route.Vertex, + chans map[uint64]*DirectedChannel) error) error + + // ForEachNode iterates through all the stored vertices/nodes in the + // graph, executing the passed callback with each node encountered. If + // the callback returns an error, then the transaction is aborted and + // the iteration stops early. Any operations performed on the NodeTx + // passed to the call-back are executed under the same read transaction + // and so, methods on the NodeTx object _MUST_ only be called from + // within the call-back. + ForEachNode(cb func(tx NodeRTx) error) error + + // ForEachNodeCacheable iterates through all the stored vertices/nodes + // in the graph, executing the passed callback with each node + // encountered. If the callback returns an error, then the transaction + // is aborted and the iteration stops early. + ForEachNodeCacheable(cb func(route.Vertex, + *lnwire.FeatureVector) error) error + + // LookupAlias attempts to return the alias as advertised by the target + // node. + LookupAlias(pub *btcec.PublicKey) (string, error) + + // DeleteLightningNode starts a new database transaction to remove a + // vertex/node from the database according to the node's public key. + DeleteLightningNode(nodePub route.Vertex) error + + // NodeUpdatesInHorizon returns all the known lightning node which have + // an update timestamp within the passed range. This method can be used + // by two nodes to quickly determine if they have the same set of up to + // date node announcements. + NodeUpdatesInHorizon(startTime, + endTime time.Time) ([]models.LightningNode, error) + + // FetchLightningNode attempts to look up a target node by its identity + // public key. If the node isn't found in the database, then + // ErrGraphNodeNotFound is returned. + FetchLightningNode(nodePub route.Vertex) ( + *models.LightningNode, error) + + // HasLightningNode determines if the graph has a vertex identified by + // the target node identity public key. If the node exists in the + // database, a timestamp of when the data for the node was lasted + // updated is returned along with a true boolean. Otherwise, an empty + // time.Time is returned with a false boolean. + HasLightningNode(nodePub [33]byte) (time.Time, bool, + error) + + // IsPublicNode is a helper method that determines whether the node with + // the given public key is seen as a public node in the graph from the + // graph's source node's point of view. + IsPublicNode(pubKey [33]byte) (bool, error) + + // GraphSession will provide the call-back with access to a + // NodeTraverser instance which can be used to perform queries against + // the channel graph. + GraphSession(cb func(graph NodeTraverser) error) error + + // ForEachChannel iterates through all the channel edges stored within + // the graph and invokes the passed callback for each edge. The callback + // takes two edges as since this is a directed graph, both the in/out + // edges are visited. If the callback returns an error, then the + // transaction is aborted and the iteration stops early. 
+ // + // NOTE: If an edge can't be found, or wasn't advertised, then a nil + // pointer for that particular channel edge routing policy will be + // passed into the callback. + ForEachChannel(cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error + + // DisabledChannelIDs returns the channel ids of disabled channels. + // A channel is disabled when two of the associated ChanelEdgePolicies + // have their disabled bit on. + DisabledChannelIDs() ([]uint64, error) + + // AddChannelEdge adds a new (undirected, blank) edge to the graph + // database. An undirected edge from the two target nodes are created. + // The information stored denotes the static attributes of the channel, + // such as the channelID, the keys involved in creation of the channel, + // and the set of features that the channel supports. The chanPoint and + // chanID are used to uniquely identify the edge globally within the + // database. + AddChannelEdge(edge *models.ChannelEdgeInfo, + op ...batch.SchedulerOption) error + + // HasChannelEdge returns true if the database knows of a channel edge + // with the passed channel ID, and false otherwise. If an edge with that + // ID is found within the graph, then two time stamps representing the + // last time the edge was updated for both directed edges are returned + // along with the boolean. If it is not found, then the zombie index is + // checked and its result is returned as the second boolean. + HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, bool, + error) + + // DeleteChannelEdges removes edges with the given channel IDs from the + // database and marks them as zombies. This ensures that we're unable to + // re-add it to our database once again. If an edge does not exist + // within the database, then ErrEdgeNotFound will be returned. If + // strictZombiePruning is true, then when we mark these edges as + // zombies, we'll set up the keys such that we require the node that + // failed to send the fresh update to be the one that resurrects the + // channel from its zombie state. The markZombie bool denotes whether + // to mark the channel as a zombie. + DeleteChannelEdges(strictZombiePruning, markZombie bool, + chanIDs ...uint64) ([]*models.ChannelEdgeInfo, error) + + // AddEdgeProof sets the proof of an existing edge in the graph + // database. + AddEdgeProof(chanID lnwire.ShortChannelID, + proof *models.ChannelAuthProof) error + + // ChannelID attempt to lookup the 8-byte compact channel ID which maps + // to the passed channel point (outpoint). If the passed channel doesn't + // exist within the database, then ErrEdgeNotFound is returned. + ChannelID(chanPoint *wire.OutPoint) (uint64, error) + + // HighestChanID returns the "highest" known channel ID in the channel + // graph. This represents the "newest" channel from the PoV of the + // chain. This method can be used by peers to quickly determine if + // they're graphs are in sync. + HighestChanID() (uint64, error) + + // ChanUpdatesInHorizon returns all the known channel edges which have + // at least one edge that has an update timestamp within the specified + // horizon. + ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, + error) + + // FilterKnownChanIDs takes a set of channel IDs and return the subset + // of chan ID's that we don't know and are not known zombies of the + // passed set. In other words, we perform a set difference of our set + // of chan ID's and the ones passed in. 
This method can be used by + // callers to determine the set of channels another peer knows of that + // we don't. The ChannelUpdateInfos for the known zombies is also + // returned. + FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, + []ChannelUpdateInfo, error) + + // FilterChannelRange returns the channel ID's of all known channels + // which were mined in a block height within the passed range. The + // channel IDs are grouped by their common block height. This method can + // be used to quickly share with a peer the set of channels we know of + // within a particular range to catch them up after a period of time + // offline. If withTimestamps is true then the timestamp info of the + // latest received channel update messages of the channel will be + // included in the response. + FilterChannelRange(startHeight, endHeight uint32, withTimestamps bool) ( + []BlockChannelRange, error) + + // FetchChanInfos returns the set of channel edges that correspond to + // the passed channel ID's. If an edge is the query is unknown to the + // database, it will skipped and the result will contain only those + // edges that exist at the time of the query. This can be used to + // respond to peer queries that are seeking to fill in gaps in their + // view of the channel graph. + FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) + + // FetchChannelEdgesByOutpoint attempts to lookup the two directed edges + // for the channel identified by the funding outpoint. If the channel + // can't be found, then ErrEdgeNotFound is returned. A struct which + // houses the general information for the channel itself is returned as + // well as two structs that contain the routing policies for the channel + // in either direction. + FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) + + // FetchChannelEdgesByID attempts to lookup the two directed edges for + // the channel identified by the channel ID. If the channel can't be + // found, then ErrEdgeNotFound is returned. A struct which houses the + // general information for the channel itself is returned as well as + // two structs that contain the routing policies for the channel in + // either direction. + // + // ErrZombieEdge can be returned if the edge is currently marked as a + // zombie within the database. In this case, the ChannelEdgePolicy's + // will be nil, and the ChannelEdgeInfo will only include the public + // keys of each node. + FetchChannelEdgesByID(chanID uint64) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) + + // ChannelView returns the verifiable edge information for each active + // channel within the known channel graph. The set of UTXO's (along with + // their scripts) returned are the ones that need to be watched on chain + // to detect channel closes on the resident blockchain. + ChannelView() ([]EdgePoint, error) + + // MarkEdgeZombie attempts to mark a channel identified by its channel + // ID as a zombie. This method is used on an ad-hoc basis, when channels + // need to be marked as zombies outside the normal pruning cycle. + MarkEdgeZombie(chanID uint64, + pubKey1, pubKey2 [33]byte) error + + // MarkEdgeLive clears an edge from our zombie index, deeming it as + // live. + MarkEdgeLive(chanID uint64) error + + // IsZombieEdge returns whether the edge is considered zombie. If it is + // a zombie, then the two node public keys corresponding to this edge + // are also returned. 
+ IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) + + // NumZombies returns the current number of zombie channels in the + // graph. + NumZombies() (uint64, error) + + // PutClosedScid stores a SCID for a closed channel in the database. + // This is so that we can ignore channel announcements that we know to + // be closed without having to validate them and fetch a block. + PutClosedScid(scid lnwire.ShortChannelID) error + + // IsClosedScid checks whether a channel identified by the passed in + // scid is closed. This helps avoid having to perform expensive + // validation checks. + IsClosedScid(scid lnwire.ShortChannelID) (bool, error) + + // UpdateEdgePolicy updates the edge routing policy for a single + // directed edge within the database for the referenced channel. The + // `flags` attribute within the ChannelEdgePolicy determines which of + // the directed edges are being updated. If the flag is 1, then the + // first node's information is being updated, otherwise it's the second + // node's information. The node ordering is determined by the + // lexicographical ordering of the identity public keys of the nodes on + // either side of the channel. + UpdateEdgePolicy(edge *models.ChannelEdgePolicy, + op ...batch.SchedulerOption) (route.Vertex, route.Vertex, error) + + // SourceNode returns the source node of the graph. The source node is + // treated as the center node within a star-graph. This method may be + // used to kick off a path finding algorithm in order to explore the + // reachability of another node based off the source node. + SourceNode() (*models.LightningNode, error) + + // SetSourceNode sets the source node within the graph database. The + // source node is to be used as the center of a star-graph within path + // finding algorithms. + SetSourceNode(node *models.LightningNode) error + + // PruneTip returns the block height and hash of the latest block that + // has been used to prune channels in the graph. Knowing the "prune tip" + // allows callers to tell if the graph is currently in sync with the + // current best known UTXO state. + PruneTip() (*chainhash.Hash, uint32, error) + + // PruneGraphNodes is a garbage collection method which attempts to + // prune out any nodes from the channel graph that are currently + // unconnected. This ensures that we only maintain a graph of reachable + // nodes. In the event that a pruned node gains more channels, it will + // be re-added back to the graph. + PruneGraphNodes() ([]route.Vertex, error) + + // PruneGraph prunes newly closed channels from the channel graph in + // response to a new block being solved on the network. Any transactions + // which spend the funding output of any known channels within he graph + // will be deleted. Additionally, the "prune tip", or the last block + // which has been used to prune the graph is stored so callers can + // ensure the graph is fully in sync with the current UTXO state. A + // slice of channels that have been closed by the target block along + // with any pruned nodes are returned if the function succeeds without + // error. + PruneGraph(spentOutputs []*wire.OutPoint, + blockHash *chainhash.Hash, blockHeight uint32) ( + []*models.ChannelEdgeInfo, []route.Vertex, error) + + // DisconnectBlockAtHeight is used to indicate that the block specified + // by the passed height has been disconnected from the main chain. This + // will "rewind" the graph back to the height below, deleting channels + // that are no longer confirmed from the graph. 
The prune log will be + // set to the last prune height valid for the remaining chain. + // Channels that were removed from the graph resulting from the + // disconnected block are returned. + DisconnectBlockAtHeight(height uint32) ([]*models.ChannelEdgeInfo, + error) +} diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 34be8cfc03..6c283cb81c 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -196,6 +196,10 @@ type KVStore struct { nodeScheduler batch.Scheduler } +// A compile-time assertion to ensure that the KVStore struct implements the +// V1Store interface. +var _ V1Store = (*KVStore)(nil) + // NewKVStore allocates a new KVStore backed by a DB instance. The // returned instance has its own unique reject cache and channel cache. func NewKVStore(db kvdb.Backend, options ...KVStoreOptionModifier) (*KVStore, From 9f79322f812916d6179cb982784d172a62dc21a7 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 17:54:53 +0200 Subject: [PATCH 32/41] graph/db: use V1Store interface in ChannelGraph Use the new `V1Store` interface as a field in the `ChannelGraph` rather than a pointer to the `KVStore` implementation in preparation for allowing a SQL implementation of `V1Store` to be used by the `ChannelGraph`. Note that two tests are adjusted in this commit to be skipped if the V1Store is not the `KVStore` implementation since the tests are very bbolt specific. --- graph/db/graph.go | 42 +++++++++++++++++++++--------------------- graph/db/graph_test.go | 34 ++++++++++++++++------------------ 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/graph/db/graph.go b/graph/db/graph.go index 9e35e58dd5..ab63b8b64b 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -48,7 +48,7 @@ type ChannelGraph struct { graphCache *GraphCache - *KVStore + V1Store *topologyManager quit chan struct{} @@ -70,7 +70,7 @@ func NewChannelGraph(cfg *Config, options ...ChanGraphOption) (*ChannelGraph, } g := &ChannelGraph{ - KVStore: store, + V1Store: store, topologyManager: newTopologyManager(), quit: make(chan struct{}), } @@ -184,7 +184,7 @@ func (c *ChannelGraph) populateCache() error { log.Info("Populating in-memory channel graph, this might take a " + "while...") - err := c.KVStore.ForEachNodeCacheable(func(node route.Vertex, + err := c.V1Store.ForEachNodeCacheable(func(node route.Vertex, features *lnwire.FeatureVector) error { c.graphCache.AddNodeFeatures(node, features) @@ -195,7 +195,7 @@ func (c *ChannelGraph) populateCache() error { return err } - err = c.KVStore.ForEachChannel(func(info *models.ChannelEdgeInfo, + err = c.V1Store.ForEachChannel(func(info *models.ChannelEdgeInfo, policy1, policy2 *models.ChannelEdgePolicy) error { c.graphCache.AddChannel(info, policy1, policy2) @@ -229,7 +229,7 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(node route.Vertex, return c.graphCache.ForEachChannel(node, cb) } - return c.KVStore.ForEachNodeDirectedChannel(node, cb) + return c.V1Store.ForEachNodeDirectedChannel(node, cb) } // FetchNodeFeatures returns the features of the given node. 
If no features are @@ -245,7 +245,7 @@ func (c *ChannelGraph) FetchNodeFeatures(node route.Vertex) ( return c.graphCache.GetFeatures(node), nil } - return c.KVStore.FetchNodeFeatures(node) + return c.V1Store.FetchNodeFeatures(node) } // GraphSession will provide the call-back with access to a NodeTraverser @@ -257,7 +257,7 @@ func (c *ChannelGraph) GraphSession(cb func(graph NodeTraverser) error) error { return cb(c) } - return c.KVStore.GraphSession(cb) + return c.V1Store.GraphSession(cb) } // ForEachNodeCached iterates through all the stored vertices/nodes in the @@ -271,7 +271,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, return c.graphCache.ForEachNode(cb) } - return c.KVStore.ForEachNodeCached(cb) + return c.V1Store.ForEachNodeCached(cb) } // AddLightningNode adds a vertex/node to the graph database. If the node is not @@ -286,7 +286,7 @@ func (c *ChannelGraph) AddLightningNode(node *models.LightningNode, c.cacheMu.Lock() defer c.cacheMu.Unlock() - err := c.KVStore.AddLightningNode(node, op...) + err := c.V1Store.AddLightningNode(node, op...) if err != nil { return err } @@ -312,7 +312,7 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error { c.cacheMu.Lock() defer c.cacheMu.Unlock() - err := c.KVStore.DeleteLightningNode(nodePub) + err := c.V1Store.DeleteLightningNode(nodePub) if err != nil { return err } @@ -336,7 +336,7 @@ func (c *ChannelGraph) AddChannelEdge(edge *models.ChannelEdgeInfo, c.cacheMu.Lock() defer c.cacheMu.Unlock() - err := c.KVStore.AddChannelEdge(edge, op...) + err := c.V1Store.AddChannelEdge(edge, op...) if err != nil { return err } @@ -361,7 +361,7 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { c.cacheMu.Lock() defer c.cacheMu.Unlock() - err := c.KVStore.MarkEdgeLive(chanID) + err := c.V1Store.MarkEdgeLive(chanID) if err != nil { return err } @@ -369,7 +369,7 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { if c.graphCache != nil { // We need to add the channel back into our graph cache, // otherwise we won't use it for path finding. 
- infos, err := c.KVStore.FetchChanInfos([]uint64{chanID}) + infos, err := c.V1Store.FetchChanInfos([]uint64{chanID}) if err != nil { return err } @@ -400,7 +400,7 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning, markZombie bool, c.cacheMu.Lock() defer c.cacheMu.Unlock() - infos, err := c.KVStore.DeleteChannelEdges( + infos, err := c.V1Store.DeleteChannelEdges( strictZombiePruning, markZombie, chanIDs..., ) if err != nil { @@ -432,7 +432,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( c.cacheMu.Lock() defer c.cacheMu.Unlock() - edges, err := c.KVStore.DisconnectBlockAtHeight(height) + edges, err := c.V1Store.DisconnectBlockAtHeight(height) if err != nil { return nil, err } @@ -463,7 +463,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, c.cacheMu.Lock() defer c.cacheMu.Unlock() - edges, nodes, err := c.KVStore.PruneGraph( + edges, nodes, err := c.V1Store.PruneGraph( spentOutputs, blockHash, blockHeight, ) if err != nil { @@ -508,7 +508,7 @@ func (c *ChannelGraph) PruneGraphNodes() error { c.cacheMu.Lock() defer c.cacheMu.Unlock() - nodes, err := c.KVStore.PruneGraphNodes() + nodes, err := c.V1Store.PruneGraphNodes() if err != nil { return err } @@ -530,7 +530,7 @@ func (c *ChannelGraph) PruneGraphNodes() error { func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo, isZombieChan func(time.Time, time.Time) bool) ([]uint64, error) { - unknown, knownZombies, err := c.KVStore.FilterKnownChanIDs(chansInfo) + unknown, knownZombies, err := c.V1Store.FilterKnownChanIDs(chansInfo) if err != nil { return nil, err } @@ -559,7 +559,7 @@ func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo, // timestamps could bring it back from the dead, then we mark it // alive, and we let it be added to the set of IDs to query our // peer for. - err := c.KVStore.MarkEdgeLive( + err := c.V1Store.MarkEdgeLive( info.ShortChannelID.ToUint64(), ) // Since there is a chance that the edge could have been marked @@ -583,7 +583,7 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64, c.cacheMu.Lock() defer c.cacheMu.Unlock() - err := c.KVStore.MarkEdgeZombie(chanID, pubKey1, pubKey2) + err := c.V1Store.MarkEdgeZombie(chanID, pubKey1, pubKey2) if err != nil { return err } @@ -608,7 +608,7 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *models.ChannelEdgePolicy, c.cacheMu.Lock() defer c.cacheMu.Unlock() - from, to, err := c.KVStore.UpdateEdgePolicy(edge, op...) + from, to, err := c.V1Store.UpdateEdgePolicy(edge, op...) if err != nil { return err } diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index de6f3a009f..1818ef7b77 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -2950,6 +2950,12 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") + // The update index only applies to the bbolt graph. 
+ boltStore, ok := graph.V1Store.(*KVStore) + if !ok { + t.Skipf("skipping test that is aimed at a bbolt graph DB") + } + sourceNode := createTestVertex(t) if err := graph.SetSourceNode(sourceNode); err != nil { t.Fatalf("unable to set source node: %v", err) @@ -2999,7 +3005,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { timestampSet[t] = struct{}{} } - err := kvdb.View(graph.db, func(tx kvdb.RTx) error { + err := kvdb.View(boltStore.db, func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -3467,6 +3473,12 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { graph, err := MakeTestGraph(t) require.NoError(t, err, "unable to make test database") + // This test currently directly edits the bytes stored in the bbolt DB. + boltStore, ok := graph.V1Store.(*KVStore) + if !ok { + t.Skipf("skipping test that is aimed at a bbolt graph DB") + } + // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. node1 := createTestVertex(t) @@ -3515,25 +3527,11 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { // Attempting to deserialize these bytes should return an error. r := bytes.NewReader(stripped) - err = kvdb.View(graph.db, func(tx kvdb.RTx) error { - nodes := tx.ReadBucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - _, err = deserializeChanEdgePolicy(r) - if err != ErrEdgePolicyOptionalFieldNotFound { - t.Fatalf("expected "+ - "ErrEdgePolicyOptionalFieldNotFound, got %v", - err) - } - - return nil - }, func() {}) - require.NoError(t, err, "error reading db") + _, err = deserializeChanEdgePolicy(r) + require.ErrorIs(t, err, ErrEdgePolicyOptionalFieldNotFound) // Put the stripped bytes in the DB. - err = kvdb.Update(graph.db, func(tx kvdb.RwTx) error { + err = kvdb.Update(boltStore.db, func(tx kvdb.RwTx) error { edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return ErrEdgeNotFound From 36d8bb2a4e2787d0b8dd4c5838ea09b8314a127a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sun, 30 Mar 2025 11:42:22 +0200 Subject: [PATCH 33/41] graph/db: init KVStore outside of ChannelGraph So that we can pass in the abstract V1Store in preparation for adding a SQL implementation of the KVStore. --- autopilot/prefattach_test.go | 5 ++++- config_builder.go | 12 ++++++++---- graph/db/graph.go | 24 +++--------------------- graph/db/graph_test.go | 7 +++++-- graph/db/kv_store.go | 8 ++++---- graph/notifications_test.go | 6 ++++-- peer/test_utils.go | 7 ++++--- routing/pathfind_test.go | 6 ++++-- 8 files changed, 36 insertions(+), 39 deletions(-) diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index 2e3b22ff3f..de7b6a7d67 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -47,7 +47,10 @@ func newDiskChanGraph(t *testing.T) (testGraph, error) { }) require.NoError(t, err) - graphDB, err := graphdb.NewChannelGraph(&graphdb.Config{KVDB: backend}) + graphStore, err := graphdb.NewKVStore(backend) + require.NoError(t, err) + + graphDB, err := graphdb.NewChannelGraph(graphStore) require.NoError(t, err) require.NoError(t, graphDB.Start()) diff --git a/config_builder.go b/config_builder.go index 2214250e7c..708e79d8f2 100644 --- a/config_builder.go +++ b/config_builder.go @@ -1046,10 +1046,14 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( ) } - dbs.GraphDB, err = graphdb.NewChannelGraph(&graphdb.Config{ - KVDB: databaseBackends.GraphDB, - KVStoreOpts: graphDBOptions, - }, chanGraphOpts...) 
+ graphStore, err := graphdb.NewKVStore( + databaseBackends.GraphDB, graphDBOptions..., + ) + if err != nil { + return nil, nil, err + } + + dbs.GraphDB, err = graphdb.NewChannelGraph(graphStore, chanGraphOpts...) if err != nil { cleanUp() diff --git a/graph/db/graph.go b/graph/db/graph.go index ab63b8b64b..f6cf6f302a 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -11,7 +11,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -20,18 +19,6 @@ import ( // busy shutting down. var ErrChanGraphShuttingDown = fmt.Errorf("ChannelGraph shutting down") -// Config is a struct that holds all the necessary dependencies for a -// ChannelGraph. -type Config struct { - // KVDB is the kvdb.Backend that will be used for initializing the - // KVStore CRUD layer. - KVDB kvdb.Backend - - // KVStoreOpts is a list of functional options that will be used when - // initializing the KVStore. - KVStoreOpts []KVStoreOptionModifier -} - // ChannelGraph is a layer above the graph's CRUD layer. // // NOTE: currently, this is purely a pass-through layer directly to the backing @@ -56,21 +43,16 @@ type ChannelGraph struct { } // NewChannelGraph creates a new ChannelGraph instance with the given backend. -func NewChannelGraph(cfg *Config, options ...ChanGraphOption) (*ChannelGraph, - error) { +func NewChannelGraph(v1Store V1Store, + options ...ChanGraphOption) (*ChannelGraph, error) { opts := defaultChanGraphOptions() for _, o := range options { o(opts) } - store, err := NewKVStore(cfg.KVDB, cfg.KVStoreOpts...) - if err != nil { - return nil, err - } - g := &ChannelGraph{ - V1Store: store, + V1Store: v1Store, topologyManager: newTopologyManager(), quit: make(chan struct{}), } diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 1818ef7b77..4551a29c5c 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -4084,7 +4084,10 @@ func TestGraphLoading(t *testing.T) { defer backend.Close() defer backendCleanup() - graph, err := NewChannelGraph(&Config{KVDB: backend}) + graphStore, err := NewKVStore(backend) + require.NoError(t, err) + + graph, err := NewChannelGraph(graphStore) require.NoError(t, err) require.NoError(t, graph.Start()) t.Cleanup(func() { @@ -4098,7 +4101,7 @@ func TestGraphLoading(t *testing.T) { // Recreate the graph. This should cause the graph cache to be // populated. - graphReloaded, err := NewChannelGraph(&Config{KVDB: backend}) + graphReloaded, err := NewChannelGraph(graphStore) require.NoError(t, err) require.NoError(t, graphReloaded.Start()) t.Cleanup(func() { diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 6c283cb81c..36da985cff 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -4725,10 +4725,10 @@ func MakeTestGraph(t testing.TB, modifiers ...KVStoreOptionModifier) ( return nil, err } - graph, err := NewChannelGraph(&Config{ - KVDB: backend, - KVStoreOpts: modifiers, - }) + graphStore, err := NewKVStore(backend, modifiers...) 
+ require.NoError(t, err) + + graph, err := NewChannelGraph(graphStore) if err != nil { backendCleanup() diff --git a/graph/notifications_test.go b/graph/notifications_test.go index ace578376a..33bac64ddf 100644 --- a/graph/notifications_test.go +++ b/graph/notifications_test.go @@ -1093,9 +1093,11 @@ func makeTestGraph(t *testing.T, useCache bool) *graphdb.ChannelGraph { t.Cleanup(backendCleanup) + graphStore, err := graphdb.NewKVStore(backend) + require.NoError(t, err) + graph, err := graphdb.NewChannelGraph( - &graphdb.Config{KVDB: backend}, - graphdb.WithUseGraphCache(useCache), + graphStore, graphdb.WithUseGraphCache(useCache), ) require.NoError(t, err) require.NoError(t, graph.Start()) diff --git a/peer/test_utils.go b/peer/test_utils.go index 87a80712fa..4d2bae8c7e 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -615,9 +615,10 @@ func createTestPeer(t *testing.T) *peerTestCtx { }) require.NoError(t, err) - dbAliceGraph, err := graphdb.NewChannelGraph(&graphdb.Config{ - KVDB: graphBackend, - }) + graphStore, err := graphdb.NewKVStore(graphBackend) + require.NoError(t, err) + + dbAliceGraph, err := graphdb.NewChannelGraph(graphStore) require.NoError(t, err) require.NoError(t, dbAliceGraph.Start()) t.Cleanup(func() { diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 8a0280686b..363858a456 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -166,9 +166,11 @@ func makeTestGraph(t *testing.T, useCache bool) (*graphdb.ChannelGraph, t.Cleanup(backendCleanup) + graphStore, err := graphdb.NewKVStore(backend) + require.NoError(t, err) + graph, err := graphdb.NewChannelGraph( - &graphdb.Config{KVDB: backend}, - graphdb.WithUseGraphCache(useCache), + graphStore, graphdb.WithUseGraphCache(useCache), ) if err != nil { return nil, nil, err From 8bf10b183d20c8028ff6ca8124c79b0d94589be1 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 5 Apr 2025 17:03:42 +0200 Subject: [PATCH 34/41] graph/db: make all ExtraOpaqueData valid TLV streams Later when we introduce our SQL version of the graph store, we will normalise the persistence of the ExtraOpaqueData using the fact that it is always made up of TLV entries. So we update our tests here to ensure that they use valid TLV streams as examples. --- graph/db/graph_test.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 4551a29c5c..a543772a40 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -110,7 +110,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { Alias: "kek", Features: testFeatures, Addresses: testAddrs, - ExtraOpaqueData: []byte("extra new data"), + ExtraOpaqueData: []byte{1, 1, 1, 2, 2, 2, 2}, PubKeyBytes: testPub, } @@ -631,9 +631,13 @@ func createChannelEdge(node1, node2 *models.LightningNode) ( BitcoinSig1Bytes: testSig.Serialize(), BitcoinSig2Bytes: testSig.Serialize(), }, - ChannelPoint: outpoint, - Capacity: 1000, - ExtraOpaqueData: []byte("new unknown feature"), + ChannelPoint: outpoint, + Capacity: 1000, + ExtraOpaqueData: []byte{ + 1, 1, 1, + 2, 2, 2, 2, + 3, 3, 3, 3, 3, + }, } copy(edgeInfo.NodeKey1Bytes[:], firstNode[:]) copy(edgeInfo.NodeKey2Bytes[:], secondNode[:]) From e6e6d7217a680c54e13722a571c20957ab4023c4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 7 May 2025 10:36:05 +0200 Subject: [PATCH 35/41] lnwire: validate that gossip messages contain valid TLV In this commit, we check that the extra bytes appended to gossip messages contain valid TLV streams. 
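To make "valid TLV stream" concrete (and to show why the previous commit swapped test blobs to values like `{1, 1, 1, 2, 2, 2, 2}`), here is a minimal hand-rolled parse. It assumes single-byte types and lengths, which only holds for BigSize values below 0xfd, and it skips the ascending-type rule, so treat it purely as an illustration rather than lnd's real `tlv` package decoding:

```go
package main

import "fmt"

// parseTLV walks a byte blob as a sequence of type-length-value records,
// assuming single-byte type and length fields for simplicity.
func parseTLV(b []byte) error {
	for i := 0; i < len(b); {
		if len(b[i:]) < 2 {
			return fmt.Errorf("truncated record at offset %d", i)
		}

		typ, length := int(b[i]), int(b[i+1])
		i += 2

		if len(b[i:]) < length {
			return fmt.Errorf("value shorter than declared length")
		}

		fmt.Printf("type=%d len=%d value=%v\n", typ, length, b[i:i+length])
		i += length
	}

	return nil
}

func main() {
	// The new test value: record (type 1, value {1}) followed by
	// record (type 2, value {2, 2}).
	if err := parseTLV([]byte{1, 1, 1, 2, 2, 2, 2}); err != nil {
		fmt.Println("invalid:", err)
	}

	// An arbitrary string such as the old "extra new data" blob does
	// not parse as a well-formed stream.
	if err := parseTLV([]byte("extra new data")); err != nil {
		fmt.Println("invalid:", err)
	}
}
```

The old free-form strings are exactly the kind of blob the new `ValidateTLV` check is meant to reject.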
We do this here for: - channel_announcement - channel_announcement_2 - channel_update - channel_update_2 - node_announcement This is in preparation for the SQL version of the graph store which will normalise TLV streams at persistence time. --- docs/release-notes/release-notes-0.20.0.md | 3 ++- lnwire/channel_announcement.go | 9 +++++++-- lnwire/channel_announcement_2.go | 2 +- lnwire/channel_update.go | 9 +++++++-- lnwire/channel_update_2.go | 2 +- lnwire/extra_bytes.go | 22 ++++++++++++++++++++++ lnwire/node_announcement.go | 9 +++++++-- 7 files changed, 47 insertions(+), 9 deletions(-) diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index 94bcd18a67..85b5f7b7ea 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -34,7 +34,8 @@ * Graph Store SQL implementation and migration project: * Introduce an [abstract graph store](https://github.com/lightningnetwork/lnd/pull/9791) interface. - + * Start [validating](https://github.com/lightningnetwork/lnd/pull/9787) that + byte blobs at the end of gossip messages are valid TLV streams. ## RPC Updates diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index 0a3989abbe..05161cca83 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -70,8 +70,8 @@ var _ SizeableMessage = (*ChannelAnnouncement1)(nil) // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (a *ChannelAnnouncement1) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, +func (a *ChannelAnnouncement1) Decode(r io.Reader, _ uint32) error { + err := ReadElements(r, &a.NodeSig1, &a.NodeSig2, &a.BitcoinSig1, @@ -85,6 +85,11 @@ func (a *ChannelAnnouncement1) Decode(r io.Reader, pver uint32) error { &a.BitcoinKey2, &a.ExtraOpaqueData, ) + if err != nil { + return err + } + + return a.ExtraOpaqueData.ValidateTLV() } // Encode serializes the target ChannelAnnouncement into the passed io.Writer diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 57b3a24b8c..95af69eda3 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -126,7 +126,7 @@ func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { c.ExtraOpaqueData = tlvRecords } - return nil + return c.ExtraOpaqueData.ValidateTLV() } // Encode serializes the target AnnounceSignatures1 into the passed io.Writer diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 88f9816715..55d9d3181a 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -132,7 +132,7 @@ var _ SizeableMessage = (*ChannelUpdate1)(nil) // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. 
-func (a *ChannelUpdate1) Decode(r io.Reader, pver uint32) error { +func (a *ChannelUpdate1) Decode(r io.Reader, _ uint32) error { err := ReadElements(r, &a.Signature, a.ChainHash[:], @@ -156,7 +156,12 @@ func (a *ChannelUpdate1) Decode(r io.Reader, pver uint32) error { } } - return a.ExtraOpaqueData.Decode(r) + err = a.ExtraOpaqueData.Decode(r) + if err != nil { + return err + } + + return a.ExtraOpaqueData.ValidateTLV() } // Encode serializes the target ChannelUpdate into the passed io.Writer diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index 56f7edf6b4..343af6b1e9 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -154,7 +154,7 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { c.ExtraOpaqueData = tlvRecords } - return nil + return c.ExtraOpaqueData.ValidateTLV() } // Encode serializes the target ChannelUpdate2 into the passed io.Writer diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index 4681426cbb..9530e06e81 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -63,6 +63,28 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error { return nil } +// ValidateTLV checks that the raw bytes that make up the ExtraOpaqueData +// instance are a valid TLV stream. +func (e *ExtraOpaqueData) ValidateTLV() error { + // There is nothing to validate if the ExtraOpaqueData is nil or empty. + if e == nil || len(*e) == 0 { + return nil + } + + tlvStream, err := tlv.NewStream() + if err != nil { + return err + } + + // Ensure that the TLV stream is valid by attempting to decode it. + _, err = tlvStream.DecodeWithParsedTypesP2P(bytes.NewReader(*e)) + if err != nil { + return fmt.Errorf("invalid TLV stream: %w: %v", err, *e) + } + + return nil +} + // PackRecords attempts to encode the set of tlv records into the target // ExtraOpaqueData instance. The records will be encoded as a raw TLV stream // and stored within the backing slice pointer. diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index 5ba2d7a1db..d3502a7d91 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -112,8 +112,8 @@ var _ SizeableMessage = (*NodeAnnouncement)(nil) // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. -func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, +func (a *NodeAnnouncement) Decode(r io.Reader, _ uint32) error { + err := ReadElements(r, &a.Signature, &a.Features, &a.Timestamp, @@ -123,6 +123,11 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { &a.Addresses, &a.ExtraOpaqueData, ) + if err != nil { + return err + } + + return a.ExtraOpaqueData.ValidateTLV() } // Encode serializes the target NodeAnnouncement into the passed io.Writer From 9d9bf714c64620334b78f50cca82c201cd91cb14 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 8 May 2025 10:06:52 +0200 Subject: [PATCH 36/41] graph/db: expand TestNodeInsertionAndDeletion Expand this existing test so that it also tests that a node's addresses and feature are fetched correctly after insertion. With this, we ensure that the `FetchNodeFeatures` and `AddrsForNodes` methods of the `V1Store` interface are properly covered by unit tests. 
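
For context, a minimal sketch (not part of the patch) of how a caller might consume the `AddrsForNode` result that this test now covers. Only the `(known, addrs, err)` signature is taken from the test below; the helper name, the `dial` callback and the error handling are purely illustrative, and the `graphdb`, `btcec`, `net` and `fmt` imports are assumed.

    // maybeDialNode is a hypothetical helper: it looks up the advertised
    // addresses for a node and hands them to some dialing routine. Only the
    // AddrsForNode call mirrors the API exercised in the test; the rest is
    // illustrative.
    func maybeDialNode(graph *graphdb.ChannelGraph, pub *btcec.PublicKey,
            dial func(net.Addr) error) error {

            known, addrs, err := graph.AddrsForNode(pub)
            if err != nil {
                    return err
            }

            // An unknown node simply has no addresses to try.
            if !known {
                    return nil
            }

            // Try each advertised address until one succeeds.
            for _, addr := range addrs {
                    if err := dial(addr); err == nil {
                            return nil
                    }
            }

            return fmt.Errorf("no reachable address for node")
    }
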
--- graph/db/graph_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index a543772a40..079f46c215 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -136,6 +136,27 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // The two nodes should match exactly! compareNodes(t, node, dbNode) + // Check that the addresses for the node are fetched correctly. + pub, err := node.PubKey() + require.NoError(t, err) + + known, addrs, err := graph.AddrsForNode(pub) + require.NoError(t, err) + require.True(t, known) + require.Equal(t, testAddrs, addrs) + + // Check that the node's features are fetched correctly. This check + // will use the graph cache to fetch the features. + features, err := graph.FetchNodeFeatures(node.PubKeyBytes) + require.NoError(t, err) + require.Equal(t, testFeatures, features) + + // Check that the node's features are fetched correctly. This check + // will check the database directly. + features, err = graph.V1Store.FetchNodeFeatures(node.PubKeyBytes) + require.NoError(t, err) + require.Equal(t, testFeatures, features) + // Next, delete the node from the graph, this should purge all data // related to the node. if err := graph.DeleteLightningNode(testPub); err != nil { From 6cd2770375d931db594da1dc0f66eb3b477b3f83 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 8 May 2025 10:12:42 +0200 Subject: [PATCH 37/41] graph/db: check for wrapped errors In preparation for a different store impl which may wrap errors, we use `require.ErrorIs` for error assertions rather than direct "=" comparisons in tests. --- graph/db/graph_test.go | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 079f46c215..3ee7d2e341 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -167,9 +167,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. _, err = graph.FetchLightningNode(testPub) - if err != ErrGraphNodeNotFound { - t.Fatalf("fetch after delete should fail!") - } + require.ErrorIs(t, err, ErrGraphNodeNotFound) } // TestPartialNode checks that we can add and retrieve a LightningNode where @@ -273,9 +271,7 @@ func TestAliasLookup(t *testing.T) { nodePub, err = node.PubKey() require.NoError(t, err, "unable to generate pubkey") _, err = graph.LookupAlias(nodePub) - if err != ErrNodeAliasNotFound { - t.Fatalf("alias lookup should fail for non-existent pubkey") - } + require.ErrorIs(t, err, ErrNodeAliasNotFound) } func TestSourceNode(t *testing.T) { @@ -290,9 +286,8 @@ func TestSourceNode(t *testing.T) { // Attempt to fetch the source node, this should return an error as the // source node hasn't yet been set. - if _, err := graph.SourceNode(); err != ErrSourceNodeNotSet { - t.Fatalf("source node shouldn't be set in new graph") - } + _, err = graph.SourceNode() + require.ErrorIs(t, err, ErrSourceNodeNotSet) // Set the source the source node, this should insert the node into the // database in a special way indicating it's the source node. @@ -387,9 +382,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { // Finally, attempt to delete a (now) non-existent edge within the // database, this should result in an error. 
err = graph.DeleteChannelEdges(false, true, chanID) - if err != ErrEdgeNotFound { - t.Fatalf("deleting a non-existent edge should fail!") - } + require.ErrorIs(t, err, ErrEdgeNotFound) } func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, @@ -721,9 +714,8 @@ func TestEdgeInfoUpdates(t *testing.T) { // Make sure inserting the policy at this point, before the edge info // is added, will fail. - if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound { - t.Fatalf("expected ErrEdgeNotFound, got: %v", err) - } + err = graph.UpdateEdgePolicy(edge1) + require.ErrorIs(t, err, ErrEdgeNotFound) require.Len(t, graph.graphCache.nodeChannels, 0) // Add the edge info. From 5aeab7c8188bd15de9ebe07191bc421ef2d0da66 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 8 May 2025 10:44:53 +0200 Subject: [PATCH 38/41] graph/db: set empty Features and ExtraOpaqueData in tests So we can use`require.Equal` for the ChannelEdgeInfo type. --- graph/db/graph_test.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 3ee7d2e341..37b27d9f75 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -399,6 +399,12 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, Index: outPointIndex, } + var ( + features = lnwire.NewRawFeatureVector() + featureBuf bytes.Buffer + ) + _ = features.Encode(&featureBuf) + node1Pub, _ := node1.PubKey() node2Pub, _ := node2.PubKey() edgeInfo := models.ChannelEdgeInfo{ @@ -410,8 +416,10 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, BitcoinSig1Bytes: testSig.Serialize(), BitcoinSig2Bytes: testSig.Serialize(), }, - ChannelPoint: outpoint, - Capacity: 9000, + ChannelPoint: outpoint, + Capacity: 9000, + ExtraOpaqueData: make([]byte, 0), + Features: featureBuf.Bytes(), } copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) @@ -634,6 +642,12 @@ func createChannelEdge(node1, node2 *models.LightningNode) ( Index: prand.Uint32(), } + var ( + features = lnwire.NewRawFeatureVector() + featureBuf bytes.Buffer + ) + _ = features.Encode(&featureBuf) + // Add the new edge to the database, this should proceed without any // errors. edgeInfo := &models.ChannelEdgeInfo{ @@ -652,6 +666,7 @@ func createChannelEdge(node1, node2 *models.LightningNode) ( 2, 2, 2, 2, 3, 3, 3, 3, 3, }, + Features: featureBuf.Bytes(), } copy(edgeInfo.NodeKey1Bytes[:], firstNode[:]) copy(edgeInfo.NodeKey2Bytes[:], secondNode[:]) From 5ebad9d49f47dbbe6b4ef9bf6a86771b1ee3e085 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 8 May 2025 10:14:27 +0200 Subject: [PATCH 39/41] graph/db: use mainnet genisis hash in tests In preparation for our SQL Graph store which wont explicitly store the chain hash but will instead obtain it from the runtime config, we replace the test chainhash value with that of the mainnet genesis hash. 
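
For reference, a short sketch of the pattern the diff below adopts, assuming only the `chaincfg` and `models` packages these tests already import. The `newTestEdgeInfo` helper is hypothetical; the field assignment mirrors the change itself.

    // newTestEdgeInfo is a hypothetical helper showing the pattern used in
    // the diff below: GenesisHash on chaincfg.MainNetParams is a
    // *chainhash.Hash, so it is dereferenced into the fixed-size ChainHash
    // field instead of using a locally stored random value.
    func newTestEdgeInfo(chanID uint64) models.ChannelEdgeInfo {
            return models.ChannelEdgeInfo{
                    ChannelID: chanID,
                    ChainHash: *chaincfg.MainNetParams.GenesisHash,
            }
    }
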
--- graph/db/graph_test.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 37b27d9f75..cae2042b7b 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -19,6 +19,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/graph/db/models" @@ -329,7 +330,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { require.NoError(t, err, "unable to generate node key") edgeInfo := models.ChannelEdgeInfo{ ChannelID: chanID, - ChainHash: key, + ChainHash: *chaincfg.MainNetParams.GenesisHash, AuthProof: &models.ChannelAuthProof{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), @@ -409,7 +410,7 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, node2Pub, _ := node2.PubKey() edgeInfo := models.ChannelEdgeInfo{ ChannelID: shortChanID.ToUint64(), - ChainHash: key, + ChainHash: *chaincfg.MainNetParams.GenesisHash, AuthProof: &models.ChannelAuthProof{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), @@ -652,7 +653,7 @@ func createChannelEdge(node1, node2 *models.LightningNode) ( // errors. edgeInfo := &models.ChannelEdgeInfo{ ChannelID: chanID, - ChainHash: key, + ChainHash: *chaincfg.MainNetParams.GenesisHash, AuthProof: &models.ChannelAuthProof{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), @@ -1363,7 +1364,7 @@ func fillTestGraph(t testing.TB, graph *ChannelGraph, numNodes, edgeInfo := models.ChannelEdgeInfo{ ChannelID: chanID, - ChainHash: key, + ChainHash: *chaincfg.MainNetParams.GenesisHash, AuthProof: &models.ChannelAuthProof{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), @@ -1545,7 +1546,7 @@ func TestGraphPruning(t *testing.T) { edgeInfo := models.ChannelEdgeInfo{ ChannelID: chanID, - ChainHash: key, + ChainHash: *chaincfg.MainNetParams.GenesisHash, AuthProof: &models.ChannelAuthProof{ NodeSig1Bytes: testSig.Serialize(), NodeSig2Bytes: testSig.Serialize(), From acad19edcd91b98ad934525cdf28df9fcb83d2c8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 8 May 2025 10:52:59 +0200 Subject: [PATCH 40/41] graph/db: add test coverage for AddEdgeProof --- graph/db/graph_test.go | 104 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 93 insertions(+), 11 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index cae2042b7b..5a774d53bd 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -619,9 +619,28 @@ func assertEdgeInfoEqual(t *testing.T, e1 *models.ChannelEdgeInfo, } } -func createChannelEdge(node1, node2 *models.LightningNode) ( - *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) { +type createEdgeConfig struct { + skipProofs bool +} + +type createEdgeOpt func(*createEdgeConfig) + +// withSkipProofs will let createChannelEdge create an edge without auth +// proofs. In this case, createChannelEdge will then also not create policies. 
+func withSkipProofs() createEdgeOpt { + return func(cfg *createEdgeConfig) { + cfg.skipProofs = true + } +} + +func createChannelEdge(node1, node2 *models.LightningNode, + options ...createEdgeOpt) (*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) { + + var opts createEdgeConfig + for _, o := range options { + o(&opts) + } var ( firstNode [33]byte @@ -652,14 +671,8 @@ func createChannelEdge(node1, node2 *models.LightningNode) ( // Add the new edge to the database, this should proceed without any // errors. edgeInfo := &models.ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: *chaincfg.MainNetParams.GenesisHash, - AuthProof: &models.ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, + ChannelID: chanID, + ChainHash: *chaincfg.MainNetParams.GenesisHash, ChannelPoint: outpoint, Capacity: 1000, ExtraOpaqueData: []byte{ @@ -674,6 +687,17 @@ func createChannelEdge(node1, node2 *models.LightningNode) ( copy(edgeInfo.BitcoinKey1Bytes[:], firstNode[:]) copy(edgeInfo.BitcoinKey2Bytes[:], secondNode[:]) + if opts.skipProofs { + return edgeInfo, nil, nil + } + + edgeInfo.AuthProof = &models.ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + } + edge1 := &models.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), ChannelID: chanID, @@ -1023,6 +1047,64 @@ func newEdgePolicy(chanID uint64, updateTime int64) *models.ChannelEdgePolicy { } } +// TestAddEdgeProof tests the ability to add an edge proof to an existing edge. +func TestAddEdgeProof(t *testing.T) { + t.Parallel() + + graph, err := MakeTestGraph(t) + require.NoError(t, err, "unable to make test database") + + // Add an edge with no proof. + node1 := createTestVertex(t) + node2 := createTestVertex(t) + edge1, _, _ := createChannelEdge(node1, node2, withSkipProofs()) + require.NoError(t, graph.AddChannelEdge(edge1)) + + // Fetch the edge and assert that the proof is nil and that the rest + // of the edge info is correct. + dbEdge, _, _, err := graph.FetchChannelEdgesByID(edge1.ChannelID) + require.NoError(t, err) + require.Nil(t, dbEdge.AuthProof) + require.Equal(t, edge1, dbEdge) + + // Now, add the edge proof. + proof := &models.ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + } + + // First, add the proof to the rest of the channel edge info and try + // to call AddChannelEdge again - this should fail due to the channel + // already existing. + edge1.AuthProof = proof + err = graph.AddChannelEdge(edge1) + require.Error(t, err, ErrEdgeAlreadyExist) + + // Now add just the proof. + scid1 := lnwire.NewShortChanIDFromInt(edge1.ChannelID) + require.NoError(t, graph.AddEdgeProof(scid1, proof)) + + // Fetch the edge again and assert that the proof is now set. + dbEdge, _, _, err = graph.FetchChannelEdgesByID(edge1.ChannelID) + require.NoError(t, err) + require.NotNil(t, dbEdge.AuthProof) + require.Equal(t, edge1, dbEdge) + + // For completeness, also test the case where we insert a new edge with + // an edge proof. Show that the proof is present from the get go. 
+ edge2, _, _ := createChannelEdge(node1, node2) + require.NoError(t, graph.AddChannelEdge(edge2)) + + // Fetch the edge and assert that the proof is nil and that the rest + // of the edge info is correct. + dbEdge2, _, _, err := graph.FetchChannelEdgesByID(edge2.ChannelID) + require.NoError(t, err) + require.NotNil(t, dbEdge2.AuthProof) + require.Equal(t, edge2, dbEdge2) +} + // TestForEachSourceNodeChannel tests that the ForEachSourceNodeChannel // correctly iterates through the channels of the set source node. func TestForEachSourceNodeChannel(t *testing.T) { From 595077b24c83d3a8683459ae1ea0f5e8dfae8e75 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 9 May 2025 10:59:17 +0200 Subject: [PATCH 41/41] graph/db: let MakeTestGraph require no error internally Instead of returning an error and needing to call `require.NoError` for each call to `MakeTestGraph`, rather just used the available testing variable to require no error within the function itself. --- graph/db/graph_test.go | 199 ++++++++++++++--------------------------- graph/db/kv_store.go | 25 ++---- 2 files changed, 77 insertions(+), 147 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 5a774d53bd..9adc2ed096 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -98,8 +98,7 @@ func createTestVertex(t testing.TB) *models.LightningNode { func TestNodeInsertionAndDeletion(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'd like to test basic insertion/deletion for vertexes from the // graph, so we'll create a test vertex to start with. @@ -176,8 +175,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { func TestPartialNode(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // To insert a partial node, we need to add a channel edge that has // node keys for nodes we are not yet aware @@ -243,8 +241,7 @@ func TestPartialNode(t *testing.T) { func TestAliasLookup(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'd like to test the alias index within the database, so first // create a new test node. @@ -278,8 +275,7 @@ func TestAliasLookup(t *testing.T) { func TestSourceNode(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'd like to test the setting/getting of the source node, so we // first create a fake node to use within the test. @@ -287,7 +283,7 @@ func TestSourceNode(t *testing.T) { // Attempt to fetch the source node, this should return an error as the // source node hasn't yet been set. - _, err = graph.SourceNode() + _, err := graph.SourceNode() require.ErrorIs(t, err, ErrSourceNodeNotSet) // Set the source the source node, this should insert the node into the @@ -306,8 +302,7 @@ func TestSourceNode(t *testing.T) { func TestEdgeInsertionDeletion(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. 
@@ -436,8 +431,7 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, func TestDisconnectBlockAtHeight(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) sourceNode := createTestVertex(t) if err := graph.SetSourceNode(sourceNode); err != nil { @@ -457,7 +451,7 @@ func TestDisconnectBlockAtHeight(t *testing.T) { // Prune the graph a few times to make sure we have entries in the // prune log. - _, err = graph.PruneGraph(spendOutputs, &blockHash, 155) + _, err := graph.PruneGraph(spendOutputs, &blockHash, 155) require.NoError(t, err, "unable to prune graph") var blockHash2 chainhash.Hash copy(blockHash2[:], bytes.Repeat([]byte{2}, 32)) @@ -733,8 +727,7 @@ func createChannelEdge(node1, node2 *models.LightningNode, func TestEdgeInfoUpdates(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. @@ -754,7 +747,7 @@ func TestEdgeInfoUpdates(t *testing.T) { // Make sure inserting the policy at this point, before the edge info // is added, will fail. - err = graph.UpdateEdgePolicy(edge1) + err := graph.UpdateEdgePolicy(edge1) require.ErrorIs(t, err, ErrEdgeNotFound) require.Len(t, graph.graphCache.nodeChannels, 0) @@ -1051,8 +1044,7 @@ func newEdgePolicy(chanID uint64, updateTime int64) *models.ChannelEdgePolicy { func TestAddEdgeProof(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // Add an edge with no proof. node1 := createTestVertex(t) @@ -1110,8 +1102,7 @@ func TestAddEdgeProof(t *testing.T) { func TestForEachSourceNodeChannel(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // Create a source node (A) and set it as such in the DB. nodeA := createTestVertex(t) @@ -1177,7 +1168,7 @@ func TestForEachSourceNodeChannel(t *testing.T) { // Now, we'll use the ForEachSourceNodeChannel and assert that it // returns the expected data in the call-back. - err = graph.ForEachSourceNodeChannel(func(chanPoint wire.OutPoint, + err := graph.ForEachSourceNodeChannel(func(chanPoint wire.OutPoint, havePolicy bool, otherNode *models.LightningNode) error { require.Contains(t, expectedSrcChans, chanPoint) @@ -1199,8 +1190,7 @@ func TestForEachSourceNodeChannel(t *testing.T) { func TestGraphTraversal(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'd like to test some of the graph traversal capabilities within // the DB, so we'll create a series of fake nodes to insert into the @@ -1219,7 +1209,7 @@ func TestGraphTraversal(t *testing.T) { // set of channels (to force the fall back), we should find all the // channel as well as the nodes included. 
graph.graphCache = nil - err = graph.ForEachNodeCached(func(node route.Vertex, + err := graph.ForEachNodeCached(func(node route.Vertex, chans map[uint64]*DirectedChannel) error { if _, ok := nodeIndex[node]; !ok { @@ -1295,8 +1285,7 @@ func TestGraphTraversal(t *testing.T) { func TestGraphTraversalCacheable(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'd like to test some of the graph traversal capabilities within // the DB, so we'll create a series of fake nodes to insert into the @@ -1308,7 +1297,7 @@ func TestGraphTraversalCacheable(t *testing.T) { // Create a map of all nodes with the iteration we know works (because // it is tested in another test). nodeMap := make(map[route.Vertex]struct{}) - err = graph.ForEachNode(func(tx NodeRTx) error { + err := graph.ForEachNode(func(tx NodeRTx) error { nodeMap[tx.Node().PubKeyBytes] = struct{}{} return nil @@ -1346,8 +1335,7 @@ func TestGraphTraversalCacheable(t *testing.T) { func TestGraphCacheTraversal(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err) + graph := MakeTestGraph(t) // We'd like to test some of the graph traversal capabilities within // the DB, so we'll create a series of fake nodes to insert into the @@ -1363,7 +1351,7 @@ func TestGraphCacheTraversal(t *testing.T) { for _, node := range nodeList { node := node - err = graph.graphCache.ForEachChannel( + err := graph.graphCache.ForEachChannel( node.PubKeyBytes, func(d *DirectedChannel) error { delete(chanIndex, d.ChannelID) @@ -1589,8 +1577,7 @@ func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, func TestGraphPruning(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) sourceNode := createTestVertex(t) if err := graph.SetSourceNode(sourceNode); err != nil { @@ -1779,8 +1766,7 @@ func TestGraphPruning(t *testing.T) { func TestHighestChanID(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // If we don't yet have any channels in the database, then we should // get a channel ID of zero if we ask for the highest channel ID. @@ -1839,8 +1825,7 @@ func TestHighestChanID(t *testing.T) { func TestChanUpdatesInHorizon(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // If we issue an arbitrary query before any channel updates are // inserted in the database, we should get zero results. @@ -1998,8 +1983,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { func TestNodeUpdatesInHorizon(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) startTime := time.Unix(1234, 0) endTime := startTime @@ -2114,8 +2098,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { func TestFilterKnownChanIDsZombieRevival(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err) + graph := MakeTestGraph(t) var ( scid1 = lnwire.ShortChannelID{BlockHeight: 1} @@ -2129,7 +2112,7 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { } // Mark channel 1 and 2 as zombies. 
- err = graph.MarkEdgeZombie(scid1.ToUint64(), [33]byte{}, [33]byte{}) + err := graph.MarkEdgeZombie(scid1.ToUint64(), [33]byte{}, [33]byte{}) require.NoError(t, err) err = graph.MarkEdgeZombie(scid2.ToUint64(), [33]byte{}, [33]byte{}) require.NoError(t, err) @@ -2180,8 +2163,7 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { func TestFilterKnownChanIDs(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) isZombieUpdate := func(updateTime1 time.Time, updateTime2 time.Time) bool { @@ -2358,19 +2340,15 @@ func TestFilterKnownChanIDs(t *testing.T) { func TestStressTestChannelGraphAPI(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err) + graph := MakeTestGraph(t) node1 := createTestVertex(t) - require.NoError(t, err, "unable to create test node") require.NoError(t, graph.AddLightningNode(node1)) node2 := createTestVertex(t) - require.NoError(t, err, "unable to create test node") require.NoError(t, graph.AddLightningNode(node2)) - err = graph.SetSourceNode(node1) - require.NoError(t, err) + require.NoError(t, graph.SetSourceNode(node1)) type chanInfo struct { info models.ChannelEdgeInfo @@ -2646,8 +2624,7 @@ func TestStressTestChannelGraphAPI(t *testing.T) { func TestFilterChannelRange(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err) + graph := MakeTestGraph(t) // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. @@ -2863,8 +2840,7 @@ func TestFilterChannelRange(t *testing.T) { func TestFetchChanInfos(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. @@ -2933,7 +2909,7 @@ func TestFetchChanInfos(t *testing.T) { if err := graph.AddChannelEdge(&zombieChan); err != nil { t.Fatalf("unable to create channel edge: %v", err) } - err = graph.DeleteChannelEdges(false, true, zombieChan.ChannelID) + err := graph.DeleteChannelEdges(false, true, zombieChan.ChannelID) require.NoError(t, err, "unable to delete and mark edge zombie") edgeQuery = append(edgeQuery, zombieChanID.ToUint64()) @@ -2966,8 +2942,7 @@ func TestFetchChanInfos(t *testing.T) { func TestIncompleteChannelPolicies(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // Create two nodes. node1 := createTestVertex(t) @@ -3062,8 +3037,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // The update index only applies to the bbolt graph. boltStore, ok := graph.V1Store.(*KVStore) @@ -3194,7 +3168,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { // index entries from the database all together. 
var blockHash chainhash.Hash copy(blockHash[:], bytes.Repeat([]byte{2}, 32)) - _, err = graph.PruneGraph( + _, err := graph.PruneGraph( []*wire.OutPoint{&edgeInfo.ChannelPoint}, &blockHash, 101, ) require.NoError(t, err, "unable to prune graph") @@ -3210,8 +3184,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { func TestPruneGraphNodes(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'll start off by inserting our source node, to ensure that it's // the only node left after we prune the graph. @@ -3266,10 +3239,8 @@ func TestPruneGraphNodes(t *testing.T) { // Finally, we'll ensure that node3, the only fully unconnected node as // properly deleted from the graph and not another node in its place. - _, err = graph.FetchLightningNode(node3.PubKeyBytes) - if err == nil { - t.Fatalf("node 3 should have been deleted!") - } + _, err := graph.FetchLightningNode(node3.PubKeyBytes) + require.NotNil(t, err) } // TestAddChannelEdgeShellNodes tests that when we attempt to add a ChannelEdge @@ -3278,8 +3249,7 @@ func TestPruneGraphNodes(t *testing.T) { func TestAddChannelEdgeShellNodes(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // To start, we'll create two nodes, and only add one of them to the // channel graph. @@ -3298,7 +3268,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { // Ensure that node1 was inserted as a full node, while node2 only has // a shell node present. - node1, err = graph.FetchLightningNode(node1.PubKeyBytes) + node1, err := graph.FetchLightningNode(node1.PubKeyBytes) require.NoError(t, err, "unable to fetch node1") if !node1.HaveNodeAnnouncement { t.Fatalf("have shell announcement for node1, shouldn't") @@ -3315,8 +3285,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { func TestNodePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // We'll first populate our graph with a single node that will be // removed shortly. @@ -3371,22 +3340,19 @@ func TestNodeIsPublic(t *testing.T) { // We'll need to create a separate database and channel graph for each // participant to replicate real-world scenarios (private edges being in // some graphs but not others, etc.). - aliceGraph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + aliceGraph := MakeTestGraph(t) aliceNode := createTestVertex(t) if err := aliceGraph.SetSourceNode(aliceNode); err != nil { t.Fatalf("unable to set source node: %v", err) } - bobGraph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + bobGraph := MakeTestGraph(t) bobNode := createTestVertex(t) if err := bobGraph.SetSourceNode(bobNode); err != nil { t.Fatalf("unable to set source node: %v", err) } - carolGraph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + carolGraph := MakeTestGraph(t) carolNode := createTestVertex(t) if err := carolGraph.SetSourceNode(carolNode); err != nil { t.Fatalf("unable to set source node: %v", err) @@ -3501,8 +3467,7 @@ func TestNodeIsPublic(t *testing.T) { func TestDisabledChannelIDs(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // Create first node and add it to the graph. 
node1 := createTestVertex(t) @@ -3585,8 +3550,7 @@ func TestDisabledChannelIDs(t *testing.T) { func TestEdgePolicyMissingMaxHtcl(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to make test database") + graph := MakeTestGraph(t) // This test currently directly edits the bytes stored in the bbolt DB. boltStore, ok := graph.V1Store.(*KVStore) @@ -3620,20 +3584,14 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { edge1.ExtraOpaqueData = nil var b bytes.Buffer - err = serializeChanEdgePolicy(&b, edge1, to) - if err != nil { - t.Fatalf("unable to serialize policy") - } + require.NoError(t, serializeChanEdgePolicy(&b, edge1, to)) // Set the max_htlc field. The extra bytes added to the serialization // will be the opaque data containing the serialized field. edge1.MessageFlags = lnwire.ChanUpdateRequiredMaxHtlc edge1.MaxHTLC = 13928598 var b2 bytes.Buffer - err = serializeChanEdgePolicy(&b2, edge1, to) - if err != nil { - t.Fatalf("unable to serialize policy") - } + require.NoError(t, serializeChanEdgePolicy(&b2, edge1, to)) withMaxHtlc := b2.Bytes() @@ -3642,7 +3600,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { // Attempting to deserialize these bytes should return an error. r := bytes.NewReader(stripped) - _, err = deserializeChanEdgePolicy(r) + _, err := deserializeChanEdgePolicy(r) require.ErrorIs(t, err, ErrEdgePolicyOptionalFieldNotFound) // Put the stripped bytes in the DB. @@ -3738,13 +3696,10 @@ func TestGraphZombieIndex(t *testing.T) { t.Parallel() // We'll start by creating our test graph along with a test edge. - graph, err := MakeTestGraph(t) - require.NoError(t, err, "unable to create test database") + graph := MakeTestGraph(t) node1 := createTestVertex(t) - require.NoError(t, err, "unable to create test vertex") node2 := createTestVertex(t) - require.NoError(t, err, "unable to create test vertex") // Swap the nodes if the second's pubkey is smaller than the first. // Without this, the comparisons at the end will fail probabilistically. @@ -3763,7 +3718,7 @@ func TestGraphZombieIndex(t *testing.T) { // If we delete the edge and mark it as a zombie, then we should expect // to see it within the index. - err = graph.DeleteChannelEdges(false, true, edge.ChannelID) + err := graph.DeleteChannelEdges(false, true, edge.ChannelID) require.NoError(t, err, "unable to mark edge as zombie") isZombie, pubKey1, pubKey2 := graph.IsZombieEdge(edge.ChannelID) require.True(t, isZombie) @@ -3922,20 +3877,15 @@ func TestComputeFee(t *testing.T) { func TestBatchedAddChannelEdge(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.Nil(t, err) + graph := MakeTestGraph(t) sourceNode := createTestVertex(t) - require.Nil(t, err) - err = graph.SetSourceNode(sourceNode) - require.Nil(t, err) + require.Nil(t, graph.SetSourceNode(sourceNode)) // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. node1 := createTestVertex(t) - require.Nil(t, err) node2 := createTestVertex(t) - require.Nil(t, err) // In addition to the fake vertexes we create some fake channel // identifiers. @@ -3945,7 +3895,7 @@ func TestBatchedAddChannelEdge(t *testing.T) { // Prune the graph a few times to make sure we have entries in the // prune log. 
- _, err = graph.PruneGraph(spendOutputs, &blockHash, 155) + _, err := graph.PruneGraph(spendOutputs, &blockHash, 155) require.Nil(t, err) var blockHash2 chainhash.Hash copy(blockHash2[:], bytes.Repeat([]byte{2}, 32)) @@ -4003,31 +3953,24 @@ func TestBatchedAddChannelEdge(t *testing.T) { func TestBatchedUpdateEdgePolicy(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.Nil(t, err) + graph := MakeTestGraph(t) // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. node1 := createTestVertex(t) - require.Nil(t, err) - err = graph.AddLightningNode(node1) - require.Nil(t, err) + require.NoError(t, graph.AddLightningNode(node1)) node2 := createTestVertex(t) - require.Nil(t, err) - err = graph.AddLightningNode(node2) - require.Nil(t, err) + require.NoError(t, graph.AddLightningNode(node2)) // Create an edge and add it to the db. edgeInfo, edge1, edge2 := createChannelEdge(node1, node2) // Make sure inserting the policy at this point, before the edge info // is added, will fail. - err = graph.UpdateEdgePolicy(edge1) - require.Error(t, ErrEdgeNotFound, err) + require.Error(t, ErrEdgeNotFound, graph.UpdateEdgePolicy(edge1)) // Add the edge info. - err = graph.AddChannelEdge(edgeInfo) - require.Nil(t, err) + require.NoError(t, graph.AddChannelEdge(edgeInfo)) errTimeout := errors.New("timeout adding batched channel") @@ -4060,8 +4003,7 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { // BenchmarkForEachChannel is a benchmark test that measures the number of // allocations and the total memory consumed by the full graph traversal. func BenchmarkForEachChannel(b *testing.B) { - graph, err := MakeTestGraph(b) - require.Nil(b, err) + graph := MakeTestGraph(b) const numNodes = 100 const numChannels = 4 @@ -4076,7 +4018,7 @@ func BenchmarkForEachChannel(b *testing.B) { ) var nodes []route.Vertex - err = graph.ForEachNodeCacheable(func(node route.Vertex, + err := graph.ForEachNodeCacheable(func(node route.Vertex, vector *lnwire.FeatureVector) error { nodes = append(nodes, node) @@ -4111,21 +4053,18 @@ func BenchmarkForEachChannel(b *testing.B) { // TestGraphCacheForEachNodeChannel tests that the forEachNodeDirectedChannel // method works as expected, and is able to handle nil self edges. func TestGraphCacheForEachNodeChannel(t *testing.T) { - graph, err := MakeTestGraph(t) - require.NoError(t, err) + t.Parallel() + + graph := MakeTestGraph(t) // Unset the channel graph cache to simulate the user running with the // option turned off. graph.graphCache = nil node1 := createTestVertex(t) - require.Nil(t, err) - err = graph.AddLightningNode(node1) - require.Nil(t, err) + require.NoError(t, graph.AddLightningNode(node1)) node2 := createTestVertex(t) - require.Nil(t, err) - err = graph.AddLightningNode(node2) - require.Nil(t, err) + require.NoError(t, graph.AddLightningNode(node2)) // Create an edge and add it to the db. 
edgeInfo, e1, e2 := createChannelEdge(node1, node2) @@ -4145,7 +4084,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) { getSingleChannel := func() *DirectedChannel { var ch *DirectedChannel - err = graph.ForEachNodeDirectedChannel(node1.PubKeyBytes, + err := graph.ForEachNodeDirectedChannel(node1.PubKeyBytes, func(c *DirectedChannel) error { require.Nil(t, ch) ch = c @@ -4240,8 +4179,7 @@ func TestGraphLoading(t *testing.T) { func TestClosedScid(t *testing.T) { t.Parallel() - graph, err := MakeTestGraph(t) - require.Nil(t, err) + graph := MakeTestGraph(t) scid := lnwire.ShortChannelID{} @@ -4277,8 +4215,7 @@ func TestLightningNodePersistence(t *testing.T) { t.Parallel() // Create a new test graph instance. - graph, err := MakeTestGraph(t) - require.NoError(t, err) + graph := MakeTestGraph(t) nodeAnnBytes, err := hex.DecodeString(testNodeAnn) require.NoError(t, err) diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 36da985cff..f7ebc241ba 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -4709,8 +4709,8 @@ func (c *chanGraphNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo, // MakeTestGraph creates a new instance of the KVStore for testing // purposes. -func MakeTestGraph(t testing.TB, modifiers ...KVStoreOptionModifier) ( - *ChannelGraph, error) { +func MakeTestGraph(t testing.TB, + modifiers ...KVStoreOptionModifier) *ChannelGraph { opts := DefaultOptions() for _, modifier := range modifiers { @@ -4719,28 +4719,21 @@ func MakeTestGraph(t testing.TB, modifiers ...KVStoreOptionModifier) ( // Next, create KVStore for the first time. backend, backendCleanup, err := kvdb.GetTestBackend(t.TempDir(), "cgr") - if err != nil { - backendCleanup() - - return nil, err - } + t.Cleanup(backendCleanup) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, backend.Close()) + }) graphStore, err := NewKVStore(backend, modifiers...) require.NoError(t, err) graph, err := NewChannelGraph(graphStore) - if err != nil { - backendCleanup() - - return nil, err - } + require.NoError(t, err) require.NoError(t, graph.Start()) - t.Cleanup(func() { - _ = backend.Close() - backendCleanup() require.NoError(t, graph.Stop()) }) - return graph, nil + return graph }
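
To close out the series, a hedged sketch of what a caller looks like after this final patch. The test name and node setup are invented for illustration; the helper signature and assertions match the diff above.

    // TestSomething is a hypothetical test illustrating the new helper
    // shape: MakeTestGraph fails the test itself on any setup error and
    // registers all teardown via t.Cleanup, so callers no longer juggle a
    // (graph, error) pair.
    func TestSomething(t *testing.T) {
            t.Parallel()

            graph := MakeTestGraph(t)

            node := createTestVertex(t)
            require.NoError(t, graph.AddLightningNode(node))
    }
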