Author: Kevin Schoon [me@kevinschoon.com]
Hash: ab392816c58221cb19dbf9892e944c5f523fa7dc
Timestamp: Mon, 29 Jul 2024 21:35:21 +0000 (2 months ago)

+151 -18 +/-3 browse
factor tcp stream out for better testing
1diff --git a/maitred/src/server.rs b/maitred/src/server.rs
2index f1ee7e7..3b68bb3 100644
3--- a/maitred/src/server.rs
4+++ b/maitred/src/server.rs
5 @@ -1,21 +1,27 @@
6+ use std::time::{Duration, Instant};
7+
8 use futures::SinkExt;
9 use smtp_proto::Request;
10- use tokio::net::{TcpListener, TcpStream};
11+ use tokio::net::TcpListener;
12 use tokio_stream::StreamExt;
13 use tokio_util::codec::Framed;
14
15 use crate::error::Error;
16- use crate::session::Session;
17+ use crate::session::{Options as SessionOptions, Session};
18 use crate::transport::Transport;
19
20 const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:2525";
21 const DEFAULT_GREETING: &str = "Maitred ESMTP Server";
22+ // Maximum amount of time the server will wait for a command before closing
23+ // the connection.
24+ const DEFAULT_GLOBAL_TIMEOUT_SECS: u64 = 300;
25
26 #[derive(Clone)]
27 struct Configuration {
28 pub address: String,
29 pub hostname: String,
30 pub greeting: String,
31+ pub global_timeout: Duration,
32 }
33
34 impl Default for Configuration {
35 @@ -24,6 +30,7 @@ impl Default for Configuration {
36 address: DEFAULT_LISTEN_ADDR.to_string(),
37 hostname: String::default(),
38 greeting: DEFAULT_GREETING.to_string(),
39+ global_timeout: Duration::from_secs(DEFAULT_GLOBAL_TIMEOUT_SECS),
40 }
41 }
42 }
43 @@ -43,6 +50,12 @@ impl Server {
44 }
45 }
46
47+ /// Greeting message returned from the server upon initial connection.
48+ pub fn with_greeting(mut self, greeting: &str) -> Self {
49+ self.config.greeting = greeting.to_string();
50+ self
51+ }
52+
53 /// Listener address for the SMTP server to bind to listen for incoming
54 /// connections.
55 pub fn with_address(mut self, address: &str) -> Self {
56 @@ -50,15 +63,24 @@ impl Server {
57 self
58 }
59
60- async fn process(&self, stream: TcpStream) -> Result<(), Error> {
61- let peer = stream.peer_addr()?;
62- tracing::info!("Processing new TCP connection from {:?}", peer);
63- let transport = Transport::default();
64- let mut framed = Framed::new(stream, transport);
65+ /// Set the maximum amount of time the server will wait for another command
66+ /// before closing the connection. RFC states the suggested time is 5m.
67+ pub fn with_timeout(mut self, timeout: Duration) -> Self {
68+ self.config.global_timeout = timeout;
69+ self
70+ }
71+
72+ async fn process<T>(
73+ &self,
74+ mut framed: Framed<T, Transport>,
75+ opts: crate::session::Options,
76+ ) -> Result<Session, Error>
77+ where
78+ T: tokio::io::AsyncRead + tokio::io::AsyncWrite + std::marker::Unpin,
79+ {
80 let mut session = Session::default();
81- let session_opts = crate::session::Options {
82- hostname: self.config.hostname.clone(),
83- };
84+ let start_time = Instant::now();
85+ let mut n_commands = 0;
86 // send inital server greeting
87 framed
88 .send(crate::session::greeting(
89 @@ -73,8 +95,7 @@ impl Server {
90 if matches!(command.0, Request::Quit) {
91 finished = true;
92 }
93-
94- match session.process(&session_opts, &command.0, command.1) {
95+ match session.process(&opts, &command.0, command.1) {
96 Ok(resp) => {
97 tracing::debug!("Returning response: {:?}", resp);
98 framed.send(resp).await?;
99 @@ -86,7 +107,8 @@ impl Server {
100 };
101 }
102 Err(err) => {
103- tracing::warn!("Socket closed with error: {:?}", err)
104+ tracing::warn!("Socket closed with error: {:?}", err);
105+ return Err(err);
106 }
107 };
108
109 @@ -95,7 +117,7 @@ impl Server {
110 }
111 }
112 tracing::info!("Connection closed");
113- Ok(())
114+ Ok(session)
115 }
116
117 pub async fn listen(&self) -> Result<(), Error> {
118 @@ -103,10 +125,120 @@ impl Server {
119 tracing::info!("Mail server listening @ {}", self.config.address);
120 loop {
121 let (socket, _) = listener.accept().await.unwrap();
122- // TODO: timeout
123- if let Err(err) = self.process(socket).await {
124+ let framed = Framed::new(socket, Transport::default());
125+ if let Err(err) = self
126+ .process(
127+ framed,
128+ SessionOptions {
129+ hostname: self.config.hostname.clone(),
130+ },
131+ )
132+ .await
133+ {
134 tracing::warn!("Client encountered an error: {:?}", err);
135 }
136 }
137 }
138 }
139+
140+ #[cfg(test)]
141+ mod test {
142+
143+ use super::*;
144+
145+ use std::io;
146+ use std::pin::Pin;
147+ use std::task::{Context, Poll};
148+ use tokio::io::{AsyncRead, AsyncWrite};
149+
150+ /// Fake TCP stream for testing purposes with "framed" line oriented
151+ /// requests to feed to the session processor.
152+ #[derive(Default)]
153+ struct FakeStream {
154+ buffer: Vec<Vec<u8>>,
155+ chunk: usize,
156+ }
157+
158+ impl AsyncRead for FakeStream {
159+ fn poll_read(
160+ self: Pin<&mut Self>,
161+ _cx: &mut Context<'_>,
162+ buf: &mut tokio::io::ReadBuf<'_>,
163+ ) -> Poll<io::Result<()>> {
164+ let inner = self.get_mut();
165+ let index = inner.chunk;
166+ if let Some(chunk) = inner.buffer.get(index) {
167+ inner.chunk = index + 1;
168+ println!("Client wrote: {:?}", String::from_utf8_lossy(chunk));
169+ buf.put_slice(chunk.as_slice());
170+ std::task::Poll::Ready(Ok(()))
171+ } else {
172+ Poll::Ready(Ok(()))
173+ }
174+ }
175+ }
176+
177+ impl AsyncWrite for FakeStream {
178+ fn poll_write(
179+ self: Pin<&mut Self>,
180+ _cx: &mut Context<'_>,
181+ buf: &[u8],
182+ ) -> Poll<Result<usize, io::Error>> {
183+ println!("Server responded: {:?}", String::from_utf8_lossy(buf));
184+ Poll::Ready(Ok(buf.len()))
185+ }
186+
187+ fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
188+ Poll::Ready(Ok(()))
189+ }
190+
191+ fn poll_shutdown(
192+ self: Pin<&mut Self>,
193+ _cx: &mut Context<'_>,
194+ ) -> Poll<Result<(), io::Error>> {
195+ todo!()
196+ }
197+ }
198+
199+ #[tokio::test]
200+ async fn test_server_process() {
201+ let stream = FakeStream {
202+ buffer: vec![
203+ "HELO example.org\r\n".into(),
204+ "MAIL FROM: <fuu@bar.com>\r\n".into(),
205+ "RCPT TO: <baz@qux.com>\r\n".into(),
206+ "DATA\r\n".into(),
207+ "Subject: Hello World\r\n.\r\n".into(),
208+ "QUIT\r\n".into(),
209+ ],
210+ ..Default::default()
211+ };
212+ let server = Server::new("example.org");
213+ let framed = Framed::new(stream, Transport::default());
214+ let session = server
215+ .process(
216+ framed,
217+ crate::session::Options {
218+ hostname: "localhost".to_string(),
219+ },
220+ )
221+ .await
222+ .unwrap();
223+ assert!(session.history.len() == 6);
224+ assert!(matches!(
225+ session.history.first().unwrap(),
226+ Request::Helo { host: _ }
227+ ));
228+ assert!(matches!(
229+ session.history.get(1).unwrap(),
230+ Request::Mail { from: _ }
231+ ));
232+ assert!(matches!(
233+ session.history.get(2).unwrap(),
234+ Request::Rcpt { to: _ }
235+ ));
236+ assert!(matches!(session.history.get(3).unwrap(), Request::Data {}));
237+ assert!(matches!(session.history.get(4).unwrap(), Request::Data {}));
238+ assert!(matches!(session.history.get(5).unwrap(), Request::Quit {}));
239+ }
240+ }
241 diff --git a/maitred/src/session.rs b/maitred/src/session.rs
242index af82e3b..7d06d19 100644
243--- a/maitred/src/session.rs
244+++ b/maitred/src/session.rs
245 @@ -1,4 +1,5 @@
246 use std::result::Result as StdResult;
247+ use std::time::{Duration, Instant};
248
249 use bytes::Bytes;
250 use mail_parser::{Addr, Message, MessageParser};
251 @@ -23,11 +24,11 @@ pub fn greeting(hostname: &str, greeting: &str) -> Response<String> {
252 }
253
254 /// Runtime options that influence server behavior
255- #[derive(Default)]
256 pub(crate) struct Options {
257 pub hostname: String,
258 }
259
260+
261 /// Stateful connection that coresponds to a single SMTP session
262 #[derive(Default)]
263 pub(crate) struct Session {
264 @@ -40,7 +41,6 @@ pub(crate) struct Session {
265 /// rcpt address
266 pub rcpt_to: Option<Address>,
267 pub hostname: Option<Host>,
268-
269 /// If an active data transfer is taking place
270 data_transfer: Option<DataTransfer>,
271 }
272 diff --git a/maitred/src/transport.rs b/maitred/src/transport.rs
273index f0f1425..fa59793 100644
274--- a/maitred/src/transport.rs
275+++ b/maitred/src/transport.rs
276 @@ -25,6 +25,7 @@ pub(crate) enum Receiver {
277 }
278
279 /// Command from the client with an optional attached payload.
280+ #[derive(Debug)]
281 pub(crate) struct Command(pub Request<String>, pub Option<Bytes>);
282
283 impl Display for Command {