diff --git a/.github/workflows/postgres-integration.yml b/.github/workflows/postgres-integration.yml new file mode 100644 index 000000000..ff0a553f2 --- /dev/null +++ b/.github/workflows/postgres-integration.yml @@ -0,0 +1,65 @@ +name: CI Checks - PostgreSQL Integration Tests + +on: [ push, pull_request ] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:latest + ports: + - 5432:5432 + env: + POSTGRES_DB: postgres + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Checkout code + uses: actions/checkout@v6 + - name: Install Rust stable toolchain + run: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile=minimal --default-toolchain stable + - name: Enable caching for bitcoind + id: cache-bitcoind + uses: actions/cache@v4 + with: + path: bin/bitcoind-${{ runner.os }}-${{ runner.arch }} + key: bitcoind-${{ runner.os }}-${{ runner.arch }} + - name: Enable caching for electrs + id: cache-electrs + uses: actions/cache@v4 + with: + path: bin/electrs-${{ runner.os }}-${{ runner.arch }} + key: electrs-${{ runner.os }}-${{ runner.arch }} + - name: Download bitcoind/electrs + if: "steps.cache-bitcoind.outputs.cache-hit != 'true' || steps.cache-electrs.outputs.cache-hit != 'true'" + run: | + source ./scripts/download_bitcoind_electrs.sh + mkdir bin + mv "$BITCOIND_EXE" bin/bitcoind-${{ runner.os }}-${{ runner.arch }} + mv "$ELECTRS_EXE" bin/electrs-${{ runner.os }}-${{ runner.arch }} + - name: Set bitcoind/electrs environment variables + run: | + echo "BITCOIND_EXE=$( pwd )/bin/bitcoind-${{ runner.os }}-${{ runner.arch }}" >> "$GITHUB_ENV" + echo "ELECTRS_EXE=$( pwd )/bin/electrs-${{ runner.os }}-${{ runner.arch }}" >> "$GITHUB_ENV" + - name: Run PostgreSQL store tests + env: + 
TEST_POSTGRES_URL: "host=localhost user=postgres password=postgres" + run: cargo test --features postgres io::postgres_store + - name: Run PostgreSQL integration tests + env: + TEST_POSTGRES_URL: "host=localhost user=postgres password=postgres" + run: | + RUSTFLAGS="--cfg no_download --cfg cycle_tests" cargo test --features postgres --test integration_tests_postgres diff --git a/Cargo.toml b/Cargo.toml index 539941677..04130f237 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ panic = 'abort' # Abort on panic [features] default = [] +postgres = ["dep:tokio-postgres", "dep:native-tls", "dep:postgres-native-tls"] [dependencies] #lightning = { version = "0.2.0", features = ["std"] } @@ -76,6 +77,9 @@ serde_json = { version = "1.0.128", default-features = false, features = ["std"] log = { version = "0.4.22", default-features = false, features = ["std"]} async-trait = { version = "0.1", default-features = false } +tokio-postgres = { version = "0.7", default-features = false, features = ["runtime"], optional = true } +native-tls = { version = "0.2", default-features = false, optional = true } +postgres-native-tls = { version = "0.5", default-features = false, features = ["runtime"], optional = true } vss-client = { package = "vss-client-ng", version = "0.5" } prost = { version = "0.11.6", default-features = false} #bitcoin-payment-instructions = { version = "0.6" } diff --git a/src/builder.rs b/src/builder.rs index cd8cc184f..7217bb7a4 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -629,6 +629,40 @@ impl NodeBuilder { self.build_with_store(node_entropy, kv_store) } + /// Builds a [`Node`] instance with a [PostgreSQL] backend and according to the options + /// previously configured. + /// + /// Connects to the PostgreSQL database at the given `connection_string`. + /// + /// The given `db_name` will be used or default to + /// [`DEFAULT_DB_NAME`](io::postgres_store::DEFAULT_DB_NAME). 
The database will be created + /// automatically if it doesn't already exist. + /// + /// The given `kv_table_name` will be used or default to + /// [`DEFAULT_KV_TABLE_NAME`](io::postgres_store::DEFAULT_KV_TABLE_NAME). + /// + /// If `tls_config` is `Some`, TLS will be used for database connections. A custom CA + /// certificate can be provided via + /// [`PostgresTlsConfig::certificate_pem`](io::postgres_store::PostgresTlsConfig::certificate_pem), + /// otherwise the system's default root certificates are used. If `tls_config` is `None`, + /// connections will be unencrypted. + /// + /// [PostgreSQL]: https://www.postgresql.org + #[cfg(feature = "postgres")] + pub fn build_with_postgres_store( + &self, node_entropy: NodeEntropy, connection_string: String, db_name: Option, + kv_table_name: Option, tls_config: Option, + ) -> Result { + let kv_store = io::postgres_store::PostgresStore::new( + connection_string, + db_name, + kv_table_name, + tls_config, + ) + .map_err(|_| BuildError::KVStoreSetupFailed)?; + self.build_with_store(node_entropy, kv_store) + } + /// Builds a [`Node`] instance with a [`FilesystemStore`] backend and according to the options /// previously configured. pub fn build_with_fs_store(&self, node_entropy: NodeEntropy) -> Result { @@ -1087,6 +1121,28 @@ impl ArcedNodeBuilder { self.inner.read().unwrap().build(*node_entropy).map(Arc::new) } + /// Builds a [`Node`] instance with a [PostgreSQL] backend and according to the options + /// previously configured. 
+ /// + /// [PostgreSQL]: https://www.postgresql.org + #[cfg(feature = "postgres")] + pub fn build_with_postgres_store( + &self, node_entropy: Arc, connection_string: String, db_name: Option, + kv_table_name: Option, tls_config: Option, + ) -> Result, BuildError> { + self.inner + .read() + .unwrap() + .build_with_postgres_store( + *node_entropy, + connection_string, + db_name, + kv_table_name, + tls_config, + ) + .map(Arc::new) + } + /// Builds a [`Node`] instance with a [`FilesystemStore`] backend and according to the options /// previously configured. pub fn build_with_fs_store( diff --git a/src/io/mod.rs b/src/io/mod.rs index e080d39f7..e16a99975 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -7,6 +7,8 @@ //! Objects and traits for data persistence. +#[cfg(feature = "postgres")] +pub mod postgres_store; pub mod sqlite_store; #[cfg(test)] pub(crate) mod test_utils; diff --git a/src/io/postgres_store/migrations.rs b/src/io/postgres_store/migrations.rs new file mode 100644 index 000000000..c9add1c57 --- /dev/null +++ b/src/io/postgres_store/migrations.rs @@ -0,0 +1,21 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. 
+ +use lightning::io; +use tokio_postgres::Client; + +pub(super) async fn migrate_schema( + _client: &Client, _kv_table_name: &str, from_version: u16, to_version: u16, +) -> io::Result<()> { + assert!(from_version < to_version); + // Future migrations go here, e.g.: + // if from_version == 1 && to_version >= 2 { + // migrate_v1_to_v2(client, kv_table_name).await?; + // from_version = 2; + // } + Ok(()) +} diff --git a/src/io/postgres_store/mod.rs b/src/io/postgres_store/mod.rs new file mode 100644 index 000000000..42d2bce10 --- /dev/null +++ b/src/io/postgres_store/mod.rs @@ -0,0 +1,1184 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +//! Objects related to [`PostgresStore`] live here. +use std::collections::HashMap; +use std::future::Future; +use std::sync::atomic::{AtomicI64, AtomicU64, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; + +use lightning::io; +use lightning::util::persist::{ + KVStore, KVStoreSync, PageToken, PaginatedKVStore, PaginatedKVStoreSync, PaginatedListResponse, +}; +use lightning_types::string::PrintableString; +use native_tls::TlsConnector; +use postgres_native_tls::MakeTlsConnector; +use tokio_postgres::{Client, Config, Error as PgError, NoTls}; + +use crate::io::utils::check_namespace_key_validity; + +mod migrations; + +/// The default database name used when none is specified. +pub const DEFAULT_DB_NAME: &str = "ldk_node"; + +/// The default table in which we store all data. +pub const DEFAULT_KV_TABLE_NAME: &str = "ldk_data"; + +// The current schema version for the PostgreSQL store. +const SCHEMA_VERSION: u16 = 1; + +// The number of entries returned per page in paginated list operations. 
+const PAGE_SIZE: usize = 50; + +// The number of worker threads for the internal runtime used by sync operations. +const INTERNAL_RUNTIME_WORKERS: usize = 2; + +/// A [`KVStoreSync`] implementation that writes to and reads from a [PostgreSQL] database. +/// +/// [PostgreSQL]: https://www.postgresql.org +pub struct PostgresStore { + inner: Arc, + + // Version counter to ensure that writes are applied in the correct order. It is assumed that read and list + // operations aren't sensitive to the order of execution. + next_write_version: AtomicU64, + + // An internal runtime we use to avoid any deadlocks we could hit when waiting on async + // operations to finish from a sync context. + internal_runtime: Option, +} + +// tokio::sync::Mutex (used for the DB client) contains UnsafeCell which opts out of +// RefUnwindSafe. std::sync::Mutex (used by SqliteStore) doesn't have this issue because +// it poisons on panic. This impl is needed for do_read_write_remove_list_persist which +// requires K: KVStoreSync + RefUnwindSafe. +#[cfg(test)] +impl std::panic::RefUnwindSafe for PostgresStore {} + +impl PostgresStore { + /// Constructs a new [`PostgresStore`]. + /// + /// Connects to the PostgreSQL database at the given `connection_string`. + /// + /// The given `db_name` will be used or default to [`DEFAULT_DB_NAME`]. The database will be + /// created automatically if it doesn't already exist. + /// + /// The given `kv_table_name` will be used or default to [`DEFAULT_KV_TABLE_NAME`]. + /// + /// If `tls_config` is `Some`, TLS will be used for database connections. A custom CA + /// certificate can be provided via [`PostgresTlsConfig::certificate_pem`], otherwise the + /// system's default root certificates are used. If `tls_config` is `None`, connections + /// will be unencrypted. 
+ pub fn new( + connection_string: String, db_name: Option, kv_table_name: Option, + tls_config: Option, + ) -> io::Result { + let tls = Self::build_tls_connector(tls_config)?; + + let internal_runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .thread_name_fn(|| { + static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("ldk-node-pg-runtime-{id}") + }) + .worker_threads(INTERNAL_RUNTIME_WORKERS) + .max_blocking_threads(INTERNAL_RUNTIME_WORKERS) + .build() + .unwrap(); + + let inner = tokio::task::block_in_place(|| { + internal_runtime.block_on(async { + PostgresStoreInner::new(connection_string, db_name, kv_table_name, tls).await + }) + })?; + + let inner = Arc::new(inner); + let next_write_version = AtomicU64::new(1); + Ok(Self { inner, next_write_version, internal_runtime: Some(internal_runtime) }) + } + + fn build_tls_connector(tls_config: Option) -> io::Result { + match tls_config { + Some(config) => { + let mut builder = TlsConnector::builder(); + if let Some(pem) = config.certificate_pem { + let crt = native_tls::Certificate::from_pem(pem.as_bytes()).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("Failed to parse PEM certificate: {e}"), + ) + })?; + builder.add_root_certificate(crt); + } + let connector = builder.build().map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("Failed to build TLS connector: {e}"), + ) + })?; + Ok(PgTlsConnector::NativeTls(MakeTlsConnector::new(connector))) + }, + None => Ok(PgTlsConnector::Plain), + } + } + + fn build_locking_key( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> String { + format!("{primary_namespace}#{secondary_namespace}#{key}") + } + + fn get_new_version_and_lock_ref( + &self, locking_key: String, + ) -> (Arc>, u64) { + let version = self.next_write_version.fetch_add(1, Ordering::Relaxed); + if version == u64::MAX { + panic!("PostgresStore version counter 
overflowed"); + } + + let inner_lock_ref = self.inner.get_inner_lock_ref(locking_key); + + (inner_lock_ref, version) + } +} + +impl Drop for PostgresStore { + fn drop(&mut self) { + let internal_runtime = self.internal_runtime.take(); + tokio::task::block_in_place(move || drop(internal_runtime)); + } +} + +impl KVStore for PostgresStore { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> impl Future, io::Error>> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { inner.read_internal(&primary_namespace, &secondary_namespace, &key).await } + } + + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> impl Future> + 'static + Send { + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .write_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + buf, + ) + .await + } + } + + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool, + ) -> impl Future> + 'static + Send { + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .remove_internal( + inner_lock_ref, + 
locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + ) + .await + } + } + + fn list( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> impl Future, io::Error>> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + async move { inner.list_internal(&primary_namespace, &secondary_namespace).await } + } +} + +impl KVStoreSync for PostgresStore { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner.read_internal(&primary_namespace, &secondary_namespace, &key).await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> io::Result<()> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .write_internal( + inner_lock_ref, + locking_key, + version, + 
&primary_namespace, + &secondary_namespace, + &key, + buf, + ) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool, + ) -> io::Result<()> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let locking_key = self.build_locking_key(primary_namespace, secondary_namespace, key); + let (inner_lock_ref, version) = self.get_new_version_and_lock_ref(locking_key.clone()); + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let key = key.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .remove_internal( + inner_lock_ref, + locking_key, + version, + &primary_namespace, + &secondary_namespace, + &key, + ) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } + + fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result> { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + let fut = + async move { inner.list_internal(&primary_namespace, &secondary_namespace).await }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } +} + +impl PaginatedKVStoreSync for PostgresStore { + fn list_paginated( + &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, + ) -> io::Result { + let internal_runtime = self.internal_runtime.as_ref().ok_or_else(|| { + 
debug_assert!(false, "Failed to access internal runtime"); + io::Error::new(io::ErrorKind::Other, "Failed to access internal runtime") + })?; + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + let fut = async move { + inner + .list_paginated_internal(&primary_namespace, &secondary_namespace, page_token) + .await + }; + tokio::task::block_in_place(move || internal_runtime.block_on(fut)) + } +} + +impl PaginatedKVStore for PostgresStore { + fn list_paginated( + &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, + ) -> impl Future> + 'static + Send { + let primary_namespace = primary_namespace.to_string(); + let secondary_namespace = secondary_namespace.to_string(); + let inner = Arc::clone(&self.inner); + async move { + inner + .list_paginated_internal(&primary_namespace, &secondary_namespace, page_token) + .await + } + } +} + +struct PostgresStoreInner { + client: tokio::sync::Mutex, + config: Config, + kv_table_name: String, + tls: PgTlsConnector, + write_version_locks: Mutex>>>, + next_sort_order: AtomicI64, +} + +impl PostgresStoreInner { + async fn new( + connection_string: String, db_name: Option, kv_table_name: Option, + tls: PgTlsConnector, + ) -> io::Result { + let kv_table_name = kv_table_name.unwrap_or(DEFAULT_KV_TABLE_NAME.to_string()); + + let mut config: Config = connection_string.parse().map_err(|e: PgError| { + let msg = format!("Failed to parse PostgreSQL connection string: {e}"); + io::Error::new(io::ErrorKind::InvalidInput, msg) + })?; + + if db_name.is_some() && config.get_dbname().is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "db_name must not be set when the connection string already contains a dbname", + )); + } + + let db_name = db_name + .or_else(|| config.get_dbname().map(|s| s.to_string())) + .unwrap_or(DEFAULT_DB_NAME.to_string()); + config.dbname(&db_name); + 
Self::create_database_if_not_exists(&config, &db_name, &tls).await?; + + let client = Self::make_config_connection(&config, &tls).await?; + + // Create the KV data table if it doesn't exist. + let sql = format!( + "CREATE TABLE IF NOT EXISTS {kv_table_name} ( + primary_namespace TEXT NOT NULL, + secondary_namespace TEXT NOT NULL DEFAULT '', + key TEXT NOT NULL CHECK (key <> ''), + value BYTEA, + sort_order BIGINT NOT NULL DEFAULT 0, + PRIMARY KEY (primary_namespace, secondary_namespace, key) + )" + ); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to create table {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + // Read the schema version from the table comment (analogous to SQLite's PRAGMA user_version). + let sql = format!("SELECT obj_description('{kv_table_name}'::regclass, 'pg_class')"); + let row = client.query_one(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to read schema version for {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + let version_res: u16 = match row.get::<_, Option<&str>>(0) { + Some(version_str) => version_str.parse().map_err(|_| { + let msg = format!("Invalid schema version: {version_str}"); + io::Error::new(io::ErrorKind::Other, msg) + })?, + None => 0, + }; + + if version_res == 0 { + // New table, set our SCHEMA_VERSION. + let sql = format!("COMMENT ON TABLE {kv_table_name} IS '{SCHEMA_VERSION}'"); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to set schema version: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + } else if version_res < SCHEMA_VERSION { + migrations::migrate_schema(&client, &kv_table_name, version_res, SCHEMA_VERSION) + .await?; + } else if version_res > SCHEMA_VERSION { + let msg = format!( + "Failed to open database: incompatible schema version {version_res}. 
Expected: {SCHEMA_VERSION}" + ); + return Err(io::Error::new(io::ErrorKind::Other, msg)); + } + + // Create composite index for paginated listing. + let sql = format!( + "CREATE INDEX IF NOT EXISTS idx_{kv_table_name}_paginated ON {kv_table_name} (primary_namespace, secondary_namespace, sort_order DESC, key ASC)" + ); + client.execute(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to create index on table {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + // Initialize next_sort_order from the max existing value. + let sql = format!("SELECT COALESCE(MAX(sort_order), 0) FROM {kv_table_name}"); + let row = client.query_one(sql.as_str(), &[]).await.map_err(|e| { + let msg = format!("Failed to read max sort_order from {kv_table_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + let max_sort_order: i64 = row.get(0); + let next_sort_order = AtomicI64::new(max_sort_order + 1); + + let client = tokio::sync::Mutex::new(client); + let write_version_locks = Mutex::new(HashMap::new()); + Ok(Self { client, config, kv_table_name, tls, write_version_locks, next_sort_order }) + } + + async fn create_database_if_not_exists( + config: &Config, db_name: &str, tls: &PgTlsConnector, + ) -> io::Result<()> { + // Connect without a dbname (to the default database) so we can create the target. 
+ let mut config = config.clone(); + config.dbname("postgres"); + + let client = Self::make_config_connection(&config, tls).await?; + + let row = client + .query_opt("SELECT 1 FROM pg_database WHERE datname = $1", &[&db_name]) + .await + .map_err(|e| { + let msg = format!("Failed to check for database {db_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + if row.is_none() { + let sql = format!("CREATE DATABASE {db_name}"); + client.execute(&sql, &[]).await.map_err(|e| { + let msg = format!("Failed to create database {db_name}: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + log::info!("Created database {db_name}"); + } + + Ok(()) + } + + async fn make_config_connection(config: &Config, tls: &PgTlsConnector) -> io::Result { + let err_map = |e| { + let msg = format!("Failed to connect to PostgreSQL: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + }; + + match tls { + PgTlsConnector::Plain => { + let (client, connection) = config.connect(NoTls).await.map_err(err_map)?; + tokio::spawn(async move { + if let Err(e) = connection.await { + log::error!("PostgreSQL connection error: {e}"); + } + }); + Ok(client) + }, + PgTlsConnector::NativeTls(tls_connector) => { + let (client, connection) = + config.connect(tls_connector.clone()).await.map_err(err_map)?; + tokio::spawn(async move { + if let Err(e) = connection.await { + log::error!("PostgreSQL connection error: {e}"); + } + }); + Ok(client) + }, + } + } + + async fn ensure_connected( + &self, client: &mut tokio::sync::MutexGuard<'_, Client>, + ) -> io::Result<()> { + if client.is_closed() || client.check_connection().await.is_err() { + log::debug!("Reconnecting to PostgreSQL database"); + let new_client = Self::make_config_connection(&self.config, &self.tls).await?; + **client = new_client; + } + Ok(()) + } + + fn get_inner_lock_ref(&self, locking_key: String) -> Arc> { + let mut outer_lock = self.write_version_locks.lock().unwrap(); + 
Arc::clone(&outer_lock.entry(locking_key).or_default()) + } + + async fn read_internal( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result> { + check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "read")?; + + let mut locked_client = self.client.lock().await; + self.ensure_connected(&mut locked_client).await?; + let sql = format!( + "SELECT value FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2 AND key=$3", + self.kv_table_name + ); + + let row = locked_client + .query_opt(sql.as_str(), &[&primary_namespace, &secondary_namespace, &key]) + .await + .map_err(|e| { + let msg = format!( + "Failed to read from key {}/{}/{}: {}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key), + e + ); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + match row { + Some(row) => { + let value: Vec = row.get(0); + Ok(value) + }, + None => { + let msg = format!( + "Failed to read as key could not be found: {}/{}/{}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key), + ); + Err(io::Error::new(io::ErrorKind::NotFound, msg)) + }, + } + } + + async fn write_internal( + &self, inner_lock_ref: Arc>, locking_key: String, version: u64, + primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec, + ) -> io::Result<()> { + check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "write")?; + + self.execute_locked_write(inner_lock_ref, locking_key, version, async move || { + let mut locked_client = self.client.lock().await; + self.ensure_connected(&mut locked_client).await?; + + let sort_order = self.next_sort_order.fetch_add(1, Ordering::Relaxed); + + let sql = format!( + "INSERT INTO {} (primary_namespace, secondary_namespace, key, value, sort_order) \ + VALUES ($1, $2, $3, $4, $5) \ + ON CONFLICT (primary_namespace, secondary_namespace, key) DO UPDATE SET value = 
EXCLUDED.value", + self.kv_table_name + ); + + locked_client + .execute( + sql.as_str(), + &[ + &primary_namespace, + &secondary_namespace, + &key, + &buf, + &sort_order, + ], + ) + .await + .map(|_| ()) + .map_err(|e| { + let msg = format!( + "Failed to write to key {}/{}/{}: {}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key), + e + ); + io::Error::new(io::ErrorKind::Other, msg) + }) + }) + .await + } + + async fn remove_internal( + &self, inner_lock_ref: Arc>, locking_key: String, version: u64, + primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result<()> { + check_namespace_key_validity(primary_namespace, secondary_namespace, Some(key), "remove")?; + + self.execute_locked_write(inner_lock_ref, locking_key, version, async move || { + let mut locked_client = self.client.lock().await; + self.ensure_connected(&mut locked_client).await?; + + let sql = format!( + "DELETE FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2 AND key=$3", + self.kv_table_name + ); + + locked_client + .execute(sql.as_str(), &[&primary_namespace, &secondary_namespace, &key]) + .await + .map_err(|e| { + let msg = format!( + "Failed to delete key {}/{}/{}: {}", + PrintableString(primary_namespace), + PrintableString(secondary_namespace), + PrintableString(key), + e + ); + io::Error::new(io::ErrorKind::Other, msg) + })?; + Ok(()) + }) + .await + } + + async fn list_internal( + &self, primary_namespace: &str, secondary_namespace: &str, + ) -> io::Result> { + check_namespace_key_validity(primary_namespace, secondary_namespace, None, "list")?; + + let mut locked_client = self.client.lock().await; + self.ensure_connected(&mut locked_client).await?; + + let sql = format!( + "SELECT key FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2", + self.kv_table_name + ); + + let rows = locked_client + .query(sql.as_str(), &[&primary_namespace, &secondary_namespace]) + .await + .map_err(|e| { + let msg = 
format!("Failed to retrieve queried rows: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + let keys: Vec = rows.iter().map(|row| row.get(0)).collect(); + Ok(keys) + } + + async fn list_paginated_internal( + &self, primary_namespace: &str, secondary_namespace: &str, page_token: Option, + ) -> io::Result { + check_namespace_key_validity( + primary_namespace, + secondary_namespace, + None, + "list_paginated", + )?; + + let mut locked_client = self.client.lock().await; + self.ensure_connected(&mut locked_client).await?; + + // Fetch one extra row beyond PAGE_SIZE to determine whether a next page exists. + let fetch_limit = (PAGE_SIZE + 1) as i64; + + let mut entries: Vec<(String, i64)> = match page_token { + Some(ref token) => { + let token_sort_order: i64 = token.as_str().parse().map_err(|_| { + let token_str = token.as_str(); + let msg = format!("Invalid page token: {token_str}"); + io::Error::new(io::ErrorKind::InvalidInput, msg) + })?; + let sql = format!( + "SELECT key, sort_order FROM {} \ + WHERE primary_namespace=$1 \ + AND secondary_namespace=$2 \ + AND sort_order < $3 \ + ORDER BY sort_order DESC, key ASC \ + LIMIT $4", + self.kv_table_name + ); + + let rows = locked_client + .query( + sql.as_str(), + &[ + &primary_namespace, + &secondary_namespace, + &token_sort_order, + &fetch_limit, + ], + ) + .await + .map_err(|e| { + let msg = format!("Failed to retrieve queried rows: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + + rows.iter().map(|row| (row.get(0), row.get(1))).collect() + }, + None => { + let sql = format!( + "SELECT key, sort_order FROM {} \ + WHERE primary_namespace=$1 \ + AND secondary_namespace=$2 \ + ORDER BY sort_order DESC, key ASC \ + LIMIT $3", + self.kv_table_name + ); + + let rows = locked_client + .query(sql.as_str(), &[&primary_namespace, &secondary_namespace, &fetch_limit]) + .await + .map_err(|e| { + let msg = format!("Failed to retrieve queried rows: {e}"); + io::Error::new(io::ErrorKind::Other, msg) + })?; + 
+ rows.into_iter().map(|row| (row.get(0), row.get(1))).collect() + }, + }; + + let has_more = entries.len() > PAGE_SIZE; + entries.truncate(PAGE_SIZE); + + let next_page_token = if has_more { + let (_, last_sort_order) = *entries.last().expect("must be non-empty"); + Some(PageToken::new(last_sort_order.to_string())) + } else { + None + }; + + let keys = entries.into_iter().map(|(k, _)| k).collect(); + Ok(PaginatedListResponse { keys, next_page_token }) + } + + async fn execute_locked_write>, FN: FnOnce() -> F>( + &self, inner_lock_ref: Arc>, locking_key: String, version: u64, + callback: FN, + ) -> Result<(), io::Error> { + let res = { + let mut last_written_version = inner_lock_ref.lock().await; + + // Check if we already have a newer version written/removed. This is used in async contexts to realize eventual + // consistency. + let is_stale_version = version <= *last_written_version; + + // If the version is not stale, we execute the callback. Otherwise, we can and must skip writing. + if is_stale_version { + Ok(()) + } else { + callback().await.map(|_| { + *last_written_version = version; + }) + } + }; + + self.clean_locks(&inner_lock_ref, locking_key); + + res + } + + fn clean_locks(&self, inner_lock_ref: &Arc>, locking_key: String) { + // If there are no arcs in use elsewhere, this means that there are no in-flight writes. We can remove the map + // entry to prevent leaking memory. The two arcs that are expected are the one in the map and the one held here + // in inner_lock_ref. The outer lock is obtained first, to avoid a new arc being cloned after we've already + // counted. + let mut outer_lock = self.write_version_locks.lock().unwrap(); + + let strong_count = Arc::strong_count(inner_lock_ref); + debug_assert!(strong_count >= 2, "Unexpected PostgresStore strong count"); + + if strong_count == 2 { + outer_lock.remove(&locking_key); + } + } +} + +/// TLS configuration for PostgreSQL connections. 
+#[derive(Debug, Clone)] +pub struct PostgresTlsConfig { + /// PEM-encoded CA certificate. If `None`, the system's default root certificates are used. + pub certificate_pem: Option<String>, +} + +#[derive(Clone)] +enum PgTlsConnector { + Plain, + NativeTls(MakeTlsConnector), +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::io::test_utils::{do_read_write_remove_list_persist, do_test_store}; + + fn test_connection_string() -> String { + std::env::var("TEST_POSTGRES_URL") + .unwrap_or_else(|_| "host=localhost user=postgres password=postgres".to_string()) + } + + fn create_test_store(table_name: &str) -> PostgresStore { + PostgresStore::new(test_connection_string(), None, Some(table_name.to_string()), None) + .unwrap() + } + + fn cleanup_store(store: &PostgresStore) { + if let Some(ref runtime) = store.internal_runtime { + let kv_table = store.inner.kv_table_name.clone(); + let inner = Arc::clone(&store.inner); + let _ = tokio::task::block_in_place(|| { + runtime.block_on(async { + let client = inner.client.lock().await; + let _ = client.execute(&format!("DROP TABLE IF EXISTS {kv_table}"), &[]).await; + }) + }); + } + } + + #[test] + fn read_write_remove_list_persist() { + let store = create_test_store("test_rwrl"); + do_read_write_remove_list_persist(&store); + cleanup_store(&store); + } + + #[test] + fn test_postgres_store() { + let store_0 = create_test_store("test_pg_store_0"); + let store_1 = create_test_store("test_pg_store_1"); + do_test_store(&store_0, &store_1); + cleanup_store(&store_0); + cleanup_store(&store_1); + } + + #[test] + fn test_postgres_store_auto_reconnect() { + let store = create_test_store("test_pg_reconnect"); + + let ns = "test_ns"; + let sub = "test_sub"; + + // Write a value before disconnecting. + KVStoreSync::write(&store, ns, sub, "key_a", vec![1u8; 8]).unwrap(); + + // Terminate the backend connection to simulate a dropped connection. 
+ if let Some(ref runtime) = store.internal_runtime { + let inner = Arc::clone(&store.inner); + tokio::task::block_in_place(|| { + runtime.block_on(async { + let client = inner.client.lock().await; + let _ = + client.execute("SELECT pg_terminate_backend(pg_backend_pid())", &[]).await; + }) + }); + } + + // Read should auto-reconnect and return the previously written value. + let data = KVStoreSync::read(&store, ns, sub, "key_a").unwrap(); + assert_eq!(data, vec![1u8; 8]); + + // Write should also work after reconnect. + KVStoreSync::write(&store, ns, sub, "key_b", vec![2u8; 8]).unwrap(); + let data = KVStoreSync::read(&store, ns, sub, "key_b").unwrap(); + assert_eq!(data, vec![2u8; 8]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_listing() { + let store = create_test_store("test_pg_paginated"); + + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + let num_entries = 225; + + for i in 0..num_entries { + let key = format!("key_{:04}", i); + let data = vec![i as u8; 32]; + KVStoreSync::write(&store, primary_namespace, secondary_namespace, &key, data).unwrap(); + } + + // Paginate through all entries and collect them + let mut all_keys = Vec::new(); + let mut page_token = None; + let mut page_count = 0; + + loop { + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + page_token, + ) + .unwrap(); + + all_keys.extend(response.keys.clone()); + page_count += 1; + + match response.next_page_token { + Some(token) => page_token = Some(token), + None => break, + } + } + + // Verify we got exactly the right number of entries + assert_eq!(all_keys.len(), num_entries); + + // Verify correct number of pages (225 entries at 50 per page = 5 pages) + assert_eq!(page_count, 5); + + // Verify no duplicates + let mut unique_keys = all_keys.clone(); + unique_keys.sort(); + unique_keys.dedup(); + assert_eq!(unique_keys.len(), num_entries); + + // Verify ordering: newest 
first (highest sort_order first). + assert_eq!(all_keys[0], format!("key_{:04}", num_entries - 1)); + assert_eq!(all_keys[num_entries - 1], "key_0000"); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_update_preserves_order() { + let store = create_test_store("test_pg_paginated_update"); + + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "first", vec![1u8; 8]) + .unwrap(); + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "second", vec![2u8; 8]) + .unwrap(); + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "third", vec![3u8; 8]) + .unwrap(); + + // Update the first entry + KVStoreSync::write(&store, primary_namespace, secondary_namespace, "first", vec![99u8; 8]) + .unwrap(); + + // Paginated listing should still show "first" with its original creation order + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + None, + ) + .unwrap(); + + // Newest first: third, second, first + assert_eq!(response.keys, vec!["third", "second", "first"]); + + // Verify the updated value was persisted + let data = + KVStoreSync::read(&store, primary_namespace, secondary_namespace, "first").unwrap(); + assert_eq!(data, vec![99u8; 8]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_empty_namespace() { + let store = create_test_store("test_pg_paginated_empty"); + + // Paginating an empty or unknown namespace returns an empty result with no token. 
+ let response = + PaginatedKVStoreSync::list_paginated(&store, "nonexistent", "ns", None).unwrap(); + assert!(response.keys.is_empty()); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_namespace_isolation() { + let store = create_test_store("test_pg_paginated_isolation"); + + KVStoreSync::write(&store, "ns_a", "sub", "key_1", vec![1u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_a", "sub", "key_2", vec![2u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_b", "sub", "key_3", vec![3u8; 8]).unwrap(); + KVStoreSync::write(&store, "ns_a", "other", "key_4", vec![4u8; 8]).unwrap(); + + // ns_a/sub should only contain key_1 and key_2 (newest first). + let response = PaginatedKVStoreSync::list_paginated(&store, "ns_a", "sub", None).unwrap(); + assert_eq!(response.keys, vec!["key_2", "key_1"]); + assert!(response.next_page_token.is_none()); + + // ns_b/sub should only contain key_3. + let response = PaginatedKVStoreSync::list_paginated(&store, "ns_b", "sub", None).unwrap(); + assert_eq!(response.keys, vec!["key_3"]); + + // ns_a/other should only contain key_4. 
+ let response = PaginatedKVStoreSync::list_paginated(&store, "ns_a", "other", None).unwrap(); + assert_eq!(response.keys, vec!["key_4"]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_removal() { + let store = create_test_store("test_pg_paginated_removal"); + + let ns = "test_ns"; + let sub = "test_sub"; + + KVStoreSync::write(&store, ns, sub, "a", vec![1u8; 8]).unwrap(); + KVStoreSync::write(&store, ns, sub, "b", vec![2u8; 8]).unwrap(); + KVStoreSync::write(&store, ns, sub, "c", vec![3u8; 8]).unwrap(); + + KVStoreSync::remove(&store, ns, sub, "b", false).unwrap(); + + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys, vec!["c", "a"]); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_exact_page_boundary() { + let store = create_test_store("test_pg_paginated_boundary"); + + let ns = "test_ns"; + let sub = "test_sub"; + + // Write exactly PAGE_SIZE entries (50). + for i in 0..PAGE_SIZE { + let key = format!("key_{:04}", i); + KVStoreSync::write(&store, ns, sub, &key, vec![i as u8; 8]).unwrap(); + } + + // Exactly PAGE_SIZE entries: all returned in one page with no next-page token. + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), PAGE_SIZE); + assert!(response.next_page_token.is_none()); + + // Add one more entry (PAGE_SIZE + 1 total). First page should now have a token. + KVStoreSync::write(&store, ns, sub, "key_extra", vec![0u8; 8]).unwrap(); + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), PAGE_SIZE); + assert!(response.next_page_token.is_some()); + + // Second page should have exactly 1 entry and no token. 
+ let response = + PaginatedKVStoreSync::list_paginated(&store, ns, sub, response.next_page_token) + .unwrap(); + assert_eq!(response.keys.len(), 1); + assert!(response.next_page_token.is_none()); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_paginated_fewer_than_page_size() { + let store = create_test_store("test_pg_paginated_few"); + + let ns = "test_ns"; + let sub = "test_sub"; + + // Write fewer entries than PAGE_SIZE. + for i in 0..5 { + let key = format!("key_{i}"); + KVStoreSync::write(&store, ns, sub, &key, vec![i as u8; 8]).unwrap(); + } + + let response = PaginatedKVStoreSync::list_paginated(&store, ns, sub, None).unwrap(); + assert_eq!(response.keys.len(), 5); + // Fewer than PAGE_SIZE means no next page. + assert!(response.next_page_token.is_none()); + // Newest first. + assert_eq!(response.keys, vec!["key_4", "key_3", "key_2", "key_1", "key_0"]); + + cleanup_store(&store); + } + + #[test] + fn test_postgres_store_write_version_persists_across_restart() { + let table_name = "test_pg_write_version_restart"; + let primary_namespace = "test_ns"; + let secondary_namespace = "test_sub"; + + { + let store = create_test_store(table_name); + + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_a", + vec![1u8; 8], + ) + .unwrap(); + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_b", + vec![2u8; 8], + ) + .unwrap(); + + // Don't clean up since we want to reopen + } + + // Open a new store instance on the same database table and write more + { + let store = create_test_store(table_name); + + KVStoreSync::write( + &store, + primary_namespace, + secondary_namespace, + "key_c", + vec![3u8; 8], + ) + .unwrap(); + + // Paginated listing should show newest first: key_c, key_b, key_a + let response = PaginatedKVStoreSync::list_paginated( + &store, + primary_namespace, + secondary_namespace, + None, + ) + .unwrap(); + + assert_eq!(response.keys, vec!["key_c", "key_b", "key_a"]); + + 
cleanup_store(&store); + } + } + + #[test] + fn test_tls_config_none_builds_plain_connector() { + let connector = PostgresStore::build_tls_connector(None).unwrap(); + assert!(matches!(connector, PgTlsConnector::Plain)); + } + + #[test] + fn test_tls_config_system_certs_builds_native_tls_connector() { + let config = Some(PostgresTlsConfig { certificate_pem: None }); + let connector = PostgresStore::build_tls_connector(config).unwrap(); + assert!(matches!(connector, PgTlsConnector::NativeTls(_))); + } + + #[test] + fn test_tls_config_invalid_pem_returns_error() { + let config = + Some(PostgresTlsConfig { certificate_pem: Some("not-a-valid-pem".to_string()) }); + let result = PostgresStore::build_tls_connector(config); + assert!(result.is_err()); + } +} diff --git a/tests/integration_tests_postgres.rs b/tests/integration_tests_postgres.rs new file mode 100644 index 000000000..0c7d46b8e --- /dev/null +++ b/tests/integration_tests_postgres.rs @@ -0,0 +1,154 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option. You may not use this file except in +// accordance with one or both of these licenses. 
+ +#![cfg(feature = "postgres")] + +mod common; + +use ldk_node::entropy::NodeEntropy; +use ldk_node::Builder; +use rand::RngCore; + +fn test_connection_string() -> String { + std::env::var("TEST_POSTGRES_URL") + .unwrap_or_else(|_| "host=localhost user=postgres password=postgres".to_string()) +} + +async fn drop_table(table_name: &str) { + let connection_string = format!("{} dbname=ldk_node", test_connection_string()); + let (client, connection) = + tokio_postgres::connect(&connection_string, tokio_postgres::NoTls).await.unwrap(); + tokio::spawn(connection); + let _ = client.execute(&format!("DROP TABLE IF EXISTS {table_name}"), &[]).await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn channel_full_cycle_with_postgres_store() { + drop_table("channel_cycle_a").await; + drop_table("channel_cycle_b").await; + + let (bitcoind, electrsd) = common::setup_bitcoind_and_electrsd(); + println!("== Node A =="); + let esplora_url = format!("http://{}", electrsd.esplora_url.as_ref().unwrap()); + let config_a = common::random_config(true); + let mut builder_a = Builder::from_config(config_a.node_config); + builder_a.set_chain_source_esplora(esplora_url.clone(), None); + let node_a = builder_a + .build_with_postgres_store( + config_a.node_entropy, + test_connection_string(), + None, + Some("channel_cycle_a".to_string()), + None, + ) + .unwrap(); + node_a.start().unwrap(); + + println!("\n== Node B =="); + let config_b = common::random_config(true); + let mut builder_b = Builder::from_config(config_b.node_config); + builder_b.set_chain_source_esplora(esplora_url.clone(), None); + let node_b = builder_b + .build_with_postgres_store( + config_b.node_entropy, + test_connection_string(), + None, + Some("channel_cycle_b".to_string()), + None, + ) + .unwrap(); + node_b.start().unwrap(); + + common::do_channel_full_cycle( + node_a, + node_b, + &bitcoind.client, + &electrsd.client, + false, + true, + false, + ) + .await; + + 
drop_table("channel_cycle_a").await; + drop_table("channel_cycle_b").await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn postgres_node_restart() { + drop_table("restart_test").await; + + let (bitcoind, electrsd) = common::setup_bitcoind_and_electrsd(); + let esplora_url = format!("http://{}", electrsd.esplora_url.as_ref().unwrap()); + let connection_string = test_connection_string(); + + let storage_path = common::random_storage_path().to_str().unwrap().to_owned(); + let mut seed_bytes = [42u8; 64]; + rand::rng().fill_bytes(&mut seed_bytes); + let node_entropy = NodeEntropy::from_seed_bytes(seed_bytes); + + // Setup initial node and fund it. + let (expected_balance_sats, expected_node_id) = { + let mut builder = Builder::new(); + builder.set_network(bitcoin::Network::Regtest); + builder.set_storage_dir_path(storage_path.clone()); + builder.set_chain_source_esplora(esplora_url.clone(), None); + let node = builder + .build_with_postgres_store( + node_entropy, + connection_string.clone(), + None, + Some("restart_test".to_string()), + None, + ) + .unwrap(); + + node.start().unwrap(); + let addr = node.onchain_payment().new_address().unwrap(); + common::premine_and_distribute_funds( + &bitcoind.client, + &electrsd.client, + vec![addr], + bitcoin::Amount::from_sat(100_000), + ) + .await; + node.sync_wallets().unwrap(); + + let balance = node.list_balances().spendable_onchain_balance_sats; + assert!(balance > 0); + let node_id = node.node_id(); + + node.stop().unwrap(); + (balance, node_id) + }; + + // Verify node can be restarted from PostgreSQL backend. 
+ let mut builder = Builder::new(); + builder.set_network(bitcoin::Network::Regtest); + builder.set_storage_dir_path(storage_path); + builder.set_chain_source_esplora(esplora_url, None); + + let node = builder + .build_with_postgres_store( + node_entropy, + connection_string.clone(), + None, + Some("restart_test".to_string()), + None, + ) + .unwrap(); + + node.start().unwrap(); + node.sync_wallets().unwrap(); + + assert_eq!(expected_node_id, node.node_id()); + assert_eq!(expected_balance_sats, node.list_balances().spendable_onchain_balance_sats); + + node.stop().unwrap(); + + drop_table("restart_test").await; +}