summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs92
1 files changed, 61 insertions, 31 deletions
diff --git a/src/main.rs b/src/main.rs
index d42ad9f..1367c60 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,10 +7,12 @@ use std::{
};
mod minesweeper;
+mod tls_stuff;
use minesweeper::*;
-//use std::convert::{ TryFrom, TryInto };
+use tls_stuff::*;
use hyper::{ Method, StatusCode, Body, Request, Response, Server };
+use hyper::server::conn::{ AddrStream, AddrIncoming };
use hyper::service::{make_service_fn, service_fn};
use tokio::sync::{
RwLock,
@@ -21,9 +23,9 @@ type HtmlResult = Result<Response<Body>, Response<Body>>;
use futures_channel::mpsc::{unbounded, UnboundedSender};
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt, sink::SinkExt};
-use tokio::net::{TcpListener, TcpStream};
use tokio::fs;
-use tokio_tungstenite::tungstenite::protocol::Message;
+use hyper_tungstenite::{ tungstenite::protocol::Message, HyperWebsocket };
+use tokio_rustls::{ rustls, Accept, server::TlsStream };
type Tx = UnboundedSender<Message>;
type MovReqTx = mpsc::UnboundedSender<MetaMove>;
@@ -41,23 +43,51 @@ const PAGE_RELPATH: &str = "./page.html";
const FONT_FILE_FUCKIT: &[u8] = include_bytes!("./VT323-Regular.ttf");
#[tokio::main]
-async fn main() {
+async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let sequential_id = Arc::new(AtomicUsize::new(0));
let peers = PeerMap::new(RwLock::new(HashMap::new()));
let peer_info: PeerInfo = (peers.clone(), sequential_id.clone());
- let http_addr = SocketAddr::from(([0, 0, 0, 0], 31235));
- println!("Http on {}", http_addr);
+ let addr = SocketAddr::from(([0, 0, 0, 0], 31235));
+
+ // Build TLS configuration.
+ let tls_cfg = {
+ // Load public certificate.
+ let certs = load_certs("cert.pem")?;
+ // Load private key.
+ let key = load_private_key("cert.rsa")?;
+ // Do not use client certificate authentication.
+ let mut cfg = rustls::ServerConfig::builder()
+ .with_safe_defaults()
+ .with_no_client_auth()
+ .with_single_cert(certs, key)
+ .map_err(|e| error(format!("{}", e)))?;
+ // Configure ALPN to accept HTTP/2, HTTP/1.1 in that order.
+ cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
+ Arc::new(cfg)
+ };
+
+ // Create a TCP listener via tokio.
+ let incoming = AddrIncoming::bind(&addr)?;
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let game_t = tokio::spawn(gameloop(cmd_rx, peers.clone()));
- // need to await this one at some point
- let conn_l = tokio::spawn(conn_listener(peer_info.clone(), cmd_tx.clone()));
- let http_serv = make_service_fn(|_| {
+ let serv = make_service_fn(|socket: &tls_stuff::TlsStream| {
+ let addr = {
+ if let State::Streaming(ref s) = &socket.state {
+ s.get_ref().0.remote_addr()
+ } else {
+ std::net::SocketAddr::new(std::net::Ipv4Addr::new(0,0,0,0).into(),0)
+ }
+ };
+ let peer_info = peer_info.clone();
+ let cmd_tx = cmd_tx.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
+ let peer_info = peer_info.clone();
+ let cmd_tx = cmd_tx.clone();
async move {
- Ok::<_,Infallible>(match handle_http_req(req).await {
+ Ok::<_,Infallible>(match handle_req(req, peer_info.clone(), cmd_tx.clone(), addr).await {
Ok(r) => r,
Err(r) => r,
})
@@ -66,12 +96,15 @@ async fn main() {
}
});
- let server = Server::bind(&http_addr)
- .serve(http_serv)
+ // Run the future, keep going until an error occurs.
+ println!("Starting to serve on https://{}.", addr);
+ let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming))
+ .serve(serv)
.with_graceful_shutdown(shutdown_signal());
if let Err(e) = server.await {
eprint!("server error: {}", e);
}
+ Ok(())
}
// If a move is made, broadcast new board, else just send current board
@@ -111,32 +144,17 @@ async fn gameloop(mut move_rx: mpsc::UnboundedReceiver<MetaMove>, peers: PeerMap
}
}
-async fn conn_listener(peer_info: PeerInfo, cmd_tx: MovReqTx) {
- let ws_addr = SocketAddr::from(([0, 0, 0, 0], 31236));
- let ws_socket = TcpListener::bind(&ws_addr).await;
- let ws_listener = ws_socket.expect("Failed to bind");
- // Let's spawn the handling of each connection in a separate task.
- println!("Websocket on {}", ws_addr);
- while let Ok((stream, addr)) = ws_listener.accept().await {
- tokio::spawn(peer_connection(peer_info.clone(), cmd_tx.clone(), stream, addr));
- }
-}
-
-async fn peer_connection(peer_info: PeerInfo, cmd_tx: MovReqTx, raw_stream: TcpStream, addr: SocketAddr) {
+async fn peer_connection(peer_info: PeerInfo, cmd_tx: MovReqTx, socket: HyperWebsocket, addr: SocketAddr) {
+ let socket = socket.await.unwrap(); // FIXME error handling
println!("Incoming TCP connection from: {}", addr);
- let ws_stream = tokio_tungstenite::accept_async(raw_stream)
- .await
- .expect("Error during the websocket handshake occurred");
- println!("WebSocket connection established: {}", addr);
-
let peer_map = peer_info.0;
let peer_seqid = peer_info.1.fetch_add(1, atomic::Ordering::AcqRel);
let mut peer_name = "unknown".to_string();
let (tx, rx) = unbounded();
- let (outgoing, mut incoming) = ws_stream.split();
+ let (outgoing, mut incoming) = socket.split();
let process_incoming = async {
while let Ok(cmd) = incoming.try_next().await {
@@ -249,7 +267,15 @@ async fn peer_connection(peer_info: PeerInfo, cmd_tx: MovReqTx, raw_stream: TcpS
peer_map.write().await.remove(&addr);
}
-async fn handle_http_req(request: Request<Body>) -> HtmlResult {
+async fn handle_req(mut request: Request<Body>, peer_info: PeerInfo, cmd_tx: MovReqTx, addr: SocketAddr) -> HtmlResult {
+ if hyper_tungstenite::is_upgrade_request(&request) {
+ let (resp, wsocket) = hyper_tungstenite::upgrade(&mut request, None).expect("couldn't upgrade to websocket");
+ tokio::spawn(async move {
+ peer_connection(peer_info.clone(), cmd_tx.clone(), wsocket, addr).await;
+ });
+ return Ok(resp);
+ }
+
let page = fs::read_to_string(PAGE_RELPATH).await.unwrap();
let mut uri_path = request.uri().path().split('/').skip(1);
let actual_path = uri_path.next();
@@ -289,6 +315,10 @@ fn errpage<T: Error>(e: T) -> Response<Body> {
.unwrap()
}
+fn error(err: String) -> std::io::Error {
+ std::io::Error::new(std::io::ErrorKind::Other, err)
+}
+
async fn shutdown_signal() {
tokio::signal::ctrl_c()
.await