scuffle_http/backend/
h3.rs

1//! HTTP3 backend.
2use std::fmt::Debug;
3use std::io;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use body::QuicIncomingBody;
9use scuffle_context::ContextFutExt;
10#[cfg(feature = "tracing")]
11use tracing::Instrument;
12use utils::copy_response_body;
13#[cfg(feature = "webtransport")]
14use {h3::ext::Protocol, h3_webtransport as h3wt};
15
16use crate::error::HttpError;
17use crate::service::{HttpService, HttpServiceFactory};
18
19pub mod body;
20mod utils;
21#[cfg(feature = "webtransport")]
22pub mod webtransport;
23
24/// A backend that handles incoming HTTP3 connections.
25///
26/// This is used internally by the [`HttpServer`](crate::server::HttpServer) but can be used directly if preferred.
27///
28/// Call [`run`](Http3Backend::run) to start the server.
29#[derive(bon::Builder, Debug, Clone)]
30pub struct Http3Backend<F> {
31    /// The [`scuffle_context::Context`] this server will live by.
32    #[builder(default = scuffle_context::Context::global())]
33    ctx: scuffle_context::Context,
34    /// The number of worker tasks to spawn for each server backend.
35    #[builder(default = 1)]
36    worker_tasks: usize,
37    /// The service factory that will be used to create new services.
38    service_factory: F,
39    /// The address to bind to.
40    ///
41    /// Use `[::]` for a dual-stack listener.
42    /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
43    bind: SocketAddr,
44    /// Enable WebTransport support.
45    #[builder(default = false)]
46    #[cfg(feature = "webtransport")]
47    enable_webtransport: bool,
48    #[builder(default = 1, setters(vis = "", name = max_webtransport_sessions_internal))]
49    #[cfg(feature = "webtransport")]
50    max_webtransport_sessions: u64,
51    /// rustls config.
52    ///
53    /// Use this field to set the server into TLS mode.
54    /// It will only accept TLS connections when this is set.
55    rustls_config: tokio_rustls::rustls::ServerConfig,
56}
57
#[cfg(feature = "webtransport")]
impl<F, S> Http3BackendBuilder<F, S>
where
    S: http3_backend_builder::State,
    S::EnableWebtransport: http3_backend_builder::IsSet,
    S::MaxWebtransportSessions: http3_backend_builder::IsUnset,
{
    /// Set the maximum number of concurrent WebTransport sessions.
    ///
    /// The value is forwarded to [h3::server::Builder::max_webtransport_sessions] for every
    /// connection. It is only used when WebTransport is enabled.
    ///
    /// This method becomes available only after `enable_webtransport` has been called.
    ///
    /// Default is 1 when WebTransport is enabled.
    pub fn max_webtransport_sessions(
        self,
        max_webtransport_sessions: u64,
    ) -> Http3BackendBuilder<F, http3_backend_builder::SetMaxWebtransportSessions<S>> {
        // Delegate to the private setter generated by `bon`.
        Self::max_webtransport_sessions_internal(self, max_webtransport_sessions)
    }
}
77
78impl<F> Http3Backend<F>
79where
80    F: HttpServiceFactory + Clone + Send + 'static,
81    F::Error: std::error::Error + Send,
82    F::Service: Clone + Send + 'static,
83    <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
84    <F::Service as HttpService>::ResBody: Send,
85    <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
86    <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
87{
88    /// Run the HTTP3 server
89    ///
90    /// This function will bind to the address specified in `bind`, listen for incoming connections and handle requests.
91    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
92    pub async fn run(mut self) -> Result<(), HttpError<F>> {
93        #[cfg(feature = "tracing")]
94        tracing::debug!("starting server");
95
96        // not quite sure why this is necessary but it is
97        self.rustls_config.max_early_data_size = u32::MAX;
98        let crypto = h3_quinn::quinn::crypto::rustls::QuicServerConfig::try_from(self.rustls_config)?;
99        let mut server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto));
100        let mut transport_config = quinn::TransportConfig::default();
101        transport_config.keep_alive_interval(Some(Duration::from_secs(2)));
102        server_config.transport = Arc::new(transport_config);
103
104        // Bind the UDP socket
105        let socket = std::net::UdpSocket::bind(self.bind)?;
106
107        // Runtime for the quinn endpoint
108        let runtime = h3_quinn::quinn::default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
109
110        // Create a child context for the workers so we can shut them down if one of them fails without shutting down the main context
111        let (worker_ctx, worker_handler) = self.ctx.new_child();
112
113        let workers = (0..self.worker_tasks).map(|_n| {
114            let ctx = worker_ctx.clone();
115            let service_factory = self.service_factory.clone();
116            let server_config = server_config.clone();
117            let socket = socket.try_clone().expect("failed to clone socket");
118            let runtime = Arc::clone(&runtime);
119
120            let worker_fut = async move {
121                let endpoint = h3_quinn::quinn::Endpoint::new(
122                    h3_quinn::quinn::EndpointConfig::default(),
123                    Some(server_config),
124                    socket,
125                    runtime,
126                )?;
127
128                #[cfg(feature = "tracing")]
129                tracing::trace!("waiting for connections");
130
131                while let Some(Some(new_conn)) = endpoint.accept().with_context(&ctx).await {
132                    let mut service_factory = service_factory.clone();
133                    let ctx = ctx.clone();
134
135                    tokio::spawn(async move {
136                        let _res: Result<_, HttpError<F>> = async move {
137                            let Some(conn) = new_conn.with_context(&ctx).await.transpose()? else {
138                                #[cfg(feature = "tracing")]
139                                tracing::trace!("context done while accepting connection");
140                                return Ok(());
141                            };
142                            let addr = conn.remote_address();
143                            let client_certs = conn
144                                .peer_identity()
145                                .and_then(|any| any.downcast::<Vec<tokio_rustls::rustls::pki_types::CertificateDer>>().ok());
146
147                            #[cfg(feature = "tracing")]
148                            tracing::debug!(addr = %addr, "accepted quic connection");
149
150                            let connection_fut = async move {
151                                #[cfg(not(feature = "webtransport"))]
152                                let h3_conn_builder = h3::server::builder();
153                                #[cfg(feature = "webtransport")]
154                                let mut h3_conn_builder = h3::server::builder();
155
156                                #[cfg(feature = "webtransport")]
157                                if self.enable_webtransport {
158                                    h3_conn_builder
159                                        .enable_webtransport(true)
160                                        .enable_extended_connect(true)
161                                        .enable_datagram(true)
162                                        .max_webtransport_sessions(self.max_webtransport_sessions)
163                                        .send_grease(true);
164                                }
165
166                                let Some(mut h3_conn) = h3_conn_builder
167                                    .build(h3_quinn::Connection::new(conn))
168                                    .with_context(&ctx)
169                                    .await
170                                    .transpose()?
171                                else {
172                                    #[cfg(feature = "tracing")]
173                                    tracing::trace!("context done while establishing connection");
174                                    return Ok(());
175                                };
176
177                                let mut extra_extensions = http::Extensions::new();
178                                extra_extensions.insert(crate::extensions::ClientAddr(addr));
179                                if let Some(certs) = client_certs {
180                                    extra_extensions.insert(crate::extensions::ClientIdentity(Arc::new(*certs)));
181                                }
182
183                                // make a new service for this connection
184                                let http_service = service_factory
185                                    .new_service(addr)
186                                    .await
187                                    .map_err(|e| HttpError::ServiceFactoryError(e))?;
188
189                                loop {
190                                    match h3_conn.accept().with_context(&ctx).await {
191                                        Some(Ok(Some(resolver))) => {
192                                            // Resolve the request
193                                            let (req, stream) = match resolver.resolve_request().await {
194                                                Ok(r) => r,
195                                                Err(_err) => {
196                                                    #[cfg(feature = "tracing")]
197                                                    tracing::warn!("error on accept: {}", _err);
198                                                    continue;
199                                                }
200                                            };
201
202                                            #[cfg(feature = "tracing")]
203                                            tracing::debug!(method = %req.method(), uri = %req.uri(), "received request");
204
205                                            // Check if this is a WebTransport CONNECT request
206                                            #[cfg(feature = "webtransport")]
207                                            if self.enable_webtransport
208                                                && req.extensions().get::<Protocol>() == Some(&Protocol::WEB_TRANSPORT)
209                                                && req.method() == http::Method::CONNECT
210                                            {
211                                                #[cfg(feature = "tracing")]
212                                                tracing::debug!("starting WebTransport session");
213
214                                                // Store the original request for passing to the service
215                                                let (parts, _) = req.into_parts();
216
217                                                // Accept the WebTransport session
218                                                let session = match h3wt::server::WebTransportSession::accept(
219                                                    http::Request::from_parts(parts.clone(), ()),
220                                                    stream,
221                                                    h3_conn,
222                                                )
223                                                .await
224                                                {
225                                                    Ok(session) => session,
226                                                    Err(_err) => {
227                                                        #[cfg(feature = "tracing")]
228                                                        tracing::warn!(err = %_err, "failed to accept WebTransport session");
229                                                        break;
230                                                    }
231                                                };
232
233                                                let wt_session =
234                                                    webtransport::WebTransportSession::new(std::sync::Arc::new(session));
235
236                                                // Create an empty body for the WebTransport request
237                                                // Since WebTransport operates on streams, not the request body
238                                                let empty_body = crate::body::IncomingBody::Empty;
239
240                                                // Reconstruct the request with the session in extensions
241                                                let mut wt_req = http::Request::from_parts(parts, empty_body);
242                                                wt_req.extensions_mut().insert(wt_session); // Call the service with the WebTransport request
243                                                tokio::spawn({
244                                                    let ctx = ctx.clone();
245                                                    let mut http_service = http_service.clone();
246                                                    async move {
247                                                        let _res: Result<_, HttpError<F>> = async move {
248                                                            let _resp = http_service
249                                                                .call(wt_req)
250                                                                .await
251                                                                .map_err(|e| HttpError::ServiceError(e))?;
252
253                                                            #[cfg(feature = "tracing")]
254                                                            tracing::debug!("WebTransport session handler completed");
255
256                                                            Ok(())
257                                                        }
258                                                        .await;
259
260                                                        #[cfg(feature = "tracing")]
261                                                        if let Err(e) = _res {
262                                                            tracing::warn!(err = %e, "WebTransport session handler error");
263                                                        }
264
265                                                        drop(ctx);
266                                                    }
267                                                });
268
269                                                break;
270                                            }
271                                            let (mut send, recv) = stream.split();
272
273                                            let size_hint = req
274                                                .headers()
275                                                .get(http::header::CONTENT_LENGTH)
276                                                .and_then(|len| len.to_str().ok().and_then(|x| x.parse().ok()));
277                                            let body = QuicIncomingBody::new(recv, size_hint);
278                                            let mut req = req.map(|_| crate::body::IncomingBody::from(body));
279
280                                            req.extensions_mut().extend(extra_extensions.clone());
281
282                                            tokio::spawn({
283                                                let ctx = ctx.clone();
284                                                let mut http_service = http_service.clone();
285                                                async move {
286                                                    let _res: Result<_, HttpError<F>> = async move {
287                                                        let resp = http_service
288                                                            .call(req)
289                                                            .await
290                                                            .map_err(|e| HttpError::ServiceError(e))?;
291                                                        let (parts, body) = resp.into_parts();
292
293                                                        send.send_response(http::Response::from_parts(parts, ())).await?;
294                                                        copy_response_body(send, body).await?;
295
296                                                        Ok(())
297                                                    }
298                                                    .await;
299
300                                                    #[cfg(feature = "tracing")]
301                                                    if let Err(e) = _res {
302                                                        tracing::warn!(err = %e, "error handling request");
303                                                    }
304
305                                                    // This moves the context into the async block because it is dropped here
306                                                    drop(ctx);
307                                                }
308                                            });
309                                        }
310                                        // indicating no more streams to be received
311                                        Some(Ok(None)) => {
312                                            break;
313                                        }
314                                        Some(Err(err)) => return Err(err.into()),
315                                        // context is done
316                                        None => {
317                                            #[cfg(feature = "tracing")]
318                                            tracing::trace!("context done, stopping connection loop");
319                                            break;
320                                        }
321                                    }
322                                }
323
324                                #[cfg(feature = "tracing")]
325                                tracing::trace!("connection closed");
326
327                                Ok(())
328                            };
329
330                            #[cfg(feature = "tracing")]
331                            let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
332
333                            connection_fut.await
334                        }
335                        .await;
336
337                        #[cfg(feature = "tracing")]
338                        if let Err(err) = _res {
339                            tracing::warn!(err = %err, "error handling connection");
340                        }
341                    });
342                }
343
344                // shut down gracefully
345                // wait for connections to be closed before exiting
346                endpoint.wait_idle().await;
347
348                Ok::<_, crate::error::HttpError<F>>(())
349            };
350
351            #[cfg(feature = "tracing")]
352            let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
353
354            tokio::spawn(worker_fut)
355        });
356
357        if let Err(_e) = futures::future::try_join_all(workers).await {
358            #[cfg(feature = "tracing")]
359            tracing::error!(err = %_e, "error running workers");
360        }
361
362        drop(worker_ctx);
363        worker_handler.shutdown().await;
364
365        #[cfg(feature = "tracing")]
366        tracing::debug!("all workers finished");
367
368        Ok(())
369    }
370}