206 lines
7.8 KiB
Kotlin
206 lines
7.8 KiB
Kotlin
|
|
package no.zyon.shade
|
||
|
|
|
||
|
|
import no.zyon.shade.crypto.CryptoProvider
|
||
|
|
import no.zyon.shade.fingerprint.computeFingerprint
|
||
|
|
import no.zyon.shade.protocol.createPreKeyBundle
|
||
|
|
import no.zyon.shade.protocol.generateIdentityKeyPair
|
||
|
|
import no.zyon.shade.protocol.generateOneTimePreKeys
|
||
|
|
import no.zyon.shade.protocol.generateSignedPreKey
|
||
|
|
import no.zyon.shade.protocol.initReceiverSession
|
||
|
|
import no.zyon.shade.protocol.initSenderSession
|
||
|
|
import no.zyon.shade.protocol.processPreKeyBundle
|
||
|
|
import no.zyon.shade.protocol.processPreKeyMessage
|
||
|
|
import no.zyon.shade.protocol.ratchetDecrypt
|
||
|
|
import no.zyon.shade.protocol.ratchetEncrypt
|
||
|
|
import no.zyon.shade.storage.StorageProvider
|
||
|
|
import no.zyon.shade.types.OneTimePreKey
|
||
|
|
import no.zyon.shade.types.PreKeyBundle
|
||
|
|
import no.zyon.shade.types.PreKeyMessage
|
||
|
|
import no.zyon.shade.types.RatchetMessage
|
||
|
|
import no.zyon.shade.types.ShadeEnvelope
|
||
|
|
import no.zyon.shade.types.SignedPreKey
|
||
|
|
|
||
|
|
/**
 * High-level API mirroring @shade/core's ShadeSessionManager.
 *
 * Handles X3DH + Double Ratchet, persists state via [StorageProvider].
 *
 * [initialize] must be called first. Methods that need the local identity throw
 * [IllegalStateException] with the message "Not initialized" until it has run.
 *
 * NOTE(review): this class keeps mutable state (identity, signed-prekey counter, pending X3DH
 * headers) without any synchronization — presumably callers serialize access; confirm.
 */
class ShadeSessionManager(
    private val crypto: CryptoProvider,
    private val storage: StorageProvider,
) {
    private var identity: no.zyon.shade.types.IdentityKeyPair? = null
    private var registrationId: Int = 0
    private var currentSignedPreKeyId: Int = 0

    // X3DH header data for the first outgoing message after initSessionFromBundle, keyed by peer
    // address. Held in memory only and consumed by the next encrypt() for that address.
    private val pendingX3DH = mutableMapOf<String, PendingX3DH>()

    /** X3DH header fields that must accompany the first message of a session we initiated. */
    private class PendingX3DH(
        val ephemeralPublicKey: ByteArray,
        val signedPreKeyId: Int,
        val preKeyId: Int?,
        val identityDHKey: ByteArray,
        val registrationId: Int,
    )

    /**
     * Loads the local identity, registration id and signed prekey from storage. Any that are
     * missing are created and persisted.
     *
     * The current signed prekey id is recovered by walking upward from [FIRST_SIGNED_PRE_KEY_ID]
     * while consecutive ids exist in storage. Keys created by [rotateSignedPreKey] are therefore
     * picked up again after a restart. Without this, the manager would fall back to id 1, which
     * republishes a stale key and makes the next rotation overwrite an existing key.
     */
    suspend fun initialize() {
        val localIdentity = storage.getIdentityKeyPair() ?: generateIdentityKeyPair(crypto).also { fresh ->
            storage.saveIdentityKeyPair(fresh)
        }
        identity = localIdentity

        registrationId = storage.getLocalRegistrationId()
        if (registrationId == 0) {
            // A stored 0 is treated as "not assigned yet", so a randomly drawn 0 is remapped to 1.
            var id = crypto.randomUint32()
            if (id == 0) id = 1
            registrationId = id
            storage.saveLocalRegistrationId(id)
        }

        currentSignedPreKeyId = findLatestSignedPreKeyId() ?: run {
            val fresh = generateSignedPreKey(crypto, localIdentity, FIRST_SIGNED_PRE_KEY_ID)
            storage.saveSignedPreKey(fresh)
            FIRST_SIGNED_PRE_KEY_ID
        }
    }

    /**
     * Returns the highest signed prekey id reachable by walking consecutive ids upward from
     * [FIRST_SIGNED_PRE_KEY_ID], or `null` when that first key is not in storage.
     */
    private suspend fun findLatestSignedPreKeyId(): Int? {
        var latest = storage.getSignedPreKey(FIRST_SIGNED_PRE_KEY_ID)?.keyId ?: return null
        // rotateSignedPreKey assigns ids as current + 1, so the newest key ends the contiguous run.
        while (storage.getSignedPreKey(latest + 1) != null) latest++
        return latest
    }

    /** Returns the loaded identity, or throws [IllegalStateException] ("Not initialized"). */
    private fun requireIdentity(): no.zyon.shade.types.IdentityKeyPair =
        checkNotNull(identity) { "Not initialized" }

    /**
     * Returns the local public keys as the pair `(signingPublicKey, dhPublicKey)`.
     *
     * @throws IllegalStateException if [initialize] has not been called.
     */
    fun getPublicIdentity(): Pair<ByteArray, ByteArray> {
        val id = requireIdentity()
        return id.signingPublicKey to id.dhPublicKey
    }

    /**
     * Computes the fingerprint of the local identity from its signing and DH public keys.
     *
     * @throws IllegalStateException if [initialize] has not been called.
     */
    suspend fun getIdentityFingerprint(): String {
        val id = requireIdentity()
        return computeFingerprint(crypto, id.signingPublicKey, id.dhPublicKey)
    }

    /**
     * Builds a [PreKeyBundle] from the registration id, the local identity and the current signed
     * prekey. No one-time prekey is passed to the bundle builder here.
     *
     * @throws IllegalStateException if not initialized, or if the current signed prekey is missing
     *   from storage.
     */
    suspend fun createPreKeyBundle(): PreKeyBundle {
        val id = requireIdentity()
        val spk = storage.getSignedPreKey(currentSignedPreKeyId)
            ?: throw IllegalStateException("No signed prekey")
        return createPreKeyBundle(registrationId, id, spk)
    }

    /**
     * Generates [count] one-time prekeys, persists each one and returns them.
     *
     * The start id is searched from `storedCount + 1` upward, skipping past any id still in
     * storage. The stored count alone is not a safe start once keys have been consumed. Example:
     * with keys 1..100 stored and key 5 consumed, the count is 99, and a batch starting at 100
     * would presumably overwrite the unused key 100.
     *
     * NOTE(review): assumes the generator assigns ids `startId until startId + count` — confirm.
     */
    suspend fun generateOneTimePreKeys(count: Int): List<OneTimePreKey> {
        val startId = findFreeOneTimePreKeyIdRange(count)
        val keys = generateOneTimePreKeys(crypto, startId, count)
        for (key in keys) storage.saveOneTimePreKey(key)
        return keys
    }

    /**
     * Returns the lowest id `>= storedCount + 1` such that every id in `[id, id + count)` is unused.
     * For `count <= 0`, returns `storedCount + 1` without probing.
     */
    private suspend fun findFreeOneTimePreKeyIdRange(count: Int): Int {
        var start = storage.getOneTimePreKeyCount() + 1
        var probe = start
        while (probe < start + count) {
            // Any window containing an occupied id is invalid; restart just past the clash.
            if (storage.getOneTimePreKey(probe) != null) start = probe + 1
            probe++
        }
        return start
    }

    /**
     * Creates, persists and switches to a new signed prekey with id `current + 1`.
     *
     * @return the newly generated signed prekey.
     * @throws IllegalStateException if [initialize] has not been called.
     */
    suspend fun rotateSignedPreKey(): SignedPreKey {
        val id = requireIdentity()
        val newId = currentSignedPreKeyId + 1
        val spk = generateSignedPreKey(crypto, id, newId)
        storage.saveSignedPreKey(spk)
        currentSignedPreKeyId = newId
        return spk
    }

    /**
     * Runs X3DH against the peer's [bundle] and stores a sender-side ratchet session for [address].
     *
     * The X3DH header is remembered in memory so that the next [encrypt] to [address] is sent as a
     * PREKEY envelope. If this instance is discarded before that call, the header is lost.
     *
     * NOTE(review): the bundle's identity is saved as trusted without first checking for an
     * existing, different identity for [address] — consider an identity-change check.
     *
     * @throws IllegalStateException if [initialize] has not been called.
     */
    suspend fun initSessionFromBundle(address: String, bundle: PreKeyBundle) {
        val id = requireIdentity()
        val x3dhResult = processPreKeyBundle(crypto, id, bundle)
        val session = initSenderSession(
            crypto,
            x3dhResult.rootKey,
            x3dhResult.remoteIdentityKey,
            x3dhResult.remoteSignedPreKey,
        )
        storage.saveSession(address, session)
        storage.saveTrustedIdentity(address, x3dhResult.remoteIdentityKey)
        pendingX3DH[address] = PendingX3DH(
            ephemeralPublicKey = x3dhResult.ephemeralPublicKey,
            signedPreKeyId = x3dhResult.signedPreKeyId,
            preKeyId = x3dhResult.preKeyId,
            identityDHKey = id.dhPublicKey,
            registrationId = registrationId,
        )
    }

    /**
     * Encrypts [plaintext] with the stored session for [address] and persists the advanced state.
     *
     * The first message after [initSessionFromBundle] is wrapped in a PREKEY envelope carrying the
     * pending X3DH header. Every other message is a RATCHET envelope.
     *
     * NOTE(review): `senderAddress` is set to the peer [address] passed in — confirm that field is
     * meant to carry the remote address.
     *
     * @throws IllegalStateException if no session exists for [address].
     */
    suspend fun encrypt(address: String, plaintext: ByteArray): ShadeEnvelope {
        val session = storage.getSession(address)
            ?: throw IllegalStateException("No session for $address")
        val ratchetMsg = ratchetEncrypt(crypto, session, plaintext)

        // Persist before consuming the pending header. If the save throws, the header stays
        // available and a retry still produces a PREKEY envelope.
        storage.saveSession(address, session)
        val pending = pendingX3DH.remove(address)

        if (pending == null) {
            return ShadeEnvelope(
                type = ShadeEnvelope.EnvelopeType.RATCHET,
                content = ratchetMsg,
                timestamp = System.currentTimeMillis(),
                senderAddress = address,
            )
        }

        val preKeyMsg = PreKeyMessage(
            registrationId = pending.registrationId,
            preKeyId = pending.preKeyId,
            signedPreKeyId = pending.signedPreKeyId,
            ephemeralKey = pending.ephemeralPublicKey,
            identityDHKey = pending.identityDHKey,
            message = ratchetMsg,
        )
        return ShadeEnvelope(
            type = ShadeEnvelope.EnvelopeType.PREKEY,
            content = preKeyMsg,
            timestamp = System.currentTimeMillis(),
            senderAddress = address,
        )
    }

    /**
     * Decrypts [envelope] from [address], dispatching on its type.
     *
     * @throws ClassCastException if the envelope content does not match its declared type.
     */
    suspend fun decrypt(address: String, envelope: ShadeEnvelope): ByteArray {
        return when (envelope.type) {
            ShadeEnvelope.EnvelopeType.PREKEY -> decryptPreKeyMessage(address, envelope.content as PreKeyMessage)
            ShadeEnvelope.EnvelopeType.RATCHET -> decryptRatchetMessage(address, envelope.content as RatchetMessage)
        }
    }

    /**
     * Completes X3DH as the receiver and decrypts the embedded ratchet message. On success it
     * stores a new session for [address] and saves the sender's identity as trusted.
     *
     * The referenced one-time prekey is deleted only after decryption succeeds (presumably
     * `ratchetDecrypt` throws on authentication failure — confirm). A forged or corrupted message
     * therefore cannot burn a one-time prekey, and a legitimate message can be retried after a
     * transient failure.
     *
     * @throws IllegalStateException if not initialized, or if a referenced prekey is missing.
     */
    private suspend fun decryptPreKeyMessage(address: String, message: PreKeyMessage): ByteArray {
        val id = requireIdentity()
        val spk = storage.getSignedPreKey(message.signedPreKeyId)
            ?: throw IllegalStateException("Signed prekey ${message.signedPreKeyId} not found")

        val oneTimeKeyId = message.preKeyId
        val oneTimePrivate: ByteArray? = oneTimeKeyId?.let { keyId ->
            val otpk = storage.getOneTimePreKey(keyId)
                ?: throw IllegalStateException("One-time prekey $keyId not found")
            otpk.keyPair.privateKey
        }

        val x3dhResult = processPreKeyMessage(
            crypto,
            id,
            spk.keyPair.privateKey,
            oneTimePrivate,
            message,
        )

        val session = initReceiverSession(
            rootKey = x3dhResult.rootKey,
            remoteIdentityKey = x3dhResult.remoteIdentityKey,
            localDHKeyPair = spk.keyPair,
        )

        val plaintext = ratchetDecrypt(crypto, session, message.message)

        // The message authenticated: consume the one-time prekey now (it must not be reused).
        if (oneTimeKeyId != null) storage.removeOneTimePreKey(oneTimeKeyId)
        storage.saveSession(address, session)
        // NOTE(review): saved as trusted unconditionally — no identity-change check is made here.
        storage.saveTrustedIdentity(address, x3dhResult.remoteIdentityKey)
        return plaintext
    }

    /**
     * Decrypts a ratchet message with the stored session for [address] and persists the advanced
     * session state.
     *
     * @throws IllegalStateException if no session exists for [address].
     */
    private suspend fun decryptRatchetMessage(address: String, message: RatchetMessage): ByteArray {
        val session = storage.getSession(address)
            ?: throw IllegalStateException("No session for $address")
        val plaintext = ratchetDecrypt(crypto, session, message)
        storage.saveSession(address, session)
        return plaintext
    }

    private companion object {
        /** Id of the first signed prekey; also the starting point when recovering the latest one. */
        const val FIRST_SIGNED_PRE_KEY_ID = 1
    }
}
|