use std::pin::Pin;
use futures::future::Future;
use log::{trace, info, warn};
use async_std::sync::{Arc, Mutex};
use async_std::io::prelude::{ReadExt, WriteExt};
use std::collections::HashMap;

use libpso::crypto::{PSOCipher, NullCipher, CipherError};
use libpso::PacketParseError;
use crate::common::serverstate::ClientId;
use crate::common::serverstate::{RecvServerPacket, SendServerPacket, ServerState, OnConnect};


/// Errors that can occur while receiving from or sending to a client connection.
#[derive(Debug)]
pub enum NetworkError {
    /// A packet could not be sent (not constructed by the code in this file).
    CouldNotSend,
    /// Encrypting or decrypting data with the connection's cipher failed.
    CipherError(CipherError),
    /// Raw bytes could not be parsed into a packet.
    PacketParseError(PacketParseError),
    /// The underlying socket operation failed.
    IOError(std::io::Error),
    /// Not enough data is available yet (not constructed by the code in this file).
    DataNotReady,
    /// The peer closed the connection: a socket read returned zero bytes.
    ClientDisconnected,
}

/// Lets `?` lift cipher failures into [`NetworkError::CipherError`].
impl From<CipherError> for NetworkError {
    fn from(cause: CipherError) -> Self {
        Self::CipherError(cause)
    }
}

/// Lets `?` lift socket failures into [`NetworkError::IOError`].
impl From<std::io::Error> for NetworkError {
    fn from(cause: std::io::Error) -> Self {
        Self::IOError(cause)
    }
}

/// Lets `?` lift packet parse failures into [`NetworkError::PacketParseError`].
impl From<PacketParseError> for NetworkError {
    fn from(cause: PacketParseError) -> Self {
        Self::PacketParseError(cause)
    }
}

struct PacketReceiver {
    socket: Arc<async_std::net::TcpStream>,
    cipher: Arc<Mutex<Box<dyn PSOCipher + Send>>>,
    recv_buffer: Vec<u8>,
    incoming_data: Vec<u8>,
}

impl PacketReceiver {
    fn new(socket: Arc<async_std::net::TcpStream>, cipher: Arc<Mutex<Box<dyn PSOCipher + Send>>>) -> PacketReceiver {
        PacketReceiver {
            socket: socket,
            cipher: cipher,
            recv_buffer: Vec::new(),
            incoming_data: Vec::new(),
        }
    }

    async fn fill_recv_buffer(&mut self) -> Result<(), NetworkError> {
        let mut data = [0u8; 0x8000];

        let mut socket = &*self.socket;
        let len = socket.read(&mut data).await?;
        if len == 0 {
            return Err(NetworkError::ClientDisconnected);
        }

        self.recv_buffer.extend_from_slice(&mut data[..len]);

        let mut dec_buf = {
            let mut cipher = self.cipher.lock().await;
            let block_chunk_len = self.recv_buffer.len() / cipher.block_size() * cipher.block_size();
            let buf = self.recv_buffer.drain(..block_chunk_len).collect();
            cipher.decrypt(&buf)?
        };
        self.incoming_data.append(&mut dec_buf);

        Ok(())
    }

    async fn recv_pkts<R: RecvServerPacket + Send + std::fmt::Debug>(&mut self) -> Result<Vec<R>, NetworkError> {
        self.fill_recv_buffer().await?;

        let mut result = Vec::new();
        loop {
            if self.incoming_data.len() < 2 {
                break;
            }
            let pkt_size = u16::from_le_bytes([self.incoming_data[0], self.incoming_data[1]]) as usize;
            let mut pkt_len = pkt_size;
            while pkt_len % self.cipher.lock().await.block_size() != 0 {
                pkt_len += 1;
            }

            if pkt_len > self.incoming_data.len() {
                break;
            }

            let pkt_data = self.incoming_data.drain(..pkt_len).collect::<Vec<_>>();

            trace!("[recv buf] {:?}", pkt_data);
            let pkt = match R::from_bytes(&pkt_data[..pkt_size]) {
                Ok(p) => p,
                Err(err) => {
                    warn!("error RecvServerPacket::from_bytes: {:?}", err);
                    continue
                },
            };

            result.push(pkt);
        }

        Ok(result)
    }
}

/// Serializes `pkt`, encrypts it with `cipher`, and writes the ciphertext to `socket`.
///
/// The cipher lock is held only while encrypting and is released before the write.
///
/// # Errors
/// Returns `NetworkError::CipherError` if encryption fails, or `NetworkError::IOError`
/// if writing to the socket fails.
async fn send_pkt<S: SendServerPacket + Send + std::fmt::Debug>(socket: Arc<async_std::net::TcpStream>,
                                                                cipher: Arc<Mutex<Box<dyn PSOCipher + Send>>>, pkt: S)
                                                                -> Result<(), NetworkError>
{
    let plaintext = pkt.as_bytes();
    let ciphertext = {
        let mut guard = cipher.lock().await;
        guard.encrypt(&plaintext)?
    };

    let mut writer = &*socket;
    writer.write_all(&ciphertext).await?;

    Ok(())
}


/// Event sent from a client's receive task to the server-state task.
///
/// `S` is the message type that can be sent back to the client; in this file it is
/// always `ServerStateAction<_>`. `R` is the decoded incoming packet type.
enum ClientAction<S, R> {
    /// A client connected; carries the channel that reaches that client's send task.
    NewClient(ClientId, async_std::sync::Sender<S>),
    /// A packet decoded from the given client.
    Packet(ClientId, R),
    /// The given client's connection has ended.
    Disconnect(ClientId),
}

/// Instruction from the server-state task to one client's send task.
enum ServerStateAction<S> {
    /// Replace the client's (incoming, outgoing) ciphers, in that order.
    Cipher(Box<dyn PSOCipher + Send + Sync>, Box<dyn PSOCipher + Send + Sync>),
    /// Encrypt this packet with the outgoing cipher and write it to the client.
    Packet(S),
    /// Stop the client's send task.
    Disconnect,
}

fn client_recv_loop<S, R>(client_id: ClientId,
                                socket: Arc<async_std::net::TcpStream>,
                                cipher: Arc<Mutex<Box<dyn PSOCipher + Send>>>,
                                server_sender: async_std::sync::Sender<ClientAction<ServerStateAction<S>, R>>,
                                client_sender: async_std::sync::Sender<ServerStateAction<S>>)
where
    S: SendServerPacket + std::fmt::Debug + Send + 'static,
    R: RecvServerPacket + std::fmt::Debug + Send + 'static,
{
    async_std::task::spawn(async move {
        server_sender.send(ClientAction::NewClient(client_id, client_sender)).await;
        let mut pkt_receiver = PacketReceiver::new(socket, cipher);

        loop {
            match pkt_receiver.recv_pkts().await {
                Ok(pkts) => {
                    for pkt in pkts {
                        trace!("[recv from {:?}] {:?}", client_id, pkt);
                        server_sender.send(ClientAction::Packet(client_id, pkt)).await;
                    }
                },
                Err(err) => {
                    match err {
                        NetworkError::ClientDisconnected => {
                            trace!("[client disconnected] {:?}", client_id);
                            server_sender.send(ClientAction::Disconnect(client_id)).await;
                            break;
                        }
                        _ => {
                            warn!("[client {:?} recv error] {:?}", client_id, err);
                        }
                    }
                }
            }
        }
    });
}

/// Spawns the task that carries out server-state instructions for one client.
///
/// Messages from `client_receiver` are processed strictly in arrival order:
/// * `Cipher` installs new incoming and outgoing ciphers.
/// * `Packet` encrypts and writes a packet; send failures are logged, not fatal.
/// * `Disconnect` ends the task.
///
/// # Panics
/// The task panics (via `unwrap`) if receiving from `client_receiver` fails, for
/// example when every sender was dropped before a `Disconnect` was delivered.
fn client_send_loop<S>(client_id: ClientId,
                             socket: Arc<async_std::net::TcpStream>,
                             cipher_in: Arc<Mutex<Box<dyn PSOCipher + Send>>>,
                             cipher_out: Arc<Mutex<Box<dyn PSOCipher + Send>>>,
                             client_receiver: async_std::sync::Receiver<ServerStateAction<S>>)
where
    S: SendServerPacket + std::fmt::Debug + Send + 'static,
{
    async_std::task::spawn(async move {
        loop {
            let next = client_receiver.recv().await.unwrap();
            match next {
                ServerStateAction::Disconnect => break,
                ServerStateAction::Cipher(incoming, outgoing) => {
                    *cipher_in.lock().await = incoming;
                    *cipher_out.lock().await = outgoing;
                },
                ServerStateAction::Packet(pkt) => {
                    trace!("[send to {:?}] {:?}", client_id, pkt);
                    let outcome = send_pkt(socket.clone(), cipher_out.clone(), pkt).await;
                    if let Err(err) = outcome {
                        warn!("[client {:?} send error ] {:?}", client_id, err);
                    }
                },
            }
        }
    });
}

/// Spawns the single task that drives the server state and routes messages to clients.
///
/// Each `ClientAction` from `server_state_receiver` is handled while holding the
/// `state` lock:
/// * `NewClient` runs `on_connect`, forwards the resulting cipher changes and packets
///   to that client, and registers its sender.
/// * `Packet` runs `handle` and delivers each `(ClientId, packet)` result to the
///   matching registered client. Unknown ids are skipped.
/// * `Disconnect` runs `on_disconnect` and delivers its packets. It then unregisters
///   the client and tells its send task to stop, even if `on_disconnect` failed.
///
/// Errors returned by the state are logged and otherwise ignored.
fn state_client_loop<STATE, S, R, E>(state: Arc<Mutex<STATE>>,
                                           server_state_receiver: async_std::sync::Receiver<ClientAction<ServerStateAction<S>, R>>) where
    STATE: ServerState<SendPacket=S, RecvPacket=R, PacketError=E> + Send + 'static,
    S: SendServerPacket + std::fmt::Debug + Send + 'static,
    R: RecvServerPacket + std::fmt::Debug + Send + 'static,
    E: std::fmt::Debug + Send,
{
    async_std::task::spawn(async move {
        // Channel to the send task of every currently connected client, keyed by id.
        let mut clients = HashMap::new();

        loop {
            let action = server_state_receiver.recv().await.unwrap();
            let mut state = state.lock().await;

            match action {
                ClientAction::NewClient(client_id, sender) => {
                    let actions = state.on_connect(client_id).await;
                    match actions {
                        Ok(actions) => {
                            for action in actions {
                                match action {
                                    OnConnect::Cipher((inc, outc)) => {
                                        sender.send(ServerStateAction::Cipher(inc, outc)).await;
                                    },
                                    OnConnect::Packet(pkt) => {
                                        sender.send(ServerStateAction::Packet(pkt)).await;
                                    }
                                }
                            }
                        },
                        Err(err) => {
                            warn!("[client {:?} state on_connect error] {:?}", client_id, err);
                        }
                    }
                    clients.insert(client_id, sender);
                },
                ClientAction::Packet(client_id, pkt) => {
                    let pkts = state.handle(client_id, &pkt).await;
                    match pkts {
                        Ok(pkts) => {
                            for (dest_id, pkt) in pkts {
                                if let Some(client) = clients.get(&dest_id) {
                                    client.send(ServerStateAction::Packet(pkt)).await;
                                }
                            }
                        },
                        Err(err) => {
                            warn!("[client {:?} state handler error] {:?}", client_id, err);
                        }
                    }
                },
                ClientAction::Disconnect(client_id) => {
                    let pkts = state.on_disconnect(client_id).await;
                    match pkts {
                        Ok(pkts) => {
                            for (dest_id, pkt) in pkts {
                                if let Some(client) = clients.get(&dest_id) {
                                    client.send(ServerStateAction::Packet(pkt)).await;
                                }
                            }
                        },
                        Err(err) => {
                            warn!("[client {:?} state on_disconnect error] {:?}", client_id, err);
                        }
                    }

                    // Always unregister the client and stop its send task, even when
                    // on_disconnect fails. Otherwise the map keeps the sender for every
                    // past connection, and the send task (which only exits on
                    // `Disconnect`) waits on its channel forever.
                    if let Some(client) = clients.remove(&client_id) {
                        client.send(ServerStateAction::Disconnect).await;
                    }
                }
            }
        }
    });
}


pub fn client_accept_mainloop<STATE, S, R, E>(state: Arc<Mutex<STATE>>, client_port: u16) -> Pin<Box<dyn Future<Output = ()>>>
where
    STATE: ServerState<SendPacket=S, RecvPacket=R, PacketError=E> + Send + 'static,
    S: SendServerPacket + std::fmt::Debug + Send + Sync + 'static,
    R: RecvServerPacket + std::fmt::Debug + Send + Sync + 'static,
    E: std::fmt::Debug + Send,
{
    Box::pin(async_std::task::spawn(async move {
        let listener = async_std::net::TcpListener::bind(&std::net::SocketAddr::from((std::net::Ipv4Addr::new(0,0,0,0), client_port))).await.unwrap();
        let mut id = 0;

        let (server_state_sender, server_state_receiver) = async_std::sync::channel(1024);
        state_client_loop(state, server_state_receiver);

        loop {
            let (sock, addr) = listener.accept().await.unwrap();
            id += 1;
            let client_id = crate::common::serverstate::ClientId(id);

            info!("new client {:?} {:?} {:?}", client_id, sock, addr);

            let (client_sender, client_receiver) = async_std::sync::channel(64);
            let socket = Arc::new(sock);
            let cipher_in: Arc<Mutex<Box<dyn PSOCipher + Send>>> = Arc::new(Mutex::new(Box::new(NullCipher {})));
            let cipher_out: Arc<Mutex<Box<dyn PSOCipher + Send>>> = Arc::new(Mutex::new(Box::new(NullCipher {})));

            client_recv_loop(client_id, socket.clone(), cipher_in.clone(), server_state_sender.clone(), client_sender);
            client_send_loop(client_id, socket.clone(), cipher_in.clone(), cipher_out.clone(), client_receiver);
        }
    }))
}