Add Task Graph

This commit is contained in:
Wojtek Figat
2021-06-12 22:43:37 +02:00
parent 25c00a0d55
commit d7e7dcc823
3 changed files with 215 additions and 11 deletions

View File

@@ -493,17 +493,8 @@ public:
/// Adds the other collection to the collection.
/// </summary>
/// <param name="other">The other collection to add.</param>
FORCE_INLINE void Add(const Array& other)
{
Add(other.Get(), other.Count());
}
/// <summary>
/// Adds the other collection to the collection.
/// </summary>
/// <param name="other">The other collection to add.</param>
template<typename U>
FORCE_INLINE void Add(const Array<U>& other)
template<typename OtherT, typename OtherAllocationType = HeapAllocation>
FORCE_INLINE void Add(const Array<OtherT, OtherAllocationType>& other)
{
Add(other.Get(), other.Count());
}

View File

@@ -0,0 +1,118 @@
// Copyright (c) 2012-2021 Wojciech Figat. All rights reserved.
#include "TaskGraph.h"
#include "JobSystem.h"
#include "Engine/Core/Collections/Sorting.h"
#include "Engine/Profiler/ProfilerCPU.h"
namespace
{
bool SortTaskGraphSystem(TaskGraphSystem* const& a, TaskGraphSystem* const& b)
{
return b->Order < a->Order;
};
}
TaskGraphSystem::TaskGraphSystem(const SpawnParams& params)
: PersistentScriptingObject(params)
{
}
void TaskGraphSystem::AddDependency(TaskGraphSystem* system)
{
_dependencies.Add(system);
}
void TaskGraphSystem::PreExecute(TaskGraph* graph)
{
}
void TaskGraphSystem::Execute(TaskGraph* graph)
{
}
void TaskGraphSystem::PostExecute(TaskGraph* graph)
{
}
TaskGraph::TaskGraph(const SpawnParams& params)
: PersistentScriptingObject(params)
{
}
const Array<TaskGraphSystem*, InlinedAllocation<64>>& TaskGraph::GetSystems() const
{
return _systems;
}
void TaskGraph::AddSystem(TaskGraphSystem* system)
{
_systems.Add(system);
}
void TaskGraph::RemoveSystem(TaskGraphSystem* system)
{
_systems.Remove(system);
}
void TaskGraph::Execute()
{
PROFILE_CPU();
for (auto system : _systems)
system->PreExecute(this);
_queue.Clear();
_remaining.Clear();
_remaining.Add(_systems);
while (_remaining.HasItems())
{
// Find systems without dependencies or with already executed dependencies
for (int32 i = _remaining.Count() - 1; i >= 0; i--)
{
auto e = _remaining[i];
bool hasReadyDependencies = true;
for (auto d : e->_dependencies)
{
if (_remaining.Contains(d))
{
hasReadyDependencies = false;
break;
}
}
if (hasReadyDependencies)
{
_queue.Add(e);
_remaining.RemoveAt(i);
}
}
// End if no systems left
if (_queue.IsEmpty())
break;
// Execute in order
Sorting::QuickSort(_queue.Get(), _queue.Count(), &SortTaskGraphSystem);
_currentLabel = 0;
for (int32 i = 0; i < _queue.Count(); i++)
{
_currentSystem = _queue[i];
_currentSystem->Execute(this);
}
_currentSystem = nullptr;
_queue.Clear();
// Wait for async jobs to finish
JobSystem::Wait(_currentLabel);
}
for (auto system : _systems)
system->PostExecute(this);
}
void TaskGraph::DispatchJob(const Function<void(int32)>& job, int32 jobCount)
{
ASSERT(_currentSystem);
_currentLabel = JobSystem::Dispatch(job, jobCount);
}

View File

@@ -0,0 +1,95 @@
// Copyright (c) 2012-2021 Wojciech Figat. All rights reserved.
#pragma once
#include "Engine/Scripting/ScriptingObject.h"
#include "Engine/Core/Collections/Array.h"
class TaskGraph;
/// <summary>
/// System that can generate work into Task Graph for asynchronous execution.
/// </summary>
API_CLASS(Abstract) class FLAXENGINE_API TaskGraphSystem : public PersistentScriptingObject
{
DECLARE_SCRIPTING_TYPE(TaskGraphSystem);
friend TaskGraph;
private:
Array<TaskGraphSystem*, InlinedAllocation<16>> _dependencies;
public:
/// <summary>
/// The execution order of the system (systems with higher order are executed earlier).
/// </summary>
API_FIELD() int32 Order = 0;
public:
/// <summary>
/// Adds the dependency on the system execution. Before this system can be executed the given dependant system has to be executed first.
/// </summary>
/// <param name="system">The system to depend on.</param>
API_FUNCTION() void AddDependency(TaskGraphSystem* system);
/// <summary>
/// Called before executing any systems of the graph. Can be used to initialize data (synchronous).
/// </summary>
/// <param name="graph">The graph executing the system.</param>
API_FUNCTION() virtual void PreExecute(TaskGraph* graph);
/// <summary>
/// Executes the system logic and schedules the asynchronous work.
/// </summary>
/// <param name="graph">The graph executing the system.</param>
API_FUNCTION() virtual void Execute(TaskGraph* graph);
/// <summary>
/// Called after executing all systems of the graph. Can be used to cleanup data (synchronous).
/// </summary>
/// <param name="graph">The graph executing the system.</param>
API_FUNCTION() virtual void PostExecute(TaskGraph* graph);
};
/// <summary>
/// Graph-based asynchronous tasks scheduler for high-performance computing and processing.
/// </summary>
API_CLASS() class FLAXENGINE_API TaskGraph : public PersistentScriptingObject
{
DECLARE_SCRIPTING_TYPE(TaskGraph);
private:
Array<TaskGraphSystem*, InlinedAllocation<64>> _systems;
Array<TaskGraphSystem*, InlinedAllocation<64>> _remaining;
Array<TaskGraphSystem*, InlinedAllocation<64>> _queue;
TaskGraphSystem* _currentSystem = nullptr;
int64 _currentLabel = 0;
public:
/// <summary>
/// Gets the list of systems.
/// </summary>
API_PROPERTY() const Array<TaskGraphSystem*, InlinedAllocation<64>>& GetSystems() const;
/// <summary>
/// Adds the system to the graph for the execution.
/// </summary>
/// <param name="system">The system to add.</param>
API_FUNCTION() void AddSystem(TaskGraphSystem* system);
/// <summary>
/// Removes the system from the graph.
/// </summary>
/// <param name="system">The system to add.</param>
API_FUNCTION() void RemoveSystem(TaskGraphSystem* system);
/// <summary>
/// Schedules the asynchronous systems execution including ordering and dependencies handling.
/// </summary>
API_FUNCTION() void Execute();
/// <summary>
/// Dispatches the job for the execution.
/// </summary>
/// <remarks>Call only from system's Execute method to properly schedule job.</remarks>
/// <param name="job">The job. Argument is an index of the job execution.</param>
/// <param name="jobCount">The job executions count.</param>
API_FUNCTION() void DispatchJob(const Function<void(int32)>& job, int32 jobCount = 1);
};