diff --git a/src/login/character.rs b/src/login/character.rs index 379d106..44a4198 100644 --- a/src/login/character.rs +++ b/src/login/character.rs @@ -1,6 +1,7 @@ use std::net; use std::sync::Arc; use std::io::Read; +use std::collections::HashMap; use rand::{Rng, RngCore}; use bcrypt::{DEFAULT_COST, hash, verify}; @@ -29,6 +30,7 @@ pub const CHARACTER_PORT: u16 = 12001; #[derive(Debug)] pub enum CharacterError { + ClientNotFound(ClientId), } #[derive(Debug)] @@ -123,18 +125,29 @@ fn generate_param_data(path: &str) -> (ParamDataHeader, Vec) { }, buffer) } -// TODO: rip these client-specific vars into a HashMap -pub struct CharacterServerState { - //shared_state: SharedLoginState, - data_access: DA, - param_header: ParamDataHeader, - +struct ClientState { param_index: usize, - param_data: Arc>, user: Option, characters: Option<[Option; 4]>, guildcard_data_buffer: Option>, +} + +impl ClientState { + fn new() -> ClientState { + ClientState { + param_index: 0, + user: None, + characters: None, + guildcard_data_buffer: None, + } + } +} +pub struct CharacterServerState { + data_access: DA, + param_header: ParamDataHeader, + param_data: Vec, + clients: HashMap, } impl CharacterServerState { @@ -145,32 +158,32 @@ impl CharacterServerState { //shared_state: shared_state, data_access: data_access, param_header: param_header, - - param_index: 0, - param_data: Arc::new(param_data), - user: None, - characters: None, - guildcard_data_buffer: None, + param_data: param_data, + clients: HashMap::new(), } } - fn validate_login(&mut self, pkt: &Login) -> Vec { - match get_login_status(&self.data_access, pkt) { + fn validate_login(&mut self, id: ClientId, pkt: &Login) -> Result, CharacterError> { + let client = self.clients.get_mut(&id).ok_or(CharacterError::ClientNotFound(id))?; + Ok(match get_login_status(&self.data_access, pkt) { Ok(user) => { let mut response = LoginResponse::by_status(AccountStatus::Ok, pkt.security_data); response.guildcard = user.guildcard.map_or(0, |gc| gc) as u32; response.team_id = user.team_id.map_or(0, |ti| ti) as u32; - self.user = Some(user); + client.user = Some(user); vec![SendCharacterPacket::LoginResponse(response)] }, Err(err) => { vec![SendCharacterPacket::LoginResponse(LoginResponse::by_status(err, pkt.security_data))] } - } + }) } - fn get_settings(&mut self) -> Vec { - let user = self.user.as_ref().unwrap(); + fn get_settings(&mut self, id: ClientId) -> Result, CharacterError> { + let client = self.clients.get_mut(&id).ok_or(CharacterError::ClientNotFound(id))?; + let user = client.user.as_ref().unwrap(); + + // TODO: this should error (data should be added on account creation, why did I copy this silly sylv logic?) let settings = match self.data_access.get_user_settings_by_user(&user) { Some(settings) => settings, None => self.data_access.create_user_settings_by_user(&user), @@ -180,16 +193,17 @@ impl CharacterServerState { settings.settings.joystick_config, 0, 0); let pkt = SendCharacterPacket::SendKeyAndTeamSettings(pkt); - vec![pkt] + Ok(vec![pkt]) } - fn char_select(&mut self, select: &CharSelect) -> Vec { - if self.characters.is_none() { - self.characters = Some(self.data_access.get_characters_by_user(self.user.as_ref().unwrap())); + fn char_select(&mut self, id: ClientId, select: &CharSelect) -> Result, CharacterError> { + let client = self.clients.get_mut(&id).ok_or(CharacterError::ClientNotFound(id))?; + if client.characters.is_none() { + client.characters = Some(self.data_access.get_characters_by_user(client.user.as_ref().unwrap())); } - let chars = self.characters.as_ref().unwrap(); - if let Some(char) = &chars[select.slot as usize] { + let chars = client.characters.as_ref().unwrap(); + Ok(if let Some(char) = &chars[select.slot as usize] { vec![SendCharacterPacket::CharacterPreview(CharacterPreview { flag: 0, slot: select.slot, @@ -202,7 +216,7 @@ impl CharacterServerState { slot: select.slot, code: 2, })] - } + }) } fn validate_checksum(&mut self) -> Vec { @@ -212,31 +226,32 @@ impl CharacterServerState { })] } - fn guildcard_data_header(&mut self) -> Vec { - let guildcard_data = self.data_access.get_guild_card_data_by_user(self.user.as_ref().unwrap()); + fn guildcard_data_header(&mut self, id: ClientId) -> Result, CharacterError> { + let client = self.clients.get_mut(&id).ok_or(CharacterError::ClientNotFound(id))?; + let guildcard_data = self.data_access.get_guild_card_data_by_user(client.user.as_ref().unwrap()); let bytes = guildcard_data.guildcard.as_bytes(); let mut crc = crc32::Digest::new(crc32::IEEE); crc.write(&bytes[..]); - self.guildcard_data_buffer = Some(bytes.to_vec()); + client.guildcard_data_buffer = Some(bytes.to_vec()); - vec![SendCharacterPacket::GuildcardDataHeader(GuildcardDataHeader::new(bytes.len(), crc.sum32()))] + Ok(vec![SendCharacterPacket::GuildcardDataHeader(GuildcardDataHeader::new(bytes.len(), crc.sum32()))]) } - fn guildcard_data_chunk(&mut self, chunk: u32, again: u32) -> Vec { - if again != 0 { + fn guildcard_data_chunk(&mut self, id: ClientId, chunk: u32, again: u32) -> Result, CharacterError> { + let client = self.clients.get_mut(&id).ok_or(CharacterError::ClientNotFound(id))?; + Ok(if again != 0 { let start = chunk as usize * GUILD_CARD_CHUNK_SIZE; - let len = std::cmp::min(GUILD_CARD_CHUNK_SIZE, self.guildcard_data_buffer.as_ref().unwrap().len() as usize - start); + let len = std::cmp::min(GUILD_CARD_CHUNK_SIZE, client.guildcard_data_buffer.as_ref().unwrap().len() as usize - start); let end = start + len; let mut buf = [0u8; GUILD_CARD_CHUNK_SIZE as usize]; - buf[..len as usize].copy_from_slice(&self.guildcard_data_buffer.as_ref().unwrap()[start..end]); + buf[..len as usize].copy_from_slice(&client.guildcard_data_buffer.as_ref().unwrap()[start..end]); vec![SendCharacterPacket::GuildcardDataChunk(GuildcardDataChunk::new(chunk, buf, len))] } else { Vec::new() - } - + }) } } @@ -247,6 +262,8 @@ impl ServerState for CharacterServerState { type PacketError = CharacterError; fn on_connect(&mut self, id: ClientId) -> Vec> { + self.clients.insert(id, ClientState::new()); + let mut rng = rand::thread_rng(); let mut server_key = [0u8; 48]; @@ -260,32 +277,34 @@ impl ServerState for CharacterServerState { ] } - fn handle(&mut self, id: ClientId, pkt: &RecvCharacterPacket) -> Box> { - match pkt { + fn handle(&mut self, id: ClientId, pkt: &RecvCharacterPacket) + -> Result>, CharacterError> { + Ok(match pkt { RecvCharacterPacket::Login(login) => { - Box::new(self.validate_login(login).into_iter().map(move |pkt| (id, pkt))) + Box::new(self.validate_login(id, login)?.into_iter().map(move |pkt| (id, pkt))) }, RecvCharacterPacket::RequestSettings(_req) => { - Box::new(self.get_settings().into_iter().map(move |pkt| (id, pkt))) + Box::new(self.get_settings(id)?.into_iter().map(move |pkt| (id, pkt))) }, RecvCharacterPacket::CharSelect(sel) => { - Box::new(self.char_select(sel).into_iter().map(move |pkt| (id, pkt))) + Box::new(self.char_select(id, sel)?.into_iter().map(move |pkt| (id, pkt))) }, RecvCharacterPacket::Checksum(_checksum) => { Box::new(self.validate_checksum().into_iter().map(move |pkt| (id, pkt))) }, RecvCharacterPacket::GuildcardDataRequest(_request) => { - Box::new(self.guildcard_data_header().into_iter().map(move |pkt| (id, pkt))) + Box::new(self.guildcard_data_header(id)?.into_iter().map(move |pkt| (id, pkt))) }, RecvCharacterPacket::GuildcardDataChunkRequest(request) => { - Box::new(self.guildcard_data_chunk(request.chunk, request.again).into_iter().map(move |pkt| (id, pkt))) + Box::new(self.guildcard_data_chunk(id, request.chunk, request.again)?.into_iter().map(move |pkt| (id, pkt))) }, RecvCharacterPacket::ParamDataRequest(_request) => { Box::new(vec![SendCharacterPacket::ParamDataHeader(self.param_header.clone())].into_iter().map(move |pkt| (id, pkt))) }, RecvCharacterPacket::ParamDataChunkRequest(_request) => { - let chunk = self.param_index; - self.param_index += 1; + let client = self.clients.get_mut(&id).ok_or(CharacterError::ClientNotFound(id))?; + let chunk = client.param_index; + client.param_index += 1; let start = chunk * 0x6800; let end = std::cmp::min((chunk+1)*0x6800, self.param_data.len()); @@ -301,7 +320,7 @@ impl ServerState for CharacterServerState { } )].into_iter().map(move |pkt| (id, pkt))) } - } + }) } } @@ -328,7 +347,8 @@ mod test { } let mut server = CharacterServerState::new(TestData {}); - server.user = Some(UserAccount { + let mut clientstate = ClientState::new(); + clientstate.user = Some(UserAccount { id: 1, username: "testuser".to_owned(), password: bcrypt::hash("mypassword", 5).unwrap(), @@ -338,8 +358,11 @@ mod test { muted_until: SystemTime::now(), created_at: SystemTime::now(), }); + server.clients.insert(ClientId(5), clientstate); - let send = server.handle(ClientId(5), &RecvCharacterPacket::RequestSettings(RequestSettings {flag: 0})).collect::>(); + let send = server.handle(ClientId(5), &RecvCharacterPacket::RequestSettings(RequestSettings {flag: 0})) + .unwrap() + .collect::>(); assert!(send.len() == 1); assert!(send[0].0 == ClientId(5)); @@ -356,7 +379,7 @@ mod test { let send = server.handle(ClientId(1), &RecvCharacterPacket::Checksum(Checksum {flag: 0, checksum: 1234, padding: 0, - })).collect::>(); + })).unwrap().collect::>(); assert!(send.len() == 1); let bytes = send[0].1.as_bytes();