Commit
+95 -52 +/-3 browse
1 | diff --git a/maitred/src/pipeline.rs b/maitred/src/pipeline.rs |
2 | index e07791f..39c837d 100644 |
3 | --- a/maitred/src/pipeline.rs |
4 | +++ b/maitred/src/pipeline.rs |
5 | @@ -1,9 +1,10 @@ |
6 | use std::result::Result as StdResult; |
7 | |
8 | - use crate::session::Result as SessionResult; |
9 | - use crate::Request; |
10 | + use crate::session::{Result as SessionResult, Session}; |
11 | + use crate::transport::Response; |
12 | + use crate::{smtp_err, smtp_ok, Request, SmtpResponse}; |
13 | |
14 | - pub type Result = StdResult<Option<Vec<SessionResult>>, Vec<SessionResult>>; |
15 | + pub type Result = Vec<Response<String>>; |
16 | pub type Transaction = (Request<String>, SessionResult); |
17 | |
18 | /// Pipeline chunks session request/responses into logical groups returning |
19 | @@ -11,9 +12,16 @@ pub type Transaction = (Request<String>, SessionResult); |
20 | #[derive(Default)] |
21 | pub struct Pipeline { |
22 | history: Vec<Transaction>, |
23 | + disable: bool, |
24 | } |
25 | |
26 | impl Pipeline { |
27 | + /// disable pipelining and return each each transaction transparently |
28 | + pub fn disable(mut self) -> Self { |
29 | + self.disable = true; |
30 | + self |
31 | + } |
32 | + |
33 | /// Checks if the pipeline is within a data transaction (if the previous |
34 | /// command was DATA/BDAT). |
35 | fn within_tx(&self) -> bool { |
36 | @@ -50,15 +58,27 @@ impl Pipeline { |
37 | .last() |
38 | .expect("to results called without history"); |
39 | if last_command.1.is_ok() && mail_from_ok && rcpt_to_ok_count > 0 { |
40 | - Ok(Some(self.history.iter().map(|tx| tx.1.clone()).collect())) |
41 | + self.history |
42 | + .iter() |
43 | + .map(|tx| tx.1.clone().unwrap_or_else(|e| e)) |
44 | + .collect() |
45 | } else if !mail_from_ok { |
46 | self.history.pop(); |
47 | - Err(self.history.iter().map(|tx| tx.1.clone()).collect()) |
48 | + self.history |
49 | + .iter() |
50 | + .map(|tx| tx.1.clone().unwrap_or_else(|e| e)) |
51 | + .collect() |
52 | } else if !rcpt_to_ok_count <= 0 { |
53 | self.history.pop(); |
54 | - Err(self.history.iter().map(|tx| tx.1.clone()).collect()) |
55 | + self.history |
56 | + .iter() |
57 | + .map(|tx| tx.1.clone().unwrap_or_else(|e| e)) |
58 | + .collect() |
59 | } else { |
60 | - Err(self.history.iter().map(|tx| tx.1.clone()).collect()) |
61 | + self.history |
62 | + .iter() |
63 | + .map(|tx| tx.1.clone().unwrap_or_else(|e| e)) |
64 | + .collect() |
65 | } |
66 | } |
67 | |
68 | @@ -90,18 +110,18 @@ impl Pipeline { |
69 | match req { |
70 | Request::Ehlo { host: _ } => { |
71 | self.history.clear(); |
72 | - Ok(Some(vec![res.clone()])) |
73 | + vec![res.clone().unwrap_or_else(|e| e)] |
74 | } |
75 | Request::Lhlo { host: _ } => { |
76 | self.history.clear(); |
77 | - Ok(Some(vec![res.clone()])) |
78 | + vec![res.clone().unwrap_or_else(|e| e)] |
79 | } |
80 | Request::Helo { host: _ } => { |
81 | self.history.clear(); |
82 | - Ok(Some(vec![res.clone()])) |
83 | + vec![res.clone().unwrap_or_else(|e| e)] |
84 | } |
85 | - Request::Mail { from: _ } => Ok(None), |
86 | - Request::Rcpt { to: _ } => Ok(None), |
87 | + Request::Mail { from: _ } => vec![], |
88 | + Request::Rcpt { to: _ } => vec![], |
89 | Request::Bdat { |
90 | chunk_size: _, |
91 | is_last: _, |
92 | @@ -111,17 +131,17 @@ impl Pipeline { |
93 | self.history.clear(); |
94 | chunk |
95 | } else { |
96 | - Ok(None) |
97 | + vec![] |
98 | } |
99 | } |
100 | Request::Auth { |
101 | mechanism: _, |
102 | initial_response: _, |
103 | } => todo!(), |
104 | - Request::Noop { value: _ } => Ok(Some(vec![res.clone()])), |
105 | + Request::Noop { value: _ } => vec![res.clone().unwrap_or_else(|e| e)], |
106 | Request::Vrfy { value: _ } => todo!(), |
107 | Request::Expn { value: _ } => todo!(), |
108 | - Request::Help { value: _ } => Ok(Some(vec![res.clone()])), |
109 | + Request::Help { value: _ } => vec![res.clone().unwrap_or_else(|e| e)], |
110 | Request::Etrn { name: _ } => todo!(), |
111 | Request::Atrn { domains: _ } => todo!(), |
112 | Request::Burl { uri: _, is_last: _ } => todo!(), |
113 | @@ -132,14 +152,14 @@ impl Pipeline { |
114 | self.history.clear(); |
115 | chunk |
116 | } else { |
117 | - Ok(None) |
118 | + vec![] |
119 | } |
120 | } |
121 | Request::Rset => { |
122 | self.history.clear(); |
123 | - Ok(Some(vec![res.clone()])) |
124 | + vec![res.clone().unwrap_or_else(|e| e)] |
125 | } |
126 | - Request::Quit => Ok(Some(vec![res.clone()])), |
127 | + Request::Quit => vec![res.clone().unwrap_or_else(|e| e)], |
128 | } |
129 | } |
130 | } |
131 | @@ -153,14 +173,17 @@ mod test { |
132 | #[test] |
133 | pub fn test_pipeline_basic() { |
134 | let mut pipeline = Pipeline::default(); |
135 | - assert!(pipeline |
136 | - .process( |
137 | - &Request::Helo { |
138 | - host: "example.org".to_string(), |
139 | - }, |
140 | - &smtp_ok!(200, 0, 0, 0, "OK") |
141 | - ) |
142 | - .is_ok_and(|responses| responses.is_some_and(|responses| responses.len() == 1))); |
143 | + assert!( |
144 | + pipeline |
145 | + .process( |
146 | + &Request::Helo { |
147 | + host: "example.org".to_string(), |
148 | + }, |
149 | + &smtp_ok!(200, 0, 0, 0, "OK") |
150 | + ) |
151 | + .len() |
152 | + == 1 |
153 | + ); |
154 | // batchable commands out of order |
155 | assert!(pipeline |
156 | .process( |
157 | @@ -172,7 +195,7 @@ mod test { |
158 | }, |
159 | &smtp_ok!(200, 0, 0, 0, "OK: baz@qux.com") |
160 | ) |
161 | - .is_ok_and(|responses| responses.is_none())); |
162 | + .is_empty()); |
163 | assert!(pipeline |
164 | .process( |
165 | &Request::Mail { |
166 | @@ -183,20 +206,14 @@ mod test { |
167 | }, |
168 | &smtp_ok!(200, 0, 0, 0, "OK: fuu@bar.com") |
169 | ) |
170 | - .is_ok_and(|responses| responses.is_none())); |
171 | + .is_empty()); |
172 | + |
173 | // initialize a data request |
174 | assert!(pipeline |
175 | .process(&Request::Data {}, &smtp_ok!(200, 0, 0, 0, "OK")) |
176 | - .is_ok_and(|responses| responses.is_none())); |
177 | + .is_empty()); |
178 | // simulate the end of a request |
179 | let result = pipeline.process(&Request::Data {}, &smtp_ok!(200, 0, 0, 0, "OK")); |
180 | - assert!( |
181 | - result.is_ok_and(|responses| responses.is_some_and(|responses| { |
182 | - responses.len() == 3 |
183 | - && responses[0].is_ok() |
184 | - && responses[1].is_ok() |
185 | - && responses[2].is_ok() |
186 | - })) |
187 | - ); |
188 | + assert!(result.len() == 3); |
189 | } |
190 | } |
191 | diff --git a/maitred/src/server.rs b/maitred/src/server.rs |
192 | index 0afa22d..9ef55f4 100644 |
193 | --- a/maitred/src/server.rs |
194 | +++ b/maitred/src/server.rs |
195 | @@ -1,15 +1,16 @@ |
196 | use std::time::Duration; |
197 | |
198 | + use bytes::Bytes; |
199 | use futures::SinkExt; |
200 | use smtp_proto::Request; |
201 | use tokio::{net::TcpListener, time::timeout}; |
202 | - use tokio_stream::StreamExt; |
203 | + use tokio_stream::{self as stream, StreamExt}; |
204 | use tokio_util::codec::Framed; |
205 | |
206 | use crate::error::Error; |
207 | use crate::pipeline::Pipeline; |
208 | use crate::session::{Options as SessionOptions, Session}; |
209 | - use crate::transport::Transport; |
210 | + use crate::transport::{Response, Transport}; |
211 | |
212 | const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:2525"; |
213 | const DEFAULT_GREETING: &str = "Maitred ESMTP Server"; |
214 | @@ -36,7 +37,33 @@ const DEFAULT_MAXIMUM_SIZE: u64 = 5_000_000; |
215 | // 250 CHUNKING |
216 | |
217 | const DEFAULT_CAPABILITIES: u32 = |
218 | - smtp_proto::EXT_SIZE + smtp_proto::EXT_ENHANCED_STATUS_CODES + smtp_proto::EXT_PIPELINING; |
219 | + smtp_proto::EXT_SIZE | smtp_proto::EXT_ENHANCED_STATUS_CODES | smtp_proto::EXT_PIPELINING; |
220 | + |
221 | + struct ConditionalPipeline<'a> { |
222 | + pub opts: &'a SessionOptions, |
223 | + pub session: &'a mut Session, |
224 | + pub pipeline: &'a mut Pipeline, |
225 | + } |
226 | + |
227 | + impl ConditionalPipeline<'_> { |
228 | + pub fn apply(&mut self, req: &Request<String>, data: Option<&Bytes>) -> Vec<Response<String>> { |
229 | + let response = self.session.process(self.opts, req, data); |
230 | + if self.opts.capabilities & smtp_proto::EXT_PIPELINING != 0 { |
231 | + self.pipeline.process(req, &response) |
232 | + } else { |
233 | + match self.session.process(self.opts, req, data) { |
234 | + Ok(response) => { |
235 | + tracing::debug!("Client response: {:?}", response); |
236 | + vec![response] |
237 | + } |
238 | + Err(response) => { |
239 | + tracing::warn!("Client error: {:?}", response); |
240 | + vec![response] |
241 | + } |
242 | + } |
243 | + } |
244 | + } |
245 | + } |
246 | |
247 | #[derive(Clone)] |
248 | struct Configuration { |
249 | @@ -125,6 +152,11 @@ impl Server { |
250 | T: tokio::io::AsyncRead + tokio::io::AsyncWrite + std::marker::Unpin, |
251 | { |
252 | let mut session = Session::default(); |
253 | + let mut pipelined = ConditionalPipeline { |
254 | + opts: &self.config.session_opts(), |
255 | + session: &mut session, |
256 | + pipeline: &mut Pipeline::default(), |
257 | + }; |
258 | // send inital server greeting |
259 | framed |
260 | .send(crate::session::greeting( |
261 | @@ -144,16 +176,10 @@ impl Server { |
262 | if matches!(command.0, Request::Quit) { |
263 | finished = true; |
264 | } |
265 | - match session.process(&opts, &command.0, &command.1) { |
266 | - Ok(resp) => { |
267 | - tracing::debug!("Returning response: {:?}", resp); |
268 | - framed.send(resp).await?; |
269 | - } |
270 | - Err(err) => { |
271 | - tracing::warn!("Client error: {:?}", err); |
272 | - framed.send(err).await?; |
273 | - } |
274 | - }; |
275 | + let responses = pipelined.apply(&command.0, command.1.as_ref()); |
276 | + for response in responses { |
277 | + framed.send(response).await?; |
278 | + } |
279 | if finished { |
280 | break 'outer; |
281 | } |
282 | diff --git a/maitred/src/session.rs b/maitred/src/session.rs |
283 | index 5498515..59f0c4e 100644 |
284 | --- a/maitred/src/session.rs |
285 | +++ b/maitred/src/session.rs |
286 | @@ -76,7 +76,7 @@ impl Session { |
287 | &mut self, |
288 | opts: &Options, |
289 | req: &Request<String>, |
290 | - data: &Option<Bytes>, |
291 | + data: Option<&Bytes>, |
292 | ) -> Result { |
293 | match req { |
294 | Request::Ehlo { host } => { |
295 | @@ -245,7 +245,7 @@ mod test { |
296 | fn process_all(session: &mut Session, opts: &Options, commands: &[TestCase]) { |
297 | commands.iter().enumerate().for_each(|(i, command)| { |
298 | println!("Running command {}/{}", i, commands.len()); |
299 | - let response = session.process(opts, &command.request, &command.payload); |
300 | + let response = session.process(opts, &command.request, command.payload.as_ref()); |
301 | println!("Response: {:?}", response); |
302 | match response { |
303 | Ok(actual_response) => { |