diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs index 6479bde48a..101b0bed72 100644 --- a/libsql/src/database/builder.rs +++ b/libsql/src/database/builder.rs @@ -107,6 +107,7 @@ impl Builder<()> { connector: None, read_your_writes: true, remote_writes: false, + push_batch_size: 0, }, } } @@ -524,6 +525,7 @@ cfg_sync! { connector: Option, remote_writes: bool, read_your_writes: bool, + push_batch_size: u32, } impl Builder { @@ -543,6 +545,11 @@ cfg_sync! { self } + pub fn set_push_batch_size(mut self, v: u32) -> Builder { + self.inner.push_batch_size = v; + self + } + /// Provide a custom http connector that will be used to create http connections. pub fn connector(mut self, connector: C) -> Builder where @@ -570,6 +577,7 @@ cfg_sync! { connector, remote_writes, read_your_writes, + push_batch_size, } = self.inner; let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned(); @@ -596,6 +604,10 @@ cfg_sync! { ) .await?; + if push_batch_size > 0 { + db.sync_ctx.as_ref().unwrap().lock().await.set_push_batch_size(push_batch_size); + } + Ok(Database { db_type: DbType::Offline { db, diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs index 5fcad3aa11..1680587b03 100644 --- a/libsql/src/sync.rs +++ b/libsql/src/sync.rs @@ -19,6 +19,7 @@ pub mod transaction; const METADATA_VERSION: u32 = 0; const DEFAULT_MAX_RETRIES: usize = 5; +const DEFAULT_PUSH_BATCH_SIZE: u32 = 128; #[derive(thiserror::Error, Debug)] #[non_exhaustive] @@ -74,6 +75,7 @@ pub struct SyncContext { sync_url: String, auth_token: Option, max_retries: usize, + push_batch_size: u32, /// The current durable generation. durable_generation: u32, /// Represents the max_frame_no from the server. @@ -102,6 +104,7 @@ impl SyncContext { sync_url, auth_token, max_retries: DEFAULT_MAX_RETRIES, + push_batch_size: DEFAULT_PUSH_BATCH_SIZE, client, durable_generation: 1, durable_frame_num: 0, @@ -117,6 +120,10 @@ impl SyncContext { Ok(me) } + pub fn set_push_batch_size(&mut self, push_batch_size: u32) { + self.push_batch_size = push_batch_size; + } + #[tracing::instrument(skip(self))] pub(crate) async fn pull_one_frame( &mut self, @@ -134,25 +141,26 @@ impl SyncContext { self.pull_with_retry(uri, self.max_retries).await } - #[tracing::instrument(skip(self, frame))] - pub(crate) async fn push_one_frame( + #[tracing::instrument(skip(self, frames))] + pub(crate) async fn push_frames( &mut self, - frame: Bytes, + frames: Bytes, generation: u32, frame_no: u32, + frames_count: u32, ) -> Result { let uri = format!( "{}/sync/{}/{}/{}", self.sync_url, generation, frame_no, - frame_no + 1 + frame_no + frames_count ); tracing::debug!("pushing frame"); - let (generation, durable_frame_num) = self.push_with_retry(uri, frame, self.max_retries).await?; + let (generation, durable_frame_num) = self.push_with_retry(uri, frames, self.max_retries).await?; - if durable_frame_num > frame_no { + if durable_frame_num > frame_no + frames_count - 1 { tracing::error!( "server returned durable_frame_num larger than what we sent: sent={}, got={}", frame_no, @@ -162,7 +170,7 @@ impl SyncContext { return Err(SyncError::InvalidPushFrameNoHigh(frame_no, durable_frame_num).into()); } - if durable_frame_num < frame_no { + if durable_frame_num < frame_no + frames_count - 1 { // Update our knowledge of where the server is at frame wise. self.durable_frame_num = durable_frame_num; @@ -186,7 +194,7 @@ impl SyncContext { Ok(durable_frame_num) } - async fn push_with_retry(&self, uri: String, frame: Bytes, max_retries: usize) -> Result<(u32, u32)> { + async fn push_with_retry(&self, uri: String, body: Bytes, max_retries: usize) -> Result<(u32, u32)> { let mut nr_retries = 0; loop { let mut req = http::Request::post(uri.clone()); @@ -200,7 +208,7 @@ impl SyncContext { None => {} } - let req = req.body(frame.clone().into()).expect("valid body"); + let req = req.body(body.clone().into()).expect("valid body"); let res = self .client @@ -537,19 +545,28 @@ async fn try_push( let mut frame_no = start_frame_no; while frame_no <= end_frame_no { - let frame = conn.wal_get_frame(frame_no, page_size)?; + let batch_size = sync_ctx.push_batch_size.min(end_frame_no - frame_no + 1); + let mut frames = conn.wal_get_frame(frame_no, page_size)?; + if batch_size > 1 { + frames.reserve((batch_size - 1) as usize * frames.len()); + } + for idx in 1..batch_size { + let frame = conn.wal_get_frame(frame_no + idx, page_size)?; + frames.extend_from_slice(frame.as_ref()) + } // The server returns its maximum frame number. To avoid resending // frames the server already knows about, we need to update the // frame number to the one returned by the server. let max_frame_no = sync_ctx - .push_one_frame(frame.freeze(), generation, frame_no) + .push_frames(frames.freeze(), generation, frame_no, batch_size) .await?; if max_frame_no > frame_no { - frame_no = max_frame_no; + frame_no = max_frame_no + 1; + } else { + frame_no += batch_size; } - frame_no += 1; } sync_ctx.write_metadata().await?; diff --git a/libsql/src/sync/test.rs b/libsql/src/sync/test.rs index 141abd05a7..84f7c4980c 100644 --- a/libsql/src/sync/test.rs +++ b/libsql/src/sync/test.rs @@ -28,7 +28,7 @@ async fn test_sync_context_push_frame() { let mut sync_ctx = sync_ctx; // Push a frame and verify the response - let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap(); + let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); assert_eq!(durable_frame, 0); // First frame should return max_frame_no = 0 @@ -56,7 +56,7 @@ async fn test_sync_context_with_auth() { let frame = Bytes::from("test frame with auth"); let mut sync_ctx = sync_ctx; - let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap(); + let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); assert_eq!(durable_frame, 0); assert_eq!(server.frame_count(), 1); @@ -82,7 +82,7 @@ async fn test_sync_context_multiple_frames() { // Push multiple frames and verify incrementing frame numbers for i in 0..3 { let frame = Bytes::from(format!("frame data {}", i)); - let durable_frame = sync_ctx.push_one_frame(frame, 1, i).await.unwrap(); + let durable_frame = sync_ctx.push_frames(frame, 1, i, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); assert_eq!(durable_frame, i); assert_eq!(sync_ctx.durable_frame_num(), i); @@ -108,7 +108,7 @@ async fn test_sync_context_corrupted_metadata() { let mut sync_ctx = sync_ctx; let frame = Bytes::from("test frame data"); - let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap(); + let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); assert_eq!(durable_frame, 0); assert_eq!(server.frame_count(), 1); @@ -152,7 +152,7 @@ async fn test_sync_restarts_with_lower_max_frame_no() { let mut sync_ctx = sync_ctx; let frame = Bytes::from("test frame data"); - let durable_frame = sync_ctx.push_one_frame(frame.clone(), 1, 0).await.unwrap(); + let durable_frame = sync_ctx.push_frames(frame.clone(), 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); assert_eq!(durable_frame, 0); assert_eq!(server.frame_count(), 1); @@ -180,14 +180,14 @@ async fn test_sync_restarts_with_lower_max_frame_no() { // This push should fail because we are ahead of the server and thus should get an invalid // frame no error. sync_ctx - .push_one_frame(frame.clone(), 1, frame_no) + .push_frames(frame.clone(), 1, frame_no, 1) .await .unwrap_err(); let frame_no = sync_ctx.durable_frame_num() + 1; // This then should work because when the last one failed it updated our state of the server // durable_frame_num and we should then start writing from there. - sync_ctx.push_one_frame(frame, 1, frame_no).await.unwrap(); + sync_ctx.push_frames(frame, 1, frame_no, 1).await.unwrap(); } #[tokio::test] @@ -215,7 +215,7 @@ async fn test_sync_context_retry_on_error() { server.return_error.store(true, Ordering::SeqCst); // First attempt should fail but retry - let result = sync_ctx.push_one_frame(frame.clone(), 1, 0).await; + let result = sync_ctx.push_frames(frame.clone(), 1, 0, 1).await; assert!(result.is_err()); // Advance time to trigger retries faster @@ -228,7 +228,7 @@ async fn test_sync_context_retry_on_error() { server.return_error.store(false, Ordering::SeqCst); // Next attempt should succeed - let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap(); + let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap(); sync_ctx.write_metadata().await.unwrap(); assert_eq!(durable_frame, 0); assert_eq!(server.frame_count(), 1);