Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

decode_coefs experiment #1325

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
77 changes: 62 additions & 15 deletions src/recon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ use crate::src::tables::dav1d_lo_ctx_offsets;
use crate::src::tables::dav1d_skip_ctx;
use crate::src::tables::dav1d_tx_type_class;
use crate::src::tables::dav1d_tx_types_per_set;
use crate::src::tables::dav1d_txfm_dimension;
use crate::src::tables::dav1d_txfm_dimensions;
use crate::src::tables::dav1d_txfm_size;
use crate::src::tables::dav1d_txtp_from_uvmode;
use crate::src::tables::TxfmInfo;
use crate::src::wedge::dav1d_ii_masks;
Expand Down Expand Up @@ -263,15 +265,15 @@ impl_MergeInt!(u32, u16);
impl_MergeInt!(u64, u32);
impl_MergeInt!(u128, u64);

#[inline]
fn get_skip_ctx(
t_dim: &TxfmInfo,
#[inline(always)]
fn get_skip_ctx<const TX: usize>(
bs: BlockSize,
a: &[u8],
l: &[u8],
chroma: bool,
layout: Rav1dPixelLayout,
) -> InRange<u8, 0, { 13 - 1 }> {
let t_dim = dav1d_txfm_dimension::<TX>();
kkysen marked this conversation as resolved.
Show resolved Hide resolved
let b_dim = bs.dimensions();
let skip_ctx = if chroma {
let ss_ver = layout == Rav1dPixelLayout::I420;
Expand Down Expand Up @@ -341,8 +343,9 @@ fn get_skip_ctx(
InRange::new(skip_ctx).unwrap()
}

#[inline]
fn get_dc_sign_ctx(tx: TxfmSize, a: &[u8], l: &[u8]) -> c_uint {
#[inline(always)]
fn get_dc_sign_ctx<const TX: usize>(a: &[u8], l: &[u8]) -> c_uint {
let tx = dav1d_txfm_size::<TX>();
lqd marked this conversation as resolved.
Show resolved Hide resolved
let mask = 0xc0c0c0c0c0c0c0c0 as u64;
let mul = 0x101010101010101 as u64;

Expand Down Expand Up @@ -492,7 +495,6 @@ fn get_lo_ctx(
let offset;
match ctx_offsets {
Some(ctx_offsets) => {
level(2, 1); // Bounds check all at once.
mag = level(0, 1) + level(1, 0);
debug_assert_matches!(tx_class, TxClass::TwoD);
mag += level(1, 1);
Expand All @@ -502,7 +504,6 @@ fn get_lo_ctx(
}
None => {
debug_assert_matches!(tx_class, TxClass::H | TxClass::V);
level(1, 4); // Bounds check all at once.
mag = level(0, 1) + level(1, 0);
mag += level(0, 2);
*hi_mag = mag;
Expand All @@ -518,6 +519,8 @@ fn get_lo_ctx(
}
}

#[rustfmt::skip]
#[inline(always)]
fn decode_coefs<BD: BitDepth>(
f: &Rav1dFrameData,
ts: usize,
Expand All @@ -535,22 +538,66 @@ fn decode_coefs<BD: BitDepth>(
txtp: &mut TxfmType,
res_ctx: &mut u8,
) -> c_int {
// We make the `TxfmSize` a const so the optimizer sees we don't need memory reads to access the
// `TxfmInfo` dimensions.
use TxfmSize::*;
match tx {
S4x4 => decode_coefs_inner::<BD, { S4x4 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
S8x8 => decode_coefs_inner::<BD, { S8x8 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
S16x16 => decode_coefs_inner::<BD, { S16x16 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
S32x32 => decode_coefs_inner::<BD, { S32x32 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
S64x64 => decode_coefs_inner::<BD, { S64x64 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R4x8 => decode_coefs_inner::<BD, { R4x8 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R8x4 => decode_coefs_inner::<BD, { R8x4 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R8x16 => decode_coefs_inner::<BD, { R8x16 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R16x8 => decode_coefs_inner::<BD, { R16x8 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R16x32 => decode_coefs_inner::<BD, { R16x32 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R32x16 => decode_coefs_inner::<BD, { R32x16 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R32x64 => decode_coefs_inner::<BD, { R32x64 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R64x32 => decode_coefs_inner::<BD, { R64x32 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R4x16 => decode_coefs_inner::<BD, { R4x16 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R16x4 => decode_coefs_inner::<BD, { R16x4 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R8x32 => decode_coefs_inner::<BD, { R8x32 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R32x8 => decode_coefs_inner::<BD, { R32x8 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R16x64 => decode_coefs_inner::<BD, { R16x64 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
R64x16 => decode_coefs_inner::<BD, { R64x16 as _ }>(f, ts, ts_c, dbg_block_info, scratch, t_cf, a, l, tx, bs, b, plane, cf, txtp, res_ctx),
}
}

#[inline(never)]
fn decode_coefs_inner<BD: BitDepth, const TX: usize>(
f: &Rav1dFrameData,
ts: usize,
ts_c: &mut Rav1dTileStateContext,
dbg_block_info: bool,
scratch: &mut TaskContextScratch,
t_cf: &mut Cf,
a: &mut [u8],
l: &mut [u8],
tx: TxfmSize,
bs: BlockSize,
b: &Av1Block,
plane: usize,
cf: CfSelect,
txtp: &mut TxfmType,
res_ctx: &mut u8,
) -> c_int {
let t_dim = const { dav1d_txfm_dimension::<TX>() };
let dc_sign_ctx;
let dc_sign;
let mut dc_dq;
let ts = &f.ts[ts];
let chroma = plane != 0;
let frame_hdr = &***f.frame_hdr.as_ref().unwrap();
let lossless = frame_hdr.segmentation.lossless[b.seg_id.get()];
let t_dim = &dav1d_txfm_dimensions[tx as usize];
let dbg = dbg_block_info && plane != 0 && false;

if dbg {
println!("Start: r={}", ts_c.msac.rng);
}

// does this block have any non-zero coefficients
let sctx = get_skip_ctx(t_dim, bs, a, l, chroma, f.cur.p.layout);
let sctx = get_skip_ctx::<TX>(bs, a, l, chroma, f.cur.p.layout);
let all_skip = rav1d_msac_decode_bool_adapt(
&mut ts_c.msac,
&mut ts_c.cdf.coef.skip[t_dim.ctx as usize][sctx.get() as usize],
Expand Down Expand Up @@ -578,7 +625,7 @@ fn decode_coefs<BD: BitDepth>(
Inter(_) if t_dim.max >= TxfmSize::S64x64 as _ => DCT_DCT,
Intra(intra) if chroma => dav1d_txtp_from_uvmode[intra.uv_mode as usize],
// inferred from either the luma txtp (inter) or a LUT (intra)
Inter(_) if chroma => get_uv_inter_txtp(t_dim, *txtp),
Inter(_) if chroma => get_uv_inter_txtp(&t_dim, *txtp),
// In libaom, lossless is checked by a literal qidx == 0, but not all
// such blocks are actually lossless. The remainder gets an implicit
// transform type (for luma)
Expand Down Expand Up @@ -760,7 +807,7 @@ fn decode_coefs<BD: BitDepth>(
let mut rc;
let mut dc_tok;

#[inline]
#[inline(always)]
fn decode_coefs_class<const TX_CLASS: usize, BD: BitDepth>(
ts_c: &mut Rav1dTileStateContext,
t_dim: &TxfmInfo,
Expand Down Expand Up @@ -1013,13 +1060,13 @@ fn decode_coefs<BD: BitDepth>(
let cf = &mut cf;
(rc, dc_tok) = match tx_class {
TxClass::TwoD => decode_coefs_class::<{ TxClass::TwoD as _ }, BD>(
ts_c, t_dim, chroma, scratch, eob, tx, dbg, cf,
ts_c, &t_dim, chroma, scratch, eob, tx, dbg, cf,
),
TxClass::H => decode_coefs_class::<{ TxClass::H as _ }, BD>(
ts_c, t_dim, chroma, scratch, eob, tx, dbg, cf,
ts_c, &t_dim, chroma, scratch, eob, tx, dbg, cf,
),
TxClass::V => decode_coefs_class::<{ TxClass::V as _ }, BD>(
ts_c, t_dim, chroma, scratch, eob, tx, dbg, cf,
ts_c, &t_dim, chroma, scratch, eob, tx, dbg, cf,
),
};
} else {
Expand Down Expand Up @@ -1084,7 +1131,7 @@ fn decode_coefs<BD: BitDepth>(
None => Ac::NoQm,
});
} else {
dc_sign_ctx = get_dc_sign_ctx(tx, a, l) as c_int;
dc_sign_ctx = get_dc_sign_ctx::<TX>(a, l) as c_int;
let dc_sign_cdf = &mut ts_c.cdf.coef.dc_sign[chroma][dc_sign_ctx as usize];
dc_sign = rav1d_msac_decode_bool_adapt(&mut ts_c.msac, dc_sign_cdf) as c_int;
if dbg {
Expand Down
76 changes: 55 additions & 21 deletions src/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,19 @@ impl BlockSize {
}
}

pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
/// Converts the const generic parameter `TX` into the matching [`TxfmSize`] variant
/// using `TxfmSize::from_repr`.
///
/// `TX` must be a valid `TxfmSize` discriminant. For any other value `from_repr`
/// yields `None` and the `else` branch reaches `unreachable_unchecked`, which is
/// undefined behavior.
pub const fn dav1d_txfm_size<const TX: usize>() -> TxfmSize {
let Some(size) = TxfmSize::from_repr(TX) else {
// SAFETY: sound only if `TX` is always a valid `TxfmSize` discriminant.
// NOTE(review): this is a safe fn, so nothing here enforces that — confirm that
// every instantiation passes a `TxfmSize` variant cast with `as _`.
unsafe {
std::hint::unreachable_unchecked();
lqd marked this conversation as resolved.
Show resolved Hide resolved
}
};
size
}

pub const fn dav1d_txfm_dimension<const TX: usize>() -> TxfmInfo {
use TxfmSize::*;
[
TxfmInfo {
match dav1d_txfm_size::<TX>() {
S4x4 => TxfmInfo {
w: 1,
h: 1,
lw: 0,
Expand All @@ -205,7 +214,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: DefaultValue::DEFAULT,
ctx: 0,
},
TxfmInfo {
S8x8 => TxfmInfo {
w: 2,
h: 2,
lw: 1,
Expand All @@ -215,7 +224,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S4x4,
ctx: 1,
},
TxfmInfo {
S16x16 => TxfmInfo {
w: 4,
h: 4,
lw: 2,
Expand All @@ -225,7 +234,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S8x8,
ctx: 2,
},
TxfmInfo {
S32x32 => TxfmInfo {
w: 8,
h: 8,
lw: 3,
Expand All @@ -235,7 +244,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S16x16,
ctx: 3,
},
TxfmInfo {
S64x64 => TxfmInfo {
w: 16,
h: 16,
lw: 4,
Expand All @@ -245,7 +254,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S32x32,
ctx: 4,
},
TxfmInfo {
R4x8 => TxfmInfo {
w: 1,
h: 2,
lw: 0,
Expand All @@ -255,7 +264,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S4x4,
ctx: 1,
},
TxfmInfo {
R8x4 => TxfmInfo {
w: 2,
h: 1,
lw: 1,
Expand All @@ -265,7 +274,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S4x4,
ctx: 1,
},
TxfmInfo {
R8x16 => TxfmInfo {
w: 2,
h: 4,
lw: 1,
Expand All @@ -275,7 +284,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S8x8,
ctx: 2,
},
TxfmInfo {
R16x8 => TxfmInfo {
w: 4,
h: 2,
lw: 2,
Expand All @@ -285,7 +294,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S8x8,
ctx: 2,
},
TxfmInfo {
R16x32 => TxfmInfo {
w: 4,
h: 8,
lw: 2,
Expand All @@ -295,7 +304,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S16x16,
ctx: 3,
},
TxfmInfo {
R32x16 => TxfmInfo {
w: 8,
h: 4,
lw: 3,
Expand All @@ -305,7 +314,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S16x16,
ctx: 3,
},
TxfmInfo {
R32x64 => TxfmInfo {
w: 8,
h: 16,
lw: 3,
Expand All @@ -315,7 +324,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S32x32,
ctx: 4,
},
TxfmInfo {
R64x32 => TxfmInfo {
w: 16,
h: 8,
lw: 4,
Expand All @@ -325,7 +334,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: S32x32,
ctx: 4,
},
TxfmInfo {
R4x16 => TxfmInfo {
w: 1,
h: 4,
lw: 0,
Expand All @@ -335,7 +344,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: R4x8,
ctx: 1,
},
TxfmInfo {
R16x4 => TxfmInfo {
w: 4,
h: 1,
lw: 2,
Expand All @@ -345,7 +354,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: R8x4,
ctx: 1,
},
TxfmInfo {
R8x32 => TxfmInfo {
w: 2,
h: 8,
lw: 1,
Expand All @@ -355,7 +364,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: R8x16,
ctx: 2,
},
TxfmInfo {
R32x8 => TxfmInfo {
w: 8,
h: 2,
lw: 3,
Expand All @@ -365,7 +374,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: R16x8,
ctx: 2,
},
TxfmInfo {
R16x64 => TxfmInfo {
w: 4,
h: 16,
lw: 2,
Expand All @@ -375,7 +384,7 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: R16x32,
ctx: 3,
},
TxfmInfo {
R64x16 => TxfmInfo {
w: 16,
h: 4,
lw: 4,
Expand All @@ -385,6 +394,31 @@ pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
sub: R32x16,
ctx: 3,
},
}
}

pub static dav1d_txfm_dimensions: [TxfmInfo; TxfmSize::COUNT] = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a static? Can it be const so we don't have to duplicate it in the const switch above? I don't see anywhere that really relies on being able to take the address of these elements (i.e. they don't seem to be used from ASM?)

Copy link
Author

@lqd lqd Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hadn't seen a need for it to remain a static, but preferred to leave this out of the experiment to have your opinion about the const generics first.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also use enum_map! to use the above fn to generate this table (and make the key typed).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did try enum_map!, but then my measurements seemed slightly worse. It also required implementing DefaultValue for TxfmInfo, which I'm not sure can be done correctly, but then again it won't be used by the enum map, so it may be fine?

Since the measurements seemed slightly worse (and I haven't had the time to check the assembly), though still within noise most of the time, I didn't add this change to this PR. I can definitely push it to another branch if you'd like to test it yourself.

I also tried completely removing the static and using the (modified) const fn for the runtime values outside of decode_coefs, but that also seemed to yield worse numbers, so I've left it out of this PR. I did push a commit removing the duplication by using the new function though.

use TxfmSize::*;
[
dav1d_txfm_dimension::<{ S4x4 as _ }>(),
dav1d_txfm_dimension::<{ S8x8 as _ }>(),
dav1d_txfm_dimension::<{ S16x16 as _ }>(),
dav1d_txfm_dimension::<{ S32x32 as _ }>(),
dav1d_txfm_dimension::<{ S64x64 as _ }>(),
dav1d_txfm_dimension::<{ R4x8 as _ }>(),
dav1d_txfm_dimension::<{ R8x4 as _ }>(),
dav1d_txfm_dimension::<{ R8x16 as _ }>(),
dav1d_txfm_dimension::<{ R16x8 as _ }>(),
dav1d_txfm_dimension::<{ R16x32 as _ }>(),
dav1d_txfm_dimension::<{ R32x16 as _ }>(),
dav1d_txfm_dimension::<{ R32x64 as _ }>(),
dav1d_txfm_dimension::<{ R64x32 as _ }>(),
dav1d_txfm_dimension::<{ R4x16 as _ }>(),
dav1d_txfm_dimension::<{ R16x4 as _ }>(),
dav1d_txfm_dimension::<{ R8x32 as _ }>(),
dav1d_txfm_dimension::<{ R32x8 as _ }>(),
dav1d_txfm_dimension::<{ R16x64 as _ }>(),
dav1d_txfm_dimension::<{ R64x16 as _ }>(),
]
};

Expand Down
Loading