diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/FromEpochNtpTimestampFactory.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/FromEpochNtpTimestampFactory.kt new file mode 100644 index 00000000..3aa824d7 --- /dev/null +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/FromEpochNtpTimestampFactory.kt @@ -0,0 +1,26 @@ +package com.tidal.networktime.internal + +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds + +internal class FromEpochNtpTimestampFactory { + operator fun invoke(epochTime: Duration) = NtpTimestamp(epochTime.epochTimeAsNtpTime) + + private val Duration.epochTimeAsNtpTime: Duration + get() { + val millis = inWholeMilliseconds + val useBase1 = millis < NtpPacket.NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_0_MILLISECONDS + val baseTimeMillis = millis - + if (useBase1) { + NtpPacket.NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_1_MILLISECONDS + } else { + NtpPacket.NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_0_MILLISECONDS + } + var seconds = baseTimeMillis / 1_000 + if (useBase1) { + seconds = seconds or 0x80000000L + } + val fraction = baseTimeMillis % 1_000 * 0x100000000L / 1_000 + return (seconds shl 32 or fraction).milliseconds + } +} diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpExchangeResult.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpExchangeResult.kt index a8796a09..cd1a6567 100644 --- a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpExchangeResult.kt +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpExchangeResult.kt @@ -1,27 +1,66 @@ +@file:Suppress("DuplicatedCode") // We need the duplicated variable list for performance reasons + package com.tidal.networktime.internal import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds internal data class NtpExchangeResult( - val timeMeasured: Duration, + val returnTime: Duration, val ntpPacket: NtpPacket, ) { val roundTripDelay: Duration - get() = timeMeasured - ntpPacket.run { - originateEpochTimestamp - - ( - transmitEpochTimestamp - - receiveEpochTimestamp - ) + get() = ntpPacket.run { + val originEpochMillis = originateEpochTimestamp.epochTime.inWholeMilliseconds + val receiveNtpMillis = receiveEpochTimestamp.ntpTime.inWholeMilliseconds + val receiveEpochMillis = receiveEpochTimestamp.epochTime.inWholeMilliseconds + val transmitNtpMillis = transmitEpochTimestamp.ntpTime.inWholeMilliseconds + val transmitEpochMillis = transmitEpochTimestamp.epochTime.inWholeMilliseconds + val returnTimeMillis = returnTime.inWholeMilliseconds + if (receiveNtpMillis == 0L || transmitNtpMillis == 0L) { + return@run if (returnTimeMillis >= originEpochMillis) { + (returnTimeMillis - originEpochMillis).milliseconds + } else { + Duration.INFINITE + } + } + var delayMillis = returnTimeMillis - originEpochMillis + val deltaMillis = transmitEpochMillis - receiveEpochMillis + if (deltaMillis <= delayMillis) { + delayMillis -= deltaMillis + } else if (deltaMillis - delayMillis == 1L) { + if (delayMillis != 0L) { + delayMillis = 0 + } + } + delayMillis.milliseconds } val clockOffset: Duration get() = ntpPacket.run { - ( - receiveEpochTimestamp - - originateEpochTimestamp + - transmitEpochTimestamp - - timeMeasured - ) - } / 2 + val originNtpMillis = originateEpochTimestamp.ntpTime.inWholeMilliseconds + val originEpochMillis = originateEpochTimestamp.epochTime.inWholeMilliseconds + val receiveNtpMillis = receiveEpochTimestamp.ntpTime.inWholeMilliseconds + val receiveEpochMillis = receiveEpochTimestamp.epochTime.inWholeMilliseconds + val transmitNtpMillis = transmitEpochTimestamp.ntpTime.inWholeMilliseconds + val transmitEpochMillis = transmitEpochTimestamp.epochTime.inWholeMilliseconds + val returnTimeMillis = returnTime.inWholeMilliseconds + if (originNtpMillis == 0L) { + if (transmitNtpMillis != 0L) { + return@run (transmitEpochMillis - returnTimeMillis).milliseconds + } + return@run Duration.INFINITE + } + if (receiveNtpMillis == 0L || transmitNtpMillis == 0L) { + if (receiveNtpMillis != 0L) { + return@run (receiveEpochMillis - originEpochMillis).milliseconds + } + if (transmitNtpMillis != 0L) { + return@run (transmitEpochMillis - returnTimeMillis).milliseconds + } + return@run Duration.INFINITE + } + ((receiveEpochMillis - originEpochMillis + transmitEpochMillis - returnTimeMillis) / 2) + .milliseconds + } } diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpExchanger.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpExchanger.kt index d767f16f..27fb2c79 100644 --- a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpExchanger.kt +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpExchanger.kt @@ -4,6 +4,7 @@ import kotlin.time.Duration internal class NtpExchanger( private val referenceClock: KotlinXDateTimeSystemClock, + private val fromEpochNtpTimestampFactory: FromEpochNtpTimestampFactory, private val ntpPacketSerializer: NtpPacketSerializer, private val ntpPacketDeserializer: NtpPacketDeserializer, ) { @@ -13,21 +14,19 @@ internal class NtpExchanger( ntpVersion: UByte, ): NtpExchangeResult? { val ntpUdpSocketOperations = NtpUdpSocketOperations() - val requestPacket = NtpPacket( - versionNumber = ntpVersion.toInt(), - mode = NTP_MODE_CLIENT, - ) return try { ntpUdpSocketOperations.prepareSocket(queryTimeout.inWholeMilliseconds) + val ntpPacket = NtpPacket(versionNumber = ntpVersion.toInt(), mode = NTP_MODE_CLIENT) val requestTime = referenceClock.referenceEpochTime - val buffer = ntpPacketSerializer(requestPacket.copy(transmitEpochTimestamp = requestTime)) + ntpPacket.transmitEpochTimestamp = fromEpochNtpTimestampFactory(requestTime) + val buffer = ntpPacketSerializer(ntpPacket) ntpUdpSocketOperations.exchangePacketInPlace( buffer, address, NTP_PORT_NUMBER, ) - val responseTime = referenceClock.referenceEpochTime - requestTime - NtpExchangeResult(responseTime, ntpPacketDeserializer(buffer)) + val returnTime = referenceClock.referenceEpochTime + ntpPacketDeserializer(buffer)?.let { NtpExchangeResult(returnTime, it) } } catch (_: Throwable) { null } finally { diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacket.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacket.kt index 97ec54da..6d79242c 100644 --- a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacket.kt +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacket.kt @@ -1,31 +1,25 @@ package com.tidal.networktime.internal import kotlin.time.Duration -import kotlin.time.Duration.Companion.days internal data class NtpPacket( val leapIndicator: Int = 0, val versionNumber: Int, val mode: Int, - val stratum: Byte = 0, - val poll: Byte = 0, - val precision: Byte = 0, + val stratum: Int = 0, + val poll: Duration = Duration.INFINITE, + val precision: Duration = Duration.INFINITE, val rootDelay: Duration = Duration.INFINITE, val rootDispersion: Duration = Duration.INFINITE, - val referenceIdentifier: Int = 0, - val referenceEpochTimestamp: Duration = Duration.INFINITE, - val originateEpochTimestamp: Duration = Duration.INFINITE, - val receiveEpochTimestamp: Duration = Duration.INFINITE, - val transmitEpochTimestamp: Duration = Duration.INFINITE, + val referenceIdentifier: String = "", + val referenceEpochTimestamp: NtpTimestamp = NtpTimestamp(Duration.ZERO), + val originateEpochTimestamp: NtpTimestamp = NtpTimestamp(Duration.ZERO), + val receiveEpochTimestamp: NtpTimestamp = NtpTimestamp(Duration.ZERO), + /** Keep this mutable to minimize delay (avoids an allocation) **/ + var transmitEpochTimestamp: NtpTimestamp = NtpTimestamp(Duration.ZERO), ) { - init { - // Check sizes of fields whose type does not match their corresponding size in the actual packet - check(leapIndicator <= 0b0011) - check(versionNumber <= 0b0111) - check(mode <= 0b0111) - } - companion object { - val NTP_EPOCH_OFFSET_WITH_EPOCH = (365.days * 70 + 17.days) + const val NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_0_MILLISECONDS = 2085978496000 + const val NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_1_MILLISECONDS = -2208988800000 } } diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacketDeserializer.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacketDeserializer.kt index dfa498d0..c00e3158 100644 --- a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacketDeserializer.kt +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacketDeserializer.kt @@ -6,68 +6,89 @@ import kotlin.time.Duration.Companion.milliseconds import kotlin.time.Duration.Companion.seconds internal class NtpPacketDeserializer { - operator fun invoke(bytes: ByteArray): NtpPacket { + operator fun invoke(bytes: ByteArray): NtpPacket? { var index = 0 + val leapIndicator = (bytes[index].toInt() shr 6) and 0b11 + if (leapIndicator == LEAP_INDICATOR_CLOCK_UNSYNCHRONIZED) { + return null + } + val versionNumber = (bytes[index].toInt() shr 3) and 0b111 + val mode = bytes[index].toInt() and 0b111 + if (mode != MODE_SERVER) { + return null + } + ++index + val stratum = bytes[index++].asUnsignedInt + if (stratum >= STRATUM_CLOCK_NOT_SYNCHRONIZED) { + return null + } + val poll = bytes[index++].asSignedIntToThePowerOf2.seconds + val precision = bytes[index++].asSignedIntToThePowerOf2.milliseconds + val rootDelay = bytes.sliceArray(index until index + 4).asNtpIntervalToInterval + index += 4 + val rootDispersion = bytes.sliceArray(index until index + 4).asNtpIntervalToInterval + index += 4 + val referenceIdentifier = bytes.sliceArray(index until index + 4).decodeToString() + index += 4 + val reference = bytes.sliceArray(index until index + 8).asNtpTimestamp + index += 8 + val originate = bytes.sliceArray(index until index + 8).asNtpTimestamp + index += 8 + val receive = bytes.sliceArray(index until index + 8).asNtpTimestamp + index += 8 + val transmit = bytes.sliceArray(index until index + 8).asNtpTimestamp return NtpPacket( - (bytes[index++].toInt() shl 8) + bytes[index++], - (bytes[index++].toInt() shl 16) + (bytes[index++].toInt() shl 24) + bytes[index++], - (bytes[index++].toInt() shl 16) + (bytes[index++].toInt() shl 24) + bytes[index++], - bytes[index++], - bytes[index++], - bytes[index++], - bytes.sliceArray(index until index + 32).asNtpIntervalToInterval.also { index += 32 }, - bytes.sliceArray(index until index + 32).asNtpIntervalToInterval.also { index += 32 }, - (bytes[index++].toInt() shl 24) + - (bytes[index++].toInt() shl 16) + - (bytes[index++].toInt() shl 8) + - bytes[index++].toInt(), - bytes.sliceArray(index until index + 64).asNtpEpochTimestampToEpochTime - .also { index += 64 }, - bytes.sliceArray(index until index + 64).asNtpEpochTimestampToEpochTime - .also { index += 64 }, - bytes.sliceArray(index until index + 64).asNtpEpochTimestampToEpochTime - .also { index += 64 }, - bytes.sliceArray(index until index + 64).asNtpEpochTimestampToEpochTime, + leapIndicator, + versionNumber, + mode, + stratum, + poll, + precision, + rootDelay, + rootDispersion, + referenceIdentifier, + reference, + originate, + receive, + transmit, ) } + private val Byte.asSignedIntToThePowerOf2 + get() = 2.toDouble().pow(toInt()) + + private val Byte.asUnsignedInt: Int + get() = toUByte().toInt() + private val ByteArray.asNtpIntervalToInterval: Duration get() { var index = 0 - val seconds = (this[index++].toUByte().toInt() shl 8) + - this[index++].toUByte().toInt() - val fraction = ( - (this[index++].toUByte().toInt() shl 8) + - this[index].toUByte().toInt() - ) * - 1_000 / - (1 shl 16) + val seconds = (this[index++].asUnsignedInt shl 8) + this[index++].asUnsignedInt + val fraction = ((this[index++].asUnsignedInt shl 8) + this[index].asUnsignedInt) + .toDouble() / (1 shl 16) * 1_000 return seconds.seconds + fraction.milliseconds } - private val ByteArray.asNtpEpochTimestampToEpochTime: Duration + private val Byte.asUnsignedLong: Long + get() = toUByte().toLong() + + private val ByteArray.asNtpTimestamp: NtpTimestamp get() { - val rollOverAdjustment = if ((this[0].toInt() shr 7) == 0) { - 2.toDouble().pow(32).seconds - } else { - Duration.ZERO - } var index = 0 - val seconds = (this[index++].toUByte().toInt() shl 24) + - (this[index++].toUByte().toInt() shl 16) + - (this[index++].toUByte().toInt() shl 8) + - this[index++].toUByte().toInt() - val fraction = ( - (this[index++].toUByte().toInt() shl 24) + - (this[index++].toUByte().toInt() shl 16) + - (this[index++].toUByte().toInt() shl 8) + - this[index].toUByte().toInt() - ) * - 1_000 / - 0b100000000000000000000000000000000 - return seconds.seconds + - rollOverAdjustment - - NtpPacket.NTP_EPOCH_OFFSET_WITH_EPOCH + - fraction.milliseconds + val ntpMillis = (this[index++].asUnsignedLong shl 56) or + (this[index++].asUnsignedLong shl 48) or + (this[index++].asUnsignedLong shl 40) or + (this[index++].asUnsignedLong shl 32) or + (this[index++].asUnsignedLong shl 24) or + (this[index++].asUnsignedLong shl 16) or + (this[index++].asUnsignedLong shl 8) or + this[index].asUnsignedLong + return NtpTimestamp(ntpMillis.milliseconds) } + + companion object { + private const val LEAP_INDICATOR_CLOCK_UNSYNCHRONIZED = 0b11 + private const val MODE_SERVER = 4 + private const val STRATUM_CLOCK_NOT_SYNCHRONIZED = 16 + } } diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacketSerializer.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacketSerializer.kt index a0f3890a..e6ec53bb 100644 --- a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacketSerializer.kt +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpPacketSerializer.kt @@ -1,61 +1,30 @@ package com.tidal.networktime.internal -import kotlin.random.Random import kotlin.time.Duration -internal class NtpPacketSerializer(private val random: Random) { - operator fun invoke(ntpPacket: NtpPacket) = with(ntpPacket) { - byteArrayOf( - ((leapIndicator shl 6) or (versionNumber shl 3) or mode).toByte(), - stratum, - poll, - precision, - *rootDelay.asIntervalToNtpInterval, - *rootDispersion.asIntervalToNtpInterval, - (referenceIdentifier shr 24).toByte(), - (referenceIdentifier shr 16).toByte(), - (referenceIdentifier shr 8).toByte(), - referenceIdentifier.toByte(), - *referenceEpochTimestamp.asEpochTimestampToNtpEpochTimestamp, - *originateEpochTimestamp.asEpochTimestampToNtpEpochTimestamp, - *receiveEpochTimestamp.asEpochTimestampToNtpEpochTimestamp, - *transmitEpochTimestamp.asEpochTimestampToNtpEpochTimestamp, - ) - } - - private val Duration.asIntervalToNtpInterval: ByteArray - get() { - if (this == Duration.INFINITE) { - return ByteArray(4) - } - val wholeSeconds = inWholeSeconds - val fraction = (inWholeMilliseconds - wholeSeconds * 1_000) * (1 shl 16) / 1_000 - return byteArrayOf( - (wholeSeconds shr 8).toByte(), - wholeSeconds.toByte(), - (fraction shr 8).toByte(), - fraction.toByte(), - ) +internal class NtpPacketSerializer { + operator fun invoke(ntpPacket: NtpPacket) = ntpPacket.run { + ByteArray(48).apply { + set(0, ((0 shl 6) or (versionNumber shl 3) or mode).toByte()) + transmitEpochTimestamp.ntpTime + .ntpTimestampAsByteArray + .forEachIndexed { i, it -> + set(40 + i, it) + } } + } - private val Duration.asEpochTimestampToNtpEpochTimestamp: ByteArray - get() { - if (this == Duration.INFINITE) { - return ByteArray(8) - } - val seconds = (this + NtpPacket.NTP_EPOCH_OFFSET_WITH_EPOCH).inWholeSeconds - val fraction = (inWholeMilliseconds - inWholeSeconds * 1_000) * - 0b100000000000000000000000000000000 / - 1_000 - return byteArrayOf( - (seconds shr 24).toByte(), - (seconds shr 16).toByte(), - (seconds shr 8).toByte(), - seconds.toByte(), - (fraction shr 24).toByte(), - (fraction shr 16).toByte(), - (fraction shr 8).toByte(), - random.nextBytes(1).single(), + private val Duration.ntpTimestampAsByteArray: ByteArray + get() = inWholeMilliseconds.run { + byteArrayOf( + (this shr 56 and 0xff).toByte(), + (this shr 48 and 0xff).toByte(), + (this shr 40 and 0xff).toByte(), + (this shr 32 and 0xff).toByte(), + (this shr 24 and 0xff).toByte(), + (this shr 16 and 0xff).toByte(), + (this shr 8 and 0xff).toByte(), + (this and 0xff).toByte(), ) } } diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpTimestamp.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpTimestamp.kt new file mode 100644 index 00000000..2766915f --- /dev/null +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/NtpTimestamp.kt @@ -0,0 +1,26 @@ +package com.tidal.networktime.internal + +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds + +/** + * NTP timestamps have more precision than epochs represented with Kotlin's Long, so use them as the + * non-computed property. + */ +// TODO Convert this and FromEpochNtpTimestampFactory into two inline classes wrapping Duration +internal data class NtpTimestamp(val ntpTime: Duration) { + val epochTime: Duration + get() { + val ntpTimeValue = ntpTime.inWholeMilliseconds + val seconds = ntpTimeValue ushr 32 and 0xffffffff + val fraction = (1000.0 * (ntpTimeValue and 0xffffffff) / 0x100000000).toLong() + val mostSignificantBit = seconds and 0x80000000L + return ( + if (mostSignificantBit == 0L) { + NtpPacket.NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_0_MILLISECONDS + } else { + NtpPacket.NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_1_MILLISECONDS + } + seconds * 1000 + fraction + ).milliseconds + } +} diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/SyncPeriodic.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/SyncPeriodic.kt index 6f3d0e2e..2b2a82b3 100644 --- a/library/src/commonMain/kotlin/com/tidal/networktime/internal/SyncPeriodic.kt +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/SyncPeriodic.kt @@ -2,7 +2,6 @@ package com.tidal.networktime.internal import com.tidal.networktime.NTPServer import kotlinx.coroutines.delay -import kotlin.random.Random import kotlin.time.Duration internal class SyncPeriodic( @@ -10,10 +9,10 @@ internal class SyncPeriodic( private val syncInterval: Duration, private val referenceClock: KotlinXDateTimeSystemClock, private val mutableState: MutableState, - random: Random = Random.Default, private val ntpExchanger: NtpExchanger = NtpExchanger( referenceClock, - NtpPacketSerializer(random), + FromEpochNtpTimestampFactory(), + NtpPacketSerializer(), NtpPacketDeserializer(), ), ) { diff --git a/library/src/commonMain/kotlin/com/tidal/networktime/internal/SyncSingular.kt b/library/src/commonMain/kotlin/com/tidal/networktime/internal/SyncSingular.kt index 3c4e9015..881e3d60 100644 --- a/library/src/commonMain/kotlin/com/tidal/networktime/internal/SyncSingular.kt +++ b/library/src/commonMain/kotlin/com/tidal/networktime/internal/SyncSingular.kt @@ -31,7 +31,7 @@ internal class SyncSingular( } } mutableState.synchronizationResult = SynchronizationResult( - selectedResult.run { timeMeasured + clockOffset }, + selectedResult.run { returnTime + clockOffset }, referenceClock.referenceEpochTime, ) } diff --git a/samples/shared/src/commonMain/kotlin/root/MainScreen.kt b/samples/shared/src/commonMain/kotlin/root/MainScreen.kt index 1a4feed4..f83524c3 100644 --- a/samples/shared/src/commonMain/kotlin/root/MainScreen.kt +++ b/samples/shared/src/commonMain/kotlin/root/MainScreen.kt @@ -41,7 +41,21 @@ fun MainScreen(mainViewModel: MainViewModel) { horizontalArrangement = Arrangement.Start, ) { Text(state.localEpoch.epochToString, style = textStyle) - Text(state.synchronizedEpoch?.epochToString ?: "None", style = textStyle) + val synchronizedEpoch = state.synchronizedEpoch ?: return@FlowColumn + val diff = synchronizedEpoch - state.localEpoch + Text( + "${synchronizedEpoch.epochToString} " + + "(${ + if (diff > Duration.ZERO) { + "ahead " + } else if (diff < Duration.ZERO) { + "behind " + } else { + "" + } + }$diff)", + style = textStyle, + ) } } } diff --git a/samples/shared/src/commonMain/kotlin/root/MainViewModel.kt b/samples/shared/src/commonMain/kotlin/root/MainViewModel.kt index 88d25d59..0cc3b6b4 100644 --- a/samples/shared/src/commonMain/kotlin/root/MainViewModel.kt +++ b/samples/shared/src/commonMain/kotlin/root/MainViewModel.kt @@ -18,6 +18,7 @@ class MainViewModel { queriesPerResolvedAddress = 1, waitBetweenResolvedAddressQueries = 1.seconds, ), + synchronizationInterval = 5.seconds, ) private val stateCalculator = StateCalculator(sntpClient) private val _uiState = MutableStateFlow(stateCalculator())