Skip to content

Commit

Permalink
Look for MLPackages and constituent .mlmodel protobufs for associated models (#193)
Browse files Browse the repository at this point in the history

* Add initial mlpackage loading (if .mlmodelc not present)

-- Does not modify model loading in OS WK.  This is a hook to modify
load path URLs.

* Always load audio encoder last

* Adjust timings to account for decoder<>encoder order swap

* Add helper for mlpackage detection

---------

Co-authored-by: ZachNagengast <[email protected]>
  • Loading branch information
bpkeene and ZachNagengast authored Aug 10, 2024
1 parent 37007ef commit 1f40552
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 22 deletions.
16 changes: 16 additions & 0 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,10 @@ public struct TranscriptionResult: Codable {
Decoding Full Loop: \(decodingLoopInfo)
-------------------------------
Model Load Time: \(String(format: "%.2f", timings.modelLoading)) seconds
- Prewarm: \(String(format: "%.2f", timings.prewarmLoadTime)) seconds
- Encoder: \(String(format: "%.2f", timings.encoderLoadTime)) seconds
- Decoder: \(String(format: "%.2f", timings.decoderLoadTime)) seconds
- Tokenizer: \(String(format: "%.2f", timings.tokenizerLoadTime)) seconds
Inference Duration (Global): \(String(format: "%.2f", timings.fullPipeline)) seconds
- Decoding Loop (Avg/window): \(String(format: "%.2f", decodeTimePerWindow)) seconds
- Audio Windows: \(String(format: "%.2f", timings.totalAudioProcessingRuns))
Expand Down Expand Up @@ -650,6 +654,10 @@ public struct TranscriptionTimings: Codable {
public var firstTokenTime: CFAbsoluteTime
public var inputAudioSeconds: TimeInterval
public var modelLoading: TimeInterval
public var prewarmLoadTime: TimeInterval
public var encoderLoadTime: TimeInterval
public var decoderLoadTime: TimeInterval
public var tokenizerLoadTime: TimeInterval
public var audioLoading: TimeInterval
public var audioProcessing: TimeInterval
public var logmels: TimeInterval
Expand Down Expand Up @@ -690,6 +698,10 @@ public struct TranscriptionTimings: Codable {

/// Initialize with all time intervals set to zero.
public init(modelLoading: TimeInterval = 0,
prewarmLoadTime: TimeInterval = 0,
encoderLoadTime: TimeInterval = 0,
decoderLoadTime: TimeInterval = 0,
tokenizerLoadTime: TimeInterval = 0,
audioLoading: TimeInterval = 0,
audioProcessing: TimeInterval = 0,
logmels: TimeInterval = 0,
Expand Down Expand Up @@ -719,6 +731,10 @@ public struct TranscriptionTimings: Codable {
self.firstTokenTime = Double.greatestFiniteMagnitude
self.inputAudioSeconds = 0.001
self.modelLoading = modelLoading
self.prewarmLoadTime = prewarmLoadTime
self.encoderLoadTime = encoderLoadTime
self.decoderLoadTime = decoderLoadTime
self.tokenizerLoadTime = tokenizerLoadTime
self.audioLoading = audioLoading
self.audioProcessing = audioProcessing
self.logmels = logmels
Expand Down
20 changes: 20 additions & 0 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,22 @@ public func modelSupport(for deviceName: String) -> (default: String, disabled:
return ("openai_whisper-base", [""])
}

/// Resolves the on-disk URL for a model, preferring a compiled `.mlmodelc` bundle and
/// falling back to the raw `.mlmodel` protobuf inside an `.mlpackage` when no compiled
/// model is present.
/// - Parameters:
///   - path: Folder expected to contain the model artifacts.
///   - modelName: Base name of the model (without extension).
/// - Returns: The `.mlpackage` protobuf URL only when it exists and the `.mlmodelc` does not;
///   otherwise the `.mlmodelc` URL (even if neither file exists — callers check existence).
public func detectModelURL(inFolder path: URL, named modelName: String) -> URL {
    let compiledUrl = path.appending(path: "\(modelName).mlmodelc")
    let packageUrl = path.appending(path: "\(modelName).mlpackage/Data/com.apple.CoreML/model.mlmodel")

    let fileManager = FileManager.default
    let hasCompiledModel = fileManager.fileExists(atPath: compiledUrl.path)
    let hasPackageModel = fileManager.fileExists(atPath: packageUrl.path)

    // Prefer the compiled model; use the mlpackage protobuf only when it is the sole option.
    guard hasPackageModel, !hasCompiledModel else {
        return compiledUrl
    }
    return packageUrl
}

public func resolveAbsolutePath(_ inputPath: String) -> String {
let fileManager = FileManager.default

Expand Down Expand Up @@ -621,6 +637,10 @@ public func mergeTranscriptionResults(_ results: [TranscriptionResult?], confirm
// Update the merged timings with non-overlapping time values
var mergedTimings = TranscriptionTimings(
modelLoading: validResults.map { $0.timings.modelLoading }.max() ?? 0,
prewarmLoadTime: validResults.map { $0.timings.prewarmLoadTime }.max() ?? 0,
encoderLoadTime: validResults.map { $0.timings.encoderLoadTime }.max() ?? 0,
decoderLoadTime: validResults.map { $0.timings.decoderLoadTime }.max() ?? 0,
tokenizerLoadTime: validResults.map { $0.timings.tokenizerLoadTime }.max() ?? 0,
audioLoading: validResults.map { $0.timings.audioLoading }.reduce(0, +),
audioProcessing: validResults.map { $0.timings.audioProcessing }.reduce(0, +),
logmels: validResults.map { $0.timings.logmels }.reduce(0, +),
Expand Down
56 changes: 34 additions & 22 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,11 @@ open class WhisperKit {

Logging.debug("Loading models from \(path.path) with prewarmMode: \(prewarmMode)")

let logmelUrl = path.appending(path: "MelSpectrogram.mlmodelc")
let encoderUrl = path.appending(path: "AudioEncoder.mlmodelc")
let decoderUrl = path.appending(path: "TextDecoder.mlmodelc")
let decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc")
// Find either mlmodelc or mlpackage models
let logmelUrl = detectModelURL(inFolder: path, named: "MelSpectrogram")
let encoderUrl = detectModelURL(inFolder: path, named: "AudioEncoder")
let decoderUrl = detectModelURL(inFolder: path, named: "TextDecoder")
let decoderPrefillUrl = detectModelURL(inFolder: path, named: "TextDecoderContextPrefill")

for item in [logmelUrl, encoderUrl, decoderUrl] {
if !FileManager.default.fileExists(atPath: item.path) {
Expand All @@ -282,40 +283,47 @@ open class WhisperKit {
Logging.debug("Loaded feature extractor")
}

if let audioEncoder = audioEncoder as? WhisperMLModel {
Logging.debug("Loading audio encoder")
try await audioEncoder.loadModel(
at: encoderUrl,
computeUnits: modelCompute.audioEncoderCompute,
if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) {
Logging.debug("Loading text decoder prefill data")
textDecoder.prefillData = TextDecoderContextPrefill()
try await textDecoder.prefillData?.loadModel(
at: decoderPrefillUrl,
computeUnits: modelCompute.prefillCompute,
prewarmMode: prewarmMode
)
Logging.debug("Loaded audio encoder")
Logging.debug("Loaded text decoder prefill data")
}

if let textDecoder = textDecoder as? WhisperMLModel {
Logging.debug("Loading text decoder")
let decoderLoadStart = CFAbsoluteTimeGetCurrent()
try await textDecoder.loadModel(
at: decoderUrl,
computeUnits: modelCompute.textDecoderCompute,
prewarmMode: prewarmMode
)
Logging.debug("Loaded text decoder")
currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart

Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s")
}

if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) {
Logging.debug("Loading text decoder prefill data")
textDecoder.prefillData = TextDecoderContextPrefill()
try await textDecoder.prefillData?.loadModel(
at: decoderPrefillUrl,
computeUnits: modelCompute.prefillCompute,
if let audioEncoder = audioEncoder as? WhisperMLModel {
Logging.debug("Loading audio encoder")
let encoderLoadStart = CFAbsoluteTimeGetCurrent()

try await audioEncoder.loadModel(
at: encoderUrl,
computeUnits: modelCompute.audioEncoderCompute,
prewarmMode: prewarmMode
)
Logging.debug("Loaded text decoder prefill data")
currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart

Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s")
}

if prewarmMode {
modelState = .prewarmed
currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart
currentTimings.prewarmLoadTime = CFAbsoluteTimeGetCurrent() - modelLoadStart
return
}

Expand All @@ -326,20 +334,24 @@ open class WhisperKit {
textDecoder.isModelMultilingual = isModelMultilingual(logitsDim: logitsDim)
modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim)
Logging.debug("Loading tokenizer for \(modelVariant)")
let tokenizerLoadStart = CFAbsoluteTimeGetCurrent()

let tokenizer = try await loadTokenizer(
for: modelVariant,
tokenizerFolder: tokenizerFolder,
useBackgroundSession: useBackgroundDownloadSession
)
currentTimings.tokenizerLoadTime = CFAbsoluteTimeGetCurrent() - tokenizerLoadStart

self.tokenizer = tokenizer
textDecoder.tokenizer = tokenizer
Logging.debug("Loaded tokenizer")
Logging.debug("Loaded tokenizer in \(String(format: "%.2f", currentTimings.tokenizerLoadTime))s")

modelState = .loaded

currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart
currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart + currentTimings.prewarmLoadTime

Logging.info("Loaded models for whisper size: \(modelVariant)")
Logging.info("Loaded models for whisper size: \(modelVariant) in \(String(format: "%.2f", currentTimings.modelLoading))s")
}

public func unloadModels() async {
Expand Down
1 change: 1 addition & 0 deletions Sources/WhisperKitCLI/TranscribeCLI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ struct TranscribeCLI: AsyncParsableCommand {
computeOptions: computeOptions,
verbose: cliArguments.verbose,
logLevel: .debug,
prewarm: false,
load: true,
useBackgroundDownloadSession: false
)
Expand Down

0 comments on commit 1f40552

Please sign in to comment.