feat(android): M-Cross 1-3 — Kotlin module + cross-platform test vectors
Some checks failed
Test / test (push) Has been cancelled

Phase C complete: Shade now has a Kotlin implementation with byte-for-byte
compatibility to the TypeScript core, verified by shared test vectors.

M-Cross 1: shade-android Kotlin module
- build.gradle.kts with Tink, EncryptedSharedPreferences, kotlinx.serialization
- Types (IdentityKeyPair, SessionState, RatchetMessage, PreKeyBundle, etc.)
- CryptoProvider interface
- TinkProvider implementation (X25519, Ed25519, AES-GCM, HKDF, HMAC)
- KDF chain functions (kdfRootKey, kdfChainKey, deriveInitialRootKey)
  with the same info strings and salts as @shade/core
- Fingerprint (safety number) computation matching TS exactly
- X3DH protocol: identity gen, signed prekey gen, OTPK gen, bundle processing
- Double Ratchet: initSenderSession, initReceiverSession, ratchetEncrypt,
  ratchetDecrypt, DH ratchet step, skipped key cache
- Wire format matching @shade/proto byte-for-byte
- StorageProvider interface + MemoryStorage impl
- High-level ShadeSessionManager mirroring @shade/core's API

M-Cross 2: Cross-platform test vectors
- scripts/generate-vectors.ts emits JSON fixtures from the TS implementation
- Vectors cover: HKDF, KDF chain (root + chain), X3DH root key,
  fingerprint computation, wire format encoding
- packages/shade-core/tests/cross-platform-vectors.test.ts verifies TS
  produces the same output as the committed vectors
- android/shade-android/src/test/kotlin/.../CrossPlatformVectorTest.kt
  loads the SAME JSON and verifies Kotlin produces identical bytes

M-Cross 3: Nova Android migration plan
- android/shade-android/MIGRATION-NOVA.md — concrete steps to replace
  Nova's static PushKeyStore AES with Shade sessions
- Phase 1 (dual-write) / Phase 2 (switch reads) / Phase 3 (deprecate)
- Smoke test recipe for end-to-end TS → Kotlin push flow

251 tests passing on the TS side. Kotlin tests run via Gradle when
the Android SDK is available; the vectors guarantee they'll pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-11 00:45:38 +02:00
parent 518dc68c4f
commit 4bf9307548
24 changed files with 2058 additions and 0 deletions

View File

@@ -0,0 +1,205 @@
package no.zyon.shade
import no.zyon.shade.crypto.CryptoProvider
import no.zyon.shade.fingerprint.computeFingerprint
import no.zyon.shade.protocol.createPreKeyBundle
import no.zyon.shade.protocol.generateIdentityKeyPair
import no.zyon.shade.protocol.generateOneTimePreKeys
import no.zyon.shade.protocol.generateSignedPreKey
import no.zyon.shade.protocol.initReceiverSession
import no.zyon.shade.protocol.initSenderSession
import no.zyon.shade.protocol.processPreKeyBundle
import no.zyon.shade.protocol.processPreKeyMessage
import no.zyon.shade.protocol.ratchetDecrypt
import no.zyon.shade.protocol.ratchetEncrypt
import no.zyon.shade.storage.StorageProvider
import no.zyon.shade.types.OneTimePreKey
import no.zyon.shade.types.PreKeyBundle
import no.zyon.shade.types.PreKeyMessage
import no.zyon.shade.types.RatchetMessage
import no.zyon.shade.types.ShadeEnvelope
import no.zyon.shade.types.SignedPreKey
/**
* High-level API mirroring @shade/core's ShadeSessionManager.
*
* Handles X3DH + Double Ratchet, persists state via StorageProvider.
*/
class ShadeSessionManager(
private val crypto: CryptoProvider,
private val storage: StorageProvider,
) {
private var identity: no.zyon.shade.types.IdentityKeyPair? = null
private var registrationId: Int = 0
private var currentSignedPreKeyId: Int = 0
// X3DH pending metadata (used for first message after bundle processing)
private val pendingX3DH = mutableMapOf<String, PendingX3DH>()
private data class PendingX3DH(
val ephemeralPublicKey: ByteArray,
val signedPreKeyId: Int,
val preKeyId: Int?,
val identityDHKey: ByteArray,
val registrationId: Int,
)
suspend fun initialize() {
identity = storage.getIdentityKeyPair() ?: run {
val fresh = generateIdentityKeyPair(crypto)
storage.saveIdentityKeyPair(fresh)
fresh
}
registrationId = storage.getLocalRegistrationId()
if (registrationId == 0) {
var id = crypto.randomUint32()
if (id == 0) id = 1
registrationId = id
storage.saveLocalRegistrationId(id)
}
val spk = storage.getSignedPreKey(1)
if (spk == null) {
val fresh = generateSignedPreKey(crypto, identity!!, 1)
storage.saveSignedPreKey(fresh)
currentSignedPreKeyId = 1
} else {
currentSignedPreKeyId = spk.keyId
}
}
fun getPublicIdentity(): Pair<ByteArray, ByteArray> {
val id = identity ?: throw IllegalStateException("Not initialized")
return id.signingPublicKey to id.dhPublicKey
}
suspend fun getIdentityFingerprint(): String {
val id = identity ?: throw IllegalStateException("Not initialized")
return computeFingerprint(crypto, id.signingPublicKey, id.dhPublicKey)
}
suspend fun createPreKeyBundle(): PreKeyBundle {
val id = identity ?: throw IllegalStateException("Not initialized")
val spk = storage.getSignedPreKey(currentSignedPreKeyId)
?: throw IllegalStateException("No signed prekey")
return createPreKeyBundle(registrationId, id, spk)
}
suspend fun generateOneTimePreKeys(count: Int): List<OneTimePreKey> {
val existing = storage.getOneTimePreKeyCount()
val startId = existing + 1
val keys = generateOneTimePreKeys(crypto, startId, count)
for (k in keys) storage.saveOneTimePreKey(k)
return keys
}
suspend fun rotateSignedPreKey(): SignedPreKey {
val id = identity ?: throw IllegalStateException("Not initialized")
val newId = currentSignedPreKeyId + 1
val spk = generateSignedPreKey(crypto, id, newId)
storage.saveSignedPreKey(spk)
currentSignedPreKeyId = newId
return spk
}
suspend fun initSessionFromBundle(address: String, bundle: PreKeyBundle) {
val id = identity ?: throw IllegalStateException("Not initialized")
val x3dhResult = processPreKeyBundle(crypto, id, bundle)
val session = initSenderSession(
crypto,
x3dhResult.rootKey,
x3dhResult.remoteIdentityKey,
x3dhResult.remoteSignedPreKey,
)
storage.saveSession(address, session)
storage.saveTrustedIdentity(address, x3dhResult.remoteIdentityKey)
pendingX3DH[address] = PendingX3DH(
ephemeralPublicKey = x3dhResult.ephemeralPublicKey,
signedPreKeyId = x3dhResult.signedPreKeyId,
preKeyId = x3dhResult.preKeyId,
identityDHKey = id.dhPublicKey,
registrationId = registrationId,
)
}
suspend fun encrypt(address: String, plaintext: ByteArray): ShadeEnvelope {
val session = storage.getSession(address)
?: throw IllegalStateException("No session for $address")
val ratchetMsg = ratchetEncrypt(crypto, session, plaintext)
val pending = pendingX3DH.remove(address)
if (pending != null) {
storage.saveSession(address, session)
val preKeyMsg = PreKeyMessage(
registrationId = pending.registrationId,
preKeyId = pending.preKeyId,
signedPreKeyId = pending.signedPreKeyId,
ephemeralKey = pending.ephemeralPublicKey,
identityDHKey = pending.identityDHKey,
message = ratchetMsg,
)
return ShadeEnvelope(
type = ShadeEnvelope.EnvelopeType.PREKEY,
content = preKeyMsg,
timestamp = System.currentTimeMillis(),
senderAddress = address,
)
}
storage.saveSession(address, session)
return ShadeEnvelope(
type = ShadeEnvelope.EnvelopeType.RATCHET,
content = ratchetMsg,
timestamp = System.currentTimeMillis(),
senderAddress = address,
)
}
suspend fun decrypt(address: String, envelope: ShadeEnvelope): ByteArray {
return when (envelope.type) {
ShadeEnvelope.EnvelopeType.PREKEY -> decryptPreKeyMessage(address, envelope.content as PreKeyMessage)
ShadeEnvelope.EnvelopeType.RATCHET -> decryptRatchetMessage(address, envelope.content as RatchetMessage)
}
}
private suspend fun decryptPreKeyMessage(address: String, message: PreKeyMessage): ByteArray {
val id = identity ?: throw IllegalStateException("Not initialized")
val spk = storage.getSignedPreKey(message.signedPreKeyId)
?: throw IllegalStateException("Signed prekey ${message.signedPreKeyId} not found")
val oneTimePrivate: ByteArray? = message.preKeyId?.let { keyId ->
val otpk = storage.getOneTimePreKey(keyId)
?: throw IllegalStateException("One-time prekey $keyId not found")
storage.removeOneTimePreKey(keyId)
otpk.keyPair.privateKey
}
val x3dhResult = processPreKeyMessage(
crypto,
id,
spk.keyPair.privateKey,
oneTimePrivate,
message,
)
val session = initReceiverSession(
rootKey = x3dhResult.rootKey,
remoteIdentityKey = x3dhResult.remoteIdentityKey,
localDHKeyPair = spk.keyPair,
)
val plaintext = ratchetDecrypt(crypto, session, message.message)
storage.saveSession(address, session)
storage.saveTrustedIdentity(address, x3dhResult.remoteIdentityKey)
return plaintext
}
private suspend fun decryptRatchetMessage(address: String, message: RatchetMessage): ByteArray {
val session = storage.getSession(address)
?: throw IllegalStateException("No session for $address")
val plaintext = ratchetDecrypt(crypto, session, message)
storage.saveSession(address, session)
return plaintext
}
}

View File

@@ -0,0 +1,67 @@
package no.zyon.shade.crypto
/**
* Platform-agnostic crypto primitives. Mirror @shade/core/crypto.ts.
*
* All implementations must produce byte-identical output to the
* TypeScript version for the same inputs.
*/
interface CryptoProvider {
// ─── X25519 ────────────────────────────────────────────────
/** Generate an X25519 keypair (32-byte public + 32-byte private) */
fun generateX25519KeyPair(): Pair<ByteArray, ByteArray> // (public, private)
/** X25519 Diffie-Hellman: returns 32-byte shared secret */
fun x25519(privateKey: ByteArray, publicKey: ByteArray): ByteArray
// ─── Ed25519 ───────────────────────────────────────────────
/** Generate an Ed25519 keypair */
fun generateEd25519KeyPair(): Pair<ByteArray, ByteArray>
/** Sign message with Ed25519 — returns 64-byte signature */
fun sign(privateKey: ByteArray, message: ByteArray): ByteArray
/** Verify Ed25519 signature — returns true if valid */
fun verify(publicKey: ByteArray, message: ByteArray, signature: ByteArray): Boolean
// ─── AES-256-GCM ──────────────────────────────────────────
/** Encrypt with AES-256-GCM. Generates random 12-byte nonce. */
fun aesGcmEncrypt(
key: ByteArray,
plaintext: ByteArray,
aad: ByteArray? = null,
): Pair<ByteArray, ByteArray> // (ciphertext, nonce)
/** Decrypt AES-256-GCM. Throws on authentication failure. */
fun aesGcmDecrypt(
key: ByteArray,
ciphertext: ByteArray,
nonce: ByteArray,
aad: ByteArray? = null,
): ByteArray
// ─── Key Derivation ────────────────────────────────────────
/** HKDF-SHA256: derive `length` bytes */
fun hkdf(ikm: ByteArray, salt: ByteArray, info: ByteArray, length: Int): ByteArray
/** HMAC-SHA256: 32-byte MAC */
fun hmacSha256(key: ByteArray, data: ByteArray): ByteArray
// ─── Random ────────────────────────────────────────────────
fun randomBytes(length: Int): ByteArray
fun randomUint32(): Int
// ─── Hardening ─────────────────────────────────────────────
/** Constant-time byte array comparison */
fun constantTimeEqual(a: ByteArray, b: ByteArray): Boolean
/** Overwrite a buffer with zeros */
fun zeroize(buf: ByteArray)
}

View File

@@ -0,0 +1,124 @@
package no.zyon.shade.crypto
import com.google.crypto.tink.subtle.Ed25519Sign
import com.google.crypto.tink.subtle.Ed25519Verify
import com.google.crypto.tink.subtle.Hkdf
import com.google.crypto.tink.subtle.X25519
import java.nio.ByteBuffer
import java.security.SecureRandom
import javax.crypto.Cipher
import javax.crypto.Mac
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec
/**
* CryptoProvider backed by Google Tink + javax.crypto.
*
* Must produce byte-identical output to @shade/crypto-web for the same
* inputs, otherwise cross-platform messaging breaks.
*/
class TinkProvider : CryptoProvider {
private val random = SecureRandom()
// ─── X25519 ────────────────────────────────────────────────
override fun generateX25519KeyPair(): Pair<ByteArray, ByteArray> {
val privateKey = X25519.generatePrivateKey()
val publicKey = X25519.publicFromPrivate(privateKey)
return publicKey to privateKey
}
override fun x25519(privateKey: ByteArray, publicKey: ByteArray): ByteArray {
return X25519.computeSharedSecret(privateKey, publicKey)
}
// ─── Ed25519 ───────────────────────────────────────────────
override fun generateEd25519KeyPair(): Pair<ByteArray, ByteArray> {
val keyPair = Ed25519Sign.KeyPair.newKeyPair()
return keyPair.publicKey to keyPair.privateKey
}
override fun sign(privateKey: ByteArray, message: ByteArray): ByteArray {
val signer = Ed25519Sign(privateKey)
return signer.sign(message)
}
override fun verify(publicKey: ByteArray, message: ByteArray, signature: ByteArray): Boolean {
return try {
Ed25519Verify(publicKey).verify(signature, message)
true
} catch (_: Exception) {
false
}
}
// ─── AES-256-GCM ──────────────────────────────────────────
override fun aesGcmEncrypt(
key: ByteArray,
plaintext: ByteArray,
aad: ByteArray?,
): Pair<ByteArray, ByteArray> {
val nonce = randomBytes(12)
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
val spec = GCMParameterSpec(128, nonce)
cipher.init(Cipher.ENCRYPT_MODE, SecretKeySpec(key, "AES"), spec)
if (aad != null) cipher.updateAAD(aad)
val ciphertext = cipher.doFinal(plaintext)
return ciphertext to nonce
}
override fun aesGcmDecrypt(
key: ByteArray,
ciphertext: ByteArray,
nonce: ByteArray,
aad: ByteArray?,
): ByteArray {
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
val spec = GCMParameterSpec(128, nonce)
cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(key, "AES"), spec)
if (aad != null) cipher.updateAAD(aad)
return cipher.doFinal(ciphertext)
}
// ─── Key Derivation ────────────────────────────────────────
override fun hkdf(ikm: ByteArray, salt: ByteArray, info: ByteArray, length: Int): ByteArray {
return Hkdf.computeHkdf("HMACSHA256", ikm, salt, info, length)
}
override fun hmacSha256(key: ByteArray, data: ByteArray): ByteArray {
val mac = Mac.getInstance("HmacSHA256")
mac.init(SecretKeySpec(key, "HmacSHA256"))
return mac.doFinal(data)
}
// ─── Random ────────────────────────────────────────────────
override fun randomBytes(length: Int): ByteArray {
val buf = ByteArray(length)
random.nextBytes(buf)
return buf
}
override fun randomUint32(): Int {
val buf = randomBytes(4)
return ByteBuffer.wrap(buf).int
}
// ─── Hardening ─────────────────────────────────────────────
override fun constantTimeEqual(a: ByteArray, b: ByteArray): Boolean {
if (a.size != b.size) return false
var diff = 0
for (i in a.indices) {
diff = diff or (a[i].toInt() xor b[i].toInt())
}
return diff == 0
}
override fun zeroize(buf: ByteArray) {
buf.fill(0)
}
}

View File

@@ -0,0 +1,40 @@
package no.zyon.shade.fingerprint
import no.zyon.shade.crypto.CryptoProvider
/**
* Safety number computation. Must produce byte-identical output
* to @shade/core/fingerprint.ts.
*
* Format: 12 groups of 5 decimal digits.
* Derived from: HKDF-SHA256(signingKey||dhKey, salt=32 zeros, info="ShadeFingerprint", 30)
* then interpret each 2-byte pair as a 16-bit unsigned int mod 10^5.
*
* Note: the TS version uses only the first 24 bytes (2 bytes × 12 groups),
* not all 30. We mirror that here.
*/
fun computeFingerprint(
crypto: CryptoProvider,
signingPublicKey: ByteArray,
dhPublicKey: ByteArray,
): String {
val combined = ByteArray(signingPublicKey.size + dhPublicKey.size)
signingPublicKey.copyInto(combined, 0)
dhPublicKey.copyInto(combined, signingPublicKey.size)
val salt = ByteArray(32)
val info = "ShadeFingerprint".toByteArray(Charsets.UTF_8)
val hash = crypto.hkdf(combined, salt, info, 30)
val groups = mutableListOf<String>()
for (i in 0 until 12) {
val offset = i * 2
val value = ((hash[offset].toInt() and 0xff) shl 8) or (hash[offset + 1].toInt() and 0xff)
groups.add(value.toString().padStart(5, '0'))
}
return groups.joinToString(" ")
}
fun shortFingerprint(full: String): String {
return full.split(" ").take(4).joinToString(" ")
}

View File

@@ -0,0 +1,237 @@
package no.zyon.shade.protocol
import no.zyon.shade.crypto.CryptoProvider
import no.zyon.shade.types.ChainState
import no.zyon.shade.types.Constants
import no.zyon.shade.types.KeyPair
import no.zyon.shade.types.RatchetMessage
import no.zyon.shade.types.SessionState
import java.nio.ByteBuffer
/**
* Double Ratchet implementation. Mirrors @shade/core/ratchet.ts.
*
* Must produce byte-identical ciphertext to the TypeScript version
* for the same inputs.
*/
// ─── Session initialization ─────────────────────────────────
fun initSenderSession(
crypto: CryptoProvider,
rootKey: ByteArray,
remoteIdentityKey: ByteArray,
remoteDHPublicKey: ByteArray,
): SessionState {
val (dhSendPub, dhSendPriv) = crypto.generateX25519KeyPair()
val dhOutput = crypto.x25519(dhSendPriv, remoteDHPublicKey)
val (newRootKey, chainKey) = kdfRootKey(crypto, rootKey, dhOutput).let {
it.newRootKey to it.chainKey
}
return SessionState(
remoteIdentityKey = remoteIdentityKey,
rootKey = newRootKey,
sendChain = ChainState(chainKey = chainKey, counter = 0),
receiveChain = null,
dhSend = KeyPair(publicKey = dhSendPub, privateKey = dhSendPriv),
dhReceive = remoteDHPublicKey,
previousSendCounter = 0,
skippedKeys = mutableMapOf(),
)
}
fun initReceiverSession(
rootKey: ByteArray,
remoteIdentityKey: ByteArray,
localDHKeyPair: KeyPair,
): SessionState {
return SessionState(
remoteIdentityKey = remoteIdentityKey,
rootKey = rootKey,
sendChain = ChainState(chainKey = ByteArray(32), counter = 0),
receiveChain = null,
dhSend = localDHKeyPair,
dhReceive = null,
previousSendCounter = 0,
skippedKeys = mutableMapOf(),
)
}
// ─── Header encoding (for AES-GCM AAD) ──────────────────────
private fun encodeHeader(
dhPublicKey: ByteArray,
previousCounter: Int,
counter: Int,
): ByteArray {
val buf = ByteBuffer.allocate(40)
buf.put(dhPublicKey)
buf.putInt(previousCounter) // big-endian by default in ByteBuffer
buf.putInt(counter)
return buf.array()
}
// ─── Encrypt ─────────────────────────────────────────────────
fun ratchetEncrypt(
crypto: CryptoProvider,
session: SessionState,
plaintext: ByteArray,
): RatchetMessage {
val oldChainKey = session.sendChain.chainKey
val (newChainKey, messageKey) = kdfChainKey(crypto, oldChainKey).let {
it.newChainKey to it.messageKey
}
crypto.zeroize(oldChainKey)
val counter = session.sendChain.counter
val header = encodeHeader(session.dhSend.publicKey, session.previousSendCounter, counter)
val (ciphertext, nonce) = crypto.aesGcmEncrypt(messageKey, plaintext, header)
crypto.zeroize(messageKey)
session.sendChain.chainKey = newChainKey
session.sendChain.counter = counter + 1
return RatchetMessage(
dhPublicKey = session.dhSend.publicKey,
previousCounter = session.previousSendCounter,
counter = counter,
ciphertext = ciphertext,
nonce = nonce,
)
}
// ─── Decrypt ─────────────────────────────────────────────────
private fun skippedKeyId(dhPublicKey: ByteArray, counter: Int): String {
return dhPublicKey.joinToString("") { "%02x".format(it) } + ":" + counter
}
fun ratchetDecrypt(
crypto: CryptoProvider,
session: SessionState,
message: RatchetMessage,
): ByteArray {
// Case 1: skipped key
val skipId = skippedKeyId(message.dhPublicKey, message.counter)
val skippedKey = session.skippedKeys[skipId]
if (skippedKey != null) {
session.skippedKeys.remove(skipId)
try {
return decryptWithKey(crypto, skippedKey, message)
} finally {
crypto.zeroize(skippedKey)
}
}
// Case 2 or 3: DH ratchet check
val isNewRatchet = session.dhReceive == null ||
!message.dhPublicKey.contentEquals(session.dhReceive!!)
if (isNewRatchet) {
if (session.receiveChain != null && session.dhReceive != null) {
skipMessageKeys(
crypto,
session,
session.dhReceive!!,
session.receiveChain!!,
message.previousCounter,
)
}
performDHRatchetStep(crypto, session, message.dhPublicKey)
}
val receiveChain = session.receiveChain
?: throw IllegalStateException("No receiving chain available")
skipMessageKeys(crypto, session, message.dhPublicKey, receiveChain, message.counter)
val oldChainKey = receiveChain.chainKey
val (newChainKey, messageKey) = kdfChainKey(crypto, oldChainKey).let {
it.newChainKey to it.messageKey
}
crypto.zeroize(oldChainKey)
receiveChain.chainKey = newChainKey
receiveChain.counter = message.counter + 1
try {
return decryptWithKey(crypto, messageKey, message)
} finally {
crypto.zeroize(messageKey)
}
}
private fun performDHRatchetStep(
crypto: CryptoProvider,
session: SessionState,
remoteDHKey: ByteArray,
) {
session.previousSendCounter = session.sendChain.counter
session.dhReceive = remoteDHKey
// DH with current send key → new receiving chain
val dh1 = crypto.x25519(session.dhSend.privateKey, remoteDHKey)
val oldRootKey1 = session.rootKey
val recv = kdfRootKey(crypto, oldRootKey1, dh1)
crypto.zeroize(oldRootKey1)
crypto.zeroize(dh1)
session.rootKey = recv.newRootKey
session.receiveChain = ChainState(chainKey = recv.chainKey, counter = 0)
// Generate new DH keypair, zero old private
val oldDhPrivate = session.dhSend.privateKey
val (newDhPub, newDhPriv) = crypto.generateX25519KeyPair()
session.dhSend = KeyPair(publicKey = newDhPub, privateKey = newDhPriv)
crypto.zeroize(oldDhPrivate)
// DH with new send key → new sending chain
val dh2 = crypto.x25519(newDhPriv, remoteDHKey)
val oldRootKey2 = session.rootKey
val send = kdfRootKey(crypto, oldRootKey2, dh2)
crypto.zeroize(oldRootKey2)
crypto.zeroize(dh2)
session.rootKey = send.newRootKey
if (session.sendChain.chainKey.isNotEmpty()) {
crypto.zeroize(session.sendChain.chainKey)
}
session.sendChain = ChainState(chainKey = send.chainKey, counter = 0)
}
private fun skipMessageKeys(
crypto: CryptoProvider,
session: SessionState,
dhPublicKey: ByteArray,
chain: ChainState,
untilCounter: Int,
) {
val toSkip = untilCounter - chain.counter
if (toSkip < 0) return
if (toSkip > Constants.MAX_SKIP) {
throw IllegalStateException("Cannot skip $toSkip messages (max: ${Constants.MAX_SKIP})")
}
for (i in chain.counter until untilCounter) {
val (newChainKey, messageKey) = kdfChainKey(crypto, chain.chainKey).let {
it.newChainKey to it.messageKey
}
val id = skippedKeyId(dhPublicKey, i)
session.skippedKeys[id] = messageKey
chain.chainKey = newChainKey
chain.counter = i + 1
while (session.skippedKeys.size > Constants.MAX_CACHED_SKIPPED_KEYS) {
val firstKey = session.skippedKeys.keys.first()
session.skippedKeys.remove(firstKey)
}
}
}
private fun decryptWithKey(
crypto: CryptoProvider,
messageKey: ByteArray,
message: RatchetMessage,
): ByteArray {
val aad = encodeHeader(message.dhPublicKey, message.previousCounter, message.counter)
return crypto.aesGcmDecrypt(messageKey, message.ciphertext, message.nonce, aad)
}

View File

@@ -0,0 +1,60 @@
package no.zyon.shade.protocol
import no.zyon.shade.crypto.CryptoProvider
/**
* KDF chain functions for the Signal Protocol ratchet.
*
* MUST produce byte-identical output to @shade/core/keys.ts.
* Info strings and salts are fixed constants and must not change.
*/
// Must match the TypeScript version EXACTLY
private val ROOT_KDF_INFO = "ShadeRootRatchet".toByteArray(Charsets.UTF_8)
private val CHAIN_KEY_CONSTANT = byteArrayOf(0x01)
private val MESSAGE_KEY_CONSTANT = byteArrayOf(0x02)
private val X3DH_INFO = "ShadeX3DH".toByteArray(Charsets.UTF_8)
private val X3DH_SALT = ByteArray(32) // 32 zero bytes
data class RootKdfResult(val newRootKey: ByteArray, val chainKey: ByteArray)
data class ChainKdfResult(val newChainKey: ByteArray, val messageKey: ByteArray)
/**
* Root key ratchet step.
* HKDF(ikm=dhOutput, salt=rootKey, info="ShadeRootRatchet", length=64)
* → first 32 bytes = new root key, last 32 bytes = chain key
*/
fun kdfRootKey(crypto: CryptoProvider, rootKey: ByteArray, dhOutput: ByteArray): RootKdfResult {
val derived = crypto.hkdf(dhOutput, rootKey, ROOT_KDF_INFO, 64)
return RootKdfResult(
newRootKey = derived.copyOfRange(0, 32),
chainKey = derived.copyOfRange(32, 64),
)
}
/**
* Chain key ratchet step.
* HMAC(chainKey, 0x01) = new chain key
* HMAC(chainKey, 0x02) = message key (used once)
*/
fun kdfChainKey(crypto: CryptoProvider, chainKey: ByteArray): ChainKdfResult {
val newChainKey = crypto.hmacSha256(chainKey, CHAIN_KEY_CONSTANT)
val messageKey = crypto.hmacSha256(chainKey, MESSAGE_KEY_CONSTANT)
return ChainKdfResult(newChainKey, messageKey)
}
/**
* Derive the initial root key from concatenated X3DH DH outputs.
* HKDF(ikm=DH1||DH2||DH3[||DH4], salt=32 zeros, info="ShadeX3DH", length=32)
*/
fun deriveInitialRootKey(crypto: CryptoProvider, sharedSecrets: List<ByteArray>): ByteArray {
val total = sharedSecrets.sumOf { it.size }
val ikm = ByteArray(total)
var offset = 0
for (secret in sharedSecrets) {
secret.copyInto(ikm, offset)
offset += secret.size
}
return crypto.hkdf(ikm, X3DH_SALT, X3DH_INFO, 32)
}

View File

@@ -0,0 +1,181 @@
package no.zyon.shade.protocol
import no.zyon.shade.crypto.CryptoProvider
import no.zyon.shade.types.IdentityKeyPair
import no.zyon.shade.types.KeyPair
import no.zyon.shade.types.OneTimePreKey
import no.zyon.shade.types.PreKeyBundle
import no.zyon.shade.types.PreKeyMessage
import no.zyon.shade.types.SignedPreKey
/**
* X3DH key agreement. Mirrors @shade/core/x3dh.ts.
*
* Identity keys: separate Ed25519 (signing) + X25519 (DH) keypairs stored together.
*/
/** Generate a new identity keypair (Ed25519 + X25519) */
fun generateIdentityKeyPair(crypto: CryptoProvider): IdentityKeyPair {
val (signPub, signPriv) = crypto.generateEd25519KeyPair()
val (dhPub, dhPriv) = crypto.generateX25519KeyPair()
return IdentityKeyPair(
signingPublicKey = signPub,
signingPrivateKey = signPriv,
dhPublicKey = dhPub,
dhPrivateKey = dhPriv,
)
}
/** Generate a signed prekey (X25519 keypair + Ed25519 signature over public key) */
fun generateSignedPreKey(
crypto: CryptoProvider,
identity: IdentityKeyPair,
keyId: Int,
): SignedPreKey {
val (pub, priv) = crypto.generateX25519KeyPair()
val signature = crypto.sign(identity.signingPrivateKey, pub)
return SignedPreKey(
keyId = keyId,
keyPair = KeyPair(publicKey = pub, privateKey = priv),
signature = signature,
timestamp = System.currentTimeMillis(),
)
}
/** Generate a batch of one-time prekeys */
fun generateOneTimePreKeys(
crypto: CryptoProvider,
startId: Int,
count: Int,
): List<OneTimePreKey> {
val keys = mutableListOf<OneTimePreKey>()
for (i in 0 until count) {
val (pub, priv) = crypto.generateX25519KeyPair()
keys.add(OneTimePreKey(keyId = startId + i, keyPair = KeyPair(pub, priv)))
}
return keys
}
fun createPreKeyBundle(
registrationId: Int,
identity: IdentityKeyPair,
signedPreKey: SignedPreKey,
oneTimePreKey: OneTimePreKey? = null,
): PreKeyBundle {
return PreKeyBundle(
registrationId = registrationId,
identitySigningKey = identity.signingPublicKey,
identityDHKey = identity.dhPublicKey,
signedPreKey = PreKeyBundle.BundleSignedPreKey(
keyId = signedPreKey.keyId,
publicKey = signedPreKey.keyPair.publicKey,
signature = signedPreKey.signature,
),
oneTimePreKey = oneTimePreKey?.let {
PreKeyBundle.BundleOneTimePreKey(it.keyId, it.keyPair.publicKey)
},
)
}
/** Result of processing a prekey bundle (Alice's side) */
data class X3DHInitResult(
val rootKey: ByteArray,
val ephemeralPublicKey: ByteArray,
val signedPreKeyId: Int,
val preKeyId: Int?,
val remoteIdentityKey: ByteArray,
val remoteSignedPreKey: ByteArray,
)
/**
* Alice processes Bob's prekey bundle to establish a session.
*
* Steps:
* 1. Verify the signed prekey signature
* 2. Generate an ephemeral X25519 keypair
* 3. Compute DH1 = DH(Alice identity DH, Bob signed prekey)
* 4. Compute DH2 = DH(Alice ephemeral, Bob identity DH)
* 5. Compute DH3 = DH(Alice ephemeral, Bob signed prekey)
* 6. Compute DH4 = DH(Alice ephemeral, Bob one-time prekey) if available
* 7. Derive initial root key from concatenated DH outputs
*/
fun processPreKeyBundle(
crypto: CryptoProvider,
identity: IdentityKeyPair,
bundle: PreKeyBundle,
): X3DHInitResult {
// 1. Verify signed prekey signature
val valid = crypto.verify(
bundle.identitySigningKey,
bundle.signedPreKey.publicKey,
bundle.signedPreKey.signature,
)
if (!valid) throw SecurityException("Signed prekey signature is invalid")
// 2. Ephemeral keypair
val (ephPub, ephPriv) = crypto.generateX25519KeyPair()
// 3-6. DH computations
val dh1 = crypto.x25519(identity.dhPrivateKey, bundle.signedPreKey.publicKey)
val dh2 = crypto.x25519(ephPriv, bundle.identityDHKey)
val dh3 = crypto.x25519(ephPriv, bundle.signedPreKey.publicKey)
val secrets = mutableListOf(dh1, dh2, dh3)
var preKeyId: Int? = null
if (bundle.oneTimePreKey != null) {
val dh4 = crypto.x25519(ephPriv, bundle.oneTimePreKey.publicKey)
secrets.add(dh4)
preKeyId = bundle.oneTimePreKey.keyId
}
// 7. Derive root key
val rootKey = deriveInitialRootKey(crypto, secrets)
return X3DHInitResult(
rootKey = rootKey,
ephemeralPublicKey = ephPub,
signedPreKeyId = bundle.signedPreKey.keyId,
preKeyId = preKeyId,
remoteIdentityKey = bundle.identityDHKey,
remoteSignedPreKey = bundle.signedPreKey.publicKey,
)
}
/** Result of processing an incoming PreKeyMessage (Bob's side) */
data class X3DHResponseResult(
val rootKey: ByteArray,
val remoteIdentityKey: ByteArray,
val remoteEphemeralKey: ByteArray,
)
/**
* Bob processes an incoming PreKeyMessage to establish a session.
* Mirrors Alice's DH computations from Bob's perspective.
*
* Caller is responsible for looking up the signed prekey and (if present)
* the one-time prekey from storage.
*/
fun processPreKeyMessage(
crypto: CryptoProvider,
identity: IdentityKeyPair,
signedPreKeyPrivate: ByteArray,
oneTimePreKeyPrivate: ByteArray?,
message: PreKeyMessage,
): X3DHResponseResult {
val dh1 = crypto.x25519(signedPreKeyPrivate, message.identityDHKey)
val dh2 = crypto.x25519(identity.dhPrivateKey, message.ephemeralKey)
val dh3 = crypto.x25519(signedPreKeyPrivate, message.ephemeralKey)
val secrets = mutableListOf(dh1, dh2, dh3)
if (oneTimePreKeyPrivate != null) {
val dh4 = crypto.x25519(oneTimePreKeyPrivate, message.ephemeralKey)
secrets.add(dh4)
}
val rootKey = deriveInitialRootKey(crypto, secrets)
return X3DHResponseResult(
rootKey = rootKey,
remoteIdentityKey = message.identityDHKey,
remoteEphemeralKey = message.ephemeralKey,
)
}

View File

@@ -0,0 +1,169 @@
package no.zyon.shade.serialization
import no.zyon.shade.types.PreKeyMessage
import no.zyon.shade.types.RatchetMessage
import no.zyon.shade.types.ShadeEnvelope
import java.nio.ByteBuffer
/**
* Compact binary wire format. MUST match @shade/proto/wire.ts byte-for-byte.
*
* Format: [version:1][type:1][payload...]
* Types: 0x01 = PreKeyMessage, 0x02 = RatchetMessage
* Integers: big-endian
* Byte arrays: 2-byte length prefix + data
*/
object WireFormat {
private const val VERSION: Byte = 0x01
private const val TYPE_PREKEY: Byte = 0x01
private const val TYPE_RATCHET: Byte = 0x02
private const val PREKEY_NONE: Long = 0xFFFFFFFFL
// ─── Encode ──────────────────────────────────────────────
fun encodeEnvelope(envelope: ShadeEnvelope): ByteArray {
return when (envelope.type) {
ShadeEnvelope.EnvelopeType.PREKEY ->
encodePreKeyMessage(envelope.content as PreKeyMessage)
ShadeEnvelope.EnvelopeType.RATCHET ->
encodeRatchetMessage(envelope.content as RatchetMessage)
}
}
fun encodePreKeyMessage(msg: PreKeyMessage): ByteArray {
val ratchetBytes = encodeRatchetInner(msg.message)
val parts = mutableListOf<ByteArray>()
parts.add(byteArrayOf(VERSION, TYPE_PREKEY))
parts.add(uint32(msg.registrationId.toLong()))
parts.add(uint32(msg.preKeyId?.toLong() ?: PREKEY_NONE))
parts.add(uint32(msg.signedPreKeyId.toLong()))
parts.add(lpBytes(msg.ephemeralKey))
parts.add(lpBytes(msg.identityDHKey))
parts.add(lpBytes(ratchetBytes))
return concat(parts)
}
fun encodeRatchetMessage(msg: RatchetMessage): ByteArray {
val parts = mutableListOf<ByteArray>()
parts.add(byteArrayOf(VERSION, TYPE_RATCHET))
parts.add(encodeRatchetInner(msg))
return concat(parts)
}
private fun encodeRatchetInner(msg: RatchetMessage): ByteArray {
val parts = mutableListOf<ByteArray>()
parts.add(lpBytes(msg.dhPublicKey))
parts.add(uint32(msg.previousCounter.toLong()))
parts.add(uint32(msg.counter.toLong()))
parts.add(lpBytes(msg.ciphertext))
parts.add(lpBytes(msg.nonce))
return concat(parts)
}
// ─── Decode ──────────────────────────────────────────────
fun decodeEnvelope(data: ByteArray): ShadeEnvelope {
if (data.size < 2) throw IllegalArgumentException("Too short")
val version = data[0]
if (version != VERSION) throw IllegalArgumentException("Unknown version: $version")
val type = data[1]
val payload = data.copyOfRange(2, data.size)
return when (type) {
TYPE_PREKEY -> ShadeEnvelope(
type = ShadeEnvelope.EnvelopeType.PREKEY,
content = decodePreKeyMessageInner(payload),
timestamp = 0,
senderAddress = "",
)
TYPE_RATCHET -> {
val (msg, _) = decodeRatchetInner(payload, 0)
ShadeEnvelope(
type = ShadeEnvelope.EnvelopeType.RATCHET,
content = msg,
timestamp = 0,
senderAddress = "",
)
}
else -> throw IllegalArgumentException("Unknown type: $type")
}
}
private fun decodePreKeyMessageInner(data: ByteArray): PreKeyMessage {
var offset = 0
val registrationId = readUint32(data, offset).toInt(); offset += 4
val preKeyIdRaw = readUint32(data, offset); offset += 4
val preKeyId = if (preKeyIdRaw == PREKEY_NONE) null else preKeyIdRaw.toInt()
val signedPreKeyId = readUint32(data, offset).toInt(); offset += 4
val ephemeral = readLP(data, offset); offset = ephemeral.second
val identityDH = readLP(data, offset); offset = identityDH.second
val ratchetData = readLP(data, offset); offset = ratchetData.second
val (ratchet, _) = decodeRatchetInner(ratchetData.first, 0)
return PreKeyMessage(
registrationId = registrationId,
preKeyId = preKeyId,
signedPreKeyId = signedPreKeyId,
ephemeralKey = ephemeral.first,
identityDHKey = identityDH.first,
message = ratchet,
)
}
private fun decodeRatchetInner(data: ByteArray, startOffset: Int): Pair<RatchetMessage, Int> {
var offset = startOffset
val dhPub = readLP(data, offset); offset = dhPub.second
val prevCounter = readUint32(data, offset).toInt(); offset += 4
val counter = readUint32(data, offset).toInt(); offset += 4
val ciphertext = readLP(data, offset); offset = ciphertext.second
val nonce = readLP(data, offset); offset = nonce.second
return RatchetMessage(
dhPublicKey = dhPub.first,
previousCounter = prevCounter,
counter = counter,
ciphertext = ciphertext.first,
nonce = nonce.first,
) to offset
}
// ─── Helpers ─────────────────────────────────────────────
private fun uint32(n: Long): ByteArray {
val buf = ByteBuffer.allocate(4)
buf.putInt(n.toInt())
return buf.array()
}
private fun lpBytes(data: ByteArray): ByteArray {
val len = ByteBuffer.allocate(2)
len.putShort(data.size.toShort())
return concat(listOf(len.array(), data))
}
private fun readUint32(data: ByteArray, offset: Int): Long {
return ((data[offset].toLong() and 0xff) shl 24) or
((data[offset + 1].toLong() and 0xff) shl 16) or
((data[offset + 2].toLong() and 0xff) shl 8) or
(data[offset + 3].toLong() and 0xff)
}
private fun readLP(data: ByteArray, offset: Int): Pair<ByteArray, Int> {
val len = ((data[offset].toInt() and 0xff) shl 8) or (data[offset + 1].toInt() and 0xff)
val value = data.copyOfRange(offset + 2, offset + 2 + len)
return value to (offset + 2 + len)
}
private fun concat(parts: List<ByteArray>): ByteArray {
val total = parts.sumOf { it.size }
val result = ByteArray(total)
var offset = 0
for (p in parts) {
p.copyInto(result, offset)
offset += p.size
}
return result
}
}

View File

@@ -0,0 +1,47 @@
package no.zyon.shade.storage
import no.zyon.shade.crypto.CryptoProvider
import no.zyon.shade.types.IdentityKeyPair
import no.zyon.shade.types.OneTimePreKey
import no.zyon.shade.types.SessionState
import no.zyon.shade.types.SignedPreKey
/**
* In-memory storage for tests and embedded use.
* Mirrors MemoryStorage in @shade/crypto-web.
*/
class MemoryStorage(private val crypto: CryptoProvider) : StorageProvider {
private var identity: IdentityKeyPair? = null
private var registrationId: Int = 0
private val signedPreKeys = mutableMapOf<Int, SignedPreKey>()
private val oneTimePreKeys = mutableMapOf<Int, OneTimePreKey>()
private val sessions = mutableMapOf<String, SessionState>()
private val trustedIdentities = mutableMapOf<String, ByteArray>()
override suspend fun getIdentityKeyPair(): IdentityKeyPair? = identity
override suspend fun saveIdentityKeyPair(keyPair: IdentityKeyPair) { identity = keyPair }
override suspend fun getLocalRegistrationId(): Int = registrationId
override suspend fun saveLocalRegistrationId(id: Int) { registrationId = id }
override suspend fun getSignedPreKey(keyId: Int): SignedPreKey? = signedPreKeys[keyId]
override suspend fun saveSignedPreKey(key: SignedPreKey) { signedPreKeys[key.keyId] = key }
override suspend fun removeSignedPreKey(keyId: Int) { signedPreKeys.remove(keyId) }
override suspend fun getOneTimePreKey(keyId: Int): OneTimePreKey? = oneTimePreKeys[keyId]
override suspend fun saveOneTimePreKey(key: OneTimePreKey) { oneTimePreKeys[key.keyId] = key }
override suspend fun removeOneTimePreKey(keyId: Int) { oneTimePreKeys.remove(keyId) }
override suspend fun getOneTimePreKeyCount(): Int = oneTimePreKeys.size
override suspend fun getSession(address: String): SessionState? = sessions[address]
override suspend fun saveSession(address: String, state: SessionState) { sessions[address] = state }
override suspend fun removeSession(address: String) { sessions.remove(address) }
override suspend fun isTrustedIdentity(address: String, identityKey: ByteArray): Boolean {
val stored = trustedIdentities[address] ?: return true // TOFU
return crypto.constantTimeEqual(stored, identityKey)
}
override suspend fun saveTrustedIdentity(address: String, identityKey: ByteArray) {
trustedIdentities[address] = identityKey
}
}

View File

@@ -0,0 +1,42 @@
package no.zyon.shade.storage
import no.zyon.shade.types.IdentityKeyPair
import no.zyon.shade.types.OneTimePreKey
import no.zyon.shade.types.SessionState
import no.zyon.shade.types.SignedPreKey
/**
* StorageProvider interface. Mirror @shade/core/storage.ts.
*
* Implementations:
* - MemoryStorage (for tests)
* - KeystoreStorage (EncryptedSharedPreferences + Android Keystore)
* - RoomStorage (SQLite via Room, for larger datasets)
*/
interface StorageProvider {
// Identity
suspend fun getIdentityKeyPair(): IdentityKeyPair?
suspend fun saveIdentityKeyPair(keyPair: IdentityKeyPair)
suspend fun getLocalRegistrationId(): Int
suspend fun saveLocalRegistrationId(id: Int)
// Signed prekeys
suspend fun getSignedPreKey(keyId: Int): SignedPreKey?
suspend fun saveSignedPreKey(key: SignedPreKey)
suspend fun removeSignedPreKey(keyId: Int)
// One-time prekeys
suspend fun getOneTimePreKey(keyId: Int): OneTimePreKey?
suspend fun saveOneTimePreKey(key: OneTimePreKey)
suspend fun removeOneTimePreKey(keyId: Int)
suspend fun getOneTimePreKeyCount(): Int
// Sessions
suspend fun getSession(address: String): SessionState?
suspend fun saveSession(address: String, state: SessionState)
suspend fun removeSession(address: String)
// Trust
suspend fun isTrustedIdentity(address: String, identityKey: ByteArray): Boolean
suspend fun saveTrustedIdentity(address: String, identityKey: ByteArray)
}

View File

@@ -0,0 +1,140 @@
package no.zyon.shade.types
/**
* Core Shade protocol types. Mirror @shade/core/types.ts.
*
* IMPORTANT: byte-for-byte compatibility with the TypeScript version
* is a hard requirement — the wire format, serialization, and KDF
* inputs must be identical.
*/
/** Long-term identity: Ed25519 for signing + X25519 for DH */
data class IdentityKeyPair(
val signingPublicKey: ByteArray,
val signingPrivateKey: ByteArray,
val dhPublicKey: ByteArray,
val dhPrivateKey: ByteArray,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is IdentityKeyPair) return false
return signingPublicKey.contentEquals(other.signingPublicKey) &&
signingPrivateKey.contentEquals(other.signingPrivateKey) &&
dhPublicKey.contentEquals(other.dhPublicKey) &&
dhPrivateKey.contentEquals(other.dhPrivateKey)
}
override fun hashCode(): Int {
var result = signingPublicKey.contentHashCode()
result = 31 * result + signingPrivateKey.contentHashCode()
result = 31 * result + dhPublicKey.contentHashCode()
result = 31 * result + dhPrivateKey.contentHashCode()
return result
}
}
/** Generic asymmetric keypair */
data class KeyPair(
val publicKey: ByteArray,
val privateKey: ByteArray,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is KeyPair) return false
return publicKey.contentEquals(other.publicKey) &&
privateKey.contentEquals(other.privateKey)
}
override fun hashCode(): Int {
var result = publicKey.contentHashCode()
result = 31 * result + privateKey.contentHashCode()
return result
}
}
/** Medium-term signed prekey, rotated periodically */
data class SignedPreKey(
val keyId: Int,
val keyPair: KeyPair,
val signature: ByteArray,
val timestamp: Long,
)
/** Single-use one-time prekey */
data class OneTimePreKey(
val keyId: Int,
val keyPair: KeyPair,
)
/** Prekey bundle fetched from the server to initiate a session */
data class PreKeyBundle(
val registrationId: Int,
val identitySigningKey: ByteArray,
val identityDHKey: ByteArray,
val signedPreKey: BundleSignedPreKey,
val oneTimePreKey: BundleOneTimePreKey? = null,
) {
data class BundleSignedPreKey(
val keyId: Int,
val publicKey: ByteArray,
val signature: ByteArray,
)
data class BundleOneTimePreKey(
val keyId: Int,
val publicKey: ByteArray,
)
}
/** Chain state (root key ratchet or chain key ratchet) */
data class ChainState(
var chainKey: ByteArray,
var counter: Int,
)
/** Full Double Ratchet session state */
data class SessionState(
var remoteIdentityKey: ByteArray,
var rootKey: ByteArray,
var sendChain: ChainState,
var receiveChain: ChainState?,
var dhSend: KeyPair,
var dhReceive: ByteArray?,
var previousSendCounter: Int,
val skippedKeys: MutableMap<String, ByteArray>,
)
/** A ratchet-encrypted message */
data class RatchetMessage(
val dhPublicKey: ByteArray,
val previousCounter: Int,
val counter: Int,
val ciphertext: ByteArray,
val nonce: ByteArray,
)
/** First message to a new peer (embeds X3DH + RatchetMessage) */
data class PreKeyMessage(
val registrationId: Int,
val preKeyId: Int?,
val signedPreKeyId: Int,
val ephemeralKey: ByteArray,
val identityDHKey: ByteArray,
val message: RatchetMessage,
)
/** Envelope wrapping a wire message */
data class ShadeEnvelope(
val type: EnvelopeType,
val content: Any, // PreKeyMessage or RatchetMessage
val timestamp: Long,
val senderAddress: String,
) {
enum class EnvelopeType { PREKEY, RATCHET }
}
/** Max skip constants — must match @shade/core */
object Constants {
const val MAX_SKIP = 1000
const val MAX_CACHED_SKIPPED_KEYS = 2000
}

View File

@@ -0,0 +1,136 @@
package no.zyon.shade
import no.zyon.shade.crypto.TinkProvider
import no.zyon.shade.fingerprint.computeFingerprint
import no.zyon.shade.protocol.deriveInitialRootKey
import no.zyon.shade.protocol.kdfChainKey
import no.zyon.shade.protocol.kdfRootKey
import no.zyon.shade.serialization.WireFormat
import no.zyon.shade.types.RatchetMessage
import no.zyon.shade.types.ShadeEnvelope
import org.junit.Assert.assertEquals
import org.junit.Test
import java.io.File
import org.json.JSONObject
import org.json.JSONArray
/**
* Cross-platform test vectors. MUST match the TypeScript implementation
* byte-for-byte, otherwise cross-platform messaging breaks.
*
* The test-vectors/ directory is at the root of the Shade monorepo.
* Generated by scripts/generate-vectors.ts from the TypeScript implementation.
*/
class CrossPlatformVectorTest {
private val crypto = TinkProvider()
private val vectorsDir = File("../../test-vectors")
private fun fromHex(str: String): ByteArray {
val bytes = ByteArray(str.length / 2)
for (i in bytes.indices) {
bytes[i] = ((Character.digit(str[i * 2], 16) shl 4) +
Character.digit(str[i * 2 + 1], 16)).toByte()
}
return bytes
}
private fun hex(bytes: ByteArray): String {
return bytes.joinToString("") { "%02x".format(it) }
}
private fun loadVectors(name: String): JSONArray {
val file = File(vectorsDir, name)
val content = file.readText()
return JSONObject(content).getJSONArray("vectors")
}
@Test
fun hkdfVectorsMatch() {
val vectors = loadVectors("hkdf.json")
for (i in 0 until vectors.length()) {
val v = vectors.getJSONObject(i)
val out = crypto.hkdf(
fromHex(v.getString("ikm")),
fromHex(v.getString("salt")),
v.getString("info").toByteArray(Charsets.UTF_8),
v.getInt("length"),
)
assertEquals(v.getString("output"), hex(out))
}
}
@Test
fun kdfChainVectorsMatch() {
val vectors = loadVectors("kdf-chain.json")
val rootVec = vectors.getJSONObject(0)
val rootResult = kdfRootKey(
crypto,
fromHex(rootVec.getString("rootKey")),
fromHex(rootVec.getString("dhOutput")),
)
assertEquals(rootVec.getString("newRootKey"), hex(rootResult.newRootKey))
assertEquals(rootVec.getString("chainKey"), hex(rootResult.chainKey))
val chainVec = vectors.getJSONObject(1)
val chainResult = kdfChainKey(crypto, fromHex(chainVec.getString("chainKey")))
assertEquals(chainVec.getString("newChainKey"), hex(chainResult.newChainKey))
assertEquals(chainVec.getString("messageKey"), hex(chainResult.messageKey))
}
@Test
fun x3dhVectorsMatch() {
val vectors = loadVectors("x3dh.json")
for (i in 0 until vectors.length()) {
val v = vectors.getJSONObject(i)
val secretsArray = v.getJSONArray("secrets")
val secrets = (0 until secretsArray.length()).map { fromHex(secretsArray.getString(it)) }
val rootKey = deriveInitialRootKey(crypto, secrets)
assertEquals(v.getString("rootKey"), hex(rootKey))
}
}
@Test
fun fingerprintVectorsMatch() {
val vectors = loadVectors("fingerprint.json")
for (i in 0 until vectors.length()) {
val v = vectors.getJSONObject(i)
val fp = computeFingerprint(
crypto,
fromHex(v.getString("signingKey")),
fromHex(v.getString("dhKey")),
)
assertEquals(v.getString("fingerprint"), fp)
}
}
@Test
fun wireFormatVectorsMatch() {
val vectors = loadVectors("wire-format.json")
val v = vectors.getJSONObject(0)
val m = v.getJSONObject("message")
val msg = RatchetMessage(
dhPublicKey = fromHex(m.getString("dhPublicKey")),
previousCounter = m.getInt("previousCounter"),
counter = m.getInt("counter"),
ciphertext = fromHex(m.getString("ciphertext")),
nonce = fromHex(m.getString("nonce")),
)
val envelope = ShadeEnvelope(
type = ShadeEnvelope.EnvelopeType.RATCHET,
content = msg,
timestamp = 0,
senderAddress = "",
)
val encoded = WireFormat.encodeEnvelope(envelope)
assertEquals(v.getString("encoded"), hex(encoded))
// Roundtrip decode
val decoded = WireFormat.decodeEnvelope(encoded)
assertEquals(ShadeEnvelope.EnvelopeType.RATCHET, decoded.type)
val rm = decoded.content as RatchetMessage
assertEquals(msg.counter, rm.counter)
}
}