scuffle_http/
body.rs

1//! Types for working with HTTP bodies.
2
3use std::fmt::Debug;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use bytes::{Buf, Bytes};
8use http_body::Frame;
9
10/// An error that can occur when reading the body of an incoming request.
11#[derive(thiserror::Error, Debug)]
12pub enum IncomingBodyError {
13    /// An error that occurred while reading a hyper body.
14    #[error("hyper error: {0}")]
15    #[cfg(any(feature = "http1", feature = "http2"))]
16    Hyper(#[from] hyper::Error),
17    /// An error that occurred while reading a h3 body.
18    #[error("h3 body error: {0}")]
19    #[cfg(feature = "http3")]
20    H3(#[from] crate::backend::h3::body::H3BodyError),
21}
22
/// A request body that hides which HTTP version delivered it.
///
/// HTTP/1 and HTTP/2 bodies come from hyper and HTTP/3 bodies from h3. Callers use all of
/// them through this type's [`http_body::Body`] implementation.
pub enum IncomingBody {
    /// Body received through hyper (HTTP/1 or HTTP/2).
    #[cfg(any(feature = "http1", feature = "http2"))]
    Hyper(hyper::body::Incoming),
    /// Body received over QUIC through h3 (HTTP/3).
    #[cfg(feature = "http3")]
    Quic(crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>),
    /// A body with no content, used for WebTransport sessions.
    #[cfg(feature = "webtransport")]
    Empty,
}
38
#[cfg(any(feature = "http1", feature = "http2"))]
impl From<hyper::body::Incoming> for IncomingBody {
    /// Wraps a hyper body in the [`IncomingBody::Hyper`] variant.
    fn from(incoming: hyper::body::Incoming) -> Self {
        Self::Hyper(incoming)
    }
}
45
#[cfg(feature = "http3")]
impl From<crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>> for IncomingBody {
    /// Wraps an h3 body in the [`IncomingBody::Quic`] variant.
    fn from(quic_body: crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>) -> Self {
        Self::Quic(quic_body)
    }
}
52
53impl http_body::Body for IncomingBody {
54    type Data = Bytes;
55    type Error = IncomingBodyError;
56
57    fn is_end_stream(&self) -> bool {
58        match self {
59            #[cfg(any(feature = "http1", feature = "http2"))]
60            IncomingBody::Hyper(body) => body.is_end_stream(),
61            #[cfg(feature = "http3")]
62            IncomingBody::Quic(body) => body.is_end_stream(),
63            #[cfg(feature = "webtransport")]
64            IncomingBody::Empty => true,
65            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
66            _ => false,
67        }
68    }
69
70    fn poll_frame(
71        self: std::pin::Pin<&mut Self>,
72        _cx: &mut std::task::Context<'_>,
73    ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
74        match self.get_mut() {
75            #[cfg(any(feature = "http1", feature = "http2"))]
76            IncomingBody::Hyper(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
77            #[cfg(feature = "http3")]
78            IncomingBody::Quic(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
79            #[cfg(feature = "webtransport")]
80            IncomingBody::Empty => std::task::Poll::Ready(None),
81            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
82            _ => std::task::Poll::Ready(None),
83        }
84    }
85
86    fn size_hint(&self) -> http_body::SizeHint {
87        match self {
88            #[cfg(any(feature = "http1", feature = "http2"))]
89            IncomingBody::Hyper(body) => body.size_hint(),
90            #[cfg(feature = "http3")]
91            IncomingBody::Quic(body) => body.size_hint(),
92            #[cfg(feature = "webtransport")]
93            IncomingBody::Empty => http_body::SizeHint::with_exact(0),
94            #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
95            _ => http_body::SizeHint::default(),
96        }
97    }
98}
99
pin_project_lite::pin_project! {
    /// A wrapper around an HTTP body that tracks the size of the data that is read from it.
    ///
    /// Each data frame's byte length is reported to the `tracker` (a [`Tracker`]) as the
    /// frame is polled. Frames that carry no data are forwarded without consulting the tracker.
    pub struct TrackedBody<B, T> {
        // Pinned because the inner body is polled through a projected `Pin<&mut B>`.
        #[pin]
        body: B,
        // Receives a callback for every data frame; not structurally pinned.
        tracker: T,
    }
}
108
109impl<B, T> TrackedBody<B, T> {
110    /// Create a new [`TrackedBody`] with the given body and tracker.
111    pub fn new(body: B, tracker: T) -> Self {
112        Self { body, tracker }
113    }
114}
115
116/// An error that can occur when tracking the body of an incoming request.
117#[derive(thiserror::Error)]
118pub enum TrackedBodyError<B, T>
119where
120    B: http_body::Body,
121    T: Tracker,
122{
123    /// An error that occurred while reading the body.
124    #[error("body error: {0}")]
125    Body(B::Error),
126    /// An error that occurred while calling [`Tracker::on_data`].
127    #[error("tracker error: {0}")]
128    Tracker(T::Error),
129}
130
131impl<B, T> Debug for TrackedBodyError<B, T>
132where
133    B: http_body::Body,
134    B::Error: Debug,
135    T: Tracker,
136    T::Error: Debug,
137{
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        match self {
140            TrackedBodyError::Body(err) => f.debug_tuple("TrackedBodyError::Body").field(err).finish(),
141            TrackedBodyError::Tracker(err) => f.debug_tuple("TrackedBodyError::Tracker").field(err).finish(),
142        }
143    }
144}
145
/// A trait for tracking the size of the data that is read from an HTTP body.
///
/// [`TrackedBody`] calls [`Tracker::on_data`] once for every data frame it receives from the
/// body it wraps. Frames that carry no data are not reported.
pub trait Tracker: Send + Sync + 'static {
    /// The error type that can occur when [`Tracker::on_data`] is called.
    type Error;

    /// Called when a data frame is read from the body.
    ///
    /// `size` is the number of bytes in the frame that was just read (its
    /// [`Buf::remaining`](bytes::Buf::remaining) when it was polled). It is *not* the amount of
    /// data still left in the body. Implementations that need a running total must keep it
    /// themselves.
    ///
    /// The default implementation accepts every frame.
    ///
    /// # Errors
    ///
    /// If this returns an error, [`TrackedBody`] drops the frame and yields
    /// [`TrackedBodyError::Tracker`] in its place.
    fn on_data(&self, size: usize) -> Result<(), Self::Error> {
        let _ = size;
        Ok(())
    }
}
159
160impl<B, T> http_body::Body for TrackedBody<B, T>
161where
162    B: http_body::Body,
163    T: Tracker,
164{
165    type Data = B::Data;
166    type Error = TrackedBodyError<B, T>;
167
168    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
169        let this = self.project();
170
171        match this.body.poll_frame(cx) {
172            Poll::Pending => Poll::Pending,
173            Poll::Ready(frame) => {
174                if let Some(Ok(frame)) = &frame
175                    && let Some(data) = frame.data_ref()
176                    && let Err(err) = this.tracker.on_data(data.remaining())
177                {
178                    return Poll::Ready(Some(Err(TrackedBodyError::Tracker(err))));
179                }
180
181                Poll::Ready(frame.transpose().map_err(TrackedBodyError::Body).transpose())
182            }
183        }
184    }
185
186    fn is_end_stream(&self) -> bool {
187        self.body.is_end_stream()
188    }
189
190    fn size_hint(&self) -> http_body::SizeHint {
191        self.body.size_hint()
192    }
193}
194
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::{Context, Poll, Waker};

    use bytes::Bytes;
    use http_body::{Body, Frame};

    use crate::body::{TrackedBody, TrackedBodyError, Tracker};

    /// A body that yields its single data chunk (if any) and then ends.
    struct OneChunkBody(Option<Bytes>);

    impl Body for OneChunkBody {
        type Data = Bytes;
        type Error = ();

        fn poll_frame(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            Poll::Ready(self.get_mut().0.take().map(|chunk| Ok(Frame::data(chunk))))
        }
    }

    /// A tracker that keeps a running byte total and fails once the total exceeds `limit`.
    struct LimitTracker {
        seen: AtomicUsize,
        limit: usize,
    }

    impl LimitTracker {
        fn new(limit: usize) -> Self {
            Self {
                seen: AtomicUsize::new(0),
                limit,
            }
        }
    }

    impl Tracker for LimitTracker {
        type Error = &'static str;

        fn on_data(&self, size: usize) -> Result<(), Self::Error> {
            let total = self.seen.fetch_add(size, Ordering::SeqCst) + size;
            if total > self.limit { Err("limit exceeded") } else { Ok(()) }
        }
    }

    /// Polls `body` once with a no-op waker.
    fn poll_once<B: Body + Unpin>(body: &mut B) -> Poll<Option<Result<Frame<B::Data>, B::Error>>> {
        let mut cx = Context::from_waker(Waker::noop());
        Pin::new(body).poll_frame(&mut cx)
    }

    #[test]
    fn tracked_body_error_debug() {
        let err = TrackedBodyError::<OneChunkBody, LimitTracker>::Body(());
        assert_eq!(format!("{err:?}"), "TrackedBodyError::Body(())");

        let err = TrackedBodyError::<OneChunkBody, LimitTracker>::Tracker("limit exceeded");
        assert_eq!(format!("{err:?}"), "TrackedBodyError::Tracker(\"limit exceeded\")");
    }

    #[test]
    fn tracked_body_reports_frame_size() {
        let mut body = TrackedBody::new(OneChunkBody(Some(Bytes::from_static(b"hello"))), LimitTracker::new(usize::MAX));

        match poll_once(&mut body) {
            Poll::Ready(Some(Ok(frame))) => assert_eq!(frame.into_data().ok(), Some(Bytes::from_static(b"hello"))),
            _ => panic!("expected a data frame"),
        }
        assert!(matches!(poll_once(&mut body), Poll::Ready(None)));
        // The tracker is told the byte length of the frame that was read.
        assert_eq!(body.tracker.seen.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn tracked_body_surfaces_tracker_error() {
        let mut body = TrackedBody::new(OneChunkBody(Some(Bytes::from_static(b"hello"))), LimitTracker::new(3));

        assert!(matches!(
            poll_once(&mut body),
            Poll::Ready(Some(Err(TrackedBodyError::Tracker("limit exceeded"))))
        ));
    }
}
227}