Commit
Merge pull request #1950 from tursodatabase/batch_push
Push up to 128 frames in sync
penberg authored Feb 11, 2025
2 parents 7aad3e6 + ac6f1bc commit 224b57e
Showing 3 changed files with 51 additions and 22 deletions.
12 changes: 12 additions & 0 deletions libsql/src/database/builder.rs

@@ -107,6 +107,7 @@ impl Builder<()> {
                 connector: None,
                 read_your_writes: true,
                 remote_writes: false,
+                push_batch_size: 0,
             },
         }
     }
@@ -524,6 +525,7 @@ cfg_sync! {
         connector: Option<crate::util::ConnectorService>,
         remote_writes: bool,
         read_your_writes: bool,
+        push_batch_size: u32,
     }

     impl Builder<SyncedDatabase> {
@@ -543,6 +545,11 @@ cfg_sync! {
             self
         }

+        pub fn set_push_batch_size(mut self, v: u32) -> Builder<SyncedDatabase> {
+            self.inner.push_batch_size = v;
+            self
+        }
+
         /// Provide a custom http connector that will be used to create http connections.
         pub fn connector<C>(mut self, connector: C) -> Builder<SyncedDatabase>
         where
@@ -570,6 +577,7 @@ cfg_sync! {
                 connector,
                 remote_writes,
                 read_your_writes,
+                push_batch_size,
             } = self.inner;

             let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned();
@@ -596,6 +604,10 @@ cfg_sync! {
             )
             .await?;

+            if push_batch_size > 0 {
+                db.sync_ctx.as_ref().unwrap().lock().await.set_push_batch_size(push_batch_size);
+            }
+
             Ok(Database {
                 db_type: DbType::Offline {
                     db,
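
The builder change wires the new knob through to SyncContext: the default of 0 means "keep SyncContext's built-in default", and any positive value overrides it after the database is built. A minimal usage sketch, assuming the crate's experimental sync feature and its new_synced_database constructor; the path, URL, token, and table below are placeholders:

use libsql::Builder;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let db = Builder::new_synced_database(
        "local.db",                             // local replica path (placeholder)
        "https://example.turso.io".to_string(), // sync URL (placeholder)
        "auth-token".to_string(),               // auth token (placeholder)
    )
    // Push up to 128 WAL frames per HTTP request; leaving this unset (0)
    // falls back to SyncContext's DEFAULT_PUSH_BATCH_SIZE.
    .set_push_batch_size(128)
    .build()
    .await?;

    let conn = db.connect()?;
    conn.execute("CREATE TABLE IF NOT EXISTS t (x INTEGER)", ()).await?;
    db.sync().await?; // frames written above are pushed in batches
    Ok(())
}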
43 changes: 30 additions & 13 deletions libsql/src/sync.rs

@@ -19,6 +19,7 @@ pub mod transaction;
 const METADATA_VERSION: u32 = 0;

 const DEFAULT_MAX_RETRIES: usize = 5;
+const DEFAULT_PUSH_BATCH_SIZE: u32 = 128;

 #[derive(thiserror::Error, Debug)]
 #[non_exhaustive]
@@ -74,6 +75,7 @@ pub struct SyncContext {
     sync_url: String,
     auth_token: Option<HeaderValue>,
     max_retries: usize,
+    push_batch_size: u32,
     /// The current durable generation.
     durable_generation: u32,
     /// Represents the max_frame_no from the server.
@@ -102,6 +104,7 @@ impl SyncContext {
             sync_url,
             auth_token,
             max_retries: DEFAULT_MAX_RETRIES,
+            push_batch_size: DEFAULT_PUSH_BATCH_SIZE,
             client,
             durable_generation: 1,
             durable_frame_num: 0,
@@ -117,6 +120,10 @@ impl SyncContext {
         Ok(me)
     }

+    pub fn set_push_batch_size(&mut self, push_batch_size: u32) {
+        self.push_batch_size = push_batch_size;
+    }
+
     #[tracing::instrument(skip(self))]
     pub(crate) async fn pull_one_frame(
         &mut self,
@@ -134,25 +141,26 @@ impl SyncContext {
         self.pull_with_retry(uri, self.max_retries).await
     }

-    #[tracing::instrument(skip(self, frame))]
-    pub(crate) async fn push_one_frame(
+    #[tracing::instrument(skip(self, frames))]
+    pub(crate) async fn push_frames(
         &mut self,
-        frame: Bytes,
+        frames: Bytes,
         generation: u32,
         frame_no: u32,
+        frames_count: u32,
     ) -> Result<u32> {
         let uri = format!(
             "{}/sync/{}/{}/{}",
             self.sync_url,
             generation,
             frame_no,
-            frame_no + 1
+            frame_no + frames_count
         );
         tracing::debug!("pushing frame");

-        let (generation, durable_frame_num) = self.push_with_retry(uri, frame, self.max_retries).await?;
+        let (generation, durable_frame_num) = self.push_with_retry(uri, frames, self.max_retries).await?;

-        if durable_frame_num > frame_no {
+        if durable_frame_num > frame_no + frames_count - 1 {
             tracing::error!(
                 "server returned durable_frame_num larger than what we sent: sent={}, got={}",
                 frame_no,
@@ -162,7 +170,7 @@ impl SyncContext {
             return Err(SyncError::InvalidPushFrameNoHigh(frame_no, durable_frame_num).into());
         }

-        if durable_frame_num < frame_no {
+        if durable_frame_num < frame_no + frames_count - 1 {
             // Update our knowledge of where the server is at frame wise.
             self.durable_frame_num = durable_frame_num;

@@ -186,7 +194,7 @@ impl SyncContext {
         Ok(durable_frame_num)
     }

-    async fn push_with_retry(&self, uri: String, frame: Bytes, max_retries: usize) -> Result<(u32, u32)> {
+    async fn push_with_retry(&self, uri: String, body: Bytes, max_retries: usize) -> Result<(u32, u32)> {
         let mut nr_retries = 0;
         loop {
             let mut req = http::Request::post(uri.clone());
@@ -200,7 +208,7 @@ impl SyncContext {
                 None => {}
             }

-            let req = req.body(frame.clone().into()).expect("valid body");
+            let req = req.body(body.clone().into()).expect("valid body");

             let res = self
                 .client
@@ -537,19 +545,28 @@ async fn try_push(

     let mut frame_no = start_frame_no;
     while frame_no <= end_frame_no {
-        let frame = conn.wal_get_frame(frame_no, page_size)?;
+        let batch_size = sync_ctx.push_batch_size.min(end_frame_no - frame_no + 1);
+        let mut frames = conn.wal_get_frame(frame_no, page_size)?;
+        if batch_size > 1 {
+            frames.reserve((batch_size - 1) as usize * frames.len());
+        }
+        for idx in 1..batch_size {
+            let frame = conn.wal_get_frame(frame_no + idx, page_size)?;
+            frames.extend_from_slice(frame.as_ref())
+        }

         // The server returns its maximum frame number. To avoid resending
         // frames the server already knows about, we need to update the
         // frame number to the one returned by the server.
         let max_frame_no = sync_ctx
-            .push_one_frame(frame.freeze(), generation, frame_no)
+            .push_frames(frames.freeze(), generation, frame_no, batch_size)
             .await?;

         if max_frame_no > frame_no {
-            frame_no = max_frame_no;
+            frame_no = max_frame_no + 1;
+        } else {
+            frame_no += batch_size;
         }
-        frame_no += 1;
     }

     sync_ctx.write_metadata().await?;
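
The loop in try_push now advances by whole batches rather than one frame per request. A worked example of the arithmetic as a self-contained sketch; the assume-the-server-acks-everything simplification and the println output are illustrative, not part of the diff:

fn batched_push_plan(push_batch_size: u32, start_frame_no: u32, end_frame_no: u32) {
    let mut frame_no = start_frame_no;
    while frame_no <= end_frame_no {
        // Same computation as try_push: never read past end_frame_no.
        let batch_size = push_batch_size.min(end_frame_no - frame_no + 1);
        // One request carries frames frame_no ..= frame_no + batch_size - 1,
        // sent to {sync_url}/sync/{generation}/{start}/{end} with an
        // exclusive end, matching the URI built in push_frames.
        println!(
            "push frames {}..={} -> /sync/<gen>/{}/{}",
            frame_no,
            frame_no + batch_size - 1,
            frame_no,
            frame_no + batch_size
        );
        // Assume the server acknowledged the whole batch.
        frame_no += batch_size;
    }
}

// batched_push_plan(128, 1, 300) prints three requests instead of 300:
//   push frames 1..=128 -> /sync/<gen>/1/129
//   push frames 129..=256 -> /sync/<gen>/129/257
//   push frames 257..=300 -> /sync/<gen>/257/301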
18 changes: 9 additions & 9 deletions libsql/src/sync/test.rs

@@ -28,7 +28,7 @@ async fn test_sync_context_push_frame() {
     let mut sync_ctx = sync_ctx;

     // Push a frame and verify the response
-    let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
+    let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap();
     sync_ctx.write_metadata().await.unwrap();
     assert_eq!(durable_frame, 0); // First frame should return max_frame_no = 0

@@ -56,7 +56,7 @@ async fn test_sync_context_with_auth() {
     let frame = Bytes::from("test frame with auth");
     let mut sync_ctx = sync_ctx;

-    let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
+    let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap();
     sync_ctx.write_metadata().await.unwrap();
     assert_eq!(durable_frame, 0);
     assert_eq!(server.frame_count(), 1);
@@ -82,7 +82,7 @@ async fn test_sync_context_multiple_frames() {
     // Push multiple frames and verify incrementing frame numbers
     for i in 0..3 {
         let frame = Bytes::from(format!("frame data {}", i));
-        let durable_frame = sync_ctx.push_one_frame(frame, 1, i).await.unwrap();
+        let durable_frame = sync_ctx.push_frames(frame, 1, i, 1).await.unwrap();
         sync_ctx.write_metadata().await.unwrap();
         assert_eq!(durable_frame, i);
         assert_eq!(sync_ctx.durable_frame_num(), i);
@@ -108,7 +108,7 @@ async fn test_sync_context_corrupted_metadata() {

     let mut sync_ctx = sync_ctx;
     let frame = Bytes::from("test frame data");
-    let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
+    let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap();
     sync_ctx.write_metadata().await.unwrap();
     assert_eq!(durable_frame, 0);
     assert_eq!(server.frame_count(), 1);
@@ -152,7 +152,7 @@ async fn test_sync_restarts_with_lower_max_frame_no() {

     let mut sync_ctx = sync_ctx;
     let frame = Bytes::from("test frame data");
-    let durable_frame = sync_ctx.push_one_frame(frame.clone(), 1, 0).await.unwrap();
+    let durable_frame = sync_ctx.push_frames(frame.clone(), 1, 0, 1).await.unwrap();
     sync_ctx.write_metadata().await.unwrap();
     assert_eq!(durable_frame, 0);
     assert_eq!(server.frame_count(), 1);
@@ -180,14 +180,14 @@ async fn test_sync_restarts_with_lower_max_frame_no() {
     // This push should fail because we are ahead of the server and thus should get an invalid
     // frame no error.
     sync_ctx
-        .push_one_frame(frame.clone(), 1, frame_no)
+        .push_frames(frame.clone(), 1, frame_no, 1)
         .await
        .unwrap_err();

     let frame_no = sync_ctx.durable_frame_num() + 1;
     // This then should work because when the last one failed it updated our state of the server
     // durable_frame_num and we should then start writing from there.
-    sync_ctx.push_one_frame(frame, 1, frame_no).await.unwrap();
+    sync_ctx.push_frames(frame, 1, frame_no, 1).await.unwrap();
 }

 #[tokio::test]
@@ -215,7 +215,7 @@ async fn test_sync_context_retry_on_error() {
     server.return_error.store(true, Ordering::SeqCst);

     // First attempt should fail but retry
-    let result = sync_ctx.push_one_frame(frame.clone(), 1, 0).await;
+    let result = sync_ctx.push_frames(frame.clone(), 1, 0, 1).await;
     assert!(result.is_err());

     // Advance time to trigger retries faster
@@ -228,7 +228,7 @@ async fn test_sync_context_retry_on_error() {
     server.return_error.store(false, Ordering::SeqCst);

     // Next attempt should succeed
-    let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
+    let durable_frame = sync_ctx.push_frames(frame, 1, 0, 1).await.unwrap();
     sync_ctx.write_metadata().await.unwrap();
     assert_eq!(durable_frame, 0);
     assert_eq!(server.frame_count(), 1);
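
All of the updated tests still pass frames_count = 1, so batching itself is exercised only through try_push. A hypothetical batched variant, assuming the same mock server these tests use and a setup helper like theirs (both names are placeholders, not part of this diff):

#[tokio::test]
async fn test_sync_context_push_batched_frames() {
    let (server, mut sync_ctx) = setup_sync_ctx().await; // hypothetical helper

    // push_frames takes the concatenated frame bytes plus an explicit count,
    // so a two-frame batch starting at frame_no = 0 looks like this.
    let mut batch = bytes::BytesMut::from(&b"frame data 0"[..]);
    batch.extend_from_slice(b"frame data 1");

    // On success the server's durable frame number should equal
    // frame_no + frames_count - 1 = 1, and both frames should be counted.
    let durable_frame = sync_ctx.push_frames(batch.freeze(), 1, 0, 2).await.unwrap();
    assert_eq!(durable_frame, 1);
    assert_eq!(server.frame_count(), 2);
}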
