diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.rs | 92 | ||||
-rw-r--r-- | src/tls_stuff.rs | 159 |
2 files changed, 220 insertions, 31 deletions
diff --git a/src/main.rs b/src/main.rs index d42ad9f..1367c60 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,10 +7,12 @@ use std::{ }; mod minesweeper; +mod tls_stuff; use minesweeper::*; -//use std::convert::{ TryFrom, TryInto }; +use tls_stuff::*; use hyper::{ Method, StatusCode, Body, Request, Response, Server }; +use hyper::server::conn::{ AddrStream, AddrIncoming }; use hyper::service::{make_service_fn, service_fn}; use tokio::sync::{ RwLock, @@ -21,9 +23,9 @@ type HtmlResult = Result<Response<Body>, Response<Body>>; use futures_channel::mpsc::{unbounded, UnboundedSender}; use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt, sink::SinkExt}; -use tokio::net::{TcpListener, TcpStream}; use tokio::fs; -use tokio_tungstenite::tungstenite::protocol::Message; +use hyper_tungstenite::{ tungstenite::protocol::Message, HyperWebsocket }; +use tokio_rustls::{ rustls, Accept, server::TlsStream }; type Tx = UnboundedSender<Message>; type MovReqTx = mpsc::UnboundedSender<MetaMove>; @@ -41,23 +43,51 @@ const PAGE_RELPATH: &str = "./page.html"; const FONT_FILE_FUCKIT: &[u8] = include_bytes!("./VT323-Regular.ttf"); #[tokio::main] -async fn main() { +async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> { let sequential_id = Arc::new(AtomicUsize::new(0)); let peers = PeerMap::new(RwLock::new(HashMap::new())); let peer_info: PeerInfo = (peers.clone(), sequential_id.clone()); - let http_addr = SocketAddr::from(([0, 0, 0, 0], 31235)); - println!("Http on {}", http_addr); + let addr = SocketAddr::from(([0, 0, 0, 0], 31235)); + + // Build TLS configuration. + let tls_cfg = { + // Load public certificate. + let certs = load_certs("cert.pem")?; + // Load private key. + let key = load_private_key("cert.rsa")?; + // Do not use client certificate authentication. + let mut cfg = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| error(format!("{}", e)))?; + // Configure ALPN to accept HTTP/2, HTTP/1.1 in that order. + cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + Arc::new(cfg) + }; + + // Create a TCP listener via tokio. + let incoming = AddrIncoming::bind(&addr)?; let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let game_t = tokio::spawn(gameloop(cmd_rx, peers.clone())); - // need to await this one at some point - let conn_l = tokio::spawn(conn_listener(peer_info.clone(), cmd_tx.clone())); - let http_serv = make_service_fn(|_| { + let serv = make_service_fn(|socket: &tls_stuff::TlsStream| { + let addr = { + if let State::Streaming(ref s) = &socket.state { + s.get_ref().0.remote_addr() + } else { + std::net::SocketAddr::new(std::net::Ipv4Addr::new(0,0,0,0).into(),0) + } + }; + let peer_info = peer_info.clone(); + let cmd_tx = cmd_tx.clone(); async move { Ok::<_, Infallible>(service_fn(move |req: Request<Body>| { + let peer_info = peer_info.clone(); + let cmd_tx = cmd_tx.clone(); async move { - Ok::<_,Infallible>(match handle_http_req(req).await { + Ok::<_,Infallible>(match handle_req(req, peer_info.clone(), cmd_tx.clone(), addr).await { Ok(r) => r, Err(r) => r, }) @@ -66,12 +96,15 @@ async fn main() { } }); - let server = Server::bind(&http_addr) - .serve(http_serv) + // Run the future, keep going until an error occurs. + println!("Starting to serve on https://{}.", addr); + let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming)) + .serve(serv) .with_graceful_shutdown(shutdown_signal()); if let Err(e) = server.await { eprint!("server error: {}", e); } + Ok(()) } // If a move is made, broadcast new board, else just send current board @@ -111,32 +144,17 @@ async fn gameloop(mut move_rx: mpsc::UnboundedReceiver<MetaMove>, peers: PeerMap } } -async fn conn_listener(peer_info: PeerInfo, cmd_tx: MovReqTx) { - let ws_addr = SocketAddr::from(([0, 0, 0, 0], 31236)); - let ws_socket = TcpListener::bind(&ws_addr).await; - let ws_listener = ws_socket.expect("Failed to bind"); - // Let's spawn the handling of each connection in a separate task. - println!("Websocket on {}", ws_addr); - while let Ok((stream, addr)) = ws_listener.accept().await { - tokio::spawn(peer_connection(peer_info.clone(), cmd_tx.clone(), stream, addr)); - } -} - -async fn peer_connection(peer_info: PeerInfo, cmd_tx: MovReqTx, raw_stream: TcpStream, addr: SocketAddr) { +async fn peer_connection(peer_info: PeerInfo, cmd_tx: MovReqTx, socket: HyperWebsocket, addr: SocketAddr) { + let socket = socket.await.unwrap(); // FIXME error handling println!("Incoming TCP connection from: {}", addr); - let ws_stream = tokio_tungstenite::accept_async(raw_stream) - .await - .expect("Error during the websocket handshake occurred"); - println!("WebSocket connection established: {}", addr); - let peer_map = peer_info.0; let peer_seqid = peer_info.1.fetch_add(1, atomic::Ordering::AcqRel); let mut peer_name = "unknown".to_string(); let (tx, rx) = unbounded(); - let (outgoing, mut incoming) = ws_stream.split(); + let (outgoing, mut incoming) = socket.split(); let process_incoming = async { while let Ok(cmd) = incoming.try_next().await { @@ -249,7 +267,15 @@ async fn peer_connection(peer_info: PeerInfo, cmd_tx: MovReqTx, raw_stream: TcpS peer_map.write().await.remove(&addr); } -async fn handle_http_req(request: Request<Body>) -> HtmlResult { +async fn handle_req(mut request: Request<Body>, peer_info: PeerInfo, cmd_tx: MovReqTx, addr: SocketAddr) -> HtmlResult { + if hyper_tungstenite::is_upgrade_request(&request) { + let (resp, wsocket) = hyper_tungstenite::upgrade(&mut request, None).expect("couldn't upgrade to websocket"); + tokio::spawn(async move { + peer_connection(peer_info.clone(), cmd_tx.clone(), wsocket, addr).await; + }); + return Ok(resp); + } + let page = fs::read_to_string(PAGE_RELPATH).await.unwrap(); let mut uri_path = request.uri().path().split('/').skip(1); let actual_path = uri_path.next(); @@ -289,6 +315,10 @@ fn errpage<T: Error>(e: T) -> Response<Body> { .unwrap() } +fn error(err: String) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, err) +} + async fn shutdown_signal() { tokio::signal::ctrl_c() .await diff --git a/src/tls_stuff.rs b/src/tls_stuff.rs new file mode 100644 index 0000000..83c0489 --- /dev/null +++ b/src/tls_stuff.rs @@ -0,0 +1,159 @@ +//! Simple HTTPS echo service based on hyper-rustls +//! +//! First parameter is the mandatory port to use. +//! Certificate and private key are hardcoded to sample files. +//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2, +//! otherwise HTTP/1.1 will be used. +use core::task::{Context, Poll}; +use futures_util::ready; +use hyper::server::accept::Accept; +use hyper::server::conn::{AddrIncoming, AddrStream}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::vec::Vec; +use std::{fs, io}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_rustls::rustls::{self, ServerConfig}; + +fn error(err: String) -> io::Error { + io::Error::new(io::ErrorKind::Other, err) +} + +pub enum State { + Handshaking(tokio_rustls::Accept<AddrStream>), + Streaming(tokio_rustls::server::TlsStream<AddrStream>), +} + +// tokio_rustls::server::TlsStream doesn't expose constructor methods, +// so we have to TlsAcceptor::accept and handshake to have access to it +// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first +pub struct TlsStream { + pub state: State, +} + +impl TlsStream { + pub fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream { + let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); + TlsStream { + state: State::Handshaking(accept), + } + } +} + +impl AsyncRead for TlsStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut ReadBuf, + ) -> Poll<io::Result<()>> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_read(cx, buf); + pin.state = State::Streaming(stream); + result + } + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let pin = self.get_mut(); + match pin.state { + State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { + Ok(mut stream) => { + let result = Pin::new(&mut stream).poll_write(cx, buf); + pin.state = State::Streaming(stream); + result + } + Err(err) => Poll::Ready(Err(err)), + }, + State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + match self.state { + State::Handshaking(_) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + match self.state { + State::Handshaking(_) => Poll::Ready(Ok(())), + State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +pub struct TlsAcceptor { + config: Arc<ServerConfig>, + incoming: AddrIncoming, +} + +impl TlsAcceptor { + pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor { + TlsAcceptor { config, incoming } + } +} + +impl Accept for TlsAcceptor { + type Conn = TlsStream; + type Error = io::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { + let pin = self.get_mut(); + match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { + Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } +} + +// Load public certificate from file. +pub fn load_certs(filename: &str) -> io::Result<Vec<rustls::Certificate>> { + // Open certificate file. + let certfile = fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = io::BufReader::new(certfile); + + // Load and return certificate. + let certs = rustls_pemfile::certs(&mut reader) + .map_err(|_| error("failed to load certificate".into()))?; + Ok(certs + .into_iter() + .map(rustls::Certificate) + .collect()) +} + +// Load private key from file. +pub fn load_private_key(filename: &str) -> io::Result<rustls::PrivateKey> { + // Open keyfile. + let keyfile = fs::File::open(filename) + .map_err(|e| error(format!("failed to open {}: {}", filename, e)))?; + let mut reader = io::BufReader::new(keyfile); + + // Load and return a single private key. + let keys = rustls_pemfile::rsa_private_keys(&mut reader) + .map_err(|_| error("failed to load private key".into()))?; + if keys.len() != 1 { + return Err(error("expected a single private key".into())); + } + + Ok(rustls::PrivateKey(keys[0].clone())) +} |