Skip to content

Commit

Permalink
Fix start time logic for file loading (#195)
Browse files Browse the repository at this point in the history
* Fix start time logic for file loading and resampling

* Add test file
  • Loading branch information
ZachNagengast authored Aug 10, 2024
1 parent 1f40552 commit c268c8d
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 22 deletions.
23 changes: 23 additions & 0 deletions .swiftpm/configuration/Package.resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"pins" : [
{
"identity" : "swift-argument-parser",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-argument-parser.git",
"state" : {
"revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41",
"version" : "1.3.0"
}
},
{
"identity" : "swift-transformers",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
"version" : "0.1.7"
}
}
],
"version" : 2
}
32 changes: 25 additions & 7 deletions Sources/WhisperKit/Core/AudioProcessor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public protocol AudioProcessing {
/// - startTime: Optional start time in seconds to read from
/// - endTime: Optional end time in seconds to read until
/// - Returns: `AVAudioPCMBuffer` containing the audio data.
static func loadAudio(fromPath audioFilePath: String, startTime: Double?, endTime: Double?) throws -> AVAudioPCMBuffer
static func loadAudio(fromPath audioFilePath: String, startTime: Double?, endTime: Double?, maxReadFrameSize: AVAudioFrameCount?) throws -> AVAudioPCMBuffer

/// Loads and converts audio data from a specified file paths.
/// - Parameter audioPaths: The file paths of the audio files.
Expand Down Expand Up @@ -185,7 +185,12 @@ public class AudioProcessor: NSObject, AudioProcessing {

// MARK: - Loading and conversion

public static func loadAudio(fromPath audioFilePath: String, startTime: Double? = 0, endTime: Double? = nil) throws -> AVAudioPCMBuffer {
public static func loadAudio(
fromPath audioFilePath: String,
startTime: Double? = 0,
endTime: Double? = nil,
maxReadFrameSize: AVAudioFrameCount? = nil
) throws -> AVAudioPCMBuffer {
guard FileManager.default.fileExists(atPath: audioFilePath) else {
throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)")
}
Expand Down Expand Up @@ -222,7 +227,8 @@ public class AudioProcessor: NSObject, AudioProcessing {
outputBuffer = buffer
} else {
// Audio needs resampling to 16khz
outputBuffer = resampleAudio(fromFile: audioFile, toSampleRate: 16000, channelCount: 1, frameCount: frameCount)
let maxReadFrameSize = maxReadFrameSize ?? Constants.defaultAudioReadFrameSize
outputBuffer = resampleAudio(fromFile: audioFile, toSampleRate: 16000, channelCount: 1, frameCount: frameCount, maxReadFrameSize: maxReadFrameSize)
}

if let outputBuffer = outputBuffer {
Expand Down Expand Up @@ -272,11 +278,13 @@ public class AudioProcessor: NSObject, AudioProcessing {
toSampleRate sampleRate: Double,
channelCount: AVAudioChannelCount,
frameCount: AVAudioFrameCount? = nil,
maxReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate
maxReadFrameSize: AVAudioFrameCount = Constants.defaultAudioReadFrameSize
) -> AVAudioPCMBuffer? {
let inputFormat = audioFile.fileFormat
let inputStartFrame = audioFile.framePosition
let inputFrameCount = frameCount ?? AVAudioFrameCount(audioFile.length)
let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate
let endFramePosition = min(inputStartFrame + AVAudioFramePosition(inputFrameCount), audioFile.length + 1)

guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount) else {
Logging.error("Failed to create output audio format")
Expand All @@ -293,8 +301,8 @@ public class AudioProcessor: NSObject, AudioProcessing {

let inputBuffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: maxReadFrameSize)!

while audioFile.framePosition < inputFrameCount {
let remainingFrames = inputFrameCount - AVAudioFrameCount(audioFile.framePosition)
while audioFile.framePosition < endFramePosition {
let remainingFrames = AVAudioFrameCount(endFramePosition - audioFile.framePosition)
let framesToRead = min(remainingFrames, maxReadFrameSize)

let currentPositionInSeconds = Double(audioFile.framePosition) / inputFormat.sampleRate
Expand Down Expand Up @@ -357,9 +365,19 @@ public class AudioProcessor: NSObject, AudioProcessing {
/// - Returns: Resampled audio as an AVAudioPCMBuffer.
/// - Throws: WhisperError if resampling fails.
public static func resampleBuffer(_ buffer: AVAudioPCMBuffer, with converter: AVAudioConverter) throws -> AVAudioPCMBuffer {
var capacity = converter.outputFormat.sampleRate * Double(buffer.frameLength) / converter.inputFormat.sampleRate

// Check if the capacity is a whole number
if capacity.truncatingRemainder(dividingBy: 1) != 0 {
// Round to the nearest whole number
let roundedCapacity = capacity.rounded(.toNearestOrEven)
Logging.debug("Rounding buffer frame capacity from \(capacity) to \(roundedCapacity) to better fit new sample rate")
capacity = roundedCapacity
}

guard let convertedBuffer = AVAudioPCMBuffer(
pcmFormat: converter.outputFormat,
frameCapacity: AVAudioFrameCount(converter.outputFormat.sampleRate * Double(buffer.frameLength) / converter.inputFormat.sampleRate)
frameCapacity: AVAudioFrameCount(capacity)
) else {
throw WhisperError.audioProcessingFailed("Failed to create converted buffer")
}
Expand Down
3 changes: 3 additions & 0 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Copyright © 2024 Argmax, Inc. All rights reserved.

import Accelerate
import AVFAudio
import CoreML
import Hub
import NaturalLanguage
Expand Down Expand Up @@ -1460,4 +1461,6 @@ public enum Constants {
public static let languageCodes: Set<String> = Set(languages.values)

public static let defaultLanguageCode: String = "en"

public static let defaultAudioReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate
}
17 changes: 9 additions & 8 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -201,27 +201,28 @@ extension AVAudioPCMBuffer {
}

guard startingFrame + AVAudioFramePosition(frameCount) <= AVAudioFramePosition(buffer.frameLength) else {
Logging.debug("Insufficient audio in buffer")
Logging.error("Insufficient audio in buffer")
return false
}

guard frameLength + frameCount <= frameCapacity else {
Logging.debug("Insufficient space in buffer")
guard let destination = floatChannelData, let source = buffer.floatChannelData else {
Logging.error("Failed to access float channel data")
return false
}

guard let destination = floatChannelData, let source = buffer.floatChannelData else {
Logging.debug("Failed to access float channel data")
return false
var calculatedFrameCount = frameCount
if frameLength + frameCount > frameCapacity {
Logging.debug("Insufficient space in buffer, reducing frame count to fit")
calculatedFrameCount = frameCapacity - frameLength
}

let calculatedStride = stride
let destinationPointer = destination.pointee.advanced(by: calculatedStride * Int(frameLength))
let sourcePointer = source.pointee.advanced(by: calculatedStride * Int(startingFrame))

memcpy(destinationPointer, sourcePointer, Int(frameCount) * calculatedStride * MemoryLayout<Float>.size)
memcpy(destinationPointer, sourcePointer, Int(calculatedFrameCount) * calculatedStride * MemoryLayout<Float>.size)

frameLength += frameCount
frameLength += calculatedFrameCount
return true
}

Expand Down
Binary file added Tests/WhisperKitTests/Resources/jfk_441khz.m4a
Binary file not shown.
46 changes: 39 additions & 7 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import Combine
import AVFoundation
import Combine
import CoreML
import Hub
import NaturalLanguage
Expand All @@ -12,6 +12,10 @@ import XCTest

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class UnitTests: XCTestCase {
override func setUp() async throws {
Logging.shared.logLevel = .debug
}

// MARK: - Model Loading Test

func testInit() async throws {
Expand Down Expand Up @@ -39,18 +43,49 @@ final class UnitTests: XCTestCase {
XCTAssertNotNil(audioBuffer, "Failed to load audio file at path: \(audioFilePath)")
XCTAssertEqual(audioBuffer.format.sampleRate, 16000)
XCTAssertEqual(audioBuffer.format.channelCount, 1)
XCTAssertEqual(audioBuffer.frameLength, 176000)
XCTAssertEqual(audioBuffer.frameLength, 176_000)
XCTAssertEqual(audioBuffer.frameLength, 11 * 16000)

let audioBufferWithStartTime = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2)
XCTAssertEqual(audioBufferWithStartTime.frameLength, AVAudioFrameCount(156800))
XCTAssertEqual(audioBufferWithStartTime.frameLength, AVAudioFrameCount(156_800))
XCTAssertEqual(audioBufferWithStartTime.frameLength, AVAudioFrameCount(16000 * (11 - 1.2)))

let audioBufferWithStartTimeAndEndTime = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2, endTime: 3.4)
XCTAssertEqual(audioBufferWithStartTimeAndEndTime.frameLength, AVAudioFrameCount(35200))
XCTAssertEqual(audioBufferWithStartTimeAndEndTime.frameLength, AVAudioFrameCount(16000 * (3.4 - 1.2)))
}

func testAudioFileLoadingWithResampling() throws {
let audioFilePath = try XCTUnwrap(
Bundle.module.path(forResource: "jfk_441khz", ofType: "m4a"),
"Audio file not found"
)
let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioFilePath)
XCTAssertNotNil(audioBuffer, "Failed to load audio file at path: \(audioFilePath)")
XCTAssertEqual(audioBuffer.format.sampleRate, 16000)
XCTAssertEqual(audioBuffer.format.channelCount, 1)
XCTAssertEqual(audioBuffer.frameLength, 176_000)

// Test start time and end time with varying max frame sizes
let audioBufferWithStartTime1 = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2)
XCTAssertEqual(audioBufferWithStartTime1.frameLength, AVAudioFrameCount(156_800))
XCTAssertEqual(audioBufferWithStartTime1.frameLength, AVAudioFrameCount(16000 * (11 - 1.2)))

let audioBufferWithStartTimeAndEndTime1 = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2, endTime: 3.4)
XCTAssertEqual(audioBufferWithStartTimeAndEndTime1.frameLength, AVAudioFrameCount(35200))
XCTAssertEqual(audioBufferWithStartTimeAndEndTime1.frameLength, AVAudioFrameCount(16000 * (3.4 - 1.2)))

// NOTE: depending on frameSize, the final frame lengths will match due to integer division between sample rates
let frameSize = AVAudioFrameCount(10024)
let audioBufferWithStartTime2 = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2, maxReadFrameSize: frameSize)
XCTAssertEqual(audioBufferWithStartTime2.frameLength, AVAudioFrameCount(156_800))
XCTAssertEqual(audioBufferWithStartTime2.frameLength, AVAudioFrameCount(16000 * (11 - 1.2)))

let audioBufferWithStartTimeAndEndTime2 = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2, endTime: 3.4, maxReadFrameSize: frameSize)
XCTAssertEqual(audioBufferWithStartTimeAndEndTime2.frameLength, AVAudioFrameCount(35200))
XCTAssertEqual(audioBufferWithStartTimeAndEndTime2.frameLength, AVAudioFrameCount(16000 * (3.4 - 1.2)))
}

func testAudioPad() {
let audioSamples = [Float](repeating: 0.0, count: 1000)
let paddedSamples = AudioProcessor.padOrTrimAudio(fromArray: audioSamples, startAt: 0, toLength: 1600)
Expand Down Expand Up @@ -85,8 +120,6 @@ final class UnitTests: XCTestCase {
}

func testAudioResampleFromFile() throws {
Logging.shared.logLevel = .debug

let audioFileURL = try XCTUnwrap(
Bundle.module.url(forResource: "jfk", withExtension: "wav"),
"Audio file not found"
Expand All @@ -95,7 +128,7 @@ final class UnitTests: XCTestCase {

let targetSampleRate = 16000.0
let targetChannelCount: AVAudioChannelCount = 1
let smallMaxReadFrameSize: AVAudioFrameCount = 10_000 // Small chunk size to test chunking logic
let smallMaxReadFrameSize: AVAudioFrameCount = 10000 // Small chunk size to test chunking logic

let resampledAudio = AudioProcessor.resampleAudio(
fromFile: audioFile,
Expand Down Expand Up @@ -1187,7 +1220,6 @@ final class UnitTests: XCTestCase {

func testVADAudioChunker() async throws {
let chunker = VADAudioChunker()
Logging.shared.logLevel = .debug

let singleChunkPath = try XCTUnwrap(
Bundle.module.path(forResource: "jfk", ofType: "wav"),
Expand Down

0 comments on commit c268c8d

Please sign in to comment.