Use a more complex algorithm for building the new order of statements

This commit is contained in:
Egor Mayorov
2025-05-27 01:41:50 +03:00
committed by Alexander
parent bc46a7f239
commit 8399c9d591

View File

@@ -17,7 +17,8 @@ using namespace std;
unordered_set<int> loop_tags = {FOR_NODE/*, FORALL_NODE, WHILE_NODE, DO_WHILE_NODE*/};
unordered_set<int> importantDepsTags = {FOR_NODE, IF_NODE, ELSEIF_NODE};
unordered_set<int> importantDepsTags = {FOR_NODE, IF_NODE};
unordered_set<int> importantUpdDepsTags = {ELSEIF_NODE};
unordered_set<int> importantEndTags = {CONTROL_END};
@@ -38,7 +39,7 @@ vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, vector<S
return result;
}
vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR)
vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR)
{
vector<SAPFOR::BasicBlock*> result;
Statement* forSt = (Statement*)st;
@@ -86,7 +87,7 @@ map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatem
{
map<SgStatement*, set<SgStatement*>> result;
for (SAPFOR::BasicBlock* bb: loopBlocks) {
std::map<SAPFOR::Argument*, std::set<int>> blockReachingDefinitions = bb -> getRD_In();
map<SAPFOR::Argument*, set<int>> blockReachingDefinitions = bb -> getRD_In();
vector<SAPFOR::IR_Block*> instructions = bb -> getInstructions();
for (SAPFOR::IR_Block* irBlock: instructions) {
// TODO: Think about what to do with function calls and array references. Because there are also dependencies there that are not reflected in RD, but they must be taken into account
@@ -100,7 +101,8 @@ map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatem
SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first;
if (foundInstruction != NULL) {
SgStatement* prevOp = foundInstruction -> getOperator();
if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber())
if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber()
&& prevOp -> lineNumber() > forStatement -> lineNumber())
result[instr -> getOperator()].insert(prevOp);
}
}
@@ -113,7 +115,8 @@ map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatem
SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first;
if (foundInstruction != NULL) {
SgStatement* prevOp = foundInstruction -> getOperator();
if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber())
if (prevOp != forStatement && instr -> getOperator() != forStatement&& instr -> getOperator() -> lineNumber() > prevOp -> lineNumber()
&& prevOp -> lineNumber() > forStatement -> lineNumber())
result[instr -> getOperator()].insert(prevOp);
}
}
@@ -131,10 +134,11 @@ void buildAdditionalDeps(SgForStmt* forStatement, map<SgStatement*, set<SgStatem
SgStatement* lastNode = forStatement->lastNodeOfStmt();
vector<SgStatement*> importantDeps;
SgStatement* st = (SgStatement*) forStatement;
st = st -> lexNext();
SgStatement* logIfOp = NULL;
importantDeps.push_back(st);
while (st && st != lastNode)
{
if(importantDeps.size() != 0)
if (st != importantDeps.back()) {
dependencies[st].insert(importantDeps.back());
}
@@ -148,123 +152,117 @@ void buildAdditionalDeps(SgForStmt* forStatement, map<SgStatement*, set<SgStatem
if (importantDepsTags.find(st -> variant()) != importantDepsTags.end()) {
importantDeps.push_back(st);
}
if (importantUpdDepsTags.find(st -> variant()) != importantUpdDepsTags.end()) {
importantDeps.pop_back();
importantDeps.push_back(st);
}
if (importantEndTags.find(st -> variant()) != importantEndTags.end()) {
if(importantDeps.size() != 0)
importantDeps.pop_back();
}
st = st -> lexNext();
}
}
void GenNodesOfGraph(
const map<SgStatement*, set<SgStatement*>>& dependencies,
set<SgStatement*>& allNodes,
map<SgStatement*, set<SgStatement*>>& outEdges,
map<SgStatement*, set<SgStatement*>>& inEdges)
{
for (const auto& node: dependencies) {
SgStatement* u = node.first;
allNodes.insert(u);
for (SgStatement* v: node.second) {
allNodes.insert(v);
outEdges[v].insert(u);
inEdges[u].insert(v);
outEdges[u];
inEdges[v];
struct ReadyOp {
SgStatement* stmt;
int degree;
size_t arrival;
ReadyOp(SgStatement* s, int d, size_t a): stmt(s), degree(d), arrival(a) {}
};
struct ReadyOpCompare {
bool operator()(const ReadyOp& a, const ReadyOp& b) const {
if (a.degree != b.degree)
return a.degree > b.degree;
else
return a.arrival > b.arrival;
}
outEdges[u];
inEdges[u];
};
// Produces an execution order for the statements mentioned in `dependencies`
// such that every statement appears after all statements it depends on.
//
// dependencies: maps a statement to the set of statements it depends on.
// Returns: the statements in scheduled order. A statement whose pending
//          dependency count never reaches zero (e.g. it is part of a cycle)
//          is never enqueued and is therefore absent from the result.
//
// Two ready pools are used: statements that have dependencies go into a
// priority queue ordered by ReadyOpCompare, statements without any go into a
// FIFO queue. The priority queue is always drained before the FIFO queue.
vector<SgStatement*> scheduleOperations(
    const map<SgStatement*, set<SgStatement*>>& dependencies
) {
    // Gather every statement that occurs either as a key or as a dependency.
    // Elements are inserted one by one on purpose: the seed order below
    // follows the iteration order of this container.
    unordered_set<SgStatement*> seen;
    for (const auto& entry : dependencies) {
        seen.insert(entry.first);
        for (SgStatement* prerequisite : entry.second)
            seen.insert(prerequisite);
    }
    const vector<SgStatement*> nodes(seen.begin(), seen.end());

    // pending[s]  - number of dependencies of s that are not scheduled yet
    // weight[s]   - original number of dependencies of s (queue priority key)
    unordered_map<SgStatement*, int> pending;
    unordered_map<SgStatement*, int> weight;
    for (SgStatement* node : nodes) {
        pending[node] = 0;
        weight[node] = 0;
    }

    // successors[d] lists the statements that wait for d (reversed edges).
    unordered_map<SgStatement*, vector<SgStatement*>> successors;
    unordered_set<SgStatement*> hasPrerequisites;
    for (const auto& entry : dependencies) {
        SgStatement* const stmt = entry.first;
        const set<SgStatement*>& prerequisites = entry.second;

        const int count = static_cast<int>(prerequisites.size());
        weight[stmt] = count;
        pending[stmt] = count;
        if (count != 0)
            hasPrerequisites.insert(stmt);

        for (SgStatement* prerequisite : prerequisites)
            successors[prerequisite].push_back(stmt);
    }

    priority_queue<ReadyOp, vector<ReadyOp>, ReadyOpCompare> readyDependent;
    queue<SgStatement*> readyIndependent;
    size_t nextArrival = 0;

    // Routes a statement that became ready into the matching pool.
    auto markReady = [&](SgStatement* node) {
        if (hasPrerequisites.count(node) != 0)
            readyDependent.emplace(node, weight[node], nextArrival++);
        else
            readyIndependent.push(node);
    };

    // Seed the pools with everything that has nothing to wait for.
    for (SgStatement* node : nodes) {
        if (pending[node] == 0)
            markReady(node);
    }

    // Main scheduling loop.
    vector<SgStatement*> order;
    while (!(readyDependent.empty() && readyIndependent.empty())) {
        SgStatement* next = nullptr;
        if (readyDependent.empty()) {
            next = readyIndependent.front();
            readyIndependent.pop();
        } else {
            next = readyDependent.top().stmt;
            readyDependent.pop();
        }
        order.push_back(next);

        // Release the statements that were waiting for `next`.
        for (SgStatement* waiting : successors[next]) {
            if (--pending[waiting] == 0)
                markReady(waiting);
        }
    }
    return order;
}
// Splits the graph into groups of statements that are connected to each other
// when edge direction is ignored.
//
// allNodes: every node of the graph.
// outEdges / inEdges: adjacency sets per node; both maps are accessed with
//                     at(), so a node without an entry raises std::out_of_range.
// Returns: one set per group, in the order in which their first node is met
//          while iterating allNodes.
vector<set<SgStatement*>> FindLinksInGraph(
    const set<SgStatement*>& allNodes,
    const map<SgStatement*, set<SgStatement*>>& outEdges,
    const map<SgStatement*, set<SgStatement*>>& inEdges)
{
    vector<set<SgStatement*>> groups;
    set<SgStatement*> reached;

    for (SgStatement* seed : allNodes) {
        // insert() reports whether the node is new; already reached nodes
        // belong to a group that has been emitted before.
        if (!reached.insert(seed).second)
            continue;

        set<SgStatement*> group;
        queue<SgStatement*> frontier;
        frontier.push(seed);

        // Adds all not yet reached neighbours to the traversal frontier.
        auto expand = [&](const set<SgStatement*>& adjacent) {
            for (SgStatement* next : adjacent) {
                if (reached.insert(next).second)
                    frontier.push(next);
            }
        };

        while (!frontier.empty()) {
            SgStatement* node = frontier.front();
            frontier.pop();
            group.insert(node);
            // Follow edges in both directions.
            expand(outEdges.at(node));
            expand(inEdges.at(node));
        }
        groups.push_back(group);
    }
    return groups;
}
// Orders the statements of one component so that a statement is emitted only
// after all of its incoming edges have been satisfied.
//
// component: the nodes to order.
// outEdges / inEdges: adjacency sets per node, accessed with at().
// Returns: the ordered nodes; nodes whose counter never drops to zero are
//          not part of the result.
vector<SgStatement*> SortComponent(
    const set<SgStatement*>& component,
    const map<SgStatement*, set<SgStatement*>>& outEdges,
    const map<SgStatement*, set<SgStatement*>>& inEdges)
{
    map<SgStatement*, int> remaining;
    queue<SgStatement*> ready;

    // Count incoming edges and seed the queue with nodes that have none.
    for (SgStatement* node : component) {
        const int incoming = static_cast<int>(inEdges.at(node).size());
        remaining[node] = incoming;
        if (incoming == 0)
            ready.push(node);
    }

    vector<SgStatement*> ordered;
    while (!ready.empty()) {
        SgStatement* node = ready.front();
        ready.pop();
        ordered.push_back(node);

        for (SgStatement* target : outEdges.at(node)) {
            // Edges leaving the component are ignored.
            if (component.count(target) == 0)
                continue;
            if (--remaining[target] == 0)
                ready.push(target);
        }
    }
    return ordered;
}
// Builds a statement order from `dependencies` in which each connected group
// of statements is emitted as one contiguous run: the graph is split with
// FindLinksInGraph, every group is ordered with SortComponent, and the
// per-group results are concatenated.
//
// dependencies: maps a statement to the set of statements it depends on.
// Returns: the concatenated order of all groups.
vector<SgStatement*> SortNoInterleaving(const map<SgStatement*, set<SgStatement*>>& dependencies)
{
    set<SgStatement*> nodes;
    map<SgStatement*, set<SgStatement*>> successors;
    map<SgStatement*, set<SgStatement*>> predecessors;
    GenNodesOfGraph(dependencies, nodes, successors, predecessors);

    vector<SgStatement*> combined;
    for (auto& group : FindLinksInGraph(nodes, successors, predecessors)) {
        const vector<SgStatement*> chunk = SortComponent(group, successors, predecessors);
        for (SgStatement* stmt : chunk)
            combined.push_back(stmt);
    }
    return combined;
}
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform)
{
std::cout << "SWAP_OPERATORS Pass" << std::endl; // to remove
countOfTransform += 1; // to remove
const int funcNum = file->numberOfFunctions();
const int funcNum = file -> numberOfFunctions();
for (int i = 0; i < funcNum; ++i)
{
SgStatement *st = file->functions(i);
SgStatement *st = file -> functions(i);
vector<SAPFOR::BasicBlock*> blocks = findFuncBlocksByFuncStatement(st, FullIR);
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopsMapping = findAndAnalyzeLoops(st, blocks);
for (pair<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopForAnalyze: loopsMapping)
@@ -272,27 +270,24 @@ void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*
map<SgStatement*, set<SgStatement*>> dependencyGraph = AnalyzeLoopAndFindDeps(loopForAnalyze.first, loopForAnalyze.second, FullIR);
// TODO: Write a function that goes through the operators and updates all dependencies so that there are no mix-ups or splits inside semantic blocks (for if, do, and maybe some other cases)
buildAdditionalDeps(loopForAnalyze.first, dependencyGraph);
cout << "\n\n";
cout << endl;
int firstLine = loopForAnalyze.first -> lineNumber();
int lastLine = loopForAnalyze.first -> lastNodeOfStmt() -> lineNumber();
// for (auto v: dependencyGraph) {
// cout << "OPERATOR: " << v.first -> lineNumber() << "\nDEPENDS ON:" << endl;
cout << "LOOP ANALYZE FROM " << firstLine << " TO " << lastLine << " RES" << endl;
// for (auto &v: dependencyGraph) {
// cout << "OPERATOR: " << v.first -> lineNumber() << " " << v.first -> variant() << "\nDEPENDS ON:" << endl;
// if (v.second.size() != 0)
// for (auto vv: v.second) {
// if (vv -> lineNumber() > firstLine)
// for (auto vv: v.second)
// cout << vv -> lineNumber() << " ";
// }
// cout << endl;
// }
if (dependencyGraph.size() != 0) {
vector<SgStatement*> new_order = SortNoInterleaving(dependencyGraph);
cout << "\n\nLOOP ANALYZE FROM " << firstLine << " TO " << lastLine << " RES\n" << endl;
vector<SgStatement*> new_order = scheduleOperations(dependencyGraph);
cout << "RESULT ORDER:" << endl;
for (auto v: new_order)
if (v -> lineNumber() > firstLine)
cout << v -> lineNumber() << endl;
}
}
}
return;
};