summaryrefslogtreecommitdiff
path: root/src/tls_stuff.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tls_stuff.rs')
-rw-r--r--src/tls_stuff.rs159
1 files changed, 159 insertions, 0 deletions
diff --git a/src/tls_stuff.rs b/src/tls_stuff.rs
new file mode 100644
index 0000000..83c0489
--- /dev/null
+++ b/src/tls_stuff.rs
@@ -0,0 +1,159 @@
+//! Simple HTTPS echo service based on hyper-rustls
+//!
+//! First parameter is the mandatory port to use.
+//! Certificate and private key are hardcoded to sample files.
+//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2,
+//! otherwise HTTP/1.1 will be used.
+use core::task::{Context, Poll};
+use futures_util::ready;
+use hyper::server::accept::Accept;
+use hyper::server::conn::{AddrIncoming, AddrStream};
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::vec::Vec;
+use std::{fs, io};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+use tokio_rustls::rustls::{self, ServerConfig};
+
+fn error(err: String) -> io::Error {
+ io::Error::new(io::ErrorKind::Other, err)
+}
+
+pub enum State {
+ Handshaking(tokio_rustls::Accept<AddrStream>),
+ Streaming(tokio_rustls::server::TlsStream<AddrStream>),
+}
+
+// tokio_rustls::server::TlsStream doesn't expose constructor methods,
+// so we have to TlsAcceptor::accept and handshake to have access to it
+// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
+pub struct TlsStream {
+ pub state: State,
+}
+
+impl TlsStream {
+ pub fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
+ let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
+ TlsStream {
+ state: State::Handshaking(accept),
+ }
+ }
+}
+
+impl AsyncRead for TlsStream {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context,
+ buf: &mut ReadBuf,
+ ) -> Poll<io::Result<()>> {
+ let pin = self.get_mut();
+ match pin.state {
+ State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
+ Ok(mut stream) => {
+ let result = Pin::new(&mut stream).poll_read(cx, buf);
+ pin.state = State::Streaming(stream);
+ result
+ }
+ Err(err) => Poll::Ready(Err(err)),
+ },
+ State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
+ }
+ }
+}
+
+impl AsyncWrite for TlsStream {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ let pin = self.get_mut();
+ match pin.state {
+ State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
+ Ok(mut stream) => {
+ let result = Pin::new(&mut stream).poll_write(cx, buf);
+ pin.state = State::Streaming(stream);
+ result
+ }
+ Err(err) => Poll::Ready(Err(err)),
+ },
+ State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
+ }
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ match self.state {
+ State::Handshaking(_) => Poll::Ready(Ok(())),
+ State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
+ }
+ }
+
+ fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ match self.state {
+ State::Handshaking(_) => Poll::Ready(Ok(())),
+ State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
+ }
+ }
+}
+
+pub struct TlsAcceptor {
+ config: Arc<ServerConfig>,
+ incoming: AddrIncoming,
+}
+
+impl TlsAcceptor {
+ pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
+ TlsAcceptor { config, incoming }
+ }
+}
+
+impl Accept for TlsAcceptor {
+ type Conn = TlsStream;
+ type Error = io::Error;
+
+ fn poll_accept(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
+ let pin = self.get_mut();
+ match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
+ Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
+ Some(Err(e)) => Poll::Ready(Some(Err(e))),
+ None => Poll::Ready(None),
+ }
+ }
+}
+
+// Load public certificate from file.
+pub fn load_certs(filename: &str) -> io::Result<Vec<rustls::Certificate>> {
+ // Open certificate file.
+ let certfile = fs::File::open(filename)
+ .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?;
+ let mut reader = io::BufReader::new(certfile);
+
+ // Load and return certificate.
+ let certs = rustls_pemfile::certs(&mut reader)
+ .map_err(|_| error("failed to load certificate".into()))?;
+ Ok(certs
+ .into_iter()
+ .map(rustls::Certificate)
+ .collect())
+}
+
+// Load private key from file.
+pub fn load_private_key(filename: &str) -> io::Result<rustls::PrivateKey> {
+ // Open keyfile.
+ let keyfile = fs::File::open(filename)
+ .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?;
+ let mut reader = io::BufReader::new(keyfile);
+
+ // Load and return a single private key.
+ let keys = rustls_pemfile::rsa_private_keys(&mut reader)
+ .map_err(|_| error("failed to load private key".into()))?;
+ if keys.len() != 1 {
+ return Err(error("expected a single private key".into()));
+ }
+
+ Ok(rustls::PrivateKey(keys[0].clone()))
+}