use std::collections::HashMap; use std::fmt::Debug; use async_std::channel; use async_std::io::prelude::{ReadExt, WriteExt}; use async_std::sync::{Arc, RwLock}; use futures::future::Future; use log::{trace, info, warn}; use libpso::crypto::{PSOCipher, NullCipher, CipherError}; use libpso::PacketParseError; use crate::common::serverstate::ClientId; use crate::common::serverstate::{RecvServerPacket, SendServerPacket, ServerState, OnConnect}; #[derive(Debug)] pub enum NetworkError { CouldNotSend, CipherError(CipherError), PacketParseError(PacketParseError), IOError(std::io::Error), DataNotReady, ClientDisconnected, } impl From for NetworkError { fn from(err: CipherError) -> NetworkError { NetworkError::CipherError(err) } } impl From for NetworkError { fn from(err: std::io::Error) -> NetworkError { NetworkError::IOError(err) } } impl From for NetworkError { fn from(err: PacketParseError) -> NetworkError { NetworkError::PacketParseError(err) } } pub struct PacketReceiver { socket: async_std::net::TcpStream, cipher: C, recv_buffer: Vec, incoming_data: Vec, } impl PacketReceiver { pub fn new(socket: async_std::net::TcpStream, cipher: C) -> PacketReceiver { PacketReceiver { socket, cipher, recv_buffer: Vec::new(), incoming_data: Vec::new(), } } async fn fill_recv_buffer(&mut self) -> Result<(), NetworkError> { let mut data = [0u8; 0x8000]; let mut socket = self.socket.clone(); let len = socket.read(&mut data).await?; if len == 0 { return Err(NetworkError::ClientDisconnected); } self.recv_buffer.extend_from_slice(&data[..len]); let mut dec_buf = { //let mut cipher = self.cipher.lock().await; let block_chunk_len = self.recv_buffer.len() / self.cipher.block_size() * self.cipher.block_size(); let buf = self.recv_buffer.drain(..block_chunk_len).collect(); self.cipher.decrypt(&buf)? }; self.incoming_data.append(&mut dec_buf); Ok(()) } pub async fn recv_pkts(&mut self) -> Result, NetworkError> { self.fill_recv_buffer().await?; let mut result = Vec::new(); loop { if self.incoming_data.len() < 2 { break; } let pkt_size = u16::from_le_bytes([self.incoming_data[0], self.incoming_data[1]]) as usize; let mut pkt_len = pkt_size; while pkt_len % self.cipher.block_size() != 0 { pkt_len += 1; } if pkt_len > self.incoming_data.len() { break; } let pkt_data = self.incoming_data.drain(..pkt_len).collect::>(); trace!("[recv buf] {:?}", pkt_data); let pkt = match R::from_bytes(&pkt_data[..pkt_size]) { Ok(p) => p, Err(err) => { warn!("error RecvServerPacket::from_bytes: {:?}", err); continue }, }; result.push(pkt); } Ok(result) } } async fn recv_loop(mut state: STATE, socket: async_std::net::TcpStream, client_id: ClientId, cipher: C, clients: Arc>>>) where STATE: ServerState + Send, S: SendServerPacket + Debug + Send, R: RecvServerPacket + Debug + Send, C: PSOCipher + Send, E: std::fmt::Debug + Send, { let mut pkt_receiver = PacketReceiver::new(socket, cipher); loop { match pkt_receiver.recv_pkts::().await { Ok(pkts) => { for pkt in pkts { info!("[recv from {:?}] {:#?}", client_id, pkt); match state.handle(client_id, pkt).await { Ok(response) => { for resp in response { clients .read() .await .get(&resp.0) .unwrap() .send(resp.1) .await .unwrap(); } }, Err(err) => { warn!("[client recv {:?}] error {:?} ", client_id, err); } } } }, Err(err) => { match err { NetworkError::ClientDisconnected => { info!("[client recv {:?}] disconnected", client_id); for pkt in state.on_disconnect(client_id).await.unwrap() { clients .read() .await .get(&pkt.0) .unwrap() .send(pkt.1) .await .unwrap(); } clients .write() .await .remove(&client_id); break; } _ => { warn!("[client {:?} recv error] {:?}", client_id, err); } } } } } } async fn send_pkt(socket: &mut async_std::net::TcpStream, cipher: &mut C, pkt: &S) -> Result<(), NetworkError> where S: SendServerPacket + std::fmt::Debug, C: PSOCipher, { let buf = pkt.as_bytes(); trace!("[send buf] {:?}", buf); let cbuf = cipher.encrypt(&buf)?; socket.write_all(&cbuf).await?; Ok(()) } async fn send_loop(mut socket: async_std::net::TcpStream, client_id: ClientId, mut cipher: C, packet_queue: channel::Receiver) where S: SendServerPacket + std::fmt::Debug, C: PSOCipher, { loop { match packet_queue.recv().await { Ok(pkt) => { info!("[send to {:?}] {:#?}", client_id, pkt); if let Err(err) = send_pkt(&mut socket, &mut cipher, &pkt).await { warn!("error sending pkt {:#?} to {:?} {:?}", pkt, client_id, err); } }, Err(err) => { info!("send to {:?} failed: {:?}", client_id, err); break; } } } } pub async fn run_server(mut state: STATE, port: u16) where STATE: ServerState + Send + 'static, S: SendServerPacket + std::fmt::Debug + Send + 'static, R: RecvServerPacket + std::fmt::Debug + Send, C: PSOCipher + Send + 'static, E: std::fmt::Debug + Send, { let listener = async_std::net::TcpListener::bind(&std::net::SocketAddr::from((std::net::Ipv4Addr::new(0,0,0,0), port))).await.unwrap(); let mut id = 0; let clients = Arc::new(RwLock::new(HashMap::new())); loop { let (mut socket, addr) = listener.accept().await.unwrap(); id += 1; let client_id = crate::common::serverstate::ClientId(id); info!("new client {:?} {:?} {:?}", client_id, socket, addr); let (client_tx, client_rx) = async_std::channel::unbounded(); clients .write() .await .insert(client_id, client_tx.clone()); let mut cipher_in: Option = None; let mut cipher_out: Option = None; for action in state.on_connect(client_id).await.unwrap() { match action { OnConnect::Cipher(cin, cout) => { cipher_in = Some(cin); cipher_out = Some(cout); }, OnConnect::Packet(pkt) => { send_pkt(&mut socket, &mut NullCipher {}, &pkt).await.unwrap(); } } } let rstate = state.clone(); let rsocket = socket.clone(); let rclients = clients.clone(); async_std::task::spawn(async move { recv_loop(rstate, rsocket, client_id, cipher_in.unwrap(), rclients).await }); async_std::task::spawn(async move { send_loop(socket, client_id, cipher_out.unwrap(), client_rx).await }); } }