Commit
Author: Kevin Schoon [me@kevinschoon.com]
Hash: 9bfcd33b6d0faede182b16faf296712c0b2eb7f4
Timestamp: Wed, 14 Aug 2024 22:08:35 +0000 (4 months ago)

+159 -89 +/-6 browse
make session async
1diff --git a/Cargo.lock b/Cargo.lock
2index b378456..f34642f 100644
3--- a/Cargo.lock
4+++ b/Cargo.lock
5 @@ -36,6 +36,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
6 checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f"
7
8 [[package]]
9+ name = "async-trait"
10+ version = "0.1.81"
11+ source = "registry+https://github.com/rust-lang/crates.io-index"
12+ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107"
13+ dependencies = [
14+ "proc-macro2",
15+ "quote",
16+ "syn",
17+ ]
18+
19+ [[package]]
20 name = "autocfg"
21 version = "1.3.0"
22 source = "registry+https://github.com/rust-lang/crates.io-index"
23 @@ -270,6 +281,7 @@ dependencies = [
24 name = "maitred"
25 version = "0.1.0"
26 dependencies = [
27+ "async-trait",
28 "bytes",
29 "email_address",
30 "futures",
31 diff --git a/maitred/Cargo.toml b/maitred/Cargo.toml
32index f2a1adb..587470b 100644
33--- a/maitred/Cargo.toml
34+++ b/maitred/Cargo.toml
35 @@ -4,6 +4,7 @@ version = "0.1.0"
36 edition = "2021"
37
38 [dependencies]
39+ async-trait = "0.1.81"
40 bytes = "1.6.1"
41 email_address = "0.2.9"
42 futures = "0.3.30"
43 diff --git a/maitred/src/expand.rs b/maitred/src/expand.rs
44index 0cfcaa7..7a75201 100644
45--- a/maitred/src/expand.rs
46+++ b/maitred/src/expand.rs
47 @@ -1,5 +1,6 @@
48 use std::result::Result as StdResult;
49
50+ use async_trait::async_trait;
51 use email_address::EmailAddress;
52
53 /// Result type containing any of the associated e-mail addresses with the
54 @@ -21,9 +22,10 @@ pub enum Error {
55 /// addresses within the list if it exists. NOTE: That this function should
56 /// only be called with proper authentication otherwise it could be used to
57 /// harvest e-mail addresses.
58+ #[async_trait]
59 pub trait Expansion {
60 /// Expand the group into an array of members
61- fn expand(&self, name: &str) -> Result;
62+ async fn expand(&self, name: &str) -> Result;
63 }
64
65 /// Helper wrapper implementing the Expansion trait
66 @@ -43,11 +45,12 @@ pub struct Func<F>(pub F)
67 where
68 F: Fn(&str) -> Result;
69
70+ #[async_trait]
71 impl<F> Expansion for Func<F>
72 where
73- F: Fn(&str) -> Result,
74+ F: Fn(&str) -> Result + Sync,
75 {
76- fn expand(&self, name: &str) -> Result {
77+ async fn expand(&self, name: &str) -> Result {
78 let f = &self.0;
79 f(name)
80 }
81 diff --git a/maitred/src/server.rs b/maitred/src/server.rs
82index 4b63e82..a78b7ec 100644
83--- a/maitred/src/server.rs
84+++ b/maitred/src/server.rs
85 @@ -4,6 +4,7 @@ use std::time::Duration;
86 use bytes::Bytes;
87 use futures::SinkExt;
88 use smtp_proto::Request;
89+ use tokio::sync::Mutex;
90 use tokio::{net::TcpListener, time::timeout};
91 use tokio_stream::StreamExt;
92 use tokio_util::codec::Framed;
93 @@ -11,7 +12,7 @@ use tokio_util::codec::Framed;
94 use crate::error::Error;
95 use crate::pipeline::Pipeline;
96 use crate::session::Session;
97- use crate::transport::Transport;
98+ use crate::transport::{Response, Transport};
99 use crate::Chunk;
100
101 /// The default port the server will listen on if none was specified in it's
102 @@ -24,14 +25,15 @@ const DEFAULT_GLOBAL_TIMEOUT_SECS: u64 = 300;
103
104 /// Apply pipelining if running in extended mode and configured to support it
105 struct ConditionalPipeline<'a> {
106- pub session: &'a mut Session,
107+ pub session: &'a Mutex<Session>,
108 pub pipeline: &'a mut Pipeline,
109 }
110
111 impl ConditionalPipeline<'_> {
112- pub fn apply(&mut self, req: &Request<String>, data: Option<&Bytes>) -> Chunk {
113- let response = self.session.process(req, data);
114- if self.session.has_capability(smtp_proto::EXT_PIPELINING) && self.session.is_extended() {
115+ pub async fn apply(&mut self, req: &Request<String>, data: Option<&Bytes>) -> Chunk {
116+ let mut session = self.session.lock().await;
117+ let response = session.process(req, data).await;
118+ if session.has_capability(smtp_proto::EXT_PIPELINING) && session.is_extended() {
119 self.pipeline.process(req, &response)
120 } else {
121 match response {
122 @@ -100,22 +102,15 @@ impl Server {
123 self
124 }
125
126- async fn process<T>(&self, mut framed: Framed<T, Transport>) -> Result<Session, Error>
127+ async fn process<T>(
128+ &self,
129+ mut framed: Framed<T, Transport>,
130+ pipeline: &mut ConditionalPipeline<'_>,
131+ greeting: Response<String>,
132+ ) -> Result<(), Error>
133 where
134 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + std::marker::Unpin,
135 {
136- let mut session = Session::default();
137- if let Some(opts) = &self.options {
138- session = session.with_options(opts.clone());
139- }
140-
141- let greeting = session.greeting();
142-
143- let mut pipelined = ConditionalPipeline {
144- session: &mut session,
145- pipeline: &mut Pipeline::default(),
146- };
147-
148 // send inital server greeting
149 framed.send(greeting).await?;
150
151 @@ -130,7 +125,8 @@ impl Server {
152 if matches!(command.0, Request::Quit) {
153 finished = true;
154 }
155- let responses = pipelined.apply(&command.0, command.1.as_ref());
156+ let responses =
157+ pipeline.apply(&command.0, command.1.as_ref()).await;
158 for response in responses.0.into_iter() {
159 framed.send(response).await?;
160 }
161 @@ -157,7 +153,7 @@ impl Server {
162 }
163 }
164 tracing::info!("Connection closed");
165- Ok(session)
166+ Ok(())
167 }
168
169 pub async fn listen(&self) -> Result<(), Error> {
170 @@ -168,7 +164,21 @@ impl Server {
171 let addr = socket.local_addr()?;
172 tracing::info!("Accepted connection on: {:?}", addr);
173 let framed = Framed::new(socket, Transport::default());
174- if let Err(err) = self.process(framed).await {
175+ let mut session = Session::default();
176+ if let Some(opts) = &self.options {
177+ session = session.with_options(opts.clone());
178+ }
179+
180+ let greeting = session.greeting();
181+
182+ let session = Mutex::new(session);
183+
184+ let mut pipelined = ConditionalPipeline {
185+ session: &session,
186+ pipeline: &mut Pipeline::default(),
187+ };
188+
189+ if let Err(err) = self.process(framed, &mut pipelined, greeting).await {
190 tracing::warn!("Client encountered an error: {:?}", err);
191 }
192 }
193 @@ -249,11 +259,24 @@ mod test {
194 };
195 let server = Server::new("example.org");
196 let framed = Framed::new(stream, Transport::default());
197- let session = server.process(framed).await.unwrap();
198+ let session = Session::default();
199+ let greeting = session.greeting();
200+ let session = Mutex::new(session);
201+
202+ let mut pipelined = ConditionalPipeline {
203+ session: &session,
204+ pipeline: &mut Pipeline::default(),
205+ };
206+ server
207+ .process(framed, &mut pipelined, greeting)
208+ .await
209+ .unwrap();
210+ let session = session.lock().await;
211 assert!(session
212 .mail_from
213+ .as_ref()
214 .is_some_and(|mail_from| mail_from.email() == "fuu@bar.com"));
215- assert!(session.rcpt_to.is_some_and(|rcpts| rcpts
216+ assert!(session.rcpt_to.as_ref().is_some_and(|rcpts| rcpts
217 .first()
218 .is_some_and(|rcpt_to| rcpt_to.email() == "baz@qux.com")));
219 }
220 diff --git a/maitred/src/session.rs b/maitred/src/session.rs
221index 03d6fce..17ebb65 100644
222--- a/maitred/src/session.rs
223+++ b/maitred/src/session.rs
224 @@ -1,9 +1,11 @@
225 use std::rc::Rc;
226 use std::result::Result as StdResult;
227 use std::str::FromStr;
228+ use std::sync::Arc;
229
230 use bytes::Bytes;
231 use email_address::EmailAddress;
232+
233 use mail_parser::MessageParser;
234 use smtp_proto::{EhloResponse, Request, Response as SmtpResponse};
235 use url::Host;
236 @@ -70,8 +72,8 @@ pub struct Options {
237 pub capabilities: u32,
238 pub help_banner: String,
239 pub greeting: String,
240- pub list_expansion: Option<Rc<dyn Expansion>>,
241- pub verification: Option<Rc<dyn Verify>>,
242+ pub list_expansion: Option<Arc<dyn Expansion>>,
243+ pub verification: Option<Arc<dyn Verify>>,
244 }
245
246 impl Default for Options {
247 @@ -113,7 +115,7 @@ impl Options {
248 where
249 T: crate::expand::Expansion + 'static,
250 {
251- self.list_expansion = Some(Rc::new(expansion));
252+ self.list_expansion = Some(Arc::new(expansion));
253 self
254 }
255
256 @@ -121,7 +123,7 @@ impl Options {
257 where
258 T: crate::verify::Verify + 'static,
259 {
260- self.verification = Some(Rc::new(verification));
261+ self.verification = Some(Arc::new(verification));
262 self
263 }
264 }
265 @@ -219,7 +221,7 @@ impl Session {
266 /// indicate that the process is starting and the second one contains the
267 /// parsed bytes from the transfer.
268 /// FIXME: Not at all reasonable yet
269- pub fn process(&mut self, req: &Request<String>, data: Option<&Bytes>) -> Result {
270+ pub async fn process(&mut self, req: &Request<String>, data: Option<&Bytes>) -> Result {
271 match req {
272 Request::Ehlo { host } => {
273 self.hostname =
274 @@ -321,7 +323,7 @@ impl Session {
275 let address = EmailAddress::from_str(value.as_str()).map_err(|e| {
276 smtp_chunk!(500, 0, 0, 0, format!("cannot parse: {} {}", value, e))
277 })?;
278- match verifier.verify(&address) {
279+ match verifier.verify(&address).await {
280 Ok(_) => {
281 smtp_chunk_ok!(250, 0, 0, 0, "OK".to_string())
282 }
283 @@ -333,7 +335,7 @@ impl Session {
284 }
285 Request::Expn { value } => {
286 if let Some(expn) = &self.opts.list_expansion {
287- match expn.expand(value) {
288+ match expn.expand(value).await {
289 Ok(addresses) => {
290 let mut result = vec![smtp_response!(250, 0, 0, 0, "OK")];
291 result.extend(
292 @@ -420,9 +422,13 @@ impl Session {
293
294 #[cfg(test)]
295 mod test {
296- use super::*;
297+ use std::sync::Arc;
298
299+ use futures::stream::{self, StreamExt};
300 use smtp_proto::{MailFrom, RcptTo};
301+ use tokio::sync::Mutex;
302+
303+ use super::*;
304
305 const EXAMPLE_HOSTNAME: &str = "example.org";
306
307 @@ -433,10 +439,13 @@ mod test {
308 }
309
310 /// process all commands returning their response
311- fn process_all(session: &mut Session, commands: &[TestCase]) {
312- commands.iter().enumerate().for_each(|(i, command)| {
313+ async fn process_all(session: &Mutex<Session>, commands: &[TestCase]) {
314+ let stream = stream::iter(commands);
315+ stream.enumerate().for_each(|(i, command)| {
316+ async move {
317+ let mut session = session.lock().await;
318 println!("Running command {}/{}", i, commands.len());
319- let response = session.process(&command.request, command.payload.as_ref());
320+ let response = session.process(&command.request, command.payload.as_ref()).await;
321 println!("Response: {:?}", response);
322 match response {
323 Ok(actual_response) => {
324 @@ -472,12 +481,13 @@ mod test {
325 },
326 }
327 }
328+ };
329 }
330- })
331+ }).await;
332 }
333
334- #[test]
335- fn test_hello_quit() {
336+ #[tokio::test]
337+ async fn test_hello_quit() {
338 let requests = &[
339 TestCase {
340 request: Request::Helo {
341 @@ -492,16 +502,18 @@ mod test {
342 expected: smtp_chunk_ok!(221, 0, 0, 0, String::from("Ciao!")),
343 },
344 ];
345- let mut session = Session::default();
346- process_all(&mut session, requests);
347+ let session = Mutex::new(Session::default());
348+ process_all(&session, requests).await;
349+ let session = session.lock().await;
350 // session should contain both requests
351 assert!(session
352 .hostname
353+ .as_ref()
354 .is_some_and(|hostname| hostname.to_string() == EXAMPLE_HOSTNAME));
355 }
356
357- #[test]
358- fn test_command_with_no_hello() {
359+ #[tokio::test]
360+ async fn test_command_with_no_hello() {
361 let requests = &[TestCase {
362 request: Request::Mail {
363 from: MailFrom {
364 @@ -512,13 +524,15 @@ mod test {
365 payload: None,
366 expected: smtp_chunk_err!(500, 0, 0, 0, String::from("It's polite to say EHLO first")),
367 }];
368- let mut session = Session::default()
369- .with_options(Options::default().our_hostname(EXAMPLE_HOSTNAME).into());
370- process_all(&mut session, requests);
371+ let session = Mutex::new(
372+ Session::default()
373+ .with_options(Options::default().our_hostname(EXAMPLE_HOSTNAME).into()),
374+ );
375+ process_all(&session, requests).await;
376 }
377
378- #[test]
379- fn test_expand() {
380+ #[tokio::test]
381+ async fn test_expand() {
382 let requests = &[
383 TestCase {
384 request: Request::Helo {
385 @@ -544,26 +558,30 @@ mod test {
386 expected: smtp_chunk_ok!(221, 0, 0, 0, String::from("Ciao!")),
387 },
388 ];
389- let mut session = Session::default().with_options(
390- Options::default()
391- .list_expansion(crate::expand::Func(|name: &str| {
392- assert!(name == "mailing-list");
393- Ok(vec![
394- EmailAddress::new_unchecked("Fuu <fuu@bar.com>"),
395- EmailAddress::new_unchecked("Baz <baz@qux.com>"),
396- ])
397- }))
398- .into(),
399+ let session = Mutex::new(
400+ Session::default().with_options(
401+ Options::default()
402+ .list_expansion(crate::expand::Func(|name: &str| {
403+ assert!(name == "mailing-list");
404+ Ok(vec![
405+ EmailAddress::new_unchecked("Fuu <fuu@bar.com>"),
406+ EmailAddress::new_unchecked("Baz <baz@qux.com>"),
407+ ])
408+ }))
409+ .into(),
410+ ),
411 );
412- process_all(&mut session, requests);
413+ process_all(&session, requests).await;
414 // session should contain both requests
415+ let session = session.lock().await;
416 assert!(session
417 .hostname
418+ .as_ref()
419 .is_some_and(|hostname| hostname.to_string() == EXAMPLE_HOSTNAME));
420 }
421
422- #[test]
423- fn test_verify() {
424+ #[tokio::test]
425+ async fn test_verify() {
426 let requests = &[
427 TestCase {
428 request: Request::Helo {
429 @@ -585,23 +603,27 @@ mod test {
430 expected: smtp_chunk_ok!(221, 0, 0, 0, String::from("Ciao!")),
431 },
432 ];
433- let mut session = Session::default().with_options(
434- Options::default()
435- .verification(crate::verify::Func(|addr: &EmailAddress| {
436- assert!(addr.email() == "bar@baz.com");
437- Ok(())
438- }))
439- .into(),
440+ let session = Mutex::new(
441+ Session::default().with_options(
442+ Options::default()
443+ .verification(crate::verify::Func(|addr: &EmailAddress| {
444+ assert!(addr.email() == "bar@baz.com");
445+ Ok(())
446+ }))
447+ .into(),
448+ ),
449 );
450- process_all(&mut session, requests);
451+ process_all(&session, requests).await;
452 // session should contain both requests
453+ let session = session.lock().await;
454 assert!(session
455 .hostname
456+ .as_ref()
457 .is_some_and(|hostname| hostname.to_string() == EXAMPLE_HOSTNAME));
458 }
459
460- #[test]
461- fn test_non_ascii_characters() {
462+ #[tokio::test]
463+ async fn test_non_ascii_characters() {
464 let mut expected_ehlo_response = EhloResponse::new(String::from("Hello example.org"));
465 expected_ehlo_response.capabilities = DEFAULT_CAPABILITIES;
466 expected_ehlo_response.size = DEFAULT_MAXIMUM_MESSAGE_SIZE as usize;
467 @@ -670,17 +692,19 @@ mod test {
468 expected: smtp_chunk_ok!(250, 0, 0, 0, "OK"),
469 },
470 ];
471- let mut session = Session::default().with_options(
472- Options::default()
473- .our_hostname(EXAMPLE_HOSTNAME)
474- .capabilities(DEFAULT_CAPABILITIES)
475- .into(),
476+ let session = Mutex::new(
477+ Session::default().with_options(
478+ Options::default()
479+ .our_hostname(EXAMPLE_HOSTNAME)
480+ .capabilities(DEFAULT_CAPABILITIES)
481+ .into(),
482+ ),
483 );
484- process_all(&mut session, requests);
485+ process_all(&session, requests).await;
486 }
487
488- #[test]
489- fn test_email_with_body() {
490+ #[tokio::test]
491+ async fn test_email_with_body() {
492 let requests = &[
493 TestCase {
494 request: Request::Helo {
495 @@ -736,17 +760,21 @@ transport rather than the session.
496 expected: smtp_chunk_ok!(250, 0, 0, 0, "OK"),
497 },
498 ];
499- let mut session = Session::default()
500- .with_options(Options::default().our_hostname(EXAMPLE_HOSTNAME).into());
501- process_all(&mut session, requests);
502+ let session = Mutex::new(
503+ Session::default()
504+ .with_options(Options::default().our_hostname(EXAMPLE_HOSTNAME).into()),
505+ );
506+ process_all(&session, requests).await;
507+ let session = session.lock().await;
508 assert!(session
509 .mail_from
510+ .as_ref()
511 .is_some_and(|mail_from| mail_from.email() == "fuu@example.org"));
512- assert!(session.rcpt_to.is_some_and(|rcpts| rcpts
513+ assert!(session.rcpt_to.as_ref().is_some_and(|rcpts| rcpts
514 .first()
515 .is_some_and(|rcpt_to| rcpt_to.email() == "bar@example.org")));
516- assert!(session.body.is_some_and(|body| {
517- let message = MessageParser::new().parse(&body).unwrap();
518+ assert!(session.body.as_ref().is_some_and(|body| {
519+ let message = MessageParser::new().parse(body).unwrap();
520 message
521 .subject()
522 .is_some_and(|subject| subject == "Hello World")
523 diff --git a/maitred/src/verify.rs b/maitred/src/verify.rs
524index 7aa2742..e4efb1c 100644
525--- a/maitred/src/verify.rs
526+++ b/maitred/src/verify.rs
527 @@ -1,5 +1,6 @@
528 use std::result::Result as StdResult;
529
530+ use async_trait::async_trait;
531 use email_address::EmailAddress;
532
533 /// Result indicating the VRFY command was successful and the user was
534 @@ -26,9 +27,10 @@ pub enum Error {
535
536 /// Verify that the given e-mail address exists on the server. Servers may
537 /// choose to implement nothing or not use this option at all if desired.
538+ #[async_trait]
539 pub trait Verify {
540 /// Verify the e-mail address on the server
541- fn verify(&self, address: &EmailAddress) -> Result;
542+ async fn verify(&self, address: &EmailAddress) -> Result;
543 }
544
545 /// Helper wrapper implementing the Verify trait.
546 @@ -36,11 +38,12 @@ pub struct Func<F>(pub F)
547 where
548 F: Fn(&EmailAddress) -> Result;
549
550+ #[async_trait]
551 impl<F> Verify for Func<F>
552 where
553- F: Fn(&EmailAddress) -> Result,
554+ F: Fn(&EmailAddress) -> Result + Sync,
555 {
556- fn verify(&self, address: &EmailAddress) -> Result {
557+ async fn verify(&self, address: &EmailAddress) -> Result {
558 let f = &self.0;
559 f(address)
560 }