package no.zyon.shade

import no.zyon.shade.serialization.SessionStateJson
import no.zyon.shade.types.ChainState
import no.zyon.shade.types.IdentityKeyPair
import no.zyon.shade.types.KeyPair
import no.zyon.shade.types.OneTimePreKey
import no.zyon.shade.types.SessionState
import no.zyon.shade.types.SignedPreKey
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotNull
import org.junit.Assert.assertNull
import org.junit.Test

/**
 * Round-trip tests for the at-rest JSON serialization used by
 * `KeystoreStorage`. The format isn't cross-platform (TS uses its
 * own shape) — what matters is `serialize → deserialize` preserves
 * every byte of every key, together with the ids, counters and
 * timestamps stored alongside them.
 */
class SessionStateJsonTest {

    /**
     * Builds an array of [n] bytes, each equal to [fill]. Every fixture field
     * uses a different fill value so that swapped fields are detected.
     */
    private fun bytes(n: Int, fill: Byte): ByteArray = ByteArray(n) { fill }

    /** Asserts that both halves of [actual] are byte-for-byte equal to [expected]. */
    private fun assertKeyPairEquals(expected: KeyPair, actual: KeyPair) {
        assertArrayEquals(expected.publicKey, actual.publicKey)
        assertArrayEquals(expected.privateKey, actual.privateKey)
    }

    /** Asserts that [actual] carries the same chain key and the same counter as [expected]. */
    private fun assertChainEquals(expected: ChainState, actual: ChainState) {
        assertArrayEquals(expected.chainKey, actual.chainKey)
        assertEquals(expected.counter, actual.counter)
    }

    @Test
    fun identityKeyPairRoundTrip() {
        val original = IdentityKeyPair(
            signingPublicKey = bytes(32, 0x11),
            signingPrivateKey = bytes(32, 0x22),
            dhPublicKey = bytes(32, 0x33),
            dhPrivateKey = bytes(32, 0x44),
        )

        val json = SessionStateJson.serializeIdentityKeyPair(original)
        val restored = SessionStateJson.deserializeIdentityKeyPair(json)

        assertArrayEquals(original.signingPublicKey, restored.signingPublicKey)
        assertArrayEquals(original.signingPrivateKey, restored.signingPrivateKey)
        assertArrayEquals(original.dhPublicKey, restored.dhPublicKey)
        assertArrayEquals(original.dhPrivateKey, restored.dhPrivateKey)
    }

    @Test
    fun signedPreKeyRoundTrip() {
        val original = SignedPreKey(
            keyId = 42,
            keyPair = KeyPair(publicKey = bytes(32, 0x55), privateKey = bytes(32, 0x66)),
            signature = bytes(64, 0x77),
            timestamp = 1_700_000_000_000L,
        )

        val json = SessionStateJson.serializeSignedPreKey(original)
        val restored = SessionStateJson.deserializeSignedPreKey(json)

        assertEquals(original.keyId, restored.keyId)
        assertKeyPairEquals(original.keyPair, restored.keyPair)
        assertArrayEquals(original.signature, restored.signature)
        assertEquals(original.timestamp, restored.timestamp)
    }

    @Test
    fun oneTimePreKeyRoundTrip() {
        // 0x88 and 0x99 do not fit in a signed Byte literal, hence the explicit conversion.
        val original = OneTimePreKey(
            keyId = 7,
            keyPair = KeyPair(
                publicKey = bytes(32, 0x88.toByte()),
                privateKey = bytes(32, 0x99.toByte()),
            ),
        )

        val json = SessionStateJson.serializeOneTimePreKey(original)
        val restored = SessionStateJson.deserializeOneTimePreKey(json)

        assertEquals(original.keyId, restored.keyId)
        assertKeyPairEquals(original.keyPair, restored.keyPair)
    }

    @Test
    fun sessionStateRoundTripFullPopulated() {
        val state = SessionState(
            remoteIdentityKey = bytes(32, 0x01),
            rootKey = bytes(32, 0x02),
            sendChain = ChainState(chainKey = bytes(32, 0x03), counter = 5),
            receiveChain = ChainState(chainKey = bytes(32, 0x04), counter = 3),
            dhSend = KeyPair(publicKey = bytes(32, 0x05), privateKey = bytes(32, 0x06)),
            dhReceive = bytes(32, 0x07),
            previousSendCounter = 9,
            skippedKeys = mutableMapOf(
                "remote:1" to bytes(32, 0x0A),
                "remote:2" to bytes(32, 0x0B),
            ),
        )

        val json = SessionStateJson.serialize(state)
        val restored = SessionStateJson.deserialize(json)

        assertArrayEquals(state.remoteIdentityKey, restored.remoteIdentityKey)
        assertArrayEquals(state.rootKey, restored.rootKey)
        assertChainEquals(state.sendChain, restored.sendChain)

        // The receive chain is optional: first prove it survived, then compare
        // both its key and its counter (send and receive counters differ in the
        // fixture so a mix-up between the two chains is caught).
        assertNotNull("receiveChain was lost in the round trip", restored.receiveChain)
        assertChainEquals(
            checkNotNull(state.receiveChain) { "fixture must populate receiveChain" },
            checkNotNull(restored.receiveChain) { "receiveChain was lost in the round trip" },
        )

        assertKeyPairEquals(state.dhSend, restored.dhSend)
        assertArrayEquals(state.dhReceive, restored.dhReceive)
        assertEquals(state.previousSendCounter, restored.previousSendCounter)

        // Same size plus every expected entry present means the two maps match.
        assertEquals(state.skippedKeys.size, restored.skippedKeys.size)
        for ((label, expectedKey) in state.skippedKeys) {
            assertArrayEquals("skipped key $label", expectedKey, restored.skippedKeys[label])
        }
    }

    @Test
    fun sessionStateRoundTripWithNullableFields() {
        val state = SessionState(
            remoteIdentityKey = bytes(32, 0x01),
            rootKey = bytes(32, 0x02),
            sendChain = ChainState(chainKey = bytes(32, 0x03), counter = 0),
            receiveChain = null,
            dhSend = KeyPair(publicKey = bytes(32, 0x05), privateKey = bytes(32, 0x06)),
            dhReceive = null,
            previousSendCounter = 0,
            skippedKeys = mutableMapOf(),
        )

        val json = SessionStateJson.serialize(state)
        val restored = SessionStateJson.deserialize(json)

        // Optional fields must come back absent, not as empty placeholders.
        assertNull(restored.receiveChain)
        assertNull(restored.dhReceive)
        assertEquals(0, restored.skippedKeys.size)

        // The mandatory fields must be unaffected by the missing optional ones.
        assertArrayEquals(state.remoteIdentityKey, restored.remoteIdentityKey)
        assertArrayEquals(state.rootKey, restored.rootKey)
        assertChainEquals(state.sendChain, restored.sendChain)
        assertKeyPairEquals(state.dhSend, restored.dhSend)
        assertEquals(state.previousSendCounter, restored.previousSendCounter)
    }
}