1use std::fmt::Debug;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use bytes::{Buf, Bytes};
8use http_body::Frame;
9
10#[derive(thiserror::Error, Debug)]
12pub enum IncomingBodyError {
13 #[error("hyper error: {0}")]
15 #[cfg(any(feature = "http1", feature = "http2"))]
16 Hyper(#[from] hyper::Error),
17 #[error("h3 body error: {0}")]
19 #[cfg(feature = "http3")]
20 H3(#[from] crate::backend::h3::body::H3BodyError),
21}
22
/// An incoming HTTP body, abstracting over the transport backend that received it.
///
/// Which variants exist depends on the enabled cargo features.
pub enum IncomingBody {
    /// Body read through hyper (`http1` / `http2` features).
    #[cfg(any(feature = "http1", feature = "http2"))]
    Hyper(hyper::body::Incoming),
    /// Body read from an HTTP/3 stream backed by `h3_quinn::RecvStream` (`http3` feature).
    #[cfg(feature = "http3")]
    Quic(crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>),
    /// A body without content: it ends immediately and reports an exact size of zero
    /// (`webtransport` feature).
    #[cfg(feature = "webtransport")]
    Empty,
}
38
#[cfg(any(feature = "http1", feature = "http2"))]
impl From<hyper::body::Incoming> for IncomingBody {
    /// Wraps a hyper body in the [`IncomingBody::Hyper`] variant.
    fn from(incoming: hyper::body::Incoming) -> Self {
        Self::Hyper(incoming)
    }
}
45
#[cfg(feature = "http3")]
impl From<crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>> for IncomingBody {
    /// Wraps an HTTP/3 body in the [`IncomingBody::Quic`] variant.
    fn from(quic: crate::backend::h3::body::QuicIncomingBody<h3_quinn::RecvStream>) -> Self {
        Self::Quic(quic)
    }
}
52
53impl http_body::Body for IncomingBody {
54 type Data = Bytes;
55 type Error = IncomingBodyError;
56
57 fn is_end_stream(&self) -> bool {
58 match self {
59 #[cfg(any(feature = "http1", feature = "http2"))]
60 IncomingBody::Hyper(body) => body.is_end_stream(),
61 #[cfg(feature = "http3")]
62 IncomingBody::Quic(body) => body.is_end_stream(),
63 #[cfg(feature = "webtransport")]
64 IncomingBody::Empty => true,
65 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
66 _ => false,
67 }
68 }
69
70 fn poll_frame(
71 self: std::pin::Pin<&mut Self>,
72 _cx: &mut std::task::Context<'_>,
73 ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
74 match self.get_mut() {
75 #[cfg(any(feature = "http1", feature = "http2"))]
76 IncomingBody::Hyper(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
77 #[cfg(feature = "http3")]
78 IncomingBody::Quic(body) => std::pin::Pin::new(body).poll_frame(_cx).map_err(Into::into),
79 #[cfg(feature = "webtransport")]
80 IncomingBody::Empty => std::task::Poll::Ready(None),
81 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
82 _ => std::task::Poll::Ready(None),
83 }
84 }
85
86 fn size_hint(&self) -> http_body::SizeHint {
87 match self {
88 #[cfg(any(feature = "http1", feature = "http2"))]
89 IncomingBody::Hyper(body) => body.size_hint(),
90 #[cfg(feature = "http3")]
91 IncomingBody::Quic(body) => body.size_hint(),
92 #[cfg(feature = "webtransport")]
93 IncomingBody::Empty => http_body::SizeHint::with_exact(0),
94 #[cfg(not(any(feature = "http1", feature = "http2", feature = "http3")))]
95 _ => http_body::SizeHint::default(),
96 }
97 }
98}
99
pin_project_lite::pin_project! {
    /// An [`http_body::Body`] wrapper that reports the size of each data frame to a
    /// [`Tracker`] as the frame passes through.
    pub struct TrackedBody<B, T> {
        // The wrapped body; structurally pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        body: B,
        // Receives `Tracker::on_data` calls; not pinned, so it is projected as `&mut T`.
        tracker: T,
    }
}
108
109impl<B, T> TrackedBody<B, T> {
110 pub fn new(body: B, tracker: T) -> Self {
112 Self { body, tracker }
113 }
114}
115
116#[derive(thiserror::Error)]
118pub enum TrackedBodyError<B, T>
119where
120 B: http_body::Body,
121 T: Tracker,
122{
123 #[error("body error: {0}")]
125 Body(B::Error),
126 #[error("tracker error: {0}")]
128 Tracker(T::Error),
129}
130
131impl<B, T> Debug for TrackedBodyError<B, T>
132where
133 B: http_body::Body,
134 B::Error: Debug,
135 T: Tracker,
136 T::Error: Debug,
137{
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 match self {
140 TrackedBodyError::Body(err) => f.debug_tuple("TrackedBodyError::Body").field(err).finish(),
141 TrackedBodyError::Tracker(err) => f.debug_tuple("TrackedBodyError::Tracker").field(err).finish(),
142 }
143 }
144}
145
/// Observer that [`TrackedBody`] notifies about the data flowing through it.
pub trait Tracker: Send + Sync + 'static {
    /// Error returned when the tracker rejects a chunk of data.
    type Error;

    /// Called with the number of bytes (`Buf::remaining`) in each data frame, before the
    /// frame is handed to the consumer.
    ///
    /// Returning `Err` makes [`TrackedBody`] discard that frame and yield
    /// [`TrackedBodyError::Tracker`] instead. The default implementation accepts every frame.
    fn on_data(&self, size: usize) -> Result<(), Self::Error> {
        // The default tracker has no use for the size; discard it explicitly.
        _ = size;
        Ok(())
    }
}
159
160impl<B, T> http_body::Body for TrackedBody<B, T>
161where
162 B: http_body::Body,
163 T: Tracker,
164{
165 type Data = B::Data;
166 type Error = TrackedBodyError<B, T>;
167
168 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
169 let this = self.project();
170
171 match this.body.poll_frame(cx) {
172 Poll::Pending => Poll::Pending,
173 Poll::Ready(frame) => {
174 if let Some(Ok(frame)) = &frame
175 && let Some(data) = frame.data_ref()
176 && let Err(err) = this.tracker.on_data(data.remaining())
177 {
178 return Poll::Ready(Some(Err(TrackedBodyError::Tracker(err))));
179 }
180
181 Poll::Ready(frame.transpose().map_err(TrackedBodyError::Body).transpose())
182 }
183 }
184 }
185
186 fn is_end_stream(&self) -> bool {
187 self.body.is_end_stream()
188 }
189
190 fn size_hint(&self) -> http_body::SizeHint {
191 self.body.size_hint()
192 }
193}
194
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use std::convert::Infallible;

    use crate::body::TrackedBodyError;

    /// Minimal body used only to instantiate `TrackedBodyError`; it yields no frames.
    struct TestBody;

    impl http_body::Body for TestBody {
        type Data = bytes::Bytes;
        type Error = ();

        fn poll_frame(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            std::task::Poll::Ready(None)
        }
    }

    /// Tracker that can never fail, because its error type is uninhabited.
    struct TestTracker;

    impl super::Tracker for TestTracker {
        type Error = Infallible;
    }

    /// Tracker with a constructible error type, so the `Tracker` variant can be built.
    struct FailingTracker;

    impl super::Tracker for FailingTracker {
        type Error = &'static str;
    }

    #[test]
    fn tracked_body_error_debug() {
        let err = TrackedBodyError::<TestBody, TestTracker>::Body(());
        assert_eq!(format!("{err:?}"), "TrackedBodyError::Body(())");
    }

    #[test]
    fn tracked_body_error_debug_tracker() {
        // Exercises the second arm of the hand-written `Debug` impl.
        let err = TrackedBodyError::<TestBody, FailingTracker>::Tracker("limit exceeded");
        assert_eq!(format!("{err:?}"), "TrackedBodyError::Tracker(\"limit exceeded\")");
    }
}