libp2p_request_response/
cbor.rs

1// Copyright 2023 Protocol Labs
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
/// A request-response behaviour using [`cbor4ii::serde`] for serializing and
/// deserializing the messages.
///
/// This is a type alias for [`crate::Behaviour`] parameterised with the CBOR
/// codec defined in this module. For the codec to be usable, both `Req` and
/// `Resp` must be `Send` and implement `serde::Serialize` and
/// `serde::de::DeserializeOwned`.
///
/// When receiving, the codec reads the stream until EOF before decoding,
/// bounded by a size cap of 1 MiB for requests and 10 MiB for responses.
///
/// # Example
///
/// ```
/// # use libp2p_request_response::{cbor, ProtocolSupport, self as request_response};
/// # use libp2p_swarm::StreamProtocol;
/// #[derive(Debug, serde::Serialize, serde::Deserialize)]
/// struct GreetRequest {
///     name: String,
/// }
///
/// #[derive(Debug, serde::Serialize, serde::Deserialize)]
/// struct GreetResponse {
///     message: String,
/// }
///
/// let behaviour = cbor::Behaviour::<GreetRequest, GreetResponse>::new(
///     [(StreamProtocol::new("/my-cbor-protocol"), ProtocolSupport::Full)],
///     request_response::Config::default()
/// );
/// ```
pub type Behaviour<Req, Resp> = crate::Behaviour<codec::Codec<Req, Resp>>;
45
46mod codec {
47    use async_trait::async_trait;
48    use cbor4ii::core::error::DecodeError;
49    use futures::prelude::*;
50    use libp2p_swarm::StreamProtocol;
51    use serde::{de::DeserializeOwned, Serialize};
52    use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
53
54    /// Max request size in bytes
55    const REQUEST_SIZE_MAXIMUM: u64 = 1024 * 1024;
56    /// Max response size in bytes
57    const RESPONSE_SIZE_MAXIMUM: u64 = 10 * 1024 * 1024;
58
59    pub struct Codec<Req, Resp> {
60        phantom: PhantomData<(Req, Resp)>,
61    }
62
63    impl<Req, Resp> Default for Codec<Req, Resp> {
64        fn default() -> Self {
65            Codec {
66                phantom: PhantomData,
67            }
68        }
69    }
70
71    impl<Req, Resp> Clone for Codec<Req, Resp> {
72        fn clone(&self) -> Self {
73            Self::default()
74        }
75    }
76
77    #[async_trait]
78    impl<Req, Resp> crate::Codec for Codec<Req, Resp>
79    where
80        Req: Send + Serialize + DeserializeOwned,
81        Resp: Send + Serialize + DeserializeOwned,
82    {
83        type Protocol = StreamProtocol;
84        type Request = Req;
85        type Response = Resp;
86
87        async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
88        where
89            T: AsyncRead + Unpin + Send,
90        {
91            let mut vec = Vec::new();
92
93            io.take(REQUEST_SIZE_MAXIMUM).read_to_end(&mut vec).await?;
94
95            cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
96        }
97
98        async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
99        where
100            T: AsyncRead + Unpin + Send,
101        {
102            let mut vec = Vec::new();
103
104            io.take(RESPONSE_SIZE_MAXIMUM).read_to_end(&mut vec).await?;
105
106            cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
107        }
108
109        async fn write_request<T>(
110            &mut self,
111            _: &Self::Protocol,
112            io: &mut T,
113            req: Self::Request,
114        ) -> io::Result<()>
115        where
116            T: AsyncWrite + Unpin + Send,
117        {
118            let data: Vec<u8> =
119                cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
120
121            io.write_all(data.as_ref()).await?;
122
123            Ok(())
124        }
125
126        async fn write_response<T>(
127            &mut self,
128            _: &Self::Protocol,
129            io: &mut T,
130            resp: Self::Response,
131        ) -> io::Result<()>
132        where
133            T: AsyncWrite + Unpin + Send,
134        {
135            let data: Vec<u8> =
136                cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
137
138            io.write_all(data.as_ref()).await?;
139
140            Ok(())
141        }
142    }
143
144    fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
145        match err {
146            cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => {
147                io::Error::new(io::ErrorKind::Other, e)
148            }
149            cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
150                io::Error::new(io::ErrorKind::Unsupported, e)
151            }
152            cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
153                io::Error::new(io::ErrorKind::UnexpectedEof, e)
154            }
155            cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
156            cbor4ii::serde::DecodeError::Custom(e) => {
157                io::Error::new(io::ErrorKind::Other, e.to_string())
158            }
159        }
160    }
161
162    fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
163        io::Error::new(io::ErrorKind::Other, err)
164    }
165}
166
#[cfg(test)]
mod tests {
    use crate::cbor::codec::Codec;
    use crate::Codec as _;
    use futures::AsyncWriteExt;
    use futures_ringbuf::Endpoint;
    use libp2p_swarm::StreamProtocol;
    use serde::{Deserialize, Serialize};

    /// Request message used to exercise the codec.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestRequest {
        payload: String,
    }

    /// Response message used to exercise the codec.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestResponse {
        payload: String,
    }

    /// Writes a request and a response through an in-memory pipe and checks
    /// that reading them back yields equal values.
    #[async_std::test]
    async fn test_codec() {
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec = Codec::default();

        // Request round trip.
        let expected_request = TestRequest {
            payload: String::from("test_payload"),
        };
        let (mut writer, mut reader) = Endpoint::pair(124, 124);

        codec
            .write_request(&protocol, &mut writer, expected_request.clone())
            .await
            .expect("Should write request");
        // Closing the writing side lets the reader observe EOF.
        writer.close().await.unwrap();

        let actual_request = codec
            .read_request(&protocol, &mut reader)
            .await
            .expect("Should read request");
        reader.close().await.unwrap();

        assert_eq!(actual_request, expected_request);

        // Response round trip, over a fresh pipe.
        let expected_response = TestResponse {
            payload: String::from("test_payload"),
        };
        let (mut writer, mut reader) = Endpoint::pair(124, 124);

        codec
            .write_response(&protocol, &mut writer, expected_response.clone())
            .await
            .expect("Should write response");
        writer.close().await.unwrap();

        let actual_response = codec
            .read_response(&protocol, &mut reader)
            .await
            .expect("Should read response");
        reader.close().await.unwrap();

        assert_eq!(actual_response, expected_response);
    }
}