scuffle_http/backend/h3/webtransport.rs
1//! WebTransport session management for HTTP/3.
2//!
3//! This module provides types for handling WebTransport sessions over HTTP/3.
4//! WebTransport allows bidirectional streams and datagrams to be established over QUIC.
5
6use std::sync::Arc;
7use std::{fmt, io};
8
9use bytes::Bytes;
10use h3::quic::StreamErrorIncoming;
11use h3_webtransport::server::{AcceptedBi as H3AcceptedBi, WebTransportSession as H3WebTransportSession};
12use h3_webtransport::stream::{BidiStream, RecvStream as WtRecvStream, SendStream as WtSendStream};
13
14/// A WebTransport session handle.
15///
16/// This type provides access to bidirectional and unidirectional streams
17/// for a WebTransport session established over HTTP/3.
18///
19/// The session can be retrieved from the request extensions when handling
20/// a WebTransport CONNECT request.
21///
22/// # Example
23///
24/// ```rust,ignore
25/// # use scuffle_http::{IncomingRequest, Response};
26/// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
27/// async fn handle_webtransport(req: IncomingRequest) -> Result<Response<()>, std::convert::Infallible> {
28/// if let Some(session) = req.extensions().get::<WebTransportSession>() {
29/// // Handle WebTransport session
30/// tokio::spawn({
31/// let session = session.clone();
32/// async move {
33/// while let Ok(Some(accepted)) = session.accept_bi().await {
34/// // Handle bidirectional streams
35/// }
36/// }
37/// });
38///
39/// return Ok(Response::builder()
40/// .status(200)
41/// .body(())
42/// .unwrap());
43/// }
44///
45/// Ok(Response::builder()
46/// .status(404)
47/// .body(())
48/// .unwrap())
49/// }
50/// ```
51#[derive(Clone)]
52pub struct WebTransportSession {
53 session: Arc<H3WebTransportSession<h3_quinn::Connection, Bytes>>,
54}
55
56impl WebTransportSession {
57 /// Create a new WebTransport session from an h3-webtransport session.
58 pub(crate) fn new(session: Arc<H3WebTransportSession<h3_quinn::Connection, Bytes>>) -> Self {
59 Self { session }
60 }
61
62 /// Accept the next incoming bidirectional stream or request.
63 ///
64 /// Returns `None` when the session is closed or no more streams are available.
65 ///
66 /// # Example
67 ///
68 /// ```rust,ignore
69 /// # use scuffle_http::backend::h3::webtransport::{WebTransportSession, AcceptedBi};
70 /// async fn handle_session(session: WebTransportSession) {
71 /// while let Ok(Some(accepted)) = session.accept_bi().await {
72 /// match accepted {
73 /// AcceptedBi::BidiStream(stream) => {
74 /// // Handle raw bidirectional stream
75 /// }
76 /// AcceptedBi::Request(req, stream) => {
77 /// // Handle HTTP request over WebTransport
78 /// }
79 /// }
80 /// }
81 /// }
82 /// ```
83 pub async fn accept_bi(&self) -> Result<Option<AcceptedBi>, h3::error::StreamError> {
84 match self.session.accept_bi().await {
85 Ok(Some(H3AcceptedBi::BidiStream(id, stream))) => {
86 Ok(Some(AcceptedBi::BidiStream(WebTransportBidiStream { stream, _id: id })))
87 }
88 Ok(Some(H3AcceptedBi::Request(req, stream))) => {
89 Ok(Some(AcceptedBi::Request(req, WebTransportRequestStream { stream })))
90 }
91 Ok(None) => Ok(None),
92 Err(e) => Err(e),
93 }
94 }
95
96 /// Accept the next incoming unidirectional stream.
97 ///
98 /// Returns `None` when the session is closed or no more streams are available.
99 pub async fn accept_uni(
100 &self,
101 ) -> Result<Option<(WebTransportStreamId, WebTransportRecvStream)>, h3::error::ConnectionError> {
102 self.session
103 .accept_uni()
104 .await
105 .map(|o| o.map(|(id, stream)| (WebTransportStreamId(id), WebTransportRecvStream { stream })))
106 }
107
108 /// Open a new bidirectional stream.
109 ///
110 /// # Example
111 ///
112 /// ```rust,ignore
113 /// # use bytes::Bytes;
114 /// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
115 /// # async fn dummy(session: WebTransportSession) -> Result<(), Box<dyn std::error::Error>> {
116 /// let (mut send, mut recv) = session.open_bi().await?;
117 /// send.write(Bytes::from("Hello")).await?;
118 /// send.finish().await?;
119 /// # Ok(())
120 /// # }
121 /// ```
122 pub async fn open_bi(&self) -> Result<(WebTransportSendStream, WebTransportRecvStream), h3::error::StreamError> {
123 let stream = self.session.open_bi(WebTransportStreamId::next_session_id()).await?;
124 use h3::quic::BidiStream;
125 let (send, recv) = stream.split();
126 Ok((
127 WebTransportSendStream { stream: send },
128 WebTransportRecvStream { stream: recv },
129 ))
130 }
131
132 /// Open a new unidirectional stream.
133 ///
134 /// # Example
135 ///
136 /// ```rust,ignore
137 /// # use bytes::Bytes;
138 /// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
139 /// # async fn dummy(session: WebTransportSession) -> Result<(), Box<dyn std::error::Error>> {
140 /// let mut send = session.open_uni().await?;
141 /// send.write(Bytes::from("Hello")).await?;
142 /// send.finish().await?;
143 /// # Ok(())
144 /// # }
145 /// ```
146 pub async fn open_uni(&self) -> Result<WebTransportSendStream, h3::error::StreamError> {
147 let send = self.session.open_uni(WebTransportStreamId::next_session_id()).await?;
148 Ok(WebTransportSendStream { stream: send })
149 }
150
151 /// Get the session ID for this WebTransport session.
152 pub fn session_id(&self) -> h3_webtransport::SessionId {
153 self.session.session_id()
154 }
155
156 /// Get a datagram sender for sending datagrams over this session.
157 ///
158 /// Datagrams are unreliable and unordered messages.
159 ///
160 /// # Example
161 ///
162 /// ```rust,ignore
163 /// # use bytes::Bytes;
164 /// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
165 /// # async fn dummy(session: WebTransportSession) -> Result<(), h3_datagram::datagram_handler::SendDatagramError> {
166 /// let mut sender = session.datagram_sender();
167 /// sender.send_datagram(Bytes::from("Hello"))?;
168 /// # Ok(())
169 /// # }
170 /// ```
171 pub fn datagram_sender(
172 &self,
173 ) -> h3_datagram::datagram_handler::DatagramSender<
174 <h3_quinn::Connection as h3_datagram::quic_traits::DatagramConnectionExt<Bytes>>::SendDatagramHandler,
175 Bytes,
176 > {
177 self.session.datagram_sender()
178 }
179
180 /// Get a datagram reader for receiving datagrams over this session.
181 ///
182 /// # Example
183 ///
184 /// ```rust,ignore
185 /// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
186 /// # async fn dummy(session: WebTransportSession) {
187 /// let mut reader = session.datagram_reader();
188 /// while let Ok(datagram) = reader.read_datagram().await {
189 /// println!("Received: {} bytes", datagram.payload().len());
190 /// }
191 /// # }
192 /// ```
193 pub fn datagram_reader(
194 &self,
195 ) -> h3_datagram::datagram_handler::DatagramReader<
196 <h3_quinn::Connection as h3_datagram::quic_traits::DatagramConnectionExt<Bytes>>::RecvDatagramHandler,
197 > {
198 self.session.datagram_reader()
199 }
200}
201
202impl fmt::Debug for WebTransportSession {
203 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 f.debug_struct("WebTransportSession").finish_non_exhaustive()
205 }
206}
207
208/// An accepted bidirectional stream or request.
209#[derive(Debug)]
210pub enum AcceptedBi {
211 /// A raw bidirectional stream.
212 BidiStream(WebTransportBidiStream),
213 /// An HTTP request over WebTransport.
214 Request(http::Request<()>, WebTransportRequestStream),
215}
216
217/// A bidirectional WebTransport stream.
218pub struct WebTransportBidiStream {
219 stream: BidiStream<h3_quinn::BidiStream<Bytes>, Bytes>,
220 _id: h3_webtransport::SessionId,
221}
222
223impl WebTransportBidiStream {
224 /// Get the inner [`h3_webtransport::stream::BidiStream`].
225 ///
226 /// Can be used to access lower-level functionality.
227 pub fn into_inner(self) -> BidiStream<h3_quinn::BidiStream<Bytes>, Bytes> {
228 self.stream
229 }
230
231 /// Split this stream into separate send and receive halves.
232 ///
233 /// # Example
234 ///
235 /// ```rust,ignore
236 /// # use bytes::Bytes;
237 /// # use h3::quic::StreamErrorIncoming;
238 /// # use scuffle_http::backend::h3::webtransport::WebTransportBidiStream;
239 /// # async fn dummy(bidi_stream: WebTransportBidiStream) -> Result<(), StreamErrorIncoming> {
240 /// let (mut send, mut recv) = bidi_stream.split();
241 /// tokio::spawn(async move {
242 /// while let Ok(Some(data)) = recv.read().await {
243 /// println!("Received: {:?}", data);
244 /// }
245 /// });
246 /// send.write(Bytes::from("Hello")).await?;
247 /// # Ok(())
248 /// # }
249 /// ```
250 pub fn split(self) -> (WebTransportSendStream, WebTransportRecvStream) {
251 use h3::quic::BidiStream;
252 let (send, recv) = self.stream.split();
253 (
254 WebTransportSendStream { stream: send },
255 WebTransportRecvStream { stream: recv },
256 )
257 }
258
259 /// Read data from the receive side of the stream.
260 pub async fn read(&mut self) -> Result<Option<Bytes>, StreamErrorIncoming> {
261 use h3::quic::RecvStream;
262 std::future::poll_fn(|cx| self.stream.poll_data(cx)).await
263 }
264
265 /// Read all remaining data from the receive side until the stream is finished.
266 ///
267 /// This collects all chunks into a single [`Bytes`] object.
268 ///
269 /// Returns an [`io::Error`] if the total size exceeds `max_size` or any [`read`](WebTransportBidiStream::read)
270 /// call errors.
271 ///
272 /// # Example
273 ///
274 /// ```rust,ignore
275 /// # use scuffle_http::backend::h3::webtransport::WebTransportBidiStream;
276 /// # async fn dummy(mut bidi_stream: WebTransportBidiStream) -> Result<(), std::io::Error> {
277 /// let data = bidi_stream.read_to_end(1024 * 1024).await?; // max 1MB
278 /// # Ok(())
279 /// # }
280 /// ```
281 pub async fn read_to_end(&mut self, max_size: usize) -> Result<Bytes, io::Error> {
282 let mut chunks = Vec::new();
283 let mut total_size = 0;
284
285 while let Some(chunk) = self
286 .read()
287 .await
288 .map_err(|e| io::Error::other(format!("stream read error: {}", e)))?
289 {
290 total_size += chunk.len();
291 if total_size > max_size {
292 return Err(io::Error::new(
293 io::ErrorKind::InvalidData,
294 format!(
295 "stream data too large: {} bytes exceeds maximum of {} bytes",
296 total_size, max_size
297 ),
298 ));
299 }
300 chunks.push(chunk);
301 }
302
303 if chunks.is_empty() {
304 Ok(Bytes::new())
305 } else if chunks.len() == 1 {
306 Ok(chunks.into_iter().next().unwrap())
307 } else {
308 let mut combined = bytes::BytesMut::with_capacity(total_size);
309 for chunk in chunks {
310 combined.extend_from_slice(&chunk);
311 }
312 Ok(combined.freeze())
313 }
314 }
315
316 /// Write data to the send side of the stream.
317 pub async fn write(&mut self, data: Bytes) -> Result<(), StreamErrorIncoming> {
318 use bytes::Buf;
319 use h3::quic::{SendStream, SendStreamUnframed};
320
321 std::future::poll_fn(|cx| self.stream.poll_ready(cx)).await?;
322 let mut buf = data;
323 while buf.has_remaining() {
324 let written = std::future::poll_fn(|cx| self.stream.poll_send(cx, &mut buf)).await?;
325 if written == 0 {
326 break;
327 }
328 }
329 Ok(())
330 }
331
332 /// Finish writing to the stream.
333 pub async fn finish(&mut self) -> Result<(), StreamErrorIncoming> {
334 use h3::quic::SendStream;
335 std::future::poll_fn(|cx| self.stream.poll_finish(cx)).await
336 }
337}
338
339impl fmt::Debug for WebTransportBidiStream {
340 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341 f.debug_struct("WebTransportBidiStream").finish()
342 }
343}
344
345/// A receive-only WebTransport stream.
346pub struct WebTransportRecvStream {
347 stream: WtRecvStream<h3_quinn::RecvStream, Bytes>,
348}
349
350impl WebTransportRecvStream {
351 /// Get the inner [`h3_webtransport::stream::RecvStream`].
352 ///
353 /// Can be used to access lower-level functionality.
354 pub fn into_inner(self) -> WtRecvStream<h3_quinn::RecvStream, Bytes> {
355 self.stream
356 }
357
358 /// Read data from the stream.
359 ///
360 /// Returns `Ok(None)` when the stream is finished.
361 ///
362 /// # Example
363 ///
364 /// ```rust,ignore
365 /// # use scuffle_http::backend::h3::webtransport::WebTransportRecvStream;
366 /// # async fn dummy(mut recv_stream: WebTransportRecvStream) {
367 /// while let Ok(Some(data)) = recv_stream.read().await {
368 /// println!("Received {} bytes", data.len());
369 /// }
370 /// # }
371 /// ```
372 pub async fn read(&mut self) -> Result<Option<Bytes>, StreamErrorIncoming> {
373 use h3::quic::RecvStream;
374 std::future::poll_fn(|cx| self.stream.poll_data(cx)).await
375 }
376
377 /// Read all remaining data from the stream until it's finished.
378 ///
379 /// This collects all chunks into a single [`Bytes`] object.
380 ///
381 /// Returns an [`io::Error`] if the total size exceeds `max_size` or any [`read`](WebTransportRecvStream::read)
382 /// call errors.
383 ///
384 /// # Example
385 ///
386 /// ```rust,ignore
387 /// # use scuffle_http::backend::h3::webtransport::WebTransportRecvStream;
388 /// # async fn dummy(mut recv_stream: WebTransportRecvStream) -> Result<(), std::io::Error> {
389 /// let data = recv_stream.read_to_end(1024 * 1024).await?; // max 1MB
390 /// println!("Received complete message: {} bytes", data.len());
391 /// # Ok(())
392 /// # }
393 /// ```
394 pub async fn read_to_end(&mut self, max_size: usize) -> Result<Bytes, io::Error> {
395 let mut chunks = Vec::new();
396 let mut total_size = 0;
397
398 while let Some(chunk) = self
399 .read()
400 .await
401 .map_err(|e| io::Error::other(format!("stream read error: {}", e)))?
402 {
403 total_size += chunk.len();
404 if total_size > max_size {
405 return Err(io::Error::new(
406 io::ErrorKind::InvalidData,
407 format!(
408 "stream data too large: {} bytes exceeds maximum of {} bytes",
409 total_size, max_size
410 ),
411 ));
412 }
413 chunks.push(chunk);
414 }
415
416 if chunks.is_empty() {
417 Ok(Bytes::new())
418 } else if chunks.len() == 1 {
419 Ok(chunks.into_iter().next().unwrap())
420 } else {
421 // Combine all chunks into a single buffer
422 let mut combined = bytes::BytesMut::with_capacity(total_size);
423 for chunk in chunks {
424 combined.extend_from_slice(&chunk);
425 }
426 Ok(combined.freeze())
427 }
428 }
429
430 /// Stop receiving data on this stream with an error code.
431 pub fn stop_sending(&mut self, error_code: u64) {
432 use h3::quic::RecvStream;
433 self.stream.stop_sending(error_code)
434 }
435}
436
437impl fmt::Debug for WebTransportRecvStream {
438 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439 f.debug_struct("WebTransportRecvStream").finish_non_exhaustive()
440 }
441}
442
443/// A send-only WebTransport stream.
444pub struct WebTransportSendStream {
445 stream: WtSendStream<h3_quinn::SendStream<Bytes>, Bytes>,
446}
447
448impl WebTransportSendStream {
449 /// Get the inner [`h3_webtransport::stream::SendStream`].
450 ///
451 /// Can be used to access lower-level functionality.
452 pub fn into_inner(self) -> WtSendStream<h3_quinn::SendStream<Bytes>, Bytes> {
453 self.stream
454 }
455
456 /// Write data to the stream.
457 ///
458 /// # Example
459 ///
460 /// ```rust,ignore
461 /// # use bytes::Bytes;
462 /// # use h3::quic::StreamErrorIncoming;
463 /// # use scuffle_http::backend::h3::webtransport::WebTransportSendStream;
464 /// # async fn dummy(mut send_stream: WebTransportSendStream) -> Result<(), StreamErrorIncoming> {
465 /// send_stream.write(Bytes::from("Hello, world!")).await?;
466 /// send_stream.finish().await?;
467 /// # Ok(())
468 /// # }
469 /// ```
470 pub async fn write(&mut self, data: Bytes) -> Result<(), StreamErrorIncoming> {
471 use bytes::Buf;
472 use h3::quic::{SendStream, SendStreamUnframed};
473
474 std::future::poll_fn(|cx| self.stream.poll_ready(cx)).await?;
475 let mut buf = data;
476 while buf.has_remaining() {
477 let written = std::future::poll_fn(|cx| self.stream.poll_send(cx, &mut buf)).await?;
478 if written == 0 {
479 break;
480 }
481 }
482 Ok(())
483 }
484
485 /// Write all data and finish the stream in one operation.
486 ///
487 /// This is a convenience method that writes the data and then finishes the stream.
488 ///
489 /// # Example
490 ///
491 /// ```rust,ignore
492 /// # use bytes::Bytes;
493 /// # use h3::quic::StreamErrorIncoming;
494 /// # use scuffle_http::backend::h3::webtransport::WebTransportSendStream;
495 /// # async fn dummy(mut send_stream: WebTransportSendStream) -> Result<(), StreamErrorIncoming> {
496 /// send_stream.write_all(Bytes::from("Complete message")).await?;
497 /// # Ok(())
498 /// # }
499 /// ```
500 pub async fn write_all(&mut self, data: Bytes) -> Result<(), StreamErrorIncoming> {
501 self.write(data).await?;
502 self.finish().await
503 }
504
505 /// Finish writing to the stream.
506 ///
507 /// This signals that no more data will be sent on this stream.
508 pub async fn finish(&mut self) -> Result<(), StreamErrorIncoming> {
509 use h3::quic::SendStream;
510 std::future::poll_fn(|cx| self.stream.poll_finish(cx)).await
511 }
512
513 /// Reset the stream with an error code.
514 pub fn reset(&mut self, reset_code: u64) {
515 use h3::quic::SendStream;
516 self.stream.reset(reset_code)
517 }
518}
519
520impl fmt::Debug for WebTransportSendStream {
521 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522 f.debug_struct("WebTransportSendStream").finish_non_exhaustive()
523 }
524}
525
526/// A stream for handling HTTP requests over WebTransport.
527pub struct WebTransportRequestStream {
528 stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
529}
530
531impl WebTransportRequestStream {
532 /// Get the inner [`h3::server::RequestStream`].
533 ///
534 /// Can be used to access lower-level functionality.
535 pub fn into_inner(self) -> h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes> {
536 self.stream
537 }
538
539 /// Split this stream into separate send and receive halves.
540 pub fn split(
541 self,
542 ) -> (
543 h3::server::RequestStream<h3_quinn::SendStream<Bytes>, Bytes>,
544 h3::server::RequestStream<h3_quinn::RecvStream, Bytes>,
545 ) {
546 self.stream.split()
547 }
548}
549
550impl fmt::Debug for WebTransportRequestStream {
551 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552 f.debug_struct("WebTransportRequestStream").finish_non_exhaustive()
553 }
554}
555
556/// A WebTransport stream identifier.
557#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
558pub struct WebTransportStreamId(h3_webtransport::SessionId);
559
560impl WebTransportStreamId {
561 fn next_session_id() -> h3_webtransport::SessionId {
562 use std::sync::atomic::{AtomicU64, Ordering};
563 static COUNTER: AtomicU64 = AtomicU64::new(0);
564 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
565 // SessionId is created from a VarInt-encoded StreamId
566 let varint = h3::proto::varint::VarInt::from_u64(id).expect("valid varint");
567 let stream_id = h3::quic::StreamId::from(varint);
568 h3_webtransport::SessionId::from(stream_id)
569 }
570
571 /// Get the inner session ID.
572 pub fn inner(&self) -> h3_webtransport::SessionId {
573 self.0
574 }
575}
576
577impl From<h3_webtransport::SessionId> for WebTransportStreamId {
578 fn from(id: h3_webtransport::SessionId) -> Self {
579 Self(id)
580 }
581}
582
583impl fmt::Display for WebTransportStreamId {
584 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
585 write!(f, "{:?}", self.0)
586 }
587}