use log::{trace, info, warn}; use async_std::sync::{Arc, Mutex}; use async_std::io::{Read, Write}; use async_std::io::prelude::{ReadExt, WriteExt}; use async_std::prelude::{StreamExt}; use std::collections::HashMap; use libpso::crypto::{PSOCipher, NullCipher, CipherError}; use libpso::PacketParseError; use crate::common::serverstate::ClientId; use crate::common::serverstate::{RecvServerPacket, SendServerPacket, ServerState, OnConnect}; #[derive(Debug)] pub enum NetworkError { CouldNotSend, CipherError(CipherError), PacketParseError(PacketParseError), IOError(std::io::Error), DataNotReady, ClientDisconnected, } impl From for NetworkError { fn from(err: CipherError) -> NetworkError { NetworkError::CipherError(err) } } impl From for NetworkError { fn from(err: std::io::Error) -> NetworkError { NetworkError::IOError(err) } } impl From for NetworkError { fn from(err: PacketParseError) -> NetworkError { NetworkError::PacketParseError(err) } } struct PacketReceiver { socket: Arc, cipher: Arc>>, recv_buffer: Vec, incoming_data: Vec, } impl PacketReceiver { fn new(socket: Arc, cipher: Arc>>) -> PacketReceiver { PacketReceiver { socket: socket, cipher: cipher, recv_buffer: Vec::new(), incoming_data: Vec::new(), } } async fn fill_recv_buffer(&mut self) -> Result<(), NetworkError> { let mut data = [0u8; 0x8000]; let mut socket = &*self.socket; let len = socket.read(&mut data).await?; if len == 0 { return Err(NetworkError::ClientDisconnected); } self.recv_buffer.extend_from_slice(&mut data[..len]); let mut dec_buf = { let mut cipher = self.cipher.lock().await; let block_chunk_len = self.recv_buffer.len() / cipher.block_size() * cipher.block_size(); let buf = self.recv_buffer.drain(..block_chunk_len).collect(); cipher.decrypt(&buf)? }; self.incoming_data.append(&mut dec_buf); Ok(()) } async fn recv_pkts(&mut self) -> Result, NetworkError> { self.fill_recv_buffer().await?; let mut result = Vec::new(); loop { if self.incoming_data.len() < 2 { break; } let pkt_size = u16::from_le_bytes([self.incoming_data[0], self.incoming_data[1]]) as usize; let mut pkt_len = pkt_size; while pkt_len % self.cipher.lock().await.block_size() != 0 { pkt_len += 1; } if pkt_len > self.incoming_data.len() { break; } let pkt_data = self.incoming_data.drain(..pkt_len).collect::>(); trace!("[recv buf] {:?}", pkt_data); let pkt = match R::from_bytes(&pkt_data[..pkt_size]) { Ok(p) => p, Err(err) => { warn!("error RecvServerPacket::from_bytes: {:?}", err); continue }, }; result.push(pkt); } Ok(result) } } async fn send_pkt(socket: Arc, cipher: Arc>>, pkt: S) -> Result<(), NetworkError> { let buf = pkt.as_bytes(); let cbuf = cipher.lock().await.encrypt(&buf)?; let mut ssock = &*socket; ssock.write_all(&cbuf).await?; Ok(()) } enum ClientAction { NewClient(ClientId, async_std::sync::Sender), Packet(ClientId, R), Disconnect(ClientId), } enum ServerStateAction { Cipher(Box, Box), Packet(S), Disconnect, } async fn server_state_loop(mut state: STATE, server_state_receiver: async_std::sync::Receiver, R>>) where STATE: ServerState + Send + 'static, S: SendServerPacket + std::fmt::Debug + Send + 'static, R: RecvServerPacket + std::fmt::Debug + Send + 'static, E: std::fmt::Debug + Send, { async_std::task::spawn(async move { let mut clients = HashMap::new(); loop { let action = server_state_receiver.recv().await.unwrap(); match action { ClientAction::NewClient(client_id, sender) => { clients.insert(client_id, sender.clone()); for action in state.on_connect(client_id) { match action { OnConnect::Cipher((inc, outc)) => { sender.send(ServerStateAction::Cipher(inc, outc)).await; }, OnConnect::Packet(pkt) => { sender.send(ServerStateAction::Packet(pkt)).await; } } } }, ClientAction::Packet(client_id, pkt) => { let pkts = state.handle(client_id, &pkt); match pkts { Ok(pkts) => { for (client_id, pkt) in pkts { if let Some(client) = clients.get_mut(&client_id) { client.send(ServerStateAction::Packet(pkt)).await; } } }, Err(err) => { warn!("[client {:?} state handler error] {:?}", client_id, err); } } }, ClientAction::Disconnect(client_id) => { let pkts = state.on_disconnect(client_id); for (client_id, pkt) in pkts { if let Some(client) = clients.get_mut(&client_id) { client.send(ServerStateAction::Packet(pkt)).await; } } if let Some(client) = clients.get_mut(&client_id) { client.send(ServerStateAction::Disconnect).await; } } } } }); } async fn client_recv_loop(client_id: ClientId, socket: Arc, cipher: Arc>>, server_sender: async_std::sync::Sender, R>>, client_sender: async_std::sync::Sender>) where S: SendServerPacket + std::fmt::Debug + Send + 'static, R: RecvServerPacket + std::fmt::Debug + Send + 'static, { async_std::task::spawn(async move { server_sender.send(ClientAction::NewClient(client_id, client_sender)).await; let mut pkt_receiver = PacketReceiver::new(socket, cipher); loop { match pkt_receiver.recv_pkts().await { Ok(pkts) => { for pkt in pkts { trace!("[recv from {:?}] {:?}", client_id, pkt); server_sender.send(ClientAction::Packet(client_id, pkt)).await; } }, Err(err) => { match err { NetworkError::ClientDisconnected => { trace!("[client disconnected] {:?}", client_id); server_sender.send(ClientAction::Disconnect(client_id)).await; break; } _ => { warn!("[client {:?} recv error] {:?}", client_id, err); } } } } } }); } async fn client_send_loop(client_id: ClientId, socket: Arc, cipher_in: Arc>>, cipher_out: Arc>>, client_receiver: async_std::sync::Receiver>, ) where S: SendServerPacket + std::fmt::Debug + Send + 'static, { async_std::task::spawn(async move { loop { let action = client_receiver.recv().await.unwrap(); match action { ServerStateAction::Cipher(inc, outc) => { *cipher_in.lock().await = inc; *cipher_out.lock().await = outc; } ServerStateAction::Packet(pkt) => { trace!("[send to {:?}] {:?}", client_id, pkt); if let Err(err) = send_pkt(socket.clone(), cipher_out.clone(), pkt).await { warn!("[client {:?} send error ] {:?}", client_id, err); } }, ServerStateAction::Disconnect => { break; } }; } }); } pub async fn mainloop_async(mut state: STATE, port: u16) where STATE: ServerState + Send + 'static, S: SendServerPacket + std::fmt::Debug + Send + Sync + 'static, R: RecvServerPacket + std::fmt::Debug + Send + Sync + 'static, E: std::fmt::Debug + Send, { let listener = async_std::task::spawn(async move { let listener = async_std::net::TcpListener::bind(&std::net::SocketAddr::from((std::net::Ipv4Addr::new(0,0,0,0), port))).await.unwrap(); let mut id = 1; let (server_state_sender, server_state_receiver) = async_std::sync::channel(1024); server_state_loop(state, server_state_receiver).await; loop { let (sock, addr) = listener.accept().await.unwrap(); let client_id = crate::common::serverstate::ClientId(id); id += 1; info!("new client {:?} {:?} {:?}", client_id, sock, addr); let (client_sender, client_receiver) = async_std::sync::channel(64); let socket = Arc::new(sock); let cipher_in: Arc>> = Arc::new(Mutex::new(Box::new(NullCipher {}))); let cipher_out: Arc>> = Arc::new(Mutex::new(Box::new(NullCipher {}))); client_recv_loop(client_id, socket.clone(), cipher_in.clone(), server_state_sender.clone(), client_sender).await; client_send_loop(client_id, socket.clone(), cipher_in.clone(), cipher_out.clone(), client_receiver).await; } }); listener.await }