yuzu-tx-update/src/core/hle/kernel/address_arbiter.cpp

178 lines
6.1 KiB
C++
Raw Normal View History

// Copyright 2018 yuzu emulator team
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#include <algorithm>
#include <vector>
#include "common/assert.h"
#include "common/common_types.h"
#include "core/core.h"
#include "core/core_cpu.h"
#include "core/hle/kernel/address_arbiter.h"
#include "core/hle/kernel/errors.h"
#include "core/hle/kernel/object.h"
#include "core/hle/kernel/process.h"
#include "core/hle/kernel/scheduler.h"
#include "core/hle/kernel/thread.h"
#include "core/hle/result.h"
#include "core/memory.h"
namespace Kernel {
namespace {
2018-06-21 23:05:34 -04:00
// Wake up num_to_wake (or all) threads in a vector.
void WakeThreads(std::vector<SharedPtr<Thread>>& waiting_threads, s32 num_to_wake) {
2018-06-21 23:05:34 -04:00
// Only process up to 'target' threads, unless 'target' is <= 0, in which case process
// them all.
std::size_t last = waiting_threads.size();
2018-06-21 23:05:34 -04:00
if (num_to_wake > 0)
last = num_to_wake;
// Signal the waiting threads.
for (std::size_t i = 0; i < last; i++) {
ASSERT(waiting_threads[i]->GetStatus() == ThreadStatus::WaitArb);
2018-06-21 23:05:34 -04:00
waiting_threads[i]->SetWaitSynchronizationResult(RESULT_SUCCESS);
waiting_threads[i]->SetArbiterWaitAddress(0);
2018-06-21 23:05:34 -04:00
waiting_threads[i]->ResumeFromWait();
}
}
} // Anonymous namespace
2018-06-21 23:05:34 -04:00
AddressArbiter::AddressArbiter() = default;
AddressArbiter::~AddressArbiter() = default;
ResultCode AddressArbiter::SignalToAddress(VAddr address, s32 num_to_wake) {
std::vector<SharedPtr<Thread>> waiting_threads = GetThreadsWaitingOnAddress(address);
2018-06-21 23:05:34 -04:00
WakeThreads(waiting_threads, num_to_wake);
return RESULT_SUCCESS;
}
ResultCode AddressArbiter::IncrementAndSignalToAddressIfEqual(VAddr address, s32 value,
s32 num_to_wake) {
2018-06-21 23:05:34 -04:00
// Ensure that we can write to the address.
if (!Memory::IsValidVirtualAddress(address)) {
return ERR_INVALID_ADDRESS_STATE;
}
if (static_cast<s32>(Memory::Read32(address)) == value) {
Memory::Write32(address, static_cast<u32>(value + 1));
2018-06-21 23:05:34 -04:00
} else {
return ERR_INVALID_STATE;
}
return SignalToAddress(address, num_to_wake);
}
ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value,
s32 num_to_wake) {
2018-06-21 23:05:34 -04:00
// Ensure that we can write to the address.
if (!Memory::IsValidVirtualAddress(address)) {
return ERR_INVALID_ADDRESS_STATE;
}
// Get threads waiting on the address.
std::vector<SharedPtr<Thread>> waiting_threads = GetThreadsWaitingOnAddress(address);
2018-06-21 23:05:34 -04:00
// Determine the modified value depending on the waiting count.
s32 updated_value;
if (waiting_threads.empty()) {
2018-06-21 23:05:34 -04:00
updated_value = value - 1;
} else if (num_to_wake <= 0 || waiting_threads.size() <= static_cast<u32>(num_to_wake)) {
2018-06-21 23:05:34 -04:00
updated_value = value + 1;
} else {
updated_value = value;
}
if (static_cast<s32>(Memory::Read32(address)) == value) {
Memory::Write32(address, static_cast<u32>(updated_value));
2018-06-21 23:05:34 -04:00
} else {
return ERR_INVALID_STATE;
}
WakeThreads(waiting_threads, num_to_wake);
return RESULT_SUCCESS;
}
ResultCode AddressArbiter::WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout,
bool should_decrement) {
2018-06-21 23:05:34 -04:00
// Ensure that we can read the address.
if (!Memory::IsValidVirtualAddress(address)) {
return ERR_INVALID_ADDRESS_STATE;
}
s32 cur_value = static_cast<s32>(Memory::Read32(address));
2018-06-21 23:05:34 -04:00
if (cur_value < value) {
if (should_decrement) {
Memory::Write32(address, static_cast<u32>(cur_value - 1));
}
2018-06-21 23:05:34 -04:00
} else {
return ERR_INVALID_STATE;
}
// Short-circuit without rescheduling, if timeout is zero.
if (timeout == 0) {
return RESULT_TIMEOUT;
}
return WaitForAddress(address, timeout);
}
ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) {
2018-06-21 23:05:34 -04:00
// Ensure that we can read the address.
if (!Memory::IsValidVirtualAddress(address)) {
return ERR_INVALID_ADDRESS_STATE;
}
// Only wait for the address if equal.
if (static_cast<s32>(Memory::Read32(address)) != value) {
2018-06-21 23:05:34 -04:00
return ERR_INVALID_STATE;
}
// Short-circuit without rescheduling, if timeout is zero.
if (timeout == 0) {
return RESULT_TIMEOUT;
}
return WaitForAddress(address, timeout);
}
ResultCode AddressArbiter::WaitForAddress(VAddr address, s64 timeout) {
SharedPtr<Thread> current_thread = GetCurrentThread();
current_thread->SetArbiterWaitAddress(address);
current_thread->SetStatus(ThreadStatus::WaitArb);
current_thread->InvalidateWakeupCallback();
current_thread->WakeAfterDelay(timeout);
Core::System::GetInstance().CpuCore(current_thread->GetProcessorID()).PrepareReschedule();
return RESULT_TIMEOUT;
}
std::vector<SharedPtr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress(VAddr address) const {
const auto RetrieveWaitingThreads = [](std::size_t core_index,
std::vector<SharedPtr<Thread>>& waiting_threads,
VAddr arb_addr) {
const auto& scheduler = Core::System::GetInstance().Scheduler(core_index);
const auto& thread_list = scheduler.GetThreadList();
for (const auto& thread : thread_list) {
if (thread->GetArbiterWaitAddress() == arb_addr)
waiting_threads.push_back(thread);
}
};
// Retrieve all threads that are waiting for this address.
std::vector<SharedPtr<Thread>> threads;
RetrieveWaitingThreads(0, threads, address);
RetrieveWaitingThreads(1, threads, address);
RetrieveWaitingThreads(2, threads, address);
RetrieveWaitingThreads(3, threads, address);
// Sort them by priority, such that the highest priority ones come first.
std::sort(threads.begin(), threads.end(),
[](const SharedPtr<Thread>& lhs, const SharedPtr<Thread>& rhs) {
return lhs->GetPriority() < rhs->GetPriority();
});
return threads;
}
} // namespace Kernel