use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use hyper::{header, upgrade, StatusCode, Body, Request, Response, Server, server::conn::AddrStream};
use hyper::service::{make_service_fn, service_fn};
use tokio_tungstenite::WebSocketStream;
use tungstenite::{handshake, Error};
use futures::stream::StreamExt;
use lazy_static::lazy_static;
use log::{error, info};
use tokio::sync::RwLock;
use protocol::State;
use crate::handler::{ClientHandler, ClientManager};
use crate::wsserver::handle_client;
pub mod wsserver;
pub mod handler;
pub mod timer;
async fn handle_request(mut request: Request<Body>, remote_addr: SocketAddr, mgr: ClientManager) -> Result<Response<Body>, Infallible> {
match (request.uri().path(), request.headers().contains_key(header::UPGRADE)) {
//if the request is ws_echo and the request headers contains an Upgrade key
("/ws", true) => {
//assume request is a handshake, so create the handshake response
let response =
match handshake::server::create_response_with_body(&request, || Body::empty()) {
Ok(response) => {
//in case the handshake response creation succeeds,
//spawn a task to handle the websocket connection
tokio::spawn(async move {
//using the hyper feature of upgrading a connection
match upgrade::on(&mut request).await {
//if successfully upgraded
Ok(upgraded) => {
//create a websocket stream from the upgraded object
let ws_stream = WebSocketStream::from_raw_socket(
//pass the upgraded object
//as the base layer stream of the Websocket
upgraded,
tokio_tungstenite::tungstenite::protocol::Role::Server,
None,
).await;
//we can split the stream into a sink and a stream
let (ws_write, ws_read) = ws_stream.split();
let (tx, rx) = tokio::sync::mpsc::channel(128);
let client = ClientHandler {
tx,
};
// Acquire the write lock in a small scope, so it's dropped as quickly as possible
{
mgr.clients.write().await.insert(remote_addr, client);
}
//forward the stream to the sink to achieve echo
match handle_client(mgr.clone(), remote_addr, rx, ws_write, ws_read).await {
Ok(_) => {},
Err(e) => error!("error on WS connection {}: {}", remote_addr, e),
};
// clean up values left over
{
mgr.clients.write().await.remove(&remote_addr);
mgr.usernames.write().await.remove(&remote_addr);
}
},
Err(e) => {
error!("error upgrading connection from {} to WS: {}", remote_addr, e);
}
}
});
//return the response to the handshake request
response
},
Err(e) => {
//probably the handshake request is not up to spec for websocket
error!("error creating websocket response to {}: {}", remote_addr, e);
let mut res = Response::new(Body::from(format!("Failed to create websocket: {}", e)));
*res.status_mut() = StatusCode::BAD_REQUEST;
return Ok(res);
}
};
Ok::<_, Infallible>(response)
},
("/ws", false) => {
Ok(Response::builder().status(400).body(Body::from("Connection-Upgrade header missing")).unwrap())
},
(url@_, false) => {
// typical HTTP file request
// TODO
Ok(Response::new(Body::empty()))
},
(_, true) => {
// http upgrade on non-/ws endpoint
Ok(Response::builder().status(400).body(Body::from("Incorrect WebSocket endpoint")).unwrap())
}
}
}
lazy_static! {
    /// Process-wide client registry, created empty on first access.
    ///
    /// `main` hands a clone of it to every request. Both maps sit behind `Arc`s.
    /// Presumably `ClientManager`'s `Clone` clones those `Arc`s, so every clone sees
    /// the same maps — confirm in `handler`.
    static ref cmgr: ClientManager = {
        let clients = Arc::new(RwLock::new(Default::default()));
        let usernames = Arc::new(RwLock::new(Default::default()));
        ClientManager { clients, usernames }
    };
}
#[tokio::main]
async fn main() {
simple_logger::init_with_env().expect("Unable to start logging service");
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
info!("Listening on {} for HTTP/WebSocket connections", addr);
let make_svc = make_service_fn(|conn: &AddrStream| {
let remote_addr = conn.remote_addr();
async move {
Ok::<_, Infallible>(service_fn({
move |request: Request<Body>| {
handle_request(request, remote_addr, cmgr.clone())
}
}))
}
});
let server = Server::bind(&addr).serve(make_svc);
if let Err(e) = server.await {
error!("error in server thread: {}", e);
}
}