scuffle_http/backend/h3/
webtransport.rs

1//! WebTransport session management for HTTP/3.
2//!
3//! This module provides types for handling WebTransport sessions over HTTP/3.
//! WebTransport supports bidirectional and unidirectional streams, as well as datagrams, over QUIC.
5
6use std::sync::Arc;
7use std::{fmt, io};
8
9use bytes::Bytes;
10use h3::quic::StreamErrorIncoming;
11use h3_webtransport::server::{AcceptedBi as H3AcceptedBi, WebTransportSession as H3WebTransportSession};
12use h3_webtransport::stream::{BidiStream, RecvStream as WtRecvStream, SendStream as WtSendStream};
13
/// A WebTransport session handle.
///
/// This type provides access to bidirectional and unidirectional streams, as well as
/// datagrams, for a WebTransport session established over HTTP/3.
///
/// Cloning is cheap: every clone shares the same underlying session through an [`Arc`].
///
/// The session can be retrieved from the request extensions when handling
/// a WebTransport CONNECT request.
///
/// # Example
///
/// ```rust,ignore
/// # use scuffle_http::{IncomingRequest, Response};
/// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
/// async fn handle_webtransport(req: IncomingRequest) -> Result<Response<()>, std::convert::Infallible> {
///     if let Some(session) = req.extensions().get::<WebTransportSession>() {
///         // Handle WebTransport session
///         tokio::spawn({
///             let session = session.clone();
///             async move {
///                 while let Ok(Some(accepted)) = session.accept_bi().await {
///                     // Handle bidirectional streams
///                 }
///             }
///         });
///
///         return Ok(Response::builder()
///             .status(200)
///             .body(())
///             .unwrap());
///     }
///
///     Ok(Response::builder()
///         .status(404)
///         .body(())
///         .unwrap())
/// }
/// ```
#[derive(Clone)]
pub struct WebTransportSession {
    // Shared handle to the underlying h3-webtransport session; cloned by `#[derive(Clone)]`.
    session: Arc<H3WebTransportSession<h3_quinn::Connection, Bytes>>,
}
55
56impl WebTransportSession {
57    /// Create a new WebTransport session from an h3-webtransport session.
58    pub(crate) fn new(session: Arc<H3WebTransportSession<h3_quinn::Connection, Bytes>>) -> Self {
59        Self { session }
60    }
61
62    /// Accept the next incoming bidirectional stream or request.
63    ///
64    /// Returns `None` when the session is closed or no more streams are available.
65    ///
66    /// # Example
67    ///
68    /// ```rust,ignore
69    /// # use scuffle_http::backend::h3::webtransport::{WebTransportSession, AcceptedBi};
70    /// async fn handle_session(session: WebTransportSession) {
71    ///     while let Ok(Some(accepted)) = session.accept_bi().await {
72    ///         match accepted {
73    ///             AcceptedBi::BidiStream(stream) => {
74    ///                 // Handle raw bidirectional stream
75    ///             }
76    ///             AcceptedBi::Request(req, stream) => {
77    ///                 // Handle HTTP request over WebTransport
78    ///             }
79    ///         }
80    ///     }
81    /// }
82    /// ```
83    pub async fn accept_bi(&self) -> Result<Option<AcceptedBi>, h3::error::StreamError> {
84        match self.session.accept_bi().await {
85            Ok(Some(H3AcceptedBi::BidiStream(id, stream))) => {
86                Ok(Some(AcceptedBi::BidiStream(WebTransportBidiStream { stream, _id: id })))
87            }
88            Ok(Some(H3AcceptedBi::Request(req, stream))) => {
89                Ok(Some(AcceptedBi::Request(req, WebTransportRequestStream { stream })))
90            }
91            Ok(None) => Ok(None),
92            Err(e) => Err(e),
93        }
94    }
95
96    /// Accept the next incoming unidirectional stream.
97    ///
98    /// Returns `None` when the session is closed or no more streams are available.
99    pub async fn accept_uni(
100        &self,
101    ) -> Result<Option<(WebTransportStreamId, WebTransportRecvStream)>, h3::error::ConnectionError> {
102        self.session
103            .accept_uni()
104            .await
105            .map(|o| o.map(|(id, stream)| (WebTransportStreamId(id), WebTransportRecvStream { stream })))
106    }
107
108    /// Open a new bidirectional stream.
109    ///
110    /// # Example
111    ///
112    /// ```rust,ignore
113    /// # use bytes::Bytes;
114    /// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
115    /// # async fn dummy(session: WebTransportSession) -> Result<(), Box<dyn std::error::Error>> {
116    /// let (mut send, mut recv) = session.open_bi().await?;
117    /// send.write(Bytes::from("Hello")).await?;
118    /// send.finish().await?;
119    /// # Ok(())
120    /// # }
121    /// ```
122    pub async fn open_bi(&self) -> Result<(WebTransportSendStream, WebTransportRecvStream), h3::error::StreamError> {
123        let stream = self.session.open_bi(WebTransportStreamId::next_session_id()).await?;
124        use h3::quic::BidiStream;
125        let (send, recv) = stream.split();
126        Ok((
127            WebTransportSendStream { stream: send },
128            WebTransportRecvStream { stream: recv },
129        ))
130    }
131
132    /// Open a new unidirectional stream.
133    ///
134    /// # Example
135    ///
136    /// ```rust,ignore
137    /// # use bytes::Bytes;
138    /// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
139    /// # async fn dummy(session: WebTransportSession) -> Result<(), Box<dyn std::error::Error>> {
140    /// let mut send = session.open_uni().await?;
141    /// send.write(Bytes::from("Hello")).await?;
142    /// send.finish().await?;
143    /// # Ok(())
144    /// # }
145    /// ```
146    pub async fn open_uni(&self) -> Result<WebTransportSendStream, h3::error::StreamError> {
147        let send = self.session.open_uni(WebTransportStreamId::next_session_id()).await?;
148        Ok(WebTransportSendStream { stream: send })
149    }
150
151    /// Get the session ID for this WebTransport session.
152    pub fn session_id(&self) -> h3_webtransport::SessionId {
153        self.session.session_id()
154    }
155
156    /// Get a datagram sender for sending datagrams over this session.
157    ///
158    /// Datagrams are unreliable and unordered messages.
159    ///
160    /// # Example
161    ///
162    /// ```rust,ignore
163    /// # use bytes::Bytes;
164    /// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
165    /// # async fn dummy(session: WebTransportSession) -> Result<(), h3_datagram::datagram_handler::SendDatagramError> {
166    /// let mut sender = session.datagram_sender();
167    /// sender.send_datagram(Bytes::from("Hello"))?;
168    /// # Ok(())
169    /// # }
170    /// ```
171    pub fn datagram_sender(
172        &self,
173    ) -> h3_datagram::datagram_handler::DatagramSender<
174        <h3_quinn::Connection as h3_datagram::quic_traits::DatagramConnectionExt<Bytes>>::SendDatagramHandler,
175        Bytes,
176    > {
177        self.session.datagram_sender()
178    }
179
180    /// Get a datagram reader for receiving datagrams over this session.
181    ///
182    /// # Example
183    ///
184    /// ```rust,ignore
185    /// # use scuffle_http::backend::h3::webtransport::WebTransportSession;
186    /// # async fn dummy(session: WebTransportSession) {
187    /// let mut reader = session.datagram_reader();
188    /// while let Ok(datagram) = reader.read_datagram().await {
189    ///     println!("Received: {} bytes", datagram.payload().len());
190    /// }
191    /// # }
192    /// ```
193    pub fn datagram_reader(
194        &self,
195    ) -> h3_datagram::datagram_handler::DatagramReader<
196        <h3_quinn::Connection as h3_datagram::quic_traits::DatagramConnectionExt<Bytes>>::RecvDatagramHandler,
197    > {
198        self.session.datagram_reader()
199    }
200}
201
202impl fmt::Debug for WebTransportSession {
203    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204        f.debug_struct("WebTransportSession").finish_non_exhaustive()
205    }
206}
207
/// An accepted bidirectional stream or request.
///
/// Yielded by [`WebTransportSession::accept_bi`], mirroring the two cases that
/// h3-webtransport distinguishes when accepting a bidirectional stream.
#[derive(Debug)]
pub enum AcceptedBi {
    /// A raw bidirectional stream belonging to the session.
    BidiStream(WebTransportBidiStream),
    /// An HTTP request over WebTransport: the request head (body type `()`) and
    /// the [`WebTransportRequestStream`] it arrived on.
    Request(http::Request<()>, WebTransportRequestStream),
}
216
/// A bidirectional WebTransport stream.
///
/// Produced by [`WebTransportSession::accept_bi`] as [`AcceptedBi::BidiStream`].
/// It can be used directly through its read/write methods, or split into
/// independent send and receive halves with [`WebTransportBidiStream::split`].
pub struct WebTransportBidiStream {
    // The underlying h3-webtransport bidirectional stream.
    stream: BidiStream<h3_quinn::BidiStream<Bytes>, Bytes>,
    // Session ID reported by h3-webtransport alongside the accepted stream;
    // stored but not exposed through a public accessor.
    _id: h3_webtransport::SessionId,
}
222
223impl WebTransportBidiStream {
224    /// Get the inner [`h3_webtransport::stream::BidiStream`].
225    ///
226    /// Can be used to access lower-level functionality.
227    pub fn into_inner(self) -> BidiStream<h3_quinn::BidiStream<Bytes>, Bytes> {
228        self.stream
229    }
230
231    /// Split this stream into separate send and receive halves.
232    ///
233    /// # Example
234    ///
235    /// ```rust,ignore
236    /// # use bytes::Bytes;
237    /// # use h3::quic::StreamErrorIncoming;
238    /// # use scuffle_http::backend::h3::webtransport::WebTransportBidiStream;
239    /// # async fn dummy(bidi_stream: WebTransportBidiStream) -> Result<(), StreamErrorIncoming> {
240    /// let (mut send, mut recv) = bidi_stream.split();
241    /// tokio::spawn(async move {
242    ///     while let Ok(Some(data)) = recv.read().await {
243    ///         println!("Received: {:?}", data);
244    ///     }
245    /// });
246    /// send.write(Bytes::from("Hello")).await?;
247    /// # Ok(())
248    /// # }
249    /// ```
250    pub fn split(self) -> (WebTransportSendStream, WebTransportRecvStream) {
251        use h3::quic::BidiStream;
252        let (send, recv) = self.stream.split();
253        (
254            WebTransportSendStream { stream: send },
255            WebTransportRecvStream { stream: recv },
256        )
257    }
258
259    /// Read data from the receive side of the stream.
260    pub async fn read(&mut self) -> Result<Option<Bytes>, StreamErrorIncoming> {
261        use h3::quic::RecvStream;
262        std::future::poll_fn(|cx| self.stream.poll_data(cx)).await
263    }
264
265    /// Read all remaining data from the receive side until the stream is finished.
266    ///
267    /// This collects all chunks into a single [`Bytes`] object.
268    ///
269    /// Returns an [`io::Error`] if the total size exceeds `max_size` or any [`read`](WebTransportBidiStream::read)
270    /// call errors.
271    ///
272    /// # Example
273    ///
274    /// ```rust,ignore
275    /// # use scuffle_http::backend::h3::webtransport::WebTransportBidiStream;
276    /// # async fn dummy(mut bidi_stream: WebTransportBidiStream) -> Result<(), std::io::Error> {
277    /// let data = bidi_stream.read_to_end(1024 * 1024).await?; // max 1MB
278    /// # Ok(())
279    /// # }
280    /// ```
281    pub async fn read_to_end(&mut self, max_size: usize) -> Result<Bytes, io::Error> {
282        let mut chunks = Vec::new();
283        let mut total_size = 0;
284
285        while let Some(chunk) = self
286            .read()
287            .await
288            .map_err(|e| io::Error::other(format!("stream read error: {}", e)))?
289        {
290            total_size += chunk.len();
291            if total_size > max_size {
292                return Err(io::Error::new(
293                    io::ErrorKind::InvalidData,
294                    format!(
295                        "stream data too large: {} bytes exceeds maximum of {} bytes",
296                        total_size, max_size
297                    ),
298                ));
299            }
300            chunks.push(chunk);
301        }
302
303        if chunks.is_empty() {
304            Ok(Bytes::new())
305        } else if chunks.len() == 1 {
306            Ok(chunks.into_iter().next().unwrap())
307        } else {
308            let mut combined = bytes::BytesMut::with_capacity(total_size);
309            for chunk in chunks {
310                combined.extend_from_slice(&chunk);
311            }
312            Ok(combined.freeze())
313        }
314    }
315
316    /// Write data to the send side of the stream.
317    pub async fn write(&mut self, data: Bytes) -> Result<(), StreamErrorIncoming> {
318        use bytes::Buf;
319        use h3::quic::{SendStream, SendStreamUnframed};
320
321        std::future::poll_fn(|cx| self.stream.poll_ready(cx)).await?;
322        let mut buf = data;
323        while buf.has_remaining() {
324            let written = std::future::poll_fn(|cx| self.stream.poll_send(cx, &mut buf)).await?;
325            if written == 0 {
326                break;
327            }
328        }
329        Ok(())
330    }
331
332    /// Finish writing to the stream.
333    pub async fn finish(&mut self) -> Result<(), StreamErrorIncoming> {
334        use h3::quic::SendStream;
335        std::future::poll_fn(|cx| self.stream.poll_finish(cx)).await
336    }
337}
338
339impl fmt::Debug for WebTransportBidiStream {
340    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341        f.debug_struct("WebTransportBidiStream").finish()
342    }
343}
344
/// A receive-only WebTransport stream.
///
/// Returned by [`WebTransportSession::accept_uni`], [`WebTransportSession::open_bi`],
/// and the `split` method of [`WebTransportBidiStream`].
pub struct WebTransportRecvStream {
    // The underlying h3-webtransport receive stream.
    stream: WtRecvStream<h3_quinn::RecvStream, Bytes>,
}
349
350impl WebTransportRecvStream {
351    /// Get the inner [`h3_webtransport::stream::RecvStream`].
352    ///
353    /// Can be used to access lower-level functionality.
354    pub fn into_inner(self) -> WtRecvStream<h3_quinn::RecvStream, Bytes> {
355        self.stream
356    }
357
358    /// Read data from the stream.
359    ///
360    /// Returns `Ok(None)` when the stream is finished.
361    ///
362    /// # Example
363    ///
364    /// ```rust,ignore
365    /// # use scuffle_http::backend::h3::webtransport::WebTransportRecvStream;
366    /// # async fn dummy(mut recv_stream: WebTransportRecvStream) {
367    /// while let Ok(Some(data)) = recv_stream.read().await {
368    ///     println!("Received {} bytes", data.len());
369    /// }
370    /// # }
371    /// ```
372    pub async fn read(&mut self) -> Result<Option<Bytes>, StreamErrorIncoming> {
373        use h3::quic::RecvStream;
374        std::future::poll_fn(|cx| self.stream.poll_data(cx)).await
375    }
376
377    /// Read all remaining data from the stream until it's finished.
378    ///
379    /// This collects all chunks into a single [`Bytes`] object.
380    ///
381    /// Returns an [`io::Error`] if the total size exceeds `max_size` or any [`read`](WebTransportRecvStream::read)
382    /// call errors.
383    ///
384    /// # Example
385    ///
386    /// ```rust,ignore
387    /// # use scuffle_http::backend::h3::webtransport::WebTransportRecvStream;
388    /// # async fn dummy(mut recv_stream: WebTransportRecvStream) -> Result<(), std::io::Error> {
389    /// let data = recv_stream.read_to_end(1024 * 1024).await?; // max 1MB
390    /// println!("Received complete message: {} bytes", data.len());
391    /// # Ok(())
392    /// # }
393    /// ```
394    pub async fn read_to_end(&mut self, max_size: usize) -> Result<Bytes, io::Error> {
395        let mut chunks = Vec::new();
396        let mut total_size = 0;
397
398        while let Some(chunk) = self
399            .read()
400            .await
401            .map_err(|e| io::Error::other(format!("stream read error: {}", e)))?
402        {
403            total_size += chunk.len();
404            if total_size > max_size {
405                return Err(io::Error::new(
406                    io::ErrorKind::InvalidData,
407                    format!(
408                        "stream data too large: {} bytes exceeds maximum of {} bytes",
409                        total_size, max_size
410                    ),
411                ));
412            }
413            chunks.push(chunk);
414        }
415
416        if chunks.is_empty() {
417            Ok(Bytes::new())
418        } else if chunks.len() == 1 {
419            Ok(chunks.into_iter().next().unwrap())
420        } else {
421            // Combine all chunks into a single buffer
422            let mut combined = bytes::BytesMut::with_capacity(total_size);
423            for chunk in chunks {
424                combined.extend_from_slice(&chunk);
425            }
426            Ok(combined.freeze())
427        }
428    }
429
430    /// Stop receiving data on this stream with an error code.
431    pub fn stop_sending(&mut self, error_code: u64) {
432        use h3::quic::RecvStream;
433        self.stream.stop_sending(error_code)
434    }
435}
436
437impl fmt::Debug for WebTransportRecvStream {
438    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439        f.debug_struct("WebTransportRecvStream").finish_non_exhaustive()
440    }
441}
442
/// A send-only WebTransport stream.
///
/// Returned by [`WebTransportSession::open_uni`], [`WebTransportSession::open_bi`],
/// and the `split` method of [`WebTransportBidiStream`].
pub struct WebTransportSendStream {
    // The underlying h3-webtransport send stream.
    stream: WtSendStream<h3_quinn::SendStream<Bytes>, Bytes>,
}
447
448impl WebTransportSendStream {
449    /// Get the inner [`h3_webtransport::stream::SendStream`].
450    ///
451    /// Can be used to access lower-level functionality.
452    pub fn into_inner(self) -> WtSendStream<h3_quinn::SendStream<Bytes>, Bytes> {
453        self.stream
454    }
455
456    /// Write data to the stream.
457    ///
458    /// # Example
459    ///
460    /// ```rust,ignore
461    /// # use bytes::Bytes;
462    /// # use h3::quic::StreamErrorIncoming;
463    /// # use scuffle_http::backend::h3::webtransport::WebTransportSendStream;
464    /// # async fn dummy(mut send_stream: WebTransportSendStream) -> Result<(), StreamErrorIncoming> {
465    /// send_stream.write(Bytes::from("Hello, world!")).await?;
466    /// send_stream.finish().await?;
467    /// # Ok(())
468    /// # }
469    /// ```
470    pub async fn write(&mut self, data: Bytes) -> Result<(), StreamErrorIncoming> {
471        use bytes::Buf;
472        use h3::quic::{SendStream, SendStreamUnframed};
473
474        std::future::poll_fn(|cx| self.stream.poll_ready(cx)).await?;
475        let mut buf = data;
476        while buf.has_remaining() {
477            let written = std::future::poll_fn(|cx| self.stream.poll_send(cx, &mut buf)).await?;
478            if written == 0 {
479                break;
480            }
481        }
482        Ok(())
483    }
484
485    /// Write all data and finish the stream in one operation.
486    ///
487    /// This is a convenience method that writes the data and then finishes the stream.
488    ///
489    /// # Example
490    ///
491    /// ```rust,ignore
492    /// # use bytes::Bytes;
493    /// # use h3::quic::StreamErrorIncoming;
494    /// # use scuffle_http::backend::h3::webtransport::WebTransportSendStream;
495    /// # async fn dummy(mut send_stream: WebTransportSendStream) -> Result<(), StreamErrorIncoming> {
496    /// send_stream.write_all(Bytes::from("Complete message")).await?;
497    /// # Ok(())
498    /// # }
499    /// ```
500    pub async fn write_all(&mut self, data: Bytes) -> Result<(), StreamErrorIncoming> {
501        self.write(data).await?;
502        self.finish().await
503    }
504
505    /// Finish writing to the stream.
506    ///
507    /// This signals that no more data will be sent on this stream.
508    pub async fn finish(&mut self) -> Result<(), StreamErrorIncoming> {
509        use h3::quic::SendStream;
510        std::future::poll_fn(|cx| self.stream.poll_finish(cx)).await
511    }
512
513    /// Reset the stream with an error code.
514    pub fn reset(&mut self, reset_code: u64) {
515        use h3::quic::SendStream;
516        self.stream.reset(reset_code)
517    }
518}
519
520impl fmt::Debug for WebTransportSendStream {
521    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522        f.debug_struct("WebTransportSendStream").finish_non_exhaustive()
523    }
524}
525
/// A stream for handling HTTP requests over WebTransport.
///
/// Delivered together with the request head in [`AcceptedBi::Request`].
pub struct WebTransportRequestStream {
    // The underlying h3 server request stream over a bidirectional QUIC stream.
    stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
}
530
531impl WebTransportRequestStream {
532    /// Get the inner [`h3::server::RequestStream`].
533    ///
534    /// Can be used to access lower-level functionality.
535    pub fn into_inner(self) -> h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes> {
536        self.stream
537    }
538
539    /// Split this stream into separate send and receive halves.
540    pub fn split(
541        self,
542    ) -> (
543        h3::server::RequestStream<h3_quinn::SendStream<Bytes>, Bytes>,
544        h3::server::RequestStream<h3_quinn::RecvStream, Bytes>,
545    ) {
546        self.stream.split()
547    }
548}
549
550impl fmt::Debug for WebTransportRequestStream {
551    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552        f.debug_struct("WebTransportRequestStream").finish_non_exhaustive()
553    }
554}
555
/// A WebTransport stream identifier.
///
/// Wraps an [`h3_webtransport::SessionId`]. [`WebTransportSession::accept_uni`]
/// returns one alongside each accepted stream, holding the ID reported by
/// h3-webtransport for that stream.
///
/// NOTE(review): despite the name, the wrapped value is a *session* ID rather than a
/// per-stream ID — confirm which semantics callers rely on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WebTransportStreamId(h3_webtransport::SessionId);
559
560impl WebTransportStreamId {
561    fn next_session_id() -> h3_webtransport::SessionId {
562        use std::sync::atomic::{AtomicU64, Ordering};
563        static COUNTER: AtomicU64 = AtomicU64::new(0);
564        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
565        // SessionId is created from a VarInt-encoded StreamId
566        let varint = h3::proto::varint::VarInt::from_u64(id).expect("valid varint");
567        let stream_id = h3::quic::StreamId::from(varint);
568        h3_webtransport::SessionId::from(stream_id)
569    }
570
571    /// Get the inner session ID.
572    pub fn inner(&self) -> h3_webtransport::SessionId {
573        self.0
574    }
575}
576
577impl From<h3_webtransport::SessionId> for WebTransportStreamId {
578    fn from(id: h3_webtransport::SessionId) -> Self {
579        Self(id)
580    }
581}
582
583impl fmt::Display for WebTransportStreamId {
584    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
585        write!(f, "{:?}", self.0)
586    }
587}