Skip to content

Commit

Permalink
upgrade zstd-rs, simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
phiresky committed Jul 17, 2022
1 parent cbce68c commit bee9be1
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 75 deletions.
16 changes: 9 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ required-features = ["benchmark"]
crate-type = ["cdylib"]

[dependencies]
zstd = {version = "0.6.0", git = "https://github.com/phiresky/zstd-rs", branch = "master"}
zstd-safe = {version = "3.0.0", git = "https://github.com/phiresky/zstd-rs", branch = "master"}
zstd = {version = "0.11.2", features = ["experimental"]}
#zstd = {version = "0.5.3", path="../zstd-rs"}
#zstd = {version = "=0.5.4"}
anyhow = "1.0.44"
Expand Down
67 changes: 31 additions & 36 deletions src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,19 @@ pub(crate) fn zstd_compress_fn<'a>(
} else {
ctx.get(arg_is_compact).context("is_compact argument")?
};
let out = Vec::new();
use zstd::stream::write::Encoder;

let dict = if ctx.len() <= arg_dict {
None
let encoder = if ctx.len() <= arg_dict {
Encoder::new(out, level)
} else {
match ctx.get_raw(arg_dict) {
ValueRef::Integer(-1) => None,
ValueRef::Null => None,
ValueRef::Blob(d) => Some(Arc::new(wrap_encoder_dict(d.to_vec(), level))),
ValueRef::Integer(_) => Some(
encoder_dict_from_ctx(ctx, arg_dict, level)
ValueRef::Integer(-1) | ValueRef::Null => Encoder::new(out, level),
ValueRef::Blob(d) => Encoder::with_dictionary(out, level, d),
//Some(Arc::new(wrap_encoder_dict(d.to_vec(), level))),
ValueRef::Integer(_) => Encoder::with_prepared_dictionary(
out,
&*encoder_dict_from_ctx(ctx, arg_dict, level)
.context("loading dictionary from int")?,
),
other => anyhow::bail!(
Expand All @@ -69,33 +72,26 @@ pub(crate) fn zstd_compress_fn<'a>(
),
}
};

let res = {
let out = Vec::new();
let mut encoder = match &dict {
Some(dict) => zstd::stream::write::Encoder::with_prepared_dictionary(out, dict),
None => zstd::stream::write::Encoder::new(out, level),
}
.context("creating zstd encoder")?;
/* encoder
.get_operation_mut()
.context
.set_pledged_src_size(input_value.len() as u64)
.context("pledge")?;*/
if compact {
encoder
.include_checksum(false)
.context("disable checksums")?;
encoder.include_contentsize(false).context("cs")?;
encoder.include_dictid(false).context("did")?;
encoder.include_magicbytes(false).context("did")?;
}
let mut encoder = encoder.context("creating zstd encoder")?;

/* encoder
.get_operation_mut()
.context
.set_pledged_src_size(input_value.len() as u64)
.context("pledge")?;*/
if compact {
encoder
.write_all(input_value)
.context("writing data to zstd encoder")?;
encoder.finish().context("finishing zstd stream")?
};
drop(dict); // to make sure the dict is still in scope because of https://github.com/gyscos/zstd-rs/issues/55
.include_checksum(false)
.context("disable checksums")?;
encoder.include_contentsize(false).context("cs")?;
encoder.include_dictid(false).context("did")?;
encoder.include_magicbytes(false).context("did")?;
}
encoder
.write_all(input_value)
.context("writing data to zstd encoder")?;
let res = encoder.finish().context("finishing zstd stream")?;

Ok(ToSqlOutput::Owned(Value::Blob(res)))
}

Expand Down Expand Up @@ -135,9 +131,8 @@ pub(crate) fn zstd_decompress_fn<'a>(
None
} else {
match ctx.get_raw(arg_dict) {
ValueRef::Integer(-1) => None,
ValueRef::Null => None,
ValueRef::Blob(d) => Some(Arc::new(wrap_decoder_dict(d.to_vec()))),
ValueRef::Integer(-1) | ValueRef::Null => None,
ValueRef::Blob(d) => Some(Arc::new(DecoderDictionary::copy(d))),
ValueRef::Integer(_) => {
Some(decoder_dict_from_ctx(ctx, arg_dict).context("load dict")?)
}
Expand Down
36 changes: 6 additions & 30 deletions src/dict_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,18 @@ use std::time::Duration;

use zstd::dict::{DecoderDictionary, EncoderDictionary};

type OwnedEncoderDict<'a> = owning_ref::OwningHandle<Vec<u8>, Box<EncoderDictionary<'a>>>;

// zstd-rs only exposes zstd_safe::create_cdict_by_reference, not zstd_safe::create_cdict
// so we need to keep a reference to the vector ourselves
// is there a better way?
/// Bundles raw dictionary bytes together with an `EncoderDictionary` built from them.
///
/// `dict_raw` is moved into the returned handle, and the prepared dictionary is
/// created from a reference into that buffer at the given compression `level`.
/// Keeping both in one value means the bytes stay around for as long as the
/// prepared dictionary does.
pub fn wrap_encoder_dict(dict_raw: Vec<u8>, level: i32) -> OwnedEncoderDict<'static> {
    owning_ref::OwningHandle::new_with_fn(dict_raw, |raw_bytes| {
        // SAFETY: NOTE(review): relies on `OwningHandle` passing a pointer into the
        // buffer it owns, which it keeps alive next to the boxed dictionary —
        // confirm against the owning_ref documentation.
        let bytes = unsafe { raw_bytes.as_ref() }.unwrap();
        let prepared = EncoderDictionary::new(bytes, level);
        Box::new(prepared)
    })
}

type OwnedDecoderDict<'a> = owning_ref::OwningHandle<Vec<u8>, Box<DecoderDictionary<'a>>>;

// zstd-rs only exposes zstd_safe::create_ddict_by_reference, not zstd_safe::create_ddict
// so we need to keep a reference to the vector ourselves
// is there a better way?
/// Bundles raw dictionary bytes together with a `DecoderDictionary` built from them.
///
/// `dict_raw` is moved into the returned handle, and the prepared dictionary is
/// created from a reference into that buffer, so the bytes stay around for as
/// long as the prepared dictionary does.
pub fn wrap_decoder_dict(dict_raw: Vec<u8>) -> OwnedDecoderDict<'static> {
    owning_ref::OwningHandle::new_with_fn(dict_raw, |raw_bytes| {
        // SAFETY: NOTE(review): relies on `OwningHandle` passing a valid pointer into
        // the buffer it owns, which it keeps alive next to the boxed dictionary —
        // confirm against the owning_ref documentation.
        let bytes = unsafe { &*raw_bytes };
        let prepared = DecoderDictionary::new(bytes);
        Box::new(prepared)
    })
}
// TODO: the rust interface currently requires a level when preparing a dictionary, but the zstd interface (ZSTD_CCtx_loadDictionary) does not.
// TODO: Using LruCache here isn't very smart
pub fn encoder_dict_from_ctx<'a>(
ctx: &'a Context,
arg_index: usize,
level: i32,
) -> anyhow::Result<Arc<OwnedEncoderDict<'static>>> {
) -> anyhow::Result<Arc<EncoderDictionary<'static>>> {
use lru_time_cache::LruCache;
// we cache the instantiated encoder dictionaries keyed by (DbConnection, dict_id, compression_level)
// DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases)
lazy_static::lazy_static! {
static ref DICTS: RwLock<LruCache<(usize, i32, i32), Arc<OwnedEncoderDict<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
static ref DICTS: RwLock<LruCache<(usize, i32, i32), Arc<EncoderDictionary<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
}
let id: i32 = ctx.get(arg_index)?;
let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213
Expand All @@ -63,7 +39,7 @@ pub fn encoder_dict_from_ctx<'a>(
|r| r.get(0),
)
.with_context(|| format!("getting dict with id={} from _zstd_dicts", id))?;
let dict = wrap_encoder_dict(dict_raw, level);
let dict = EncoderDictionary::copy(&dict_raw, level);
Arc::new(dict)
}),
lru_time_cache::Entry::Occupied(o) => o.into_mut(),
Expand All @@ -75,12 +51,12 @@ pub fn encoder_dict_from_ctx<'a>(
pub fn decoder_dict_from_ctx<'a>(
ctx: &'a Context,
arg_index: usize,
) -> anyhow::Result<Arc<OwnedDecoderDict<'static>>> {
) -> anyhow::Result<Arc<DecoderDictionary<'static>>> {
use lru_time_cache::LruCache;
// we cache the instantiated decoder dictionaries keyed by (DbConnection, dict_id)
// DbConnection would ideally be db.path() because it's the same for multiple connections to the same db, but that would be less robust (e.g. in-memory databases)
lazy_static::lazy_static! {
static ref DICTS: RwLock<LruCache<(usize, i32), Arc<OwnedDecoderDict<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
static ref DICTS: RwLock<LruCache<(usize, i32), Arc<DecoderDictionary<'static>>>> = RwLock::new(LruCache::with_expiry_duration(Duration::from_secs(10)));
}
let id: i32 = ctx.get(arg_index)?;
let db = unsafe { ctx.get_connection()? }; // SAFETY: This might be unsafe depending on how the connection is used. See https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213
Expand All @@ -101,7 +77,7 @@ pub fn decoder_dict_from_ctx<'a>(
|r| r.get(0),
)
.with_context(|| format!("getting dict with id={} from _zstd_dicts", id))?;
let dict = wrap_decoder_dict(dict_raw);
let dict = DecoderDictionary::copy(&dict_raw);
Arc::new(dict)
}),
lru_time_cache::Entry::Occupied(o) => o.into_mut(),
Expand Down

0 comments on commit bee9be1

Please sign in to comment.