Author: Kevin Schoon [me@kevinschoon.com]
Hash: 35c8c191423e5ff426b9f82ca89f2fb3de28718c
Timestamp: Mon, 30 Sep 2024 10:47:05 +0000 (2 weeks ago)

+130 -169 +/-6 browse
Reduce duplicated code for STARTTLS connections
1diff --git a/.gitignore b/.gitignore
2index 25c055c..6042ec5 100644
3--- a/.gitignore
4+++ b/.gitignore
5 @@ -1,2 +1,4 @@
6 target
7 mail
8+ key.pem
9+ cert.pem
10 diff --git a/maitred.toml b/maitred.toml
11index 46d1033..d09b5f6 100644
12--- a/maitred.toml
13+++ b/maitred.toml
14 @@ -7,6 +7,10 @@ address = "0.0.0.0:2525"
15 # Enable HAProxy's PROXY Protocol
16 proxy_protocol = false
17
18+ [tls]
19+ certificate = "cert.pem"
20+ key = "key.pem"
21+
22 [dkim]
23 enabled = false
24
25 diff --git a/maitred/src/server.rs b/maitred/src/server.rs
26index d641090..d036169 100644
27--- a/maitred/src/server.rs
28+++ b/maitred/src/server.rs
29 @@ -15,7 +15,6 @@ use mail_auth::Resolver;
30 use mail_parser::Message;
31 use proxy_header::{ParseConfig, ProxyHeader};
32 use smtp_proto::Request;
33- use tokio::io::BufStream;
34 use tokio::net::TcpListener;
35 use tokio::sync::mpsc::Sender;
36 use tokio::sync::Mutex;
37 @@ -88,6 +87,13 @@ impl From<&Session> for Envelope {
38 }
39 }
40
41+ /// Action for controlling a TCP session
42+ pub(crate) enum Action {
43+ Continue,
44+ Shutdown,
45+ TlsUpgrade,
46+ }
47+
48 /// Server implements everything that is required to run an SMTP server by
49 /// binding to the configured address and processing individual TCP connections
50 /// as they are received.
51 @@ -207,102 +213,97 @@ impl Server {
52 .with_single_cert(certs, private_key)?)
53 }
54
55- // TODO: Eliminate duplicated code
56- async fn serve_tls<T>(
57+ /// drive the session forward
58+ async fn next<T>(
59 &self,
60- stream: &mut BufStream<T>,
61- mut session: Session,
62- msg_queue: Arc<Injector<Envelope>>,
63- pipelining: bool,
64- send_greeting: bool,
65- ) -> Result<(), ServerError>
66+ framed: &mut Framed<T, Transport>,
67+ session: &mut Session,
68+ queue: Arc<Injector<Envelope>>,
69+ tls_active: bool,
70+ ) -> Result<Action, ServerError>
71 where
72 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + std::marker::Unpin,
73 {
74- let acceptor = TlsAcceptor::from(Arc::new(self.rustls_config().await?));
75- let tls_stream = acceptor.accept(stream).await?;
76-
77- let mut framed = Framed::new(tls_stream, Transport::default().pipelining(pipelining));
78-
79- if send_greeting {
80- let greeting = session.greeting();
81- // send inital server greeting
82- framed.send(greeting).await?;
83- }
84-
85- let mut shutdown = false;
86-
87- 'outer: while !shutdown {
88- let frame = timeout(self.global_timeout, framed.next()).await;
89- match frame {
90- Ok(Some(Ok(Command::Requests(commands)))) => {
91- shutdown = is_quit(commands.as_slice());
92- for command in commands {
93- match session.process(&command).await {
94- Ok(responses) => {
95- for response in responses {
96- framed.send(response).await?;
97- }
98- }
99- Err(e) => {
100- tracing::warn!("Client error: {:?}", e);
101- let fatal = e.is_fatal();
102- framed.send(e).await?;
103- if fatal {
104- break 'outer;
105- }
106- }
107- }
108- }
109- }
110- Ok(Some(Ok(Command::Payload(payload)))) => {
111- match session.handle_data(&payload).await {
112+ match timeout(self.global_timeout, framed.next()).await {
113+ Ok(Some(Ok(Command::Requests(commands)))) => {
114+ let shutdown = is_quit(commands.as_slice());
115+ let starttls = is_starttls(commands.as_slice());
116+ for command in commands {
117+ match session.process(&command).await {
118 Ok(responses) => {
119 for response in responses {
120 framed.send(response).await?;
121 }
122- msg_queue.push(Envelope::from(&session));
123 }
124- Err(response) => {
125- tracing::warn!("Error handling message payload: {:?}", response);
126-
127- framed.send(response).await?;
128+ Err(e) => {
129+ tracing::warn!("Client error: {:?}", e);
130+ let fatal = e.is_fatal();
131+ framed.send(e).await?;
132+ if fatal {
133+ return Ok(Action::Shutdown);
134+ } else {
135+ return Ok(Action::Continue);
136+ }
137 }
138 }
139 }
140- Ok(Some(Err(err))) => {
141- tracing::warn!("Client Error: {}", err);
142- let response = match err {
143- crate::transport::TransportError::PipelineNotEnabled => {
144- crate::smtp_response!(500, 0, 0, 0, "Pipelining is not enabled")
145- }
146- crate::transport::TransportError::Smtp(e) => {
147- crate::session::smtp_error_to_response(e)
148- }
149- // IO Errors considered fatal for the entire session
150- crate::transport::TransportError::Io(e) => return Err(ServerError::Io(e)),
151- };
152- framed.send(response).await?;
153+ if starttls {
154+ if tls_active {
155+ tracing::warn!(
156+ "Client attempted to upgrade to TLS but they already have TLS"
157+ );
158+ framed.send(crate::session::tls_already_active()).await?;
159+ return Ok(Action::Continue);
160+ }
161+ Ok(Action::TlsUpgrade)
162+ } else if shutdown {
163+ Ok(Action::Shutdown)
164+ } else {
165+ Ok(Action::Continue)
166 }
167- Ok(None) => {
168- tracing::info!("Client connection closing");
169- break 'outer;
170+ }
171+ Ok(Some(Ok(Command::Payload(payload)))) => match session.handle_data(&payload).await {
172+ Ok(responses) => {
173+ for response in responses {
174+ framed.send(response).await?;
175+ }
176+ queue.push(Envelope::from(&session.clone()));
177+ Ok(Action::Continue)
178 }
179- Err(timeout) => {
180- tracing::warn!("Client connection exceeded: {:?}", self.global_timeout);
181- framed
182- .send(crate::session::timeout(&timeout.to_string()))
183- .await?;
184- return Err(ServerError::Timeout(self.global_timeout.as_secs()));
185+ Err(response) => {
186+ tracing::warn!("Error handling message payload: {:?}", response);
187+ framed.send(response).await?;
188+ Ok(Action::Continue)
189 }
190+ },
191+ Ok(Some(Err(err))) => {
192+ tracing::warn!("Client Error: {}", err);
193+ let response = match err {
194+ crate::transport::TransportError::PipelineNotEnabled => {
195+ crate::smtp_response!(500, 0, 0, 0, "Pipelining is not enabled")
196+ }
197+ crate::transport::TransportError::Smtp(e) => {
198+ crate::session::smtp_error_to_response(e)
199+ }
200+ // IO Errors considered fatal for the entire session
201+ crate::transport::TransportError::Io(e) => return Err(ServerError::Io(e)),
202+ };
203+ framed.send(response).await?;
204+ Ok(Action::Continue)
205+ }
206+ Ok(None) => Ok(Action::Shutdown),
207+ Err(e) => {
208+ tracing::warn!("Client connection exceeded: {:?}", self.global_timeout);
209+ framed.send(crate::session::timeout(&e.to_string())).await?;
210+ Err(ServerError::Timeout(self.global_timeout.as_secs()))
211 }
212 }
213- Ok(())
214 }
215
216+ /// Serve a plain SMTP connection that may be upgradable to TLS.
217 async fn serve_plain<T>(
218 &self,
219- stream: BufStream<T>,
220+ stream: &mut T,
221 msg_queue: Arc<Injector<Envelope>>,
222 pipelining: bool,
223 remote_addr: SocketAddr,
224 @@ -322,94 +323,37 @@ impl Server {
225
226 let greeting = session.greeting();
227
228- let mut framed = Framed::new(stream, Transport::default().pipelining(pipelining));
229+ let transport = Transport::default().pipelining(pipelining);
230+
231+ let mut framed = Framed::new(&mut *stream, transport.clone());
232
233- // send inital server greeting
234 framed.send(greeting).await?;
235
236- let mut shutdown = false;
237-
238- 'outer: while !shutdown {
239- let frame = timeout(self.global_timeout, framed.next()).await;
240- match frame {
241- Ok(Some(Ok(Command::Requests(commands)))) => {
242- shutdown = is_quit(commands.as_slice());
243- let starttls = is_starttls(commands.as_slice());
244- for command in commands {
245- match session.process(&command).await {
246- Ok(responses) => {
247- for response in responses {
248- framed.send(response).await?;
249- }
250- }
251- Err(e) => {
252- tracing::warn!("Client error: {:?}", e);
253- let fatal = e.is_fatal();
254- framed.send(e).await?;
255- if fatal {
256- break 'outer;
257- }
258- if starttls {
259- continue 'outer;
260- }
261- }
262- }
263- }
264- if starttls {
265- tracing::info!("Upgrading client connection with STARTTLS");
266- return self
267- .serve_tls(
268- framed.get_mut(),
269- session.clone().with_options(self.options.clone().unwrap()),
270- msg_queue.clone(),
271- pipelining,
272- false,
273- )
274- .await;
275- }
276- }
277- Ok(Some(Ok(Command::Payload(payload)))) => {
278- match session.handle_data(&payload).await {
279- Ok(responses) => {
280- for response in responses {
281- framed.send(response).await?;
282- }
283- msg_queue.push(Envelope::from(&session));
284- }
285- Err(response) => {
286- tracing::warn!("Error handling message payload: {:?}", response);
287- framed.send(response).await?;
288+ loop {
289+ match self
290+ .next(&mut framed, &mut session, msg_queue.clone(), false)
291+ .await?
292+ {
293+ Action::Continue => {}
294+ Action::Shutdown => return Ok(()),
295+ Action::TlsUpgrade => {
296+ let acceptor = TlsAcceptor::from(Arc::new(self.rustls_config().await?));
297+ let tls_stream = acceptor.accept(&mut *stream).await?;
298+ let mut tls_framed =
299+ Framed::new(tls_stream, transport.clone());
300+ loop {
301+ match self
302+ .next(&mut tls_framed, &mut session, msg_queue.clone(), true)
303+ .await?
304+ {
305+ Action::Continue => {}
306+ Action::Shutdown => return Ok(()),
307+ Action::TlsUpgrade => unreachable!(),
308 }
309 }
310 }
311- Ok(Some(Err(err))) => {
312- tracing::warn!("Client Error: {}", err);
313- let response = match err {
314- crate::transport::TransportError::PipelineNotEnabled => {
315- crate::smtp_response!(500, 0, 0, 0, "Pipelining is not enabled")
316- }
317- crate::transport::TransportError::Smtp(e) => {
318- crate::session::smtp_error_to_response(e)
319- }
320- // IO Errors considered fatal for the entire session
321- crate::transport::TransportError::Io(e) => return Err(ServerError::Io(e)),
322- };
323- framed.send(response).await?;
324- }
325- Ok(None) => {
326- tracing::info!("Client connection closing");
327- break 'outer;
328- }
329- Err(timeout) => {
330- tracing::warn!("Client connection exceeded: {:?}", self.global_timeout);
331- framed
332- .send(crate::session::timeout(&timeout.to_string()))
333- .await?;
334- return Err(ServerError::Timeout(self.global_timeout.as_secs()));
335- }
336 }
337 }
338- Ok(())
339 }
340
341 async fn spawn_workers(&mut self, global_queue: Arc<Injector<Envelope>>) {
342 @@ -473,7 +417,7 @@ impl Server {
343 let global_queue = Arc::new(Injector::<Envelope>::new());
344 self.spawn_workers(global_queue.clone()).await;
345 loop {
346- let (socket, addr) = listener.accept().await.unwrap();
347+ let (mut socket, addr) = listener.accept().await.unwrap();
348 let local_addr = socket.local_addr()?;
349 tracing::info!("Accepted connection on: {:?} from: {:?}", local_addr, addr);
350 // pass the proxied address if proxy protocol is enabled
351 @@ -502,12 +446,7 @@ impl Server {
352 .is_some_and(|opts| opts.capabilities & smtp_proto::EXT_PIPELINING != 0)
353 || self.options.is_none();
354 match self
355- .serve_plain(
356- BufStream::new(socket),
357- global_queue.clone(),
358- pipelining,
359- addr,
360- )
361+ .serve_plain(&mut socket, global_queue.clone(), pipelining, addr)
362 .await
363 {
364 Ok(_) => {
365 @@ -585,7 +524,7 @@ mod test {
366
367 #[tokio::test]
368 async fn test_server() {
369- let stream = FakeStream {
370+ let mut stream = FakeStream {
371 buffer: vec![
372 "HELO example.org\r\n".into(),
373 "MAIL FROM: <fuu@bar.com>\r\n".into(),
374 @@ -602,7 +541,7 @@ mod test {
375 let global_queue = Arc::new(Injector::<Envelope>::new());
376 server
377 .serve_plain(
378- BufStream::new(stream),
379+ &mut stream,
380 global_queue.clone(),
381 false,
382 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 25)),
383 @@ -619,7 +558,7 @@ mod test {
384
385 #[tokio::test]
386 async fn test_server_pipelined() {
387- let stream = FakeStream {
388+ let mut stream = FakeStream {
389 buffer: vec![
390 "HELO example.org\r\n".into(),
391 "MAIL FROM: <fuu@bar.com>\r\n".into(),
392 @@ -634,7 +573,7 @@ mod test {
393 let global_queue = Arc::new(Injector::<Envelope>::new());
394 server
395 .serve_plain(
396- BufStream::new(stream),
397+ &mut stream,
398 global_queue.clone(),
399 false,
400 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 25)),
401 diff --git a/maitred/src/session.rs b/maitred/src/session.rs
402index b1e3cdc..8ec2e7c 100644
403--- a/maitred/src/session.rs
404+++ b/maitred/src/session.rs
405 @@ -66,6 +66,10 @@ pub fn timeout(message: &str) -> Response<String> {
406 smtp_response!(421, 4, 4, 2, format!("Timeout exceeded: {}", message))
407 }
408
409+ pub fn tls_already_active() -> Response<String> {
410+ smtp_response!(400, 0, 0, 0, "TLS is already active")
411+ }
412+
413 pub fn smtp_error_to_response(e: smtp_proto::Error) -> Response<String> {
414 match e {
415 smtp_proto::Error::NeedsMoreData { bytes_left: _ } => {
416 diff --git a/maitred/src/transport.rs b/maitred/src/transport.rs
417index e8d6e59..3a7650b 100644
418--- a/maitred/src/transport.rs
419+++ b/maitred/src/transport.rs
420 @@ -103,7 +103,6 @@ impl Display for Command {
421 }
422
423 /// Line oriented transport
424- /// TODO: BDAT
425 /// TODO: BINARYMIME
426 #[derive(Default)]
427 pub(crate) struct Transport {
428 @@ -112,6 +111,16 @@ pub(crate) struct Transport {
429 pipelining: bool,
430 }
431
432+ impl Clone for Transport {
433+ fn clone(&self) -> Self {
434+ Transport {
435+ receiver: None,
436+ buf: Vec::new(),
437+ pipelining: self.pipelining,
438+ }
439+ }
440+ }
441+
442 impl Transport {
443 /// If the transport should allow piplining commands
444 pub fn pipelining(mut self, enabled: bool) -> Self {
445 @@ -141,7 +150,6 @@ impl Decoder for Transport {
446 type Error = TransportError;
447
448 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
449-
450 tracing::trace!("{}", String::from_utf8_lossy(src));
451
452 if src.is_empty() {
453 diff --git a/scripts/gen_certs.sh b/scripts/gen_certs.sh
454new file mode 100755
455index 0000000..8eef9fd
456--- /dev/null
457+++ b/scripts/gen_certs.sh
458 @@ -0,0 +1,4 @@
459+ #!/bin/sh
460+
461+ openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem \
462+ -days 365 -nodes -subj '/CN=localhost'