diff --git a/include/PinPatternManager.h b/include/PinPatternManager.h index 75fbd221..29dad47f 100644 --- a/include/PinPatternManager.h +++ b/include/PinPatternManager.h @@ -1,10 +1,10 @@ #pragma once #include "Common.h" +#include "SimpleMutex.h" #include -#include #include #include @@ -41,6 +41,6 @@ namespace OpenShock { gpio_num_t m_gpioPin; std::vector m_pattern; TaskHandle_t m_taskHandle; - SemaphoreHandle_t m_taskMutex; + SimpleMutex m_taskMutex; }; } // namespace OpenShock diff --git a/include/RGBPatternManager.h b/include/RGBPatternManager.h index 6f8cc822..e98e1760 100644 --- a/include/RGBPatternManager.h +++ b/include/RGBPatternManager.h @@ -1,10 +1,10 @@ #pragma once #include "Common.h" +#include "SimpleMutex.h" #include -#include #include #include @@ -32,7 +32,8 @@ namespace OpenShock { void SetPattern(const RGBState* pattern, std::size_t patternLength); template - inline void SetPattern(const RGBState (&pattern)[N]) { + inline void SetPattern(const RGBState (&pattern)[N]) + { SetPattern(pattern, N); } void SetBrightness(uint8_t brightness); @@ -47,6 +48,6 @@ namespace OpenShock { std::vector m_pattern; rmt_obj_t* m_rmtHandle; TaskHandle_t m_taskHandle; - SemaphoreHandle_t m_taskMutex; + SimpleMutex m_taskMutex; }; } // namespace OpenShock diff --git a/src/CommandHandler.cpp b/src/CommandHandler.cpp index 9b798d6f..cb261ea0 100644 --- a/src/CommandHandler.cpp +++ b/src/CommandHandler.cpp @@ -11,11 +11,11 @@ const char* const TAG = "CommandHandler"; #include "Logging.h" #include "radio/RFTransmitter.h" #include "ReadWriteMutex.h" +#include "SimpleMutex.h" #include "Time.h" #include "util/TaskUtils.h" #include -#include #include #include @@ -23,8 +23,6 @@ const char* const TAG = "CommandHandler"; const int64_t KEEP_ALIVE_INTERVAL = 60'000; const uint16_t KEEP_ALIVE_DURATION = 300; -using namespace OpenShock; - uint32_t calculateEepyTime(int64_t timeToKeepAlive) { int64_t now = OpenShock::millis(); @@ -33,19 +31,21 @@ uint32_t calculateEepyTime(int64_t timeToKeepAlive) struct KnownShocker { bool killTask; - ShockerModelType model; + OpenShock::ShockerModelType model; uint16_t shockerId; int64_t lastActivityTimestamp; }; -static ReadWriteMutex s_rfTransmitterMutex = {}; -static std::unique_ptr s_rfTransmitter = nullptr; +static OpenShock::ReadWriteMutex s_rfTransmitterMutex = {}; +static std::unique_ptr s_rfTransmitter = nullptr; -static SemaphoreHandle_t s_estopManagerMutex = nullptr; +static OpenShock::SimpleMutex s_estopManagerMutex = {}; -static ReadWriteMutex s_keepAliveMutex = {}; -static QueueHandle_t s_keepAliveQueue = nullptr; -static TaskHandle_t s_keepAliveTaskHandle = nullptr; +static OpenShock::ReadWriteMutex s_keepAliveMutex = {}; +static QueueHandle_t s_keepAliveQueue = nullptr; +static TaskHandle_t s_keepAliveTaskHandle = nullptr; + +using namespace OpenShock; void _keepAliveTask(void* arg) { @@ -111,7 +111,7 @@ bool _internalSetKeepAliveEnabled(bool enabled) return true; } - ScopedWriteLock keepAliveLock(&s_keepAliveMutex); + ScopedWriteLock lock__(&s_keepAliveMutex); if (enabled) { OS_LOGV(TAG, "Enabling keep-alive task"); @@ -160,8 +160,6 @@ bool CommandHandler::Init() } initialized = true; - s_estopManagerMutex = xSemaphoreCreateMutex(); - Config::RFConfig rfConfig; if (!Config::GetRFConfig(rfConfig)) { OS_LOGE(TAG, "Failed to get RF config"); @@ -218,7 +216,7 @@ SetGPIOResultCode CommandHandler::SetRfTxPin(gpio_num_t txPin) return SetGPIOResultCode::InvalidPin; } - ScopedWriteLock rftxLock(&s_rfTransmitterMutex); + ScopedWriteLock lock__(&s_rfTransmitterMutex); if (s_rfTransmitter != nullptr) { OS_LOGV(TAG, "Destroying existing RF transmitter"); @@ -245,23 +243,20 @@ SetGPIOResultCode CommandHandler::SetRfTxPin(gpio_num_t txPin) SetGPIOResultCode CommandHandler::SetEStopPin(gpio_num_t estopPin) { if (OpenShock::IsValidInputPin(static_cast(estopPin))) { - xSemaphoreTake(s_estopManagerMutex, portMAX_DELAY); + ScopedLock lock__(&s_estopManagerMutex); if (!EStopManager::SetEStopPin(estopPin)) { OS_LOGE(TAG, "Failed to set EStop pin"); - xSemaphoreGive(s_estopManagerMutex); return SetGPIOResultCode::InternalError; } if (!Config::SetEStopGpioPin(estopPin)) { OS_LOGE(TAG, "Failed to set EStop pin in config"); - xSemaphoreGive(s_estopManagerMutex); return SetGPIOResultCode::InternalError; } - xSemaphoreGive(s_estopManagerMutex); return SetGPIOResultCode::Success; } else { return SetGPIOResultCode::InvalidPin; @@ -318,7 +313,7 @@ gpio_num_t CommandHandler::GetRfTxPin() bool CommandHandler::HandleCommand(ShockerModelType model, uint16_t shockerId, ShockerCommandType type, uint8_t intensity, uint16_t durationMs) { - ScopedReadLock rftxLock(&s_rfTransmitterMutex); + ScopedReadLock lock__rf(&s_rfTransmitterMutex); if (s_rfTransmitter == nullptr) { OS_LOGW(TAG, "RF Transmitter is not initialized, ignoring command"); @@ -340,8 +335,8 @@ bool CommandHandler::HandleCommand(ShockerModelType model, uint16_t shockerId, S bool ok = s_rfTransmitter->SendCommand(model, shockerId, type, intensity, durationMs); - rftxLock.unlock(); - ScopedReadLock keepAliveLock(&s_keepAliveMutex); + lock__rf.unlock(); + ScopedReadLock lock__ka(&s_keepAliveMutex); if (ok && s_keepAliveQueue != nullptr) { KnownShocker cmd {.model = model, .shockerId = shockerId, .lastActivityTimestamp = OpenShock::millis() + durationMs}; diff --git a/src/EStopManager.cpp b/src/EStopManager.cpp index 1558d854..3fb6e534 100644 --- a/src/EStopManager.cpp +++ b/src/EStopManager.cpp @@ -213,7 +213,7 @@ bool EStopManager::Init() return false; } - OpenShock::ScopedLock lock(&s_estopMutex); + OpenShock::ScopedLock lock__(&s_estopMutex); if (!_setEStopPinImpl(cfg.gpioPin)) { OS_LOGE(TAG, "Failed to set EStop pin"); @@ -230,7 +230,7 @@ bool EStopManager::Init() bool EStopManager::SetEStopEnabled(bool enabled) { - OpenShock::ScopedLock lock(&s_estopMutex); + OpenShock::ScopedLock lock__(&s_estopMutex); if (s_estopPin == GPIO_NUM_NC) { gpio_num_t pin; @@ -251,7 +251,7 @@ bool EStopManager::SetEStopEnabled(bool enabled) bool EStopManager::SetEStopPin(gpio_num_t pin) { - OpenShock::ScopedLock lock(&s_estopMutex); + OpenShock::ScopedLock lock__(&s_estopMutex); return _setEStopPinImpl(pin); } diff --git a/src/OtaUpdateManager.cpp b/src/OtaUpdateManager.cpp index 907bda9a..33247351 100644 --- a/src/OtaUpdateManager.cpp +++ b/src/OtaUpdateManager.cpp @@ -11,6 +11,7 @@ const char* const TAG = "OtaUpdateManager"; #include "Logging.h" #include "SemVer.h" #include "serialization/WSGateway.h" +#include "SimpleMutex.h" #include "Time.h" #include "util/HexUtils.h" #include "util/PartitionUtils.h" @@ -49,12 +50,11 @@ using namespace std::string_view_literals; /// /// @see .platformio/packages/framework-arduinoespressif32/cores/esp32/esp32-hal-misc.c /// @return true -bool verifyRollbackLater() { +bool verifyRollbackLater() +{ return true; } -using namespace OpenShock; - enum OtaTaskEventFlag : uint32_t { OTA_TASK_EVENT_UPDATE_REQUESTED = 1 << 0, OTA_TASK_EVENT_WIFI_DISCONNECTED = 1 << 1, // If both connected and disconnected are set, disconnected takes priority. @@ -65,46 +65,53 @@ static esp_ota_img_states_t _otaImageState; static OpenShock::FirmwareBootType _bootType; static TaskHandle_t _taskHandle; static OpenShock::SemVer _requestedVersion; -static SemaphoreHandle_t _requestedVersionMutex = xSemaphoreCreateMutex(); +static OpenShock::SimpleMutex _requestedVersionMutex = {}; + +using namespace OpenShock; -bool _tryQueueUpdateRequest(const OpenShock::SemVer& version) { - if (xSemaphoreTake(_requestedVersionMutex, pdMS_TO_TICKS(1000)) != pdTRUE) { +bool _tryQueueUpdateRequest(const OpenShock::SemVer& version) +{ + if (!_requestedVersionMutex.lock(pdMS_TO_TICKS(1000))) { OS_LOGE(TAG, "Failed to take requested version mutex"); return false; } _requestedVersion = version; - xSemaphoreGive(_requestedVersionMutex); + _requestedVersionMutex.unlock(); xTaskNotify(_taskHandle, OTA_TASK_EVENT_UPDATE_REQUESTED, eSetBits); return true; } -bool _tryGetRequestedVersion(OpenShock::SemVer& version) { - if (xSemaphoreTake(_requestedVersionMutex, pdMS_TO_TICKS(1000)) != pdTRUE) { +bool _tryGetRequestedVersion(OpenShock::SemVer& version) +{ + if (!_requestedVersionMutex.lock(pdMS_TO_TICKS(1000))) { OS_LOGE(TAG, "Failed to take requested version mutex"); return false; } version = _requestedVersion; - xSemaphoreGive(_requestedVersionMutex); + _requestedVersionMutex.unlock(); return true; } -void _otaEvGotIPHandler(arduino_event_t* event) { +void _otaEvGotIPHandler(arduino_event_t* event) +{ (void)event; xTaskNotify(_taskHandle, OTA_TASK_EVENT_WIFI_CONNECTED, eSetBits); } -void _otaEvWiFiDisconnectedHandler(arduino_event_t* event) { +void _otaEvWiFiDisconnectedHandler(arduino_event_t* event) +{ (void)event; xTaskNotify(_taskHandle, OTA_TASK_EVENT_WIFI_DISCONNECTED, eSetBits); } -bool _sendProgressMessage(Serialization::Gateway::OtaInstallProgressTask task, float progress) { +bool _sendProgressMessage(Serialization::Gateway::OtaInstallProgressTask task, float progress) +{ int32_t updateId; if (!Config::GetOtaUpdateId(updateId)) { OS_LOGE(TAG, "Failed to get OTA update ID"); @@ -118,7 +125,8 @@ bool _sendProgressMessage(Serialization::Gateway::OtaInstallProgressTask task, f return true; } -bool _sendFailureMessage(std::string_view message, bool fatal = false) { +bool _sendFailureMessage(std::string_view message, bool fatal = false) +{ int32_t updateId; if (!Config::GetOtaUpdateId(updateId)) { OS_LOGE(TAG, "Failed to get OTA update ID"); @@ -133,7 +141,8 @@ bool _sendFailureMessage(std::string_view message, bool fatal = false) { return true; } -bool _flashAppPartition(const esp_partition_t* partition, std::string_view remoteUrl, const uint8_t (&remoteHash)[32]) { +bool _flashAppPartition(const esp_partition_t* partition, std::string_view remoteUrl, const uint8_t (&remoteHash)[32]) +{ OS_LOGD(TAG, "Flashing app partition"); if (!_sendProgressMessage(Serialization::Gateway::OtaInstallProgressTask::FlashingApplication, 0.0f)) { @@ -168,7 +177,8 @@ bool _flashAppPartition(const esp_partition_t* partition, std::string_view remot return true; } -bool _flashFilesystemPartition(const esp_partition_t* parition, std::string_view remoteUrl, const uint8_t (&remoteHash)[32]) { +bool _flashFilesystemPartition(const esp_partition_t* parition, std::string_view remoteUrl, const uint8_t (&remoteHash)[32]) +{ if (!_sendProgressMessage(Serialization::Gateway::OtaInstallProgressTask::PreparingForInstall, 0.0f)) { return false; } @@ -218,13 +228,14 @@ bool _flashFilesystemPartition(const esp_partition_t* parition, std::string_view return true; } -void _otaUpdateTask(void* arg) { +void _otaUpdateTask(void* arg) +{ (void)arg; OS_LOGD(TAG, "OTA update task started"); - bool connected = false; - bool updateRequested = false; + bool connected = false; + bool updateRequested = false; int64_t lastUpdateCheck = 0; // Update task loop. @@ -264,7 +275,7 @@ void _otaUpdateTask(void* arg) { continue; } - bool firstCheck = lastUpdateCheck == 0; + bool firstCheck = lastUpdateCheck == 0; int64_t diff = now - lastUpdateCheck; int64_t diffMins = diff / 60'000LL; @@ -395,7 +406,8 @@ void _otaUpdateTask(void* arg) { esp_restart(); } -bool _tryGetStringList(std::string_view url, std::vector& list) { +bool _tryGetStringList(std::string_view url, std::vector& list) +{ auto response = OpenShock::HTTP::GetString( url, { @@ -428,7 +440,8 @@ bool _tryGetStringList(std::string_view url, std::vector& list) { return true; } -bool OtaUpdateManager::Init() { +bool OtaUpdateManager::Init() +{ OS_LOGN(TAG, "Fetching current partition"); // Fetch current partition info. @@ -484,7 +497,8 @@ bool OtaUpdateManager::Init() { return true; } -bool OtaUpdateManager::TryGetFirmwareVersion(OtaUpdateChannel channel, OpenShock::SemVer& version) { +bool OtaUpdateManager::TryGetFirmwareVersion(OtaUpdateChannel channel, OpenShock::SemVer& version) +{ std::string_view channelIndexUrl; switch (channel) { case OtaUpdateChannel::Stable: @@ -523,7 +537,8 @@ bool OtaUpdateManager::TryGetFirmwareVersion(OtaUpdateChannel channel, OpenShock return true; } -bool OtaUpdateManager::TryGetFirmwareBoards(const OpenShock::SemVer& version, std::vector& boards) { +bool OtaUpdateManager::TryGetFirmwareBoards(const OpenShock::SemVer& version, std::vector& boards) +{ std::string channelIndexUrl; if (!FormatToString(channelIndexUrl, OPENSHOCK_FW_CDN_BOARDS_INDEX_URL_FORMAT, version.toString().c_str())) { // TODO: This is abusing the SemVer::toString() method causing alot of string copies, fix this OS_LOGE(TAG, "Failed to format URL"); @@ -540,7 +555,8 @@ bool OtaUpdateManager::TryGetFirmwareBoards(const OpenShock::SemVer& version, st return true; } -bool _tryParseIntoHash(std::string_view hash, uint8_t (&hashBytes)[32]) { +bool _tryParseIntoHash(std::string_view hash, uint8_t (&hashBytes)[32]) +{ if (!HexUtils::TryParseHex(hash.data(), hash.size(), hashBytes, 32)) { OS_LOGE(TAG, "Failed to parse hash: %.*s", hash.size(), hash.data()); return false; @@ -549,7 +565,8 @@ bool _tryParseIntoHash(std::string_view hash, uint8_t (&hashBytes)[32]) { return true; } -bool OtaUpdateManager::TryGetFirmwareRelease(const OpenShock::SemVer& version, FirmwareRelease& release) { +bool OtaUpdateManager::TryGetFirmwareRelease(const OpenShock::SemVer& version, FirmwareRelease& release) +{ auto versionStr = version.toString(); // TODO: This is abusing the SemVer::toString() method causing alot of string copies, fix this if (!FormatToString(release.appBinaryUrl, OPENSHOCK_FW_CDN_APP_URL_FORMAT, versionStr.c_str())) { @@ -633,21 +650,25 @@ bool OtaUpdateManager::TryGetFirmwareRelease(const OpenShock::SemVer& version, F return true; } -bool OtaUpdateManager::TryStartFirmwareInstallation(const OpenShock::SemVer& version) { +bool OtaUpdateManager::TryStartFirmwareInstallation(const OpenShock::SemVer& version) +{ OS_LOGD(TAG, "Requesting firmware version %s", version.toString().c_str()); // TODO: This is abusing the SemVer::toString() method causing alot of string copies, fix this return _tryQueueUpdateRequest(version); } -FirmwareBootType OtaUpdateManager::GetFirmwareBootType() { +FirmwareBootType OtaUpdateManager::GetFirmwareBootType() +{ return _bootType; } -bool OtaUpdateManager::IsValidatingApp() { +bool OtaUpdateManager::IsValidatingApp() +{ return _otaImageState == ESP_OTA_IMG_PENDING_VERIFY; } -void OtaUpdateManager::InvalidateAndRollback() { +void OtaUpdateManager::InvalidateAndRollback() +{ // Set OTA boot type in config. if (!Config::SetOtaUpdateStep(OpenShock::OtaUpdateStep::RollingBack)) { OS_PANIC(TAG, "Failed to set OTA firmware boot type in critical section"); // TODO: THIS IS A CRITICAL SECTION, WHAT DO WE DO? @@ -674,7 +695,8 @@ void OtaUpdateManager::InvalidateAndRollback() { esp_restart(); } -void OtaUpdateManager::ValidateApp() { +void OtaUpdateManager::ValidateApp() +{ if (esp_ota_mark_app_valid_cancel_rollback() != ESP_OK) { OS_PANIC(TAG, "Unable to mark app as valid, WTF?"); // TODO: Wtf do we do here? } diff --git a/src/PinPatternManager.cpp b/src/PinPatternManager.cpp index b56d928b..573e0a14 100644 --- a/src/PinPatternManager.cpp +++ b/src/PinPatternManager.cpp @@ -15,7 +15,7 @@ PinPatternManager::PinPatternManager(gpio_num_t gpioPin) : m_gpioPin(GPIO_NUM_NC) , m_pattern() , m_taskHandle(nullptr) - , m_taskMutex(xSemaphoreCreateMutex()) + , m_taskMutex() { if (gpioPin == GPIO_NUM_NC) { OS_LOGE(TAG, "Pin is not set"); @@ -45,8 +45,6 @@ PinPatternManager::~PinPatternManager() { ClearPattern(); - vSemaphoreDelete(m_taskMutex); - if (m_gpioPin != GPIO_NUM_NC) { gpio_reset_pin(m_gpioPin); } @@ -54,6 +52,8 @@ PinPatternManager::~PinPatternManager() void PinPatternManager::SetPattern(const State* pattern, std::size_t patternLength) { + m_taskMutex.lock(portMAX_DELAY); + ClearPatternInternal(); // Set new values @@ -72,20 +72,20 @@ void PinPatternManager::SetPattern(const State* pattern, std::size_t patternLeng m_pattern.clear(); } - // Give the semaphore back - xSemaphoreGive(m_taskMutex); + m_taskMutex.unlock(); } void PinPatternManager::ClearPattern() { + m_taskMutex.lock(portMAX_DELAY); + ClearPatternInternal(); - xSemaphoreGive(m_taskMutex); + + m_taskMutex.unlock(); } void PinPatternManager::ClearPatternInternal() { - xSemaphoreTake(m_taskMutex, portMAX_DELAY); - if (m_taskHandle != nullptr) { vTaskDelete(m_taskHandle); m_taskHandle = nullptr; diff --git a/src/RGBPatternManager.cpp b/src/RGBPatternManager.cpp index 7fdae6b3..ee509538 100644 --- a/src/RGBPatternManager.cpp +++ b/src/RGBPatternManager.cpp @@ -22,7 +22,7 @@ RGBPatternManager::RGBPatternManager(gpio_num_t gpioPin) , m_pattern() , m_rmtHandle(nullptr) , m_taskHandle(nullptr) - , m_taskMutex(xSemaphoreCreateMutex()) + , m_taskMutex() { if (gpioPin == GPIO_NUM_NC) { OS_LOGE(TAG, "Pin is not set"); @@ -52,11 +52,13 @@ RGBPatternManager::~RGBPatternManager() { ClearPattern(); - vSemaphoreDelete(m_taskMutex); + rmtDeinit(m_rmtHandle); } void RGBPatternManager::SetPattern(const RGBState* pattern, std::size_t patternLength) { + m_taskMutex.lock(portMAX_DELAY); + ClearPatternInternal(); // Set new values @@ -72,20 +74,20 @@ void RGBPatternManager::SetPattern(const RGBState* pattern, std::size_t patternL m_pattern.clear(); } - // Give the semaphore back - xSemaphoreGive(m_taskMutex); + m_taskMutex.unlock(); } void RGBPatternManager::ClearPattern() { + m_taskMutex.lock(portMAX_DELAY); + ClearPatternInternal(); - xSemaphoreGive(m_taskMutex); + + m_taskMutex.unlock(); } void RGBPatternManager::ClearPatternInternal() { - xSemaphoreTake(m_taskMutex, portMAX_DELAY); - if (m_taskHandle != nullptr) { vTaskDelete(m_taskHandle); m_taskHandle = nullptr; diff --git a/src/http/HTTPRequestManager.cpp b/src/http/HTTPRequestManager.cpp index 751a36fc..9ed9d704 100644 --- a/src/http/HTTPRequestManager.cpp +++ b/src/http/HTTPRequestManager.cpp @@ -4,6 +4,7 @@ const char* const TAG = "HTTPRequestManager"; #include "Common.h" #include "Logging.h" +#include "SimpleMutex.h" #include "Time.h" #include "util/StringUtils.h" @@ -22,31 +23,40 @@ const std::size_t HTTP_BUFFER_SIZE = 4096LLU; const int HTTP_DOWNLOAD_SIZE_LIMIT = 200 * 1024 * 1024; // 200 MB struct RateLimit { - RateLimit() : m_mutex(xSemaphoreCreateMutex()), m_blockUntilMs(0), m_limits(), m_requests() { } + RateLimit() + : m_mutex() + , m_blockUntilMs(0) + , m_limits() + , m_requests() + { + } - void addLimit(uint32_t durationMs, uint16_t count) { - xSemaphoreTake(m_mutex, portMAX_DELAY); + void addLimit(uint32_t durationMs, uint16_t count) + { + m_mutex.lock(portMAX_DELAY); // Insert sorted m_limits.insert(std::upper_bound(m_limits.begin(), m_limits.end(), durationMs, [](int64_t durationMs, const Limit& limit) { return durationMs > limit.durationMs; }), {durationMs, count}); - xSemaphoreGive(m_mutex); + m_mutex.unlock(); } - void clearLimits() { - xSemaphoreTake(m_mutex, portMAX_DELAY); + + void clearLimits() + { + m_mutex.lock(portMAX_DELAY); m_limits.clear(); - xSemaphoreGive(m_mutex); + m_mutex.unlock(); } - bool tryRequest() { + bool tryRequest() + { int64_t now = OpenShock::millis(); - xSemaphoreTake(m_mutex, portMAX_DELAY); + OpenShock::ScopedLock lock__(&m_mutex); if (m_blockUntilMs > now) { - xSemaphoreGive(m_mutex); return false; } @@ -59,34 +69,37 @@ struct RateLimit { auto it = std::find_if(m_limits.begin(), m_limits.end(), [this](const RateLimit::Limit& limit) { return m_requests.size() >= limit.count; }); if (it != m_limits.end()) { m_blockUntilMs = now + it->durationMs; - xSemaphoreGive(m_mutex); return false; } // Add the request m_requests.push_back(now); - xSemaphoreGive(m_mutex); - return true; } - void clearRequests() { - xSemaphoreTake(m_mutex, portMAX_DELAY); + void clearRequests() + { + m_mutex.lock(portMAX_DELAY); + m_requests.clear(); - xSemaphoreGive(m_mutex); + + m_mutex.unlock(); } - void blockUntil(int64_t blockUntilMs) { - xSemaphoreTake(m_mutex, portMAX_DELAY); + void blockUntil(int64_t blockUntilMs) + { + m_mutex.lock(portMAX_DELAY); + m_blockUntilMs = blockUntilMs; - xSemaphoreGive(m_mutex); + + m_mutex.unlock(); } - uint32_t requestsSince(int64_t sinceMs) { - xSemaphoreTake(m_mutex, portMAX_DELAY); - uint32_t result = std::count_if(m_requests.begin(), m_requests.end(), [sinceMs](int64_t requestMs) { return requestMs >= sinceMs; }); - xSemaphoreGive(m_mutex); - return result; + uint32_t requestsSince(int64_t sinceMs) + { + OpenShock::ScopedLock lock__(&m_mutex); + + return std::count_if(m_requests.begin(), m_requests.end(), [sinceMs](int64_t requestMs) { return requestMs >= sinceMs; }); } private: @@ -95,18 +108,19 @@ struct RateLimit { uint16_t count; }; - SemaphoreHandle_t m_mutex; + OpenShock::SimpleMutex m_mutex; int64_t m_blockUntilMs; std::vector m_limits; std::vector m_requests; }; -SemaphoreHandle_t s_rateLimitsMutex = xSemaphoreCreateMutex(); -std::unordered_map> s_rateLimits; +static OpenShock::SimpleMutex s_rateLimitsMutex = {}; +static std::unordered_map> s_rateLimits = {}; using namespace OpenShock; -std::string_view _getDomain(std::string_view url) { +std::string_view _getDomain(std::string_view url) +{ if (url.empty()) { return {}; } @@ -142,7 +156,8 @@ std::string_view _getDomain(std::string_view url) { return url; } -std::shared_ptr _rateLimitFactory(std::string_view domain) { +std::shared_ptr _rateLimitFactory(std::string_view domain) +{ auto rateLimit = std::make_shared(); // Add default limits @@ -158,13 +173,14 @@ std::shared_ptr _rateLimitFactory(std::string_view domain) { return rateLimit; } -std::shared_ptr _getRateLimiter(std::string_view url) { +std::shared_ptr _getRateLimiter(std::string_view url) +{ auto domain = std::string(_getDomain(url)); if (domain.empty()) { return nullptr; } - xSemaphoreTake(s_rateLimitsMutex, portMAX_DELAY); + s_rateLimitsMutex.lock(portMAX_DELAY); auto it = s_rateLimits.find(domain); if (it == s_rateLimits.end()) { @@ -172,12 +188,13 @@ std::shared_ptr _getRateLimiter(std::string_view url) { it = s_rateLimits.find(domain); } - xSemaphoreGive(s_rateLimitsMutex); + s_rateLimitsMutex.unlock(); return it->second; } -void _setupClient(HTTPClient& client) { +void _setupClient(HTTPClient& client) +{ client.setUserAgent(OpenShock::Constants::FW_USERAGENT); } @@ -186,10 +203,12 @@ struct StreamReaderResult { std::size_t nWritten; }; -constexpr bool _isCRLF(const uint8_t* buffer) { +constexpr bool _isCRLF(const uint8_t* buffer) +{ return buffer[0] == '\r' && buffer[1] == '\n'; } -constexpr bool _tryFindCRLF(std::size_t& pos, const uint8_t* buffer, std::size_t len) { +constexpr bool _tryFindCRLF(std::size_t& pos, const uint8_t* buffer, std::size_t len) +{ const uint8_t* cur = buffer; const uint8_t* end = buffer + len - 1; @@ -204,7 +223,8 @@ constexpr bool _tryFindCRLF(std::size_t& pos, const uint8_t* buffer, std::size_t return false; } -constexpr bool _tryParseHexSizeT(std::size_t& result, std::string_view str) { +constexpr bool _tryParseHexSizeT(std::size_t& result, std::string_view str) +{ if (str.empty() || str.size() > sizeof(std::size_t) * 2) { return false; } @@ -232,7 +252,8 @@ enum ParserState : uint8_t { Invalid, }; -ParserState _parseChunkHeader(const uint8_t* buffer, std::size_t bufferLen, std::size_t& headerLen, std::size_t& payloadLen) { +ParserState _parseChunkHeader(const uint8_t* buffer, std::size_t bufferLen, std::size_t& headerLen, std::size_t& payloadLen) +{ if (bufferLen < 5) { // Bare minimum: "0\r\n\r\n" return ParserState::NeedMoreData; } @@ -282,7 +303,8 @@ ParserState _parseChunkHeader(const uint8_t* buffer, std::size_t bufferLen, std: return ParserState::Ok; } -ParserState _parseChunk(const uint8_t* buffer, std::size_t bufferLen, std::size_t& payloadPos, std::size_t& payloadLen) { +ParserState _parseChunk(const uint8_t* buffer, std::size_t bufferLen, std::size_t& payloadPos, std::size_t& payloadLen) +{ if (payloadPos == 0) { ParserState state = _parseChunkHeader(buffer, bufferLen, payloadPos, payloadLen); if (state != ParserState::Ok) { @@ -304,7 +326,8 @@ ParserState _parseChunk(const uint8_t* buffer, std::size_t bufferLen, std::size_ return ParserState::Ok; } -void _alignChunk(uint8_t* buffer, std::size_t& bufferCursor, std::size_t payloadPos, std::size_t payloadLen) { +void _alignChunk(uint8_t* buffer, std::size_t& bufferCursor, std::size_t payloadPos, std::size_t payloadLen) +{ std::size_t totalLen = payloadPos + payloadLen + 2; // +2 for CRLF std::size_t remaining = bufferCursor - totalLen; if (remaining > 0) { @@ -315,7 +338,8 @@ void _alignChunk(uint8_t* buffer, std::size_t& bufferCursor, std::size_t payload } } -StreamReaderResult _readStreamDataChunked(HTTPClient& client, WiFiClient* stream, HTTP::DownloadCallback downloadCallback, int64_t begin, uint32_t timeoutMs) { +StreamReaderResult _readStreamDataChunked(HTTPClient& client, WiFiClient* stream, HTTP::DownloadCallback downloadCallback, int64_t begin, uint32_t timeoutMs) +{ std::size_t totalWritten = 0; HTTP::RequestResult result = HTTP::RequestResult::Success; @@ -395,7 +419,8 @@ StreamReaderResult _readStreamDataChunked(HTTPClient& client, WiFiClient* stream return {result, totalWritten}; } -StreamReaderResult _readStreamData(HTTPClient& client, WiFiClient* stream, std::size_t contentLength, HTTP::DownloadCallback downloadCallback, int64_t begin, uint32_t timeoutMs) { +StreamReaderResult _readStreamData(HTTPClient& client, WiFiClient* stream, std::size_t contentLength, HTTP::DownloadCallback downloadCallback, int64_t begin, uint32_t timeoutMs) +{ std::size_t nWritten = 0; HTTP::RequestResult result = HTTP::RequestResult::Success; @@ -448,7 +473,8 @@ HTTP::Response _doGetStream( HTTP::GotContentLengthCallback contentLengthCallback, HTTP::DownloadCallback downloadCallback, uint32_t timeoutMs -) { +) +{ int64_t begin = OpenShock::millis(); if (!client.begin(OpenShock::StringToArduinoString(url))) { OS_LOGE(TAG, "Failed to begin HTTP request"); @@ -535,7 +561,8 @@ HTTP::Response _doGetStream( } HTTP::Response - HTTP::Download(std::string_view url, const std::map& headers, HTTP::GotContentLengthCallback contentLengthCallback, HTTP::DownloadCallback downloadCallback, const std::vector& acceptedCodes, uint32_t timeoutMs) { + HTTP::Download(std::string_view url, const std::map& headers, HTTP::GotContentLengthCallback contentLengthCallback, HTTP::DownloadCallback downloadCallback, const std::vector& acceptedCodes, uint32_t timeoutMs) +{ std::shared_ptr rateLimiter = _getRateLimiter(url); if (rateLimiter == nullptr) { return {RequestResult::InvalidURL, 0, 0}; @@ -551,7 +578,8 @@ HTTP::Response return _doGetStream(client, url, headers, acceptedCodes, rateLimiter, contentLengthCallback, downloadCallback, timeoutMs); } -HTTP::Response HTTP::GetString(std::string_view url, const std::map& headers, const std::vector& acceptedCodes, uint32_t timeoutMs) { +HTTP::Response HTTP::GetString(std::string_view url, const std::map& headers, const std::vector& acceptedCodes, uint32_t timeoutMs) +{ std::string result; auto allocator = [&result](std::size_t contentLength) { diff --git a/src/wifi/WiFiScanManager.cpp b/src/wifi/WiFiScanManager.cpp index 23074e06..ed2ea260 100644 --- a/src/wifi/WiFiScanManager.cpp +++ b/src/wifi/WiFiScanManager.cpp @@ -5,6 +5,7 @@ const char* const TAG = "WiFiScanManager"; #include "Logging.h" +#include "SimpleMutex.h" #include "util/TaskUtils.h" #include @@ -22,40 +23,40 @@ enum WiFiScanTaskNotificationFlags { CLEAR_FLAGS = CHANNEL_DONE | ERROR }; -using namespace OpenShock; - -static bool s_initialized = false; -static TaskHandle_t s_scanTaskHandle = nullptr; -static SemaphoreHandle_t s_scanTaskMutex = xSemaphoreCreateMutex(); -static uint8_t s_currentChannel = 0; -static std::map s_statusChangedHandlers; -static std::map s_networksDiscoveredHandlers; +static bool s_initialized = false; +static TaskHandle_t s_scanTaskHandle = nullptr; +static OpenShock::SimpleMutex s_scanTaskMutex = {}; +static uint8_t s_currentChannel = 0; +static std::map s_statusChangedHandlers; +static std::map s_networksDiscoveredHandlers; -bool _notifyTask(WiFiScanTaskNotificationFlags flags) { - xSemaphoreTake(s_scanTaskMutex, portMAX_DELAY); +using namespace OpenShock; - bool success = false; +bool _notifyTask(WiFiScanTaskNotificationFlags flags) +{ + ScopedLock lock__(&s_scanTaskMutex); - if (s_scanTaskHandle != nullptr) { - success = xTaskNotify(s_scanTaskHandle, flags, eSetBits) == pdPASS; + if (s_scanTaskHandle == nullptr) { + return false; } - xSemaphoreGive(s_scanTaskMutex); - - return success; + return xTaskNotify(s_scanTaskHandle, flags, eSetBits) == pdPASS; } -void _notifyStatusChangedHandlers(OpenShock::WiFiScanStatus status) { +void _notifyStatusChangedHandlers(OpenShock::WiFiScanStatus status) +{ for (auto& it : s_statusChangedHandlers) { it.second(status); } } -bool _isScanError(int16_t retval) { +bool _isScanError(int16_t retval) +{ return retval < 0 && retval != WIFI_SCAN_RUNNING; } -void _handleScanError(int16_t retval) { +void _handleScanError(int16_t retval) +{ if (retval >= 0) return; _notifyTask(WiFiScanTaskNotificationFlags::ERROR); @@ -73,7 +74,8 @@ void _handleScanError(int16_t retval) { OS_LOGE(TAG, "Scan returned an unknown error"); } -int16_t _scanChannel(uint8_t channel) { +int16_t _scanChannel(uint8_t channel) +{ int16_t retval = WiFi.scanNetworks(true, true, false, OPENSHOCK_WIFI_SCAN_MAX_MS_PER_CHANNEL, channel); if (!_isScanError(retval)) { return retval; @@ -84,7 +86,8 @@ int16_t _scanChannel(uint8_t channel) { return retval; } -WiFiScanStatus _scanningTaskImpl() { +WiFiScanStatus _scanningTaskImpl() +{ // Start the scan on the highest channel and work our way down uint8_t channel = OPENSHOCK_WIFI_SCAN_MAX_CHANNEL; @@ -138,25 +141,29 @@ WiFiScanStatus _scanningTaskImpl() { return WiFiScanStatus::Completed; } -void _scanningTask(void* arg) { +void _scanningTask(void* arg) +{ (void)arg; - + // Start the scan WiFiScanStatus status = _scanningTaskImpl(); // Notify the status changed handlers of the scan result _notifyStatusChangedHandlers(status); + s_scanTaskMutex.lock(portMAX_DELAY); + // Clear the task handle - xSemaphoreTake(s_scanTaskMutex, portMAX_DELAY); s_scanTaskHandle = nullptr; - xSemaphoreGive(s_scanTaskMutex); + + s_scanTaskMutex.unlock(); // Kill this task vTaskDelete(nullptr); } -void _evScanCompleted(arduino_event_id_t event, arduino_event_info_t info) { +void _evScanCompleted(arduino_event_id_t event, arduino_event_info_t info) +{ (void)event; (void)info; @@ -192,14 +199,16 @@ void _evScanCompleted(arduino_event_id_t event, arduino_event_info_t info) { // Notify the scan task that we're done _notifyTask(WiFiScanTaskNotificationFlags::CHANNEL_DONE); } -void _evSTAStopped(arduino_event_id_t event, arduino_event_info_t info) { +void _evSTAStopped(arduino_event_id_t event, arduino_event_info_t info) +{ (void)event; (void)info; _notifyTask(WiFiScanTaskNotificationFlags::WIFI_DISABLED); } -bool WiFiScanManager::Init() { +bool WiFiScanManager::Init() +{ if (s_initialized) { OS_LOGW(TAG, "WiFiScanManager is already initialized"); return true; @@ -213,40 +222,36 @@ bool WiFiScanManager::Init() { return true; } -bool WiFiScanManager::IsScanning() { +bool WiFiScanManager::IsScanning() +{ return s_scanTaskHandle != nullptr; } -bool WiFiScanManager::StartScan() { - xSemaphoreTake(s_scanTaskMutex, portMAX_DELAY); +bool WiFiScanManager::StartScan() +{ + ScopedLock lock__(&s_scanTaskMutex); // Check if a scan is already in progress if (s_scanTaskHandle != nullptr && eTaskGetState(s_scanTaskHandle) != eDeleted) { OS_LOGW(TAG, "Cannot start scan: scan task is already running"); - - xSemaphoreGive(s_scanTaskMutex); return false; } // Start the scan task if (TaskUtils::TaskCreateExpensive(_scanningTask, "WiFiScanManager", 4096, nullptr, 1, &s_scanTaskHandle) != pdPASS) { // PROFILED: 1.8KB stack usage OS_LOGE(TAG, "Failed to create scan task"); - - xSemaphoreGive(s_scanTaskMutex); return false; } - xSemaphoreGive(s_scanTaskMutex); return true; } -bool WiFiScanManager::AbortScan() { - xSemaphoreTake(s_scanTaskMutex, portMAX_DELAY); +bool WiFiScanManager::AbortScan() +{ + ScopedLock lock__(&s_scanTaskMutex); // Check if a scan is in progress if (s_scanTaskHandle == nullptr || eTaskGetState(s_scanTaskHandle) == eDeleted) { OS_LOGW(TAG, "Cannot abort scan: no scan is in progress"); - - xSemaphoreGive(s_scanTaskMutex); return false; } @@ -259,17 +264,18 @@ bool WiFiScanManager::AbortScan() { it.second(WiFiScanStatus::Aborted); } - xSemaphoreGive(s_scanTaskMutex); return true; } -uint64_t WiFiScanManager::RegisterStatusChangedHandler(const WiFiScanManager::StatusChangedHandler& handler) { - static uint64_t nextHandle = 0; - uint64_t handle = nextHandle++; +uint64_t WiFiScanManager::RegisterStatusChangedHandler(const WiFiScanManager::StatusChangedHandler& handler) +{ + static uint64_t nextHandle = 0; + uint64_t handle = nextHandle++; s_statusChangedHandlers[handle] = handler; return handle; } -void WiFiScanManager::UnregisterStatusChangedHandler(uint64_t handle) { +void WiFiScanManager::UnregisterStatusChangedHandler(uint64_t handle) +{ auto it = s_statusChangedHandlers.find(handle); if (it != s_statusChangedHandlers.end()) { @@ -277,13 +283,15 @@ void WiFiScanManager::UnregisterStatusChangedHandler(uint64_t handle) { } } -uint64_t WiFiScanManager::RegisterNetworksDiscoveredHandler(const WiFiScanManager::NetworksDiscoveredHandler& handler) { - static uint64_t nextHandle = 0; - uint64_t handle = nextHandle++; +uint64_t WiFiScanManager::RegisterNetworksDiscoveredHandler(const WiFiScanManager::NetworksDiscoveredHandler& handler) +{ + static uint64_t nextHandle = 0; + uint64_t handle = nextHandle++; s_networksDiscoveredHandlers[handle] = handler; return handle; } -void WiFiScanManager::UnregisterNetworksDiscoveredHandler(uint64_t handle) { +void WiFiScanManager::UnregisterNetworksDiscoveredHandler(uint64_t handle) +{ auto it = s_networksDiscoveredHandlers.find(handle); if (it != s_networksDiscoveredHandlers.end()) {