libp2p_request_response/
cbor.rs
1pub type Behaviour<Req, Resp> = crate::Behaviour<codec::Codec<Req, Resp>>;
45
46mod codec {
47 use async_trait::async_trait;
48 use cbor4ii::core::error::DecodeError;
49 use futures::prelude::*;
50 use libp2p_swarm::StreamProtocol;
51 use serde::{de::DeserializeOwned, Serialize};
52 use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
53
54 const REQUEST_SIZE_MAXIMUM: u64 = 1024 * 1024;
56 const RESPONSE_SIZE_MAXIMUM: u64 = 10 * 1024 * 1024;
58
59 pub struct Codec<Req, Resp> {
60 phantom: PhantomData<(Req, Resp)>,
61 }
62
63 impl<Req, Resp> Default for Codec<Req, Resp> {
64 fn default() -> Self {
65 Codec {
66 phantom: PhantomData,
67 }
68 }
69 }
70
71 impl<Req, Resp> Clone for Codec<Req, Resp> {
72 fn clone(&self) -> Self {
73 Self::default()
74 }
75 }
76
77 #[async_trait]
78 impl<Req, Resp> crate::Codec for Codec<Req, Resp>
79 where
80 Req: Send + Serialize + DeserializeOwned,
81 Resp: Send + Serialize + DeserializeOwned,
82 {
83 type Protocol = StreamProtocol;
84 type Request = Req;
85 type Response = Resp;
86
87 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
88 where
89 T: AsyncRead + Unpin + Send,
90 {
91 let mut vec = Vec::new();
92
93 io.take(REQUEST_SIZE_MAXIMUM).read_to_end(&mut vec).await?;
94
95 cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
96 }
97
98 async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
99 where
100 T: AsyncRead + Unpin + Send,
101 {
102 let mut vec = Vec::new();
103
104 io.take(RESPONSE_SIZE_MAXIMUM).read_to_end(&mut vec).await?;
105
106 cbor4ii::serde::from_slice(vec.as_slice()).map_err(decode_into_io_error)
107 }
108
109 async fn write_request<T>(
110 &mut self,
111 _: &Self::Protocol,
112 io: &mut T,
113 req: Self::Request,
114 ) -> io::Result<()>
115 where
116 T: AsyncWrite + Unpin + Send,
117 {
118 let data: Vec<u8> =
119 cbor4ii::serde::to_vec(Vec::new(), &req).map_err(encode_into_io_error)?;
120
121 io.write_all(data.as_ref()).await?;
122
123 Ok(())
124 }
125
126 async fn write_response<T>(
127 &mut self,
128 _: &Self::Protocol,
129 io: &mut T,
130 resp: Self::Response,
131 ) -> io::Result<()>
132 where
133 T: AsyncWrite + Unpin + Send,
134 {
135 let data: Vec<u8> =
136 cbor4ii::serde::to_vec(Vec::new(), &resp).map_err(encode_into_io_error)?;
137
138 io.write_all(data.as_ref()).await?;
139
140 Ok(())
141 }
142 }
143
144 fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
145 match err {
146 cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => {
147 io::Error::new(io::ErrorKind::Other, e)
148 }
149 cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
150 io::Error::new(io::ErrorKind::Unsupported, e)
151 }
152 cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
153 io::Error::new(io::ErrorKind::UnexpectedEof, e)
154 }
155 cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
156 cbor4ii::serde::DecodeError::Custom(e) => {
157 io::Error::new(io::ErrorKind::Other, e.to_string())
158 }
159 }
160 }
161
162 fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
163 io::Error::new(io::ErrorKind::Other, err)
164 }
165}
166
#[cfg(test)]
mod tests {
    use crate::{cbor::codec::Codec, Codec as _};
    use futures::AsyncWriteExt;
    use futures_ringbuf::Endpoint;
    use libp2p_swarm::StreamProtocol;
    use serde::{Deserialize, Serialize};

    /// Sends one request and one response through the CBOR codec over in-memory
    /// ring-buffer endpoint pairs and checks that each decodes to the value sent.
    #[async_std::test]
    async fn test_codec() {
        let protocol = StreamProtocol::new("/test_cbor/1");
        let mut codec = Codec::default();

        let expected_request = TestRequest {
            payload: "test_payload".to_string(),
        };
        let expected_response = TestResponse {
            payload: "test_payload".to_string(),
        };

        // Request direction. The writer is closed after writing because
        // `read_request` reads until end-of-stream.
        let (mut writer, mut reader) = Endpoint::pair(124, 124);
        codec
            .write_request(&protocol, &mut writer, expected_request.clone())
            .await
            .expect("Should write request");
        writer.close().await.unwrap();
        let decoded_request = codec
            .read_request(&protocol, &mut reader)
            .await
            .expect("Should read request");
        reader.close().await.unwrap();
        assert_eq!(decoded_request, expected_request);

        // Response direction, over a fresh endpoint pair.
        let (mut writer, mut reader) = Endpoint::pair(124, 124);
        codec
            .write_response(&protocol, &mut writer, expected_response.clone())
            .await
            .expect("Should write response");
        writer.close().await.unwrap();
        let decoded_response = codec
            .read_response(&protocol, &mut reader)
            .await
            .expect("Should read response");
        reader.close().await.unwrap();
        assert_eq!(decoded_response, expected_response);
    }

    /// Request payload used by the round-trip test.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestRequest {
        payload: String,
    }

    /// Response payload used by the round-trip test.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestResponse {
        payload: String,
    }
}