From c7ea9c03e96a9f7d410dd13dfa23a949e50db75e Mon Sep 17 00:00:00 2001
From: jake <jake@sharnoth.com>
Date: Sat, 18 Nov 2023 23:27:14 -0700
Subject: [PATCH] add PSOPacketData derive to enums

---
 psopacket/src/lib.rs | 84 +++++++++++++++++++++++++++++++++++++-------
 1 file changed, 71 insertions(+), 13 deletions(-)

diff --git a/psopacket/src/lib.rs b/psopacket/src/lib.rs
index ef5a7a9..ff7c63c 100644
--- a/psopacket/src/lib.rs
+++ b/psopacket/src/lib.rs
@@ -647,19 +647,7 @@ pub fn pso_message(attr: TokenStream, item: TokenStream) -> TokenStream {
     q.into()
 }
 
-#[proc_macro_derive(PSOPacketData)]
-pub fn pso_packet_data(input: TokenStream) -> TokenStream {
-    let derive = parse_macro_input!(input as DeriveInput);
-
-    let name = derive.ident;
-
-    let fields = if let syn::Data::Struct(strct) = derive.data {
-        strct.fields
-    }
-    else {
-        return syn::Error::new(name.span(), "PSOPacketData only works on structs").to_compile_error().into();
-    };
-
+fn pso_packet_data_struct(name: syn::Ident, fields: syn::Fields) -> TokenStream {
     let attrs = match get_struct_fields(fields.iter()) {
         Ok(a) => a,
         Err(err) => return err
@@ -695,3 +683,73 @@ pub fn pso_packet_data(input: TokenStream) -> TokenStream {
 
     q.into()
 }
+
+fn pso_packet_data_enum<'a>(name: syn::Ident, repr_type: syn::Ident, variants: impl Iterator<Item = &'a syn::Variant> + Clone) -> TokenStream {
+    let value_to_variant = variants
+        .clone()
+        .enumerate()
+        .map(|(i, variant)| {
+            quote! {
+                #i => #name::#variant,
+            }
+        })
+        .collect::<Vec<_>>();
+
+    let variant_to_value = variants
+        .enumerate()
+        .map(|(i, variant)| {
+            quote! {
+                #name::#variant => #repr_type::to_le_bytes(#i as #repr_type).to_vec(),
+            }
+        })
+        .collect::<Vec<_>>();
+    let impl_pso_data_packet = quote! {
+        impl PSOPacketData for #name {
+            fn from_bytes<R: std::io::Read + std::io::Seek>(mut cur: &mut R) -> Result<Self, PacketParseError> {
+                let mut buf = #repr_type::default().to_le_bytes();
+                cur.read_exact(&mut buf).unwrap();
+                let value = #repr_type::from_le_bytes(buf);
+
+                Ok(match value as usize {
+                    #(#value_to_variant)*
+                    _ => return Err(PacketParseError::InvalidValue)
+                })
+            }
+
+            fn as_bytes(&self) -> Vec<#repr_type> {
+                match self {
+                    #(#variant_to_value)*
+                }
+            }
+        }
+    };
+
+    impl_pso_data_packet.into()
+}
+
+#[proc_macro_derive(PSOPacketData)]
+pub fn pso_packet_data(input: TokenStream) -> TokenStream {
+    let derive = parse_macro_input!(input as DeriveInput);
+
+    let name = derive.ident;
+
+    if let syn::Data::Struct(strct) = derive.data {
+        pso_packet_data_struct(name, strct.fields)
+    }
+    else if let syn::Data::Enum(enm) = derive.data {
+        let repr_type = derive.attrs.iter().fold(None, |mut repr_type, attr| {
+            if attr.path().is_ident("repr") {
+                attr.parse_nested_meta(|meta| {
+                    repr_type = Some(meta.path.get_ident().cloned().unwrap());
+                    Ok(())
+                }).unwrap();
+            }
+            repr_type
+        });
+
+        pso_packet_data_enum(name, repr_type.unwrap(), enm.variants.iter())
+    }
+    else {
+        syn::Error::new(name.span(), "PSOPacketData only works on structs and enums").to_compile_error().into()
+    }
+}