diff --git a/include/lockSet.h b/include/lockSet.h index 7d6d6a7..fce7212 100644 --- a/include/lockSet.h +++ b/include/lockSet.h @@ -30,9 +30,16 @@ public: * LockerAndInvoker that this LockSet has registered into that Qutex's * queue. */ - typedef std::pair< - std::reference_wrapper, - typename LockerAndInvokerBase::List::iterator> LockUsageDesc; + struct LockUsageDesc + { + std::reference_wrapper qutex; + typename LockerAndInvokerBase::List::iterator iterator; + bool hasBeenReleased = false; + + LockUsageDesc(std::reference_wrapper qutexRef, + typename LockerAndInvokerBase::List::iterator iter) + : qutex(qutexRef), iterator(iter), hasBeenReleased(false) {} + }; typedef std::vector> Set; @@ -90,7 +97,7 @@ public: */ for (auto& lockUsageDesc : locks) { - lockUsageDesc.second = lockUsageDesc.first.get().registerInQueue( + lockUsageDesc.iterator = lockUsageDesc.qutex.get().registerInQueue( lockvoker); } @@ -110,8 +117,8 @@ public: // Unregister from all qutex queues for (auto& lockUsageDesc : locks) { - auto it = lockUsageDesc.second; - lockUsageDesc.first.get().unregisterFromQueue(it); + auto it = lockUsageDesc.iterator; + lockUsageDesc.qutex.get().unregisterFromQueue(it); } } @@ -149,11 +156,11 @@ public: const int nRequiredLocks = static_cast(locks.size()); for (auto& lockUsageDesc : locks) { - if (!lockUsageDesc.first.get().tryAcquire( + if (!lockUsageDesc.qutex.get().tryAcquire( lockvoker, nRequiredLocks)) { // Set the first failed qutex for debugging - firstFailedQutex = std::ref(lockUsageDesc.first.get()); + firstFailedQutex = std::ref(lockUsageDesc.qutex.get()); break; } @@ -164,7 +171,7 @@ public: { // Release any locks we managed to acquire for (int i = 0; i < nAcquired; i++) { - locks[i].first.get().backoff(lockvoker, nRequiredLocks); + locks[i].qutex.get().backoff(lockvoker, nRequiredLocks); } return false; @@ -192,8 +199,11 @@ public: ": LockSet::release() called but allLocksAcquired is false"); } - for (auto& lockUsageDesc : locks) { - lockUsageDesc.first.get().release(); + for (auto& lockUsageDesc : locks) + { + if (lockUsageDesc.hasBeenReleased) { continue; } + + lockUsageDesc.qutex.get().release(); } allLocksAcquired = false; @@ -203,7 +213,7 @@ public: { for (auto& lockUsageDesc : locks) { - if (&lockUsageDesc.first.get() == &criterionLock) { + if (&lockUsageDesc.qutex.get() == &criterionLock) { return lockUsageDesc; } } @@ -214,6 +224,31 @@ public: ": Qutex not found in this LockSet"); } + /** + * @brief Release a specific qutex early and mark it as released + * @param qutex The qutex to release early + */ + void releaseQutexEarly(Qutex &qutex) + { + if (!allLocksAcquired) + { + throw std::runtime_error( + std::string(__func__) + + ": LockSet::releaseQutexEarly() called but allLocksAcquired is false"); + } + + auto& lockUsageDesc = const_cast( + getLockUsageDesc(qutex)); + + if (!lockUsageDesc.hasBeenReleased) + { + lockUsageDesc.qutex.get().release(); + lockUsageDesc.hasBeenReleased = true; + } + + return; + } + public: std::vector locks; diff --git a/include/serializedAsynchronousContinuation.h b/include/serializedAsynchronousContinuation.h index f8c4cb3..39b84b3 100644 --- a/include/serializedAsynchronousContinuation.h +++ b/include/serializedAsynchronousContinuation.h @@ -43,6 +43,13 @@ public: std::unique_ptr>> getAcquiredQutexHistory() const; + /** + * @brief Release a specific qutex early + * @param qutex The qutex to release early + */ + void releaseQutexEarly(Qutex &qutex) + { requiredLocks.releaseQutexEarly(qutex); } + public: LockSet requiredLocks; std::atomic isAwakeOrBeingAwakened{false}; @@ -109,7 +116,7 @@ public: getLockvokerIteratorForQutex(Qutex& qutex) const override { return serializedContinuation.requiredLocks.getLockUsageDesc( - qutex).second; + qutex).iterator; } /** @@ -133,7 +140,7 @@ public: Qutex& getLockAt(size_t index) const override { return serializedContinuation.requiredLocks.locks[index] - .first.get(); + .qutex.get(); } private: @@ -211,9 +218,9 @@ public: : serializedContinuation.requiredLocks.locks) { if (traceContinuationHistoryForDeadlockOn( - lockUsageDesc.first.get())) + lockUsageDesc.qutex.get())) { - return std::ref(lockUsageDesc.first.get()); + return std::ref(lockUsageDesc.qutex.get()); } } return std::nullopt; @@ -289,7 +296,7 @@ const // Add this continuation's locks to the held locks list for (size_t i = 0; i < serializedCont->requiredLocks.locks.size(); ++i) { - heldLocks->push_front(serializedCont->requiredLocks.locks[i].first); + heldLocks->push_front(serializedCont->requiredLocks.locks[i].qutex); } }