124 lines
4.9 KiB
Kotlin
124 lines
4.9 KiB
Kotlin
|
|
package no.zyon.shade
|
||
|
|
|
||
|
|
import no.zyon.shade.serialization.SessionStateJson
|
||
|
|
import no.zyon.shade.types.ChainState
|
||
|
|
import no.zyon.shade.types.IdentityKeyPair
|
||
|
|
import no.zyon.shade.types.KeyPair
|
||
|
|
import no.zyon.shade.types.OneTimePreKey
|
||
|
|
import no.zyon.shade.types.SessionState
|
||
|
|
import no.zyon.shade.types.SignedPreKey
|
||
|
|
import org.junit.Assert.assertArrayEquals
|
||
|
|
import org.junit.Assert.assertEquals
|
||
|
|
import org.junit.Assert.assertNotNull
|
||
|
|
import org.junit.Assert.assertNull
|
||
|
|
import org.junit.Test
|
||
|
|
|
||
|
|
/**
 * Round-trip tests for the at-rest JSON serialization used by
 * `KeystoreStorage`. The format isn't cross-platform (TS uses its
 * own shape) — what matters is that `serialize → deserialize` preserves
 * every byte of every key.
 *
 * Fixture buffers use a per-index byte pattern (see [bytes]) rather than a
 * single repeated value, so a serializer that reorders, shifts, or
 * truncates-and-pads bytes is caught — not only one that corrupts them wholesale.
 */
class SessionStateJsonTest {

    /**
     * Builds an [n]-byte buffer whose byte at index `i` is `seed + i`
     * (wrapping modulo 256). For `n <= 256` every position holds a distinct
     * value, so byte-order mistakes in the serializer show up as mismatches.
     */
    private fun bytes(n: Int, seed: Byte): ByteArray = ByteArray(n) { i -> (seed + i).toByte() }

    @Test
    fun identityKeyPairRoundTrip() {
        val k = IdentityKeyPair(
            signingPublicKey = bytes(32, 0x11),
            signingPrivateKey = bytes(32, 0x22),
            dhPublicKey = bytes(32, 0x33),
            dhPrivateKey = bytes(32, 0x44),
        )
        val d = SessionStateJson.deserializeIdentityKeyPair(SessionStateJson.serializeIdentityKeyPair(k))

        assertArrayEquals(k.signingPublicKey, d.signingPublicKey)
        assertArrayEquals(k.signingPrivateKey, d.signingPrivateKey)
        assertArrayEquals(k.dhPublicKey, d.dhPublicKey)
        assertArrayEquals(k.dhPrivateKey, d.dhPrivateKey)
    }

    @Test
    fun signedPreKeyRoundTrip() {
        val k = SignedPreKey(
            keyId = 42,
            keyPair = KeyPair(publicKey = bytes(32, 0x55), privateKey = bytes(32, 0x66)),
            signature = bytes(64, 0x77),
            timestamp = 1_700_000_000_000L,
        )
        val d = SessionStateJson.deserializeSignedPreKey(SessionStateJson.serializeSignedPreKey(k))

        assertEquals(k.keyId, d.keyId)
        assertArrayEquals(k.keyPair.publicKey, d.keyPair.publicKey)
        assertArrayEquals(k.keyPair.privateKey, d.keyPair.privateKey)
        assertArrayEquals(k.signature, d.signature)
        assertEquals(k.timestamp, d.timestamp)
    }

    @Test
    fun oneTimePreKeyRoundTrip() {
        // Seeds above 0x7F exercise bytes with the sign bit set.
        val k = OneTimePreKey(
            keyId = 7,
            keyPair = KeyPair(publicKey = bytes(32, 0x88.toByte()), privateKey = bytes(32, 0x99.toByte())),
        )
        val d = SessionStateJson.deserializeOneTimePreKey(SessionStateJson.serializeOneTimePreKey(k))

        assertEquals(k.keyId, d.keyId)
        assertArrayEquals(k.keyPair.publicKey, d.keyPair.publicKey)
        assertArrayEquals(k.keyPair.privateKey, d.keyPair.privateKey)
    }

    @Test
    fun sessionStateRoundTripFullPopulated() {
        // Kept as a local so assertions can reference it without `!!`.
        val receiveChain = ChainState(chainKey = bytes(32, 0x04), counter = 3)
        val state = SessionState(
            remoteIdentityKey = bytes(32, 0x01),
            rootKey = bytes(32, 0x02),
            sendChain = ChainState(chainKey = bytes(32, 0x03), counter = 5),
            receiveChain = receiveChain,
            dhSend = KeyPair(publicKey = bytes(32, 0x05), privateKey = bytes(32, 0x06)),
            dhReceive = bytes(32, 0x07),
            previousSendCounter = 9,
            skippedKeys = mutableMapOf(
                "remote:1" to bytes(32, 0x0A),
                "remote:2" to bytes(32, 0x0B),
            ),
        )
        val d = SessionStateJson.deserialize(SessionStateJson.serialize(state))

        assertArrayEquals(state.remoteIdentityKey, d.remoteIdentityKey)
        assertArrayEquals(state.rootKey, d.rootKey)
        assertArrayEquals(state.sendChain.chainKey, d.sendChain.chainKey)
        assertEquals(state.sendChain.counter, d.sendChain.counter)

        assertNotNull("receiveChain was dropped by the round-trip", d.receiveChain)
        // Safe: the assertNotNull above fails the test first if this is null.
        val decodedReceive = d.receiveChain!!
        assertArrayEquals(receiveChain.chainKey, decodedReceive.chainKey)
        // The receive counter must survive too, not just the chain key.
        assertEquals(receiveChain.counter, decodedReceive.counter)

        assertArrayEquals(state.dhSend.publicKey, d.dhSend.publicKey)
        assertArrayEquals(state.dhSend.privateKey, d.dhSend.privateKey)
        assertArrayEquals(state.dhReceive, d.dhReceive)
        assertEquals(state.previousSendCounter, d.previousSendCounter)

        // Same key identities (implies same size), then same bytes per key.
        assertEquals(state.skippedKeys.keys, d.skippedKeys.keys)
        for ((id, key) in state.skippedKeys) {
            assertArrayEquals("skipped key $id", key, d.skippedKeys[id])
        }
    }

    @Test
    fun sessionStateRoundTripWithNullableFields() {
        val state = SessionState(
            remoteIdentityKey = bytes(32, 0x01),
            rootKey = bytes(32, 0x02),
            sendChain = ChainState(chainKey = bytes(32, 0x03), counter = 0),
            receiveChain = null,
            dhSend = KeyPair(publicKey = bytes(32, 0x05), privateKey = bytes(32, 0x06)),
            dhReceive = null,
            previousSendCounter = 0,
            skippedKeys = mutableMapOf(),
        )
        val d = SessionStateJson.deserialize(SessionStateJson.serialize(state))

        assertNull(d.receiveChain)
        assertNull(d.dhReceive)
        assertEquals(0, d.skippedKeys.size)

        // The non-null fields must still round-trip alongside the null ones.
        assertArrayEquals(state.remoteIdentityKey, d.remoteIdentityKey)
        assertArrayEquals(state.rootKey, d.rootKey)
        assertArrayEquals(state.sendChain.chainKey, d.sendChain.chainKey)
        assertEquals(state.sendChain.counter, d.sendChain.counter)
        assertArrayEquals(state.dhSend.publicKey, d.dhSend.publicKey)
        assertArrayEquals(state.dhSend.privateKey, d.dhSend.privateKey)
        assertEquals(state.previousSendCounter, d.previousSendCounter)
    }
}
|