Commit
+130 -169 +/-6 browse
1 | diff --git a/.gitignore b/.gitignore |
2 | index 25c055c..6042ec5 100644 |
3 | --- a/.gitignore |
4 | +++ b/.gitignore |
5 | @@ -1,2 +1,4 @@ |
6 | target |
7 | |
8 | + key.pem |
9 | + cert.pem |
10 | diff --git a/maitred.toml b/maitred.toml |
11 | index 46d1033..d09b5f6 100644 |
12 | --- a/maitred.toml |
13 | +++ b/maitred.toml |
14 | @@ -7,6 +7,10 @@ address = "0.0.0.0:2525" |
15 | # Enable HAProxy's PROXY Protocol |
16 | proxy_protocol = false |
17 | |
18 | + [tls] |
19 | + certificate = "cert.pem" |
20 | + key = "key.pem" |
21 | + |
22 | [dkim] |
23 | enabled = false |
24 | |
25 | diff --git a/maitred/src/server.rs b/maitred/src/server.rs |
26 | index d641090..d036169 100644 |
27 | --- a/maitred/src/server.rs |
28 | +++ b/maitred/src/server.rs |
29 | @@ -15,7 +15,6 @@ use mail_auth::Resolver; |
30 | use mail_parser::Message; |
31 | use proxy_header::{ParseConfig, ProxyHeader}; |
32 | use smtp_proto::Request; |
33 | - use tokio::io::BufStream; |
34 | use tokio::net::TcpListener; |
35 | use tokio::sync::mpsc::Sender; |
36 | use tokio::sync::Mutex; |
37 | @@ -88,6 +87,13 @@ impl From<&Session> for Envelope { |
38 | } |
39 | } |
40 | |
41 | + /// Action for controlling a TCP session |
42 | + pub(crate) enum Action { |
43 | + Continue, |
44 | + Shutdown, |
45 | + TlsUpgrade, |
46 | + } |
47 | + |
48 | /// Server implements everything that is required to run an SMTP server by |
49 | /// binding to the configured address and processing individual TCP connections |
50 | /// as they are received. |
51 | @@ -207,102 +213,97 @@ impl Server { |
52 | .with_single_cert(certs, private_key)?) |
53 | } |
54 | |
55 | - // TODO: Eliminate duplicated code |
56 | - async fn serve_tls<T>( |
57 | + /// drive the session forward |
58 | + async fn next<T>( |
59 | &self, |
60 | - stream: &mut BufStream<T>, |
61 | - mut session: Session, |
62 | - msg_queue: Arc<Injector<Envelope>>, |
63 | - pipelining: bool, |
64 | - send_greeting: bool, |
65 | - ) -> Result<(), ServerError> |
66 | + framed: &mut Framed<T, Transport>, |
67 | + session: &mut Session, |
68 | + queue: Arc<Injector<Envelope>>, |
69 | + tls_active: bool, |
70 | + ) -> Result<Action, ServerError> |
71 | where |
72 | T: tokio::io::AsyncRead + tokio::io::AsyncWrite + std::marker::Unpin, |
73 | { |
74 | - let acceptor = TlsAcceptor::from(Arc::new(self.rustls_config().await?)); |
75 | - let tls_stream = acceptor.accept(stream).await?; |
76 | - |
77 | - let mut framed = Framed::new(tls_stream, Transport::default().pipelining(pipelining)); |
78 | - |
79 | - if send_greeting { |
80 | - let greeting = session.greeting(); |
81 | - // send inital server greeting |
82 | - framed.send(greeting).await?; |
83 | - } |
84 | - |
85 | - let mut shutdown = false; |
86 | - |
87 | - 'outer: while !shutdown { |
88 | - let frame = timeout(self.global_timeout, framed.next()).await; |
89 | - match frame { |
90 | - Ok(Some(Ok(Command::Requests(commands)))) => { |
91 | - shutdown = is_quit(commands.as_slice()); |
92 | - for command in commands { |
93 | - match session.process(&command).await { |
94 | - Ok(responses) => { |
95 | - for response in responses { |
96 | - framed.send(response).await?; |
97 | - } |
98 | - } |
99 | - Err(e) => { |
100 | - tracing::warn!("Client error: {:?}", e); |
101 | - let fatal = e.is_fatal(); |
102 | - framed.send(e).await?; |
103 | - if fatal { |
104 | - break 'outer; |
105 | - } |
106 | - } |
107 | - } |
108 | - } |
109 | - } |
110 | - Ok(Some(Ok(Command::Payload(payload)))) => { |
111 | - match session.handle_data(&payload).await { |
112 | + match timeout(self.global_timeout, framed.next()).await { |
113 | + Ok(Some(Ok(Command::Requests(commands)))) => { |
114 | + let shutdown = is_quit(commands.as_slice()); |
115 | + let starttls = is_starttls(commands.as_slice()); |
116 | + for command in commands { |
117 | + match session.process(&command).await { |
118 | Ok(responses) => { |
119 | for response in responses { |
120 | framed.send(response).await?; |
121 | } |
122 | - msg_queue.push(Envelope::from(&session)); |
123 | } |
124 | - Err(response) => { |
125 | - tracing::warn!("Error handling message payload: {:?}", response); |
126 | - |
127 | - framed.send(response).await?; |
128 | + Err(e) => { |
129 | + tracing::warn!("Client error: {:?}", e); |
130 | + let fatal = e.is_fatal(); |
131 | + framed.send(e).await?; |
132 | + if fatal { |
133 | + return Ok(Action::Shutdown); |
134 | + } else { |
135 | + return Ok(Action::Continue); |
136 | + } |
137 | } |
138 | } |
139 | } |
140 | - Ok(Some(Err(err))) => { |
141 | - tracing::warn!("Client Error: {}", err); |
142 | - let response = match err { |
143 | - crate::transport::TransportError::PipelineNotEnabled => { |
144 | - crate::smtp_response!(500, 0, 0, 0, "Pipelining is not enabled") |
145 | - } |
146 | - crate::transport::TransportError::Smtp(e) => { |
147 | - crate::session::smtp_error_to_response(e) |
148 | - } |
149 | - // IO Errors considered fatal for the entire session |
150 | - crate::transport::TransportError::Io(e) => return Err(ServerError::Io(e)), |
151 | - }; |
152 | - framed.send(response).await?; |
153 | + if starttls { |
154 | + if tls_active { |
155 | + tracing::warn!( |
156 | + "Client attempted to upgrade to TLS but they already have TLS" |
157 | + ); |
158 | + framed.send(crate::session::tls_already_active()).await?; |
159 | + return Ok(Action::Continue); |
160 | + } |
161 | + Ok(Action::TlsUpgrade) |
162 | + } else if shutdown { |
163 | + Ok(Action::Shutdown) |
164 | + } else { |
165 | + Ok(Action::Continue) |
166 | } |
167 | - Ok(None) => { |
168 | - tracing::info!("Client connection closing"); |
169 | - break 'outer; |
170 | + } |
171 | + Ok(Some(Ok(Command::Payload(payload)))) => match session.handle_data(&payload).await { |
172 | + Ok(responses) => { |
173 | + for response in responses { |
174 | + framed.send(response).await?; |
175 | + } |
176 | + queue.push(Envelope::from(&session.clone())); |
177 | + Ok(Action::Continue) |
178 | } |
179 | - Err(timeout) => { |
180 | - tracing::warn!("Client connection exceeded: {:?}", self.global_timeout); |
181 | - framed |
182 | - .send(crate::session::timeout(&timeout.to_string())) |
183 | - .await?; |
184 | - return Err(ServerError::Timeout(self.global_timeout.as_secs())); |
185 | + Err(response) => { |
186 | + tracing::warn!("Error handling message payload: {:?}", response); |
187 | + framed.send(response).await?; |
188 | + Ok(Action::Continue) |
189 | } |
190 | + }, |
191 | + Ok(Some(Err(err))) => { |
192 | + tracing::warn!("Client Error: {}", err); |
193 | + let response = match err { |
194 | + crate::transport::TransportError::PipelineNotEnabled => { |
195 | + crate::smtp_response!(500, 0, 0, 0, "Pipelining is not enabled") |
196 | + } |
197 | + crate::transport::TransportError::Smtp(e) => { |
198 | + crate::session::smtp_error_to_response(e) |
199 | + } |
200 | + // IO Errors considered fatal for the entire session |
201 | + crate::transport::TransportError::Io(e) => return Err(ServerError::Io(e)), |
202 | + }; |
203 | + framed.send(response).await?; |
204 | + Ok(Action::Continue) |
205 | + } |
206 | + Ok(None) => Ok(Action::Shutdown), |
207 | + Err(e) => { |
208 | + tracing::warn!("Client connection exceeded: {:?}", self.global_timeout); |
209 | + framed.send(crate::session::timeout(&e.to_string())).await?; |
210 | + Err(ServerError::Timeout(self.global_timeout.as_secs())) |
211 | } |
212 | } |
213 | - Ok(()) |
214 | } |
215 | |
216 | + /// Serve a plain SMTP connection that may be upgradable to TLS. |
217 | async fn serve_plain<T>( |
218 | &self, |
219 | - stream: BufStream<T>, |
220 | + stream: &mut T, |
221 | msg_queue: Arc<Injector<Envelope>>, |
222 | pipelining: bool, |
223 | remote_addr: SocketAddr, |
224 | @@ -322,94 +323,37 @@ impl Server { |
225 | |
226 | let greeting = session.greeting(); |
227 | |
228 | - let mut framed = Framed::new(stream, Transport::default().pipelining(pipelining)); |
229 | + let transport = Transport::default().pipelining(pipelining); |
230 | + |
231 | + let mut framed = Framed::new(&mut *stream, transport.clone()); |
232 | |
233 | - // send inital server greeting |
234 | framed.send(greeting).await?; |
235 | |
236 | - let mut shutdown = false; |
237 | - |
238 | - 'outer: while !shutdown { |
239 | - let frame = timeout(self.global_timeout, framed.next()).await; |
240 | - match frame { |
241 | - Ok(Some(Ok(Command::Requests(commands)))) => { |
242 | - shutdown = is_quit(commands.as_slice()); |
243 | - let starttls = is_starttls(commands.as_slice()); |
244 | - for command in commands { |
245 | - match session.process(&command).await { |
246 | - Ok(responses) => { |
247 | - for response in responses { |
248 | - framed.send(response).await?; |
249 | - } |
250 | - } |
251 | - Err(e) => { |
252 | - tracing::warn!("Client error: {:?}", e); |
253 | - let fatal = e.is_fatal(); |
254 | - framed.send(e).await?; |
255 | - if fatal { |
256 | - break 'outer; |
257 | - } |
258 | - if starttls { |
259 | - continue 'outer; |
260 | - } |
261 | - } |
262 | - } |
263 | - } |
264 | - if starttls { |
265 | - tracing::info!("Upgrading client connection with STARTTLS"); |
266 | - return self |
267 | - .serve_tls( |
268 | - framed.get_mut(), |
269 | - session.clone().with_options(self.options.clone().unwrap()), |
270 | - msg_queue.clone(), |
271 | - pipelining, |
272 | - false, |
273 | - ) |
274 | - .await; |
275 | - } |
276 | - } |
277 | - Ok(Some(Ok(Command::Payload(payload)))) => { |
278 | - match session.handle_data(&payload).await { |
279 | - Ok(responses) => { |
280 | - for response in responses { |
281 | - framed.send(response).await?; |
282 | - } |
283 | - msg_queue.push(Envelope::from(&session)); |
284 | - } |
285 | - Err(response) => { |
286 | - tracing::warn!("Error handling message payload: {:?}", response); |
287 | - framed.send(response).await?; |
288 | + loop { |
289 | + match self |
290 | + .next(&mut framed, &mut session, msg_queue.clone(), false) |
291 | + .await? |
292 | + { |
293 | + Action::Continue => {} |
294 | + Action::Shutdown => return Ok(()), |
295 | + Action::TlsUpgrade => { |
296 | + let acceptor = TlsAcceptor::from(Arc::new(self.rustls_config().await?)); |
297 | + let tls_stream = acceptor.accept(&mut *stream).await?; |
298 | + let mut tls_framed = |
299 | + Framed::new(tls_stream, transport.clone()); |
300 | + loop { |
301 | + match self |
302 | + .next(&mut tls_framed, &mut session, msg_queue.clone(), true) |
303 | + .await? |
304 | + { |
305 | + Action::Continue => {} |
306 | + Action::Shutdown => return Ok(()), |
307 | + Action::TlsUpgrade => unreachable!(), |
308 | } |
309 | } |
310 | } |
311 | - Ok(Some(Err(err))) => { |
312 | - tracing::warn!("Client Error: {}", err); |
313 | - let response = match err { |
314 | - crate::transport::TransportError::PipelineNotEnabled => { |
315 | - crate::smtp_response!(500, 0, 0, 0, "Pipelining is not enabled") |
316 | - } |
317 | - crate::transport::TransportError::Smtp(e) => { |
318 | - crate::session::smtp_error_to_response(e) |
319 | - } |
320 | - // IO Errors considered fatal for the entire session |
321 | - crate::transport::TransportError::Io(e) => return Err(ServerError::Io(e)), |
322 | - }; |
323 | - framed.send(response).await?; |
324 | - } |
325 | - Ok(None) => { |
326 | - tracing::info!("Client connection closing"); |
327 | - break 'outer; |
328 | - } |
329 | - Err(timeout) => { |
330 | - tracing::warn!("Client connection exceeded: {:?}", self.global_timeout); |
331 | - framed |
332 | - .send(crate::session::timeout(&timeout.to_string())) |
333 | - .await?; |
334 | - return Err(ServerError::Timeout(self.global_timeout.as_secs())); |
335 | - } |
336 | } |
337 | } |
338 | - Ok(()) |
339 | } |
340 | |
341 | async fn spawn_workers(&mut self, global_queue: Arc<Injector<Envelope>>) { |
342 | @@ -473,7 +417,7 @@ impl Server { |
343 | let global_queue = Arc::new(Injector::<Envelope>::new()); |
344 | self.spawn_workers(global_queue.clone()).await; |
345 | loop { |
346 | - let (socket, addr) = listener.accept().await.unwrap(); |
347 | + let (mut socket, addr) = listener.accept().await.unwrap(); |
348 | let local_addr = socket.local_addr()?; |
349 | tracing::info!("Accepted connection on: {:?} from: {:?}", local_addr, addr); |
350 | // pass the proxied address if proxy protocol is enabled |
351 | @@ -502,12 +446,7 @@ impl Server { |
352 | .is_some_and(|opts| opts.capabilities & smtp_proto::EXT_PIPELINING != 0) |
353 | || self.options.is_none(); |
354 | match self |
355 | - .serve_plain( |
356 | - BufStream::new(socket), |
357 | - global_queue.clone(), |
358 | - pipelining, |
359 | - addr, |
360 | - ) |
361 | + .serve_plain(&mut socket, global_queue.clone(), pipelining, addr) |
362 | .await |
363 | { |
364 | Ok(_) => { |
365 | @@ -585,7 +524,7 @@ mod test { |
366 | |
367 | #[tokio::test] |
368 | async fn test_server() { |
369 | - let stream = FakeStream { |
370 | + let mut stream = FakeStream { |
371 | buffer: vec![ |
372 | "HELO example.org\r\n".into(), |
373 | "MAIL FROM: <fuu@bar.com>\r\n".into(), |
374 | @@ -602,7 +541,7 @@ mod test { |
375 | let global_queue = Arc::new(Injector::<Envelope>::new()); |
376 | server |
377 | .serve_plain( |
378 | - BufStream::new(stream), |
379 | + &mut stream, |
380 | global_queue.clone(), |
381 | false, |
382 | SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 25)), |
383 | @@ -619,7 +558,7 @@ mod test { |
384 | |
385 | #[tokio::test] |
386 | async fn test_server_pipelined() { |
387 | - let stream = FakeStream { |
388 | + let mut stream = FakeStream { |
389 | buffer: vec![ |
390 | "HELO example.org\r\n".into(), |
391 | "MAIL FROM: <fuu@bar.com>\r\n".into(), |
392 | @@ -634,7 +573,7 @@ mod test { |
393 | let global_queue = Arc::new(Injector::<Envelope>::new()); |
394 | server |
395 | .serve_plain( |
396 | - BufStream::new(stream), |
397 | + &mut stream, |
398 | global_queue.clone(), |
399 | false, |
400 | SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 25)), |
401 | diff --git a/maitred/src/session.rs b/maitred/src/session.rs |
402 | index b1e3cdc..8ec2e7c 100644 |
403 | --- a/maitred/src/session.rs |
404 | +++ b/maitred/src/session.rs |
405 | @@ -66,6 +66,10 @@ pub fn timeout(message: &str) -> Response<String> { |
406 | smtp_response!(421, 4, 4, 2, format!("Timeout exceeded: {}", message)) |
407 | } |
408 | |
409 | + pub fn tls_already_active() -> Response<String> { |
410 | + smtp_response!(400, 0, 0, 0, "TLS is already active") |
411 | + } |
412 | + |
413 | pub fn smtp_error_to_response(e: smtp_proto::Error) -> Response<String> { |
414 | match e { |
415 | smtp_proto::Error::NeedsMoreData { bytes_left: _ } => { |
416 | diff --git a/maitred/src/transport.rs b/maitred/src/transport.rs |
417 | index e8d6e59..3a7650b 100644 |
418 | --- a/maitred/src/transport.rs |
419 | +++ b/maitred/src/transport.rs |
420 | @@ -103,7 +103,6 @@ impl Display for Command { |
421 | } |
422 | |
423 | /// Line oriented transport |
424 | - /// TODO: BDAT |
425 | /// TODO: BINARYMIME |
426 | #[derive(Default)] |
427 | pub(crate) struct Transport { |
428 | @@ -112,6 +111,16 @@ pub(crate) struct Transport { |
429 | pipelining: bool, |
430 | } |
431 | |
432 | + impl Clone for Transport { |
433 | + fn clone(&self) -> Self { |
434 | + Transport { |
435 | + receiver: None, |
436 | + buf: Vec::new(), |
437 | + pipelining: self.pipelining, |
438 | + } |
439 | + } |
440 | + } |
441 | + |
442 | impl Transport { |
443 | /// If the transport should allow piplining commands |
444 | pub fn pipelining(mut self, enabled: bool) -> Self { |
445 | @@ -141,7 +150,6 @@ impl Decoder for Transport { |
446 | type Error = TransportError; |
447 | |
448 | fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { |
449 | - |
450 | tracing::trace!("{}", String::from_utf8_lossy(src)); |
451 | |
452 | if src.is_empty() { |
453 | diff --git a/scripts/gen_certs.sh b/scripts/gen_certs.sh |
454 | new file mode 100755 |
455 | index 0000000..8eef9fd |
456 | --- /dev/null |
457 | +++ b/scripts/gen_certs.sh |
458 | @@ -0,0 +1,4 @@ |
459 | + #!/bin/sh |
460 | + |
461 | + openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem \ |
462 | + -days 365 -nodes -subj '/CN=localhost' |