diff options
author | stale <redkugelblitzin@gmail.com> | 2022-05-03 20:22:43 -0300 |
---|---|---|
committer | stale <redkugelblitzin@gmail.com> | 2022-05-03 20:22:43 -0300 |
commit | 88cc92f64dfc4a241410f67b20dec4680461a344 (patch) | |
tree | 1fe938cc6013c8f8692881cd1803a27fac3a26a9 /src/tls_stuff.rs | |
parent | 2f9687126ecb538f40b57bd6963129b406786175 (diff) |
this is terrible
Diffstat (limited to 'src/tls_stuff.rs')
-rw-r--r-- | src/tls_stuff.rs | 159 |
1 files changed, 159 insertions, 0 deletions
diff --git a/src/tls_stuff.rs b/src/tls_stuff.rs new file mode 100644 index 0000000..83c0489 --- /dev/null +++ b/src/tls_stuff.rs @@ -0,0 +1,159 @@ +//! Simple HTTPS echo service based on hyper-rustls +//! +//! First parameter is the mandatory port to use. +//! Certificate and private key are hardcoded to sample files. +//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2, +//! otherwise HTTP/1.1 will be used. +use core::task::{Context, Poll}; +use futures_util::ready; +use hyper::server::accept::Accept; +use hyper::server::conn::{AddrIncoming, AddrStream}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::vec::Vec; +use std::{fs, io}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::rustls::{self, ServerConfig}; + +fn error(err: String) -> io::Error { + io::Error::new(io::ErrorKind::Other, err) +} + +pub enum State { + Handshaking(tokio_rustls::Accept<AddrStream>), + Streaming(tokio_rustls::server::TlsStream<AddrStream>), +} + +// tokio_rustls::server::TlsStream doesn't expose constructor methods, +// so we have to TlsAcceptor::accept and handshake to have access to it +// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first +pub struct TlsStream { + pub state: State, +} + +impl TlsStream { + pub fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream { + let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); + TlsStream { + state: State::Handshaking(accept), + } + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut ReadBuf, + ) -> Poll<io::Result<()>> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_read(cx, buf); + pin.state = State::Streaming(stream); + result + } + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_write(cx, buf); + pin.state = State::Streaming(stream); + result + } + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + match self.state { + State::Handshaking(_) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + match self.state { + State::Handshaking(_) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +pub struct TlsAcceptor { + config: Arc<ServerConfig>, + incoming: AddrIncoming, +} + +impl TlsAcceptor { + pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor { + TlsAcceptor { config, incoming } + } +} + +impl Accept for TlsAcceptor { + type Conn = TlsStream; + type Error = io::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + let pin = self.get_mut(); + match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { + Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} + +// Load public certificate from file. +pub fn load_certs(filename: &str) -> io::Result<Vec<rustls::Certificate>> { + // Open certificate file. + let certfile = fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = io::BufReader::new(certfile); + + // Load and return certificate. + let certs = rustls_pemfile::certs(&mut reader) + .map_err(|_| error("failed to load certificate".into()))?; + Ok(certs + .into_iter() + .map(rustls::Certificate) + .collect()) +} + +// Load private key from file. +pub fn load_private_key(filename: &str) -> io::Result<rustls::PrivateKey> { + // Open keyfile. + let keyfile = fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = io::BufReader::new(keyfile); + + // Load and return a single private key. + let keys = rustls_pemfile::rsa_private_keys(&mut reader) + .map_err(|_| error("failed to load private key".into()))?; + if keys.len() != 1 { + return Err(error("expected a single private key".into())); + } + + Ok(rustls::PrivateKey(keys[0].clone())) +} |