use std::collections::HashMap;
use std::fmt::Debug;
use async_std::channel;
use async_std::io::prelude::{ReadExt, WriteExt};
use async_std::sync::{Arc, RwLock};
use futures::future::Future;
use log::{trace, info, warn};

use libpso::crypto::{PSOCipher, NullCipher, CipherError};
use libpso::PacketParseError;
use crate::common::serverstate::ClientId;
use crate::common::serverstate::{RecvServerPacket, SendServerPacket, ServerState, OnConnect};


/// Errors produced while receiving from or sending to a client connection.
#[derive(Debug)]
pub enum NetworkError {
    /// Data could not be sent. NOTE(review): not constructed anywhere in this file.
    CouldNotSend,
    /// Encrypting or decrypting packet bytes failed.
    CipherError(CipherError),
    /// Raw bytes could not be parsed into a packet.
    PacketParseError(PacketParseError),
    /// An underlying socket operation failed.
    IOError(std::io::Error),
    /// NOTE(review): not constructed anywhere in this file; presumably "not enough data yet".
    DataNotReady,
    /// The peer closed the connection (a socket read returned 0 bytes).
    ClientDisconnected,
}

/// Lets `?` lift cipher failures into a `NetworkError`.
impl From<CipherError> for NetworkError {
    fn from(err: CipherError) -> Self {
        Self::CipherError(err)
    }
}

/// Lets `?` lift socket I/O failures into a `NetworkError`.
impl From<std::io::Error> for NetworkError {
    fn from(err: std::io::Error) -> Self {
        Self::IOError(err)
    }
}

/// Lets `?` lift packet parse failures into a `NetworkError`.
impl From<PacketParseError> for NetworkError {
    fn from(err: PacketParseError) -> Self {
        Self::PacketParseError(err)
    }
}

pub struct PacketReceiver<C: PSOCipher> {
    socket: async_std::net::TcpStream,
    cipher: C,
    recv_buffer: Vec<u8>,
    incoming_data: Vec<u8>,
}

impl<C: PSOCipher> PacketReceiver<C> {
    pub fn new(socket: async_std::net::TcpStream, cipher: C) -> PacketReceiver<C> {
        PacketReceiver {
            socket,
            cipher,
            recv_buffer: Vec::new(),
            incoming_data: Vec::new(),
        }
    }

    async fn fill_recv_buffer(&mut self) -> Result<(), NetworkError> {
        let mut data = [0u8; 0x8000];

        let mut socket = self.socket.clone();
        let len = socket.read(&mut data).await?;
        if len == 0 {
            return Err(NetworkError::ClientDisconnected);
        }

        self.recv_buffer.extend_from_slice(&data[..len]);

        let mut dec_buf = {
            //let mut cipher = self.cipher.lock().await;
            let block_chunk_len = self.recv_buffer.len() / self.cipher.block_size() * self.cipher.block_size();
            let buf = self.recv_buffer.drain(..block_chunk_len).collect();
            self.cipher.decrypt(&buf)?
        };
        self.incoming_data.append(&mut dec_buf);

        Ok(())
    }

    pub async fn recv_pkts<R: RecvServerPacket + std::fmt::Debug>(&mut self) -> Result<Vec<R>, NetworkError> {
        self.fill_recv_buffer().await?;

        let mut result = Vec::new();
        loop {
            if self.incoming_data.len() < 2 {
                break;
            }
            let pkt_size = u16::from_le_bytes([self.incoming_data[0], self.incoming_data[1]]) as usize;
            let mut pkt_len = pkt_size;
            while pkt_len % self.cipher.block_size() != 0 {
                pkt_len += 1;
            }

            if pkt_len > self.incoming_data.len() {
                break;
            }

            let pkt_data = self.incoming_data.drain(..pkt_len).collect::<Vec<_>>();

            trace!("[recv buf] {:?}", pkt_data);
            let pkt = match R::from_bytes(&pkt_data[..pkt_size]) {
                Ok(p) => p,
                Err(err) => {
                    warn!("error RecvServerPacket::from_bytes: {:?}", err);
                    continue
                },
            };

            result.push(pkt);
        }

        Ok(result)
    }
}

async fn recv_loop<STATE, S, R, C, E>(mut state: STATE,
                                      socket: async_std::net::TcpStream,
                                      client_id: ClientId,
                                      cipher: C,
                                      clients: Arc<RwLock<HashMap<ClientId, channel::Sender<S>>>>)
where
    STATE: ServerState<SendPacket=S, RecvPacket=R, Cipher=C, PacketError=E> + Send,
    S: SendServerPacket + Debug + Send,
    R: RecvServerPacket + Debug + Send,
    C: PSOCipher + Send,
    E: std::fmt::Debug + Send,
{
    let mut pkt_receiver = PacketReceiver::new(socket, cipher);
    loop {
        match pkt_receiver.recv_pkts::<R>().await {
            Ok(pkts) => {
                for pkt in pkts {
                    info!("[recv from {:?}] {:#?}", client_id, pkt);
                    match state.handle(client_id, pkt).await {
                        Ok(response) => {
                            for resp in response {
                                clients
                                    .read()
                                    .await
                                    .get(&resp.0)
                                    .unwrap()
                                    .send(resp.1)
                                    .await
                                    .unwrap();
                            }
                        },
                        Err(err) => {
                            warn!("[client recv {:?}] error {:?} ", client_id, err);
                        }
                    }
                }
            },
            Err(err) => {
                match err {
                    NetworkError::ClientDisconnected => {
                        info!("[client recv {:?}] disconnected", client_id);
                        for pkt in state.on_disconnect(client_id).await.unwrap() {
                            clients
                                .read()
                                .await
                                .get(&pkt.0)
                                .unwrap()
                                .send(pkt.1)
                                .await
                                .unwrap();
                        }
                        clients
                            .write()
                            .await
                            .remove(&client_id);
                        break;
                    }
                    _ => {
                        warn!("[client {:?} recv error] {:?}", client_id, err);
                    }
                }
            }
        }
    }
}


async fn send_pkt<S, C>(socket: &mut async_std::net::TcpStream,
                        cipher: &mut C,
                        pkt: &S)
                        -> Result<(), NetworkError>
where
    S: SendServerPacket + std::fmt::Debug,
    C: PSOCipher,
{
    let buf = pkt.as_bytes();
    trace!("[send buf] {:?}", buf);
    let cbuf = cipher.encrypt(&buf)?;
    socket.write_all(&cbuf).await?;
    Ok(())
}

/// Writes every packet taken from `packet_queue` to `socket`, encrypted with `cipher`.
///
/// A packet that fails to send is logged and skipped. The loop ends when
/// `packet_queue.recv()` returns an error; for `async_std::channel` that happens
/// once the queue is empty and every sender has been dropped.
async fn send_loop<S, C>(mut socket: async_std::net::TcpStream, client_id: ClientId, mut cipher: C,  packet_queue: channel::Receiver<S>)
where
    S: SendServerPacket + std::fmt::Debug,
    C: PSOCipher,
{
    let closed = loop {
        let pkt = match packet_queue.recv().await {
            Ok(pkt) => pkt,
            Err(err) => break err,
        };
        if let Err(err) = send_pkt(&mut socket, &mut cipher, &pkt).await {
            warn!("error sending pkt {:#?} to {:?} {:?}", pkt, client_id, err);
        }
    };
    info!("send to {:?} failed: {:?}", client_id, closed);
}

pub async fn run_server<STATE, S, R, C, E>(mut state: STATE, port: u16)
where
    STATE: ServerState<SendPacket=S, RecvPacket=R, Cipher=C, PacketError=E> + Send + 'static,
    S: SendServerPacket + std::fmt::Debug + Send + 'static,
    R: RecvServerPacket + std::fmt::Debug + Send,
    C: PSOCipher + Send + 'static,
    E: std::fmt::Debug + Send,
{
    let listener = async_std::net::TcpListener::bind(&std::net::SocketAddr::from((std::net::Ipv4Addr::new(0,0,0,0), port))).await.unwrap();
    let mut id = 0;

    let clients = Arc::new(RwLock::new(HashMap::new()));

    loop {
        let (mut socket, addr) = listener.accept().await.unwrap();
        id += 1;

        let client_id = crate::common::serverstate::ClientId(id);
        info!("new client {:?} {:?} {:?}", client_id, socket, addr);

        let (client_tx, client_rx) = async_std::channel::unbounded();

        clients
            .write()
            .await
            .insert(client_id, client_tx.clone());

        let mut cipher_in: Option<C> = None;
        let mut cipher_out: Option<C> = None;

        for action in state.on_connect(client_id).await.unwrap() {
            match action {
                OnConnect::Cipher(cin, cout) => {
                    cipher_in = Some(cin);
                    cipher_out = Some(cout);
                },
                OnConnect::Packet(pkt) => {
                    send_pkt(&mut socket, &mut NullCipher {}, &pkt).await.unwrap();
                }
            }
        }

        let rstate = state.clone();
        let rsocket = socket.clone();
        let rclients = clients.clone();
        async_std::task::spawn(async move {
            recv_loop(rstate, rsocket, client_id, cipher_in.unwrap(), rclients).await
        });

        async_std::task::spawn(async move {
            send_loop(socket, client_id, cipher_out.unwrap(), client_rx).await
        });
    }
}