Skip to content

Commit

Permalink
feat: Expose recover error service (#2159)
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto authored Jan 23, 2025
1 parent dcaa679 commit 47ed9d3
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 9 deletions.
3 changes: 3 additions & 0 deletions tonic/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ pub use self::layered::{LayerExt, Layered};
pub use self::router::{Routes, RoutesBuilder};
#[cfg(feature = "router")]
pub use axum::{body::Body as AxumBody, Router as AxumRouter};

pub mod recover_error;
pub use self::recover_error::{RecoverError, RecoverErrorLayer};
Original file line number Diff line number Diff line change
@@ -1,24 +1,50 @@
//! Middleware which recovers from error.
use std::{
fmt,
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};

use http::Response;
use pin_project::pin_project;
use tower_layer::Layer;
use tower_service::Service;

use crate::Status;

/// Layer which applies the [`RecoverError`] middleware.
#[derive(Debug, Default, Clone)]
pub struct RecoverErrorLayer {
_priv: (),
}

impl RecoverErrorLayer {
/// Create a new `RecoverErrorLayer`.
pub fn new() -> Self {
Self { _priv: () }
}
}

impl<S> Layer<S> for RecoverErrorLayer {
type Service = RecoverError<S>;

fn layer(&self, inner: S) -> Self::Service {
RecoverError::new(inner)
}
}

/// Middleware that attempts to recover from service errors by turning them into a response built
/// from the `Status`.
#[derive(Debug, Clone)]
pub(crate) struct RecoverError<S> {
pub struct RecoverError<S> {
inner: S,
}

impl<S> RecoverError<S> {
pub(crate) fn new(inner: S) -> Self {
/// Create a new `RecoverError` middleware.
pub fn new(inner: S) -> Self {
Self { inner }
}
}
Expand All @@ -43,12 +69,19 @@ where
}
}

/// Response future for [`RecoverError`].
#[pin_project]
pub(crate) struct ResponseFuture<F> {
pub struct ResponseFuture<F> {
#[pin]
inner: F,
}

impl<F> fmt::Debug for ResponseFuture<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ResponseFuture").finish()
}
}

impl<F, E, ResBody> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<ResBody>, E>>,
Expand All @@ -74,12 +107,19 @@ where
}
}

/// Response body for [`RecoverError`].
#[pin_project]
pub(crate) struct ResponseBody<B> {
pub struct ResponseBody<B> {
#[pin]
inner: Option<B>,
}

impl<B> fmt::Debug for ResponseBody<B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ResponseBody").finish()
}
}

impl<B> ResponseBody<B> {
fn full(inner: B) -> Self {
Self { inner: Some(inner) }
Expand Down
5 changes: 3 additions & 2 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ pub use incoming::TcpIncoming;
#[cfg(feature = "_tls-any")]
use crate::transport::Error;

use self::service::{ConnectInfoLayer, RecoverError, ServerIo};
use self::service::{ConnectInfoLayer, ServerIo};
use super::service::GrpcTimeout;
use crate::body::Body;
use crate::service::RecoverErrorLayer;
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::BodyExt;
Expand Down Expand Up @@ -1087,7 +1088,7 @@ where
let trace_interceptor = self.trace_interceptor.clone();

let svc = ServiceBuilder::new()
.layer_fn(RecoverError::new)
.layer(RecoverErrorLayer::new())
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
.layer_fn(|s| GrpcTimeout::new(s, timeout))
.service(svc);
Expand Down
3 changes: 0 additions & 3 deletions tonic/src/transport/server/service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
mod io;
pub(crate) use self::io::{ConnectInfoLayer, ServerIo};

mod recover_error;
pub(crate) use self::recover_error::RecoverError;

#[cfg(feature = "_tls-any")]
mod tls;
#[cfg(feature = "_tls-any")]
Expand Down

0 comments on commit 47ed9d3

Please sign in to comment.