Skip to content

Commit

Permalink
Complete deserialization logic
Browse files Browse the repository at this point in the history
  • Loading branch information
stoyicker committed Nov 15, 2023
1 parent 9e32c58 commit 10e7797
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 146 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.tidal.networktime.internal

import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds

internal class FromEpochNtpTimestampFactory {
operator fun invoke(epochTime: Duration) = NtpTimestamp(epochTime.epochTimeAsNtpTime)

private val Duration.epochTimeAsNtpTime: Duration
get() {
val millis = inWholeMilliseconds
val useBase1 = millis < NtpPacket.NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_0_MILLISECONDS
val baseTimeMillis = millis -
if (useBase1) {
NtpPacket.NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_1_MILLISECONDS
} else {
NtpPacket.NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_0_MILLISECONDS
}
var seconds = baseTimeMillis / 1_000
if (useBase1) {
seconds = seconds or 0x80000000L
}
val fraction = baseTimeMillis % 1_000 * 0x100000000L / 1_000
return (seconds shl 32 or fraction).milliseconds
}
}
Original file line number Diff line number Diff line change
@@ -1,27 +1,64 @@
@file:Suppress("DuplicatedCode") // We need the duplicated variable list for performance reasons

package com.tidal.networktime.internal

import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds

internal data class NtpExchangeResult(
val timeMeasured: Duration,
val returnTime: Duration,
val ntpPacket: NtpPacket,
) {
val roundTripDelay: Duration
get() = timeMeasured - ntpPacket.run {
originateEpochTimestamp -
(
transmitEpochTimestamp -
receiveEpochTimestamp
)
get() = ntpPacket.run {
val originEpochMillis = originateEpochTimestamp.epochTime.inWholeMilliseconds
val receiveNtpMillis = receiveEpochTimestamp.ntpTime.inWholeMilliseconds
val receiveEpochMillis = receiveEpochTimestamp.epochTime.inWholeMilliseconds
val transmitNtpMillis = transmitEpochTimestamp.ntpTime.inWholeMilliseconds
val transmitEpochMillis = transmitEpochTimestamp.epochTime.inWholeMilliseconds
val returnTimeMillis = returnTime.inWholeMilliseconds
if (receiveNtpMillis == 0L || transmitNtpMillis == 0L) {
return@run if (returnTimeMillis >= originEpochMillis) {
(returnTimeMillis - originEpochMillis).milliseconds
} else {
Duration.INFINITE
}
}
var delayMillis = returnTimeMillis - originEpochMillis
val deltaMillis = transmitEpochMillis - receiveEpochMillis
if (deltaMillis <= delayMillis) {
delayMillis -= deltaMillis
} else if (deltaMillis - delayMillis == 1L) {
if (delayMillis != 0L) {
delayMillis = 0
}
}
delayMillis.milliseconds
}

val clockOffset: Duration
get() = ntpPacket.run {
(
receiveEpochTimestamp -
originateEpochTimestamp +
transmitEpochTimestamp -
timeMeasured
)
} / 2
val originNtpMillis = originateEpochTimestamp.ntpTime.inWholeMilliseconds
val originEpochMillis = originateEpochTimestamp.epochTime.inWholeMilliseconds
val receiveNtpMillis = receiveEpochTimestamp.ntpTime.inWholeMilliseconds
val receiveEpochMillis = receiveEpochTimestamp.epochTime.inWholeMilliseconds
val transmitNtpMillis = transmitEpochTimestamp.ntpTime.inWholeMilliseconds
val transmitEpochMillis = transmitEpochTimestamp.epochTime.inWholeMilliseconds
val returnTimeMillis = returnTime.inWholeMilliseconds
if (originNtpMillis == 0L) {
if (transmitNtpMillis != 0L) {
return@run (transmitEpochMillis - returnTimeMillis).milliseconds
}
}
if (receiveNtpMillis == 0L || transmitNtpMillis == 0L) {
if (receiveNtpMillis != 0L) {
return@run (receiveEpochMillis - originEpochMillis).milliseconds
}
if (transmitNtpMillis != 0L) {
return@run (transmitEpochMillis - returnTimeMillis).milliseconds
}
}
((receiveEpochMillis - originEpochMillis + transmitEpochMillis - returnTimeMillis) / 2)
.milliseconds
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import kotlin.time.Duration

internal class NtpExchanger(
private val referenceClock: KotlinXDateTimeSystemClock,
private val fromEpochNtpTimestampFactory: FromEpochNtpTimestampFactory,
private val ntpPacketSerializer: NtpPacketSerializer,
private val ntpPacketDeserializer: NtpPacketDeserializer,
) {
Expand All @@ -13,21 +14,19 @@ internal class NtpExchanger(
ntpVersion: UByte,
): NtpExchangeResult? {
val ntpUdpSocketOperations = NtpUdpSocketOperations()
val requestPacket = NtpPacket(
versionNumber = ntpVersion.toInt(),
mode = NTP_MODE_CLIENT,
)
return try {
ntpUdpSocketOperations.prepareSocket(queryTimeout.inWholeMilliseconds)
val ntpPacket = NtpPacket(versionNumber = ntpVersion.toInt(), mode = NTP_MODE_CLIENT)
val requestTime = referenceClock.referenceEpochTime
val buffer = ntpPacketSerializer(requestPacket.copy(transmitEpochTimestamp = requestTime))
ntpPacket.transmitEpochTimestamp = fromEpochNtpTimestampFactory(requestTime)
val buffer = ntpPacketSerializer(ntpPacket)
ntpUdpSocketOperations.exchangePacketInPlace(
buffer,
address,
NTP_PORT_NUMBER,
)
val responseTime = referenceClock.referenceEpochTime - requestTime
NtpExchangeResult(responseTime, ntpPacketDeserializer(buffer))
val returnTime = referenceClock.referenceEpochTime
ntpPacketDeserializer(buffer)?.let { NtpExchangeResult(returnTime, it) }
} catch (_: Throwable) {
null
} finally {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,25 @@
package com.tidal.networktime.internal

import kotlin.time.Duration
import kotlin.time.Duration.Companion.days

internal data class NtpPacket(
val leapIndicator: Int = 0,
val versionNumber: Int,
val mode: Int,
val stratum: Byte = 0,
val poll: Byte = 0,
val precision: Byte = 0,
val stratum: Int = 0,
val poll: Duration = Duration.INFINITE,
val precision: Duration = Duration.INFINITE,
val rootDelay: Duration = Duration.INFINITE,
val rootDispersion: Duration = Duration.INFINITE,
val referenceIdentifier: Int = 0,
val referenceEpochTimestamp: Duration = Duration.INFINITE,
val originateEpochTimestamp: Duration = Duration.INFINITE,
val receiveEpochTimestamp: Duration = Duration.INFINITE,
val transmitEpochTimestamp: Duration = Duration.INFINITE,
val referenceIdentifier: String = "",
val referenceEpochTimestamp: NtpTimestamp = NtpTimestamp(Duration.ZERO),
val originateEpochTimestamp: NtpTimestamp = NtpTimestamp(Duration.ZERO),
val receiveEpochTimestamp: NtpTimestamp = NtpTimestamp(Duration.ZERO),
/** Keep this mutable to minimize delay (avoids an allocation) **/
var transmitEpochTimestamp: NtpTimestamp = NtpTimestamp(Duration.ZERO),
) {
init {
// Check sizes of fields whose type does not match their corresponding size in the actual packet
check(leapIndicator <= 0b0011)
check(versionNumber <= 0b0111)
check(mode <= 0b0111)
}

companion object {
val NTP_EPOCH_OFFSET_WITH_EPOCH = (365.days * 70 + 17.days)
const val NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_0_MILLISECONDS = 2085978496000
const val NTP_TIMESTAMP_BASE_WITH_EPOCH_MSB_1_MILLISECONDS = -2085978496000
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,68 +6,89 @@ import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

internal class NtpPacketDeserializer {
operator fun invoke(bytes: ByteArray): NtpPacket {
operator fun invoke(bytes: ByteArray): NtpPacket? {
var index = 0
val leapIndicator = (bytes[index].toInt() shr 6) and 0b11
if (leapIndicator == LEAP_INDICATOR_CLOCK_UNSYNCHRONIZED) {
return null
}
val versionNumber = (bytes[index].toInt() shr 3) and 0b111
val mode = bytes[index].toInt() and 0b111
if (mode != MODE_SERVER) {
return null
}
++index
val stratum = bytes[index++].asUnsignedInt
if (stratum >= STRATUM_CLOCK_NOT_SYNCHRONIZED) {
return null
}
val poll = bytes[index++].asSignedIntToThePowerOf2.seconds
val precision = bytes[index++].asSignedIntToThePowerOf2.milliseconds
val rootDelay = bytes.sliceArray(index until index + 4).asNtpIntervalToInterval
index += 4
val rootDispersion = bytes.sliceArray(index until index + 4).asNtpIntervalToInterval
index += 4
val referenceIdentifier = bytes.sliceArray(index until index + 4).decodeToString()
index += 4
val reference = bytes.sliceArray(index until index + 8).asNtpTimestamp
index += 8
val originate = bytes.sliceArray(index until index + 8).asNtpTimestamp
index += 8
val receive = bytes.sliceArray(index until index + 8).asNtpTimestamp
index += 8
val transmit = bytes.sliceArray(index until index + 8).asNtpTimestamp
return NtpPacket(
(bytes[index++].toInt() shl 8) + bytes[index++],
(bytes[index++].toInt() shl 16) + (bytes[index++].toInt() shl 24) + bytes[index++],
(bytes[index++].toInt() shl 16) + (bytes[index++].toInt() shl 24) + bytes[index++],
bytes[index++],
bytes[index++],
bytes[index++],
bytes.sliceArray(index until index + 32).asNtpIntervalToInterval.also { index += 32 },
bytes.sliceArray(index until index + 32).asNtpIntervalToInterval.also { index += 32 },
(bytes[index++].toInt() shl 24) +
(bytes[index++].toInt() shl 16) +
(bytes[index++].toInt() shl 8) +
bytes[index++].toInt(),
bytes.sliceArray(index until index + 64).asNtpEpochTimestampToEpochTime
.also { index += 64 },
bytes.sliceArray(index until index + 64).asNtpEpochTimestampToEpochTime
.also { index += 64 },
bytes.sliceArray(index until index + 64).asNtpEpochTimestampToEpochTime
.also { index += 64 },
bytes.sliceArray(index until index + 64).asNtpEpochTimestampToEpochTime,
leapIndicator,
versionNumber,
mode,
stratum,
poll,
precision,
rootDelay,
rootDispersion,
referenceIdentifier,
reference,
originate,
receive,
transmit,
)
}

private val Byte.asSignedIntToThePowerOf2
get() = 2.toDouble().pow(toInt())

private val Byte.asUnsignedInt: Int
get() = toUByte().toInt()

private val ByteArray.asNtpIntervalToInterval: Duration
get() {
var index = 0
val seconds = (this[index++].toUByte().toInt() shl 8) +
this[index++].toUByte().toInt()
val fraction = (
(this[index++].toUByte().toInt() shl 8) +
this[index].toUByte().toInt()
) *
1_000 /
(1 shl 16)
val seconds = (this[index++].asUnsignedInt shl 8) + this[index++].asUnsignedInt
val fraction = ((this[index++].asUnsignedInt shl 8) + this[index].asUnsignedInt)
.toDouble() / (1 shl 16) * 1_000
return seconds.seconds + fraction.milliseconds
}

private val ByteArray.asNtpEpochTimestampToEpochTime: Duration
private val Byte.asUnsignedLong: Long
get() = toUByte().toLong()

private val ByteArray.asNtpTimestamp: NtpTimestamp
get() {
val rollOverAdjustment = if ((this[0].toInt() shr 7) == 0) {
2.toDouble().pow(32).seconds
} else {
Duration.ZERO
}
var index = 0
val seconds = (this[index++].toUByte().toInt() shl 24) +
(this[index++].toUByte().toInt() shl 16) +
(this[index++].toUByte().toInt() shl 8) +
this[index++].toUByte().toInt()
val fraction = (
(this[index++].toUByte().toInt() shl 24) +
(this[index++].toUByte().toInt() shl 16) +
(this[index++].toUByte().toInt() shl 8) +
this[index].toUByte().toInt()
) *
1_000 /
0b100000000000000000000000000000000
return seconds.seconds +
rollOverAdjustment -
NtpPacket.NTP_EPOCH_OFFSET_WITH_EPOCH +
fraction.milliseconds
val ntpMillis = (this[index++].asUnsignedLong shl 56) or
(this[index++].asUnsignedLong shl 48) or
(this[index++].asUnsignedLong shl 40) or
(this[index++].asUnsignedLong shl 32) or
(this[index++].asUnsignedLong shl 24) or
(this[index++].asUnsignedLong shl 16) or
(this[index++].asUnsignedLong shl 8) or
this[index].asUnsignedLong
return NtpTimestamp(ntpMillis.milliseconds)
}

companion object {
private const val LEAP_INDICATOR_CLOCK_UNSYNCHRONIZED = 0b11
private const val MODE_SERVER = 4
private const val STRATUM_CLOCK_NOT_SYNCHRONIZED = 16
}
}
Loading

0 comments on commit 10e7797

Please sign in to comment.