scuffle_http/backend/h3.rs
1//! HTTP3 backend.
2use std::fmt::Debug;
3use std::io;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7
8use body::QuicIncomingBody;
9use scuffle_context::ContextFutExt;
10#[cfg(feature = "tracing")]
11use tracing::Instrument;
12use utils::copy_response_body;
13#[cfg(feature = "webtransport")]
14use {h3::ext::Protocol, h3_webtransport as h3wt};
15
16use crate::error::HttpError;
17use crate::service::{HttpService, HttpServiceFactory};
18
19pub mod body;
20mod utils;
21#[cfg(feature = "webtransport")]
22pub mod webtransport;
23
24/// A backend that handles incoming HTTP3 connections.
25///
26/// This is used internally by the [`HttpServer`](crate::server::HttpServer) but can be used directly if preferred.
27///
28/// Call [`run`](Http3Backend::run) to start the server.
29#[derive(bon::Builder, Debug, Clone)]
30pub struct Http3Backend<F> {
31 /// The [`scuffle_context::Context`] this server will live by.
32 #[builder(default = scuffle_context::Context::global())]
33 ctx: scuffle_context::Context,
34 /// The number of worker tasks to spawn for each server backend.
35 #[builder(default = 1)]
36 worker_tasks: usize,
37 /// The service factory that will be used to create new services.
38 service_factory: F,
39 /// The address to bind to.
40 ///
41 /// Use `[::]` for a dual-stack listener.
42 /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
43 bind: SocketAddr,
44 /// Enable WebTransport support.
45 #[builder(default = false)]
46 #[cfg(feature = "webtransport")]
47 enable_webtransport: bool,
48 #[builder(default = 1, setters(vis = "", name = max_webtransport_sessions_internal))]
49 #[cfg(feature = "webtransport")]
50 max_webtransport_sessions: u64,
51 /// rustls config.
52 ///
53 /// Use this field to set the server into TLS mode.
54 /// It will only accept TLS connections when this is set.
55 rustls_config: tokio_rustls::rustls::ServerConfig,
56}
57
#[cfg(feature = "webtransport")]
impl<F, S> Http3BackendBuilder<F, S>
where
    S: http3_backend_builder::State,
    S::EnableWebtransport: http3_backend_builder::IsSet,
    S::MaxWebtransportSessions: http3_backend_builder::IsUnset,
{
    /// Set how many WebTransport sessions may be open at the same time.
    ///
    /// This can only be called after `enable_webtransport` has been set, and at most once.
    /// The value is forwarded to [h3::server::Builder::max_webtransport_sessions].
    ///
    /// If WebTransport is enabled and this is never called, the limit is 1.
    pub fn max_webtransport_sessions(
        self,
        max_webtransport_sessions: u64,
    ) -> Http3BackendBuilder<F, http3_backend_builder::SetMaxWebtransportSessions<S>> {
        // Delegate to the private setter generated by `bon`.
        Self::max_webtransport_sessions_internal(self, max_webtransport_sessions)
    }
}
77
78impl<F> Http3Backend<F>
79where
80 F: HttpServiceFactory + Clone + Send + 'static,
81 F::Error: std::error::Error + Send,
82 F::Service: Clone + Send + 'static,
83 <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
84 <F::Service as HttpService>::ResBody: Send,
85 <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
86 <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
87{
88 /// Run the HTTP3 server
89 ///
90 /// This function will bind to the address specified in `bind`, listen for incoming connections and handle requests.
91 #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
92 pub async fn run(mut self) -> Result<(), HttpError<F>> {
93 #[cfg(feature = "tracing")]
94 tracing::debug!("starting server");
95
96 // not quite sure why this is necessary but it is
97 self.rustls_config.max_early_data_size = u32::MAX;
98 let crypto = h3_quinn::quinn::crypto::rustls::QuicServerConfig::try_from(self.rustls_config)?;
99 let mut server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto));
100 let mut transport_config = quinn::TransportConfig::default();
101 transport_config.keep_alive_interval(Some(Duration::from_secs(2)));
102 server_config.transport = Arc::new(transport_config);
103
104 // Bind the UDP socket
105 let socket = std::net::UdpSocket::bind(self.bind)?;
106
107 // Runtime for the quinn endpoint
108 let runtime = h3_quinn::quinn::default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
109
110 // Create a child context for the workers so we can shut them down if one of them fails without shutting down the main context
111 let (worker_ctx, worker_handler) = self.ctx.new_child();
112
113 let workers = (0..self.worker_tasks).map(|_n| {
114 let ctx = worker_ctx.clone();
115 let service_factory = self.service_factory.clone();
116 let server_config = server_config.clone();
117 let socket = socket.try_clone().expect("failed to clone socket");
118 let runtime = Arc::clone(&runtime);
119
120 let worker_fut = async move {
121 let endpoint = h3_quinn::quinn::Endpoint::new(
122 h3_quinn::quinn::EndpointConfig::default(),
123 Some(server_config),
124 socket,
125 runtime,
126 )?;
127
128 #[cfg(feature = "tracing")]
129 tracing::trace!("waiting for connections");
130
131 while let Some(Some(new_conn)) = endpoint.accept().with_context(&ctx).await {
132 let mut service_factory = service_factory.clone();
133 let ctx = ctx.clone();
134
135 tokio::spawn(async move {
136 let _res: Result<_, HttpError<F>> = async move {
137 let Some(conn) = new_conn.with_context(&ctx).await.transpose()? else {
138 #[cfg(feature = "tracing")]
139 tracing::trace!("context done while accepting connection");
140 return Ok(());
141 };
142 let addr = conn.remote_address();
143 let client_certs = conn
144 .peer_identity()
145 .and_then(|any| any.downcast::<Vec<tokio_rustls::rustls::pki_types::CertificateDer>>().ok());
146
147 #[cfg(feature = "tracing")]
148 tracing::debug!(addr = %addr, "accepted quic connection");
149
150 let connection_fut = async move {
151 #[cfg(not(feature = "webtransport"))]
152 let h3_conn_builder = h3::server::builder();
153 #[cfg(feature = "webtransport")]
154 let mut h3_conn_builder = h3::server::builder();
155
156 #[cfg(feature = "webtransport")]
157 if self.enable_webtransport {
158 h3_conn_builder
159 .enable_webtransport(true)
160 .enable_extended_connect(true)
161 .enable_datagram(true)
162 .max_webtransport_sessions(self.max_webtransport_sessions)
163 .send_grease(true);
164 }
165
166 let Some(mut h3_conn) = h3_conn_builder
167 .build(h3_quinn::Connection::new(conn))
168 .with_context(&ctx)
169 .await
170 .transpose()?
171 else {
172 #[cfg(feature = "tracing")]
173 tracing::trace!("context done while establishing connection");
174 return Ok(());
175 };
176
177 let mut extra_extensions = http::Extensions::new();
178 extra_extensions.insert(crate::extensions::ClientAddr(addr));
179 if let Some(certs) = client_certs {
180 extra_extensions.insert(crate::extensions::ClientIdentity(Arc::new(*certs)));
181 }
182
183 // make a new service for this connection
184 let http_service = service_factory
185 .new_service(addr)
186 .await
187 .map_err(|e| HttpError::ServiceFactoryError(e))?;
188
189 loop {
190 match h3_conn.accept().with_context(&ctx).await {
191 Some(Ok(Some(resolver))) => {
192 // Resolve the request
193 let (req, stream) = match resolver.resolve_request().await {
194 Ok(r) => r,
195 Err(_err) => {
196 #[cfg(feature = "tracing")]
197 tracing::warn!("error on accept: {}", _err);
198 continue;
199 }
200 };
201
202 #[cfg(feature = "tracing")]
203 tracing::debug!(method = %req.method(), uri = %req.uri(), "received request");
204
205 // Check if this is a WebTransport CONNECT request
206 #[cfg(feature = "webtransport")]
207 if self.enable_webtransport
208 && req.extensions().get::<Protocol>() == Some(&Protocol::WEB_TRANSPORT)
209 && req.method() == http::Method::CONNECT
210 {
211 #[cfg(feature = "tracing")]
212 tracing::debug!("starting WebTransport session");
213
214 // Store the original request for passing to the service
215 let (parts, _) = req.into_parts();
216
217 // Accept the WebTransport session
218 let session = match h3wt::server::WebTransportSession::accept(
219 http::Request::from_parts(parts.clone(), ()),
220 stream,
221 h3_conn,
222 )
223 .await
224 {
225 Ok(session) => session,
226 Err(_err) => {
227 #[cfg(feature = "tracing")]
228 tracing::warn!(err = %_err, "failed to accept WebTransport session");
229 break;
230 }
231 };
232
233 let wt_session =
234 webtransport::WebTransportSession::new(std::sync::Arc::new(session));
235
236 // Create an empty body for the WebTransport request
237 // Since WebTransport operates on streams, not the request body
238 let empty_body = crate::body::IncomingBody::Empty;
239
240 // Reconstruct the request with the session in extensions
241 let mut wt_req = http::Request::from_parts(parts, empty_body);
242 wt_req.extensions_mut().insert(wt_session); // Call the service with the WebTransport request
243 tokio::spawn({
244 let ctx = ctx.clone();
245 let mut http_service = http_service.clone();
246 async move {
247 let _res: Result<_, HttpError<F>> = async move {
248 let _resp = http_service
249 .call(wt_req)
250 .await
251 .map_err(|e| HttpError::ServiceError(e))?;
252
253 #[cfg(feature = "tracing")]
254 tracing::debug!("WebTransport session handler completed");
255
256 Ok(())
257 }
258 .await;
259
260 #[cfg(feature = "tracing")]
261 if let Err(e) = _res {
262 tracing::warn!(err = %e, "WebTransport session handler error");
263 }
264
265 drop(ctx);
266 }
267 });
268
269 break;
270 }
271 let (mut send, recv) = stream.split();
272
273 let size_hint = req
274 .headers()
275 .get(http::header::CONTENT_LENGTH)
276 .and_then(|len| len.to_str().ok().and_then(|x| x.parse().ok()));
277 let body = QuicIncomingBody::new(recv, size_hint);
278 let mut req = req.map(|_| crate::body::IncomingBody::from(body));
279
280 req.extensions_mut().extend(extra_extensions.clone());
281
282 tokio::spawn({
283 let ctx = ctx.clone();
284 let mut http_service = http_service.clone();
285 async move {
286 let _res: Result<_, HttpError<F>> = async move {
287 let resp = http_service
288 .call(req)
289 .await
290 .map_err(|e| HttpError::ServiceError(e))?;
291 let (parts, body) = resp.into_parts();
292
293 send.send_response(http::Response::from_parts(parts, ())).await?;
294 copy_response_body(send, body).await?;
295
296 Ok(())
297 }
298 .await;
299
300 #[cfg(feature = "tracing")]
301 if let Err(e) = _res {
302 tracing::warn!(err = %e, "error handling request");
303 }
304
305 // This moves the context into the async block because it is dropped here
306 drop(ctx);
307 }
308 });
309 }
310 // indicating no more streams to be received
311 Some(Ok(None)) => {
312 break;
313 }
314 Some(Err(err)) => return Err(err.into()),
315 // context is done
316 None => {
317 #[cfg(feature = "tracing")]
318 tracing::trace!("context done, stopping connection loop");
319 break;
320 }
321 }
322 }
323
324 #[cfg(feature = "tracing")]
325 tracing::trace!("connection closed");
326
327 Ok(())
328 };
329
330 #[cfg(feature = "tracing")]
331 let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
332
333 connection_fut.await
334 }
335 .await;
336
337 #[cfg(feature = "tracing")]
338 if let Err(err) = _res {
339 tracing::warn!(err = %err, "error handling connection");
340 }
341 });
342 }
343
344 // shut down gracefully
345 // wait for connections to be closed before exiting
346 endpoint.wait_idle().await;
347
348 Ok::<_, crate::error::HttpError<F>>(())
349 };
350
351 #[cfg(feature = "tracing")]
352 let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
353
354 tokio::spawn(worker_fut)
355 });
356
357 if let Err(_e) = futures::future::try_join_all(workers).await {
358 #[cfg(feature = "tracing")]
359 tracing::error!(err = %_e, "error running workers");
360 }
361
362 drop(worker_ctx);
363 worker_handler.shutdown().await;
364
365 #[cfg(feature = "tracing")]
366 tracing::debug!("all workers finished");
367
368 Ok(())
369 }
370}