Make pass correct

This commit is contained in:
Egor Mayorov
2025-10-23 14:54:43 +03:00
parent 2746d072d4
commit e9d5a2ee70

View File

@@ -3,6 +3,7 @@
#include <vector> #include <vector>
#include <queue> #include <queue>
#include <iostream> #include <iostream>
#include <algorithm>
#include "../../Utils/errors.h" #include "../../Utils/errors.h"
#include "../../Utils/SgUtils.h" #include "../../Utils/SgUtils.h"
@@ -15,319 +16,206 @@
using namespace std; using namespace std;
string getNameByArg(SAPFOR::Argument* arg);
SgSymbol* getSybolByArg(SAPFOR::Argument* arg);
unordered_set<int> loop_tags = {FOR_NODE/*, FORALL_NODE, WHILE_NODE, DO_WHILE_NODE*/}; static vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, vector<SAPFOR::BasicBlock*> Blocks) {
unordered_set<int> importantDepsTags = {FOR_NODE, IF_NODE};
unordered_set<int> importantUpdDepsTags = {ELSEIF_NODE};
unordered_set<int> importantEndTags = {CONTROL_END};
vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, vector<SAPFOR::BasicBlock*> Blocks)
{
vector<SAPFOR::IR_Block*> result; vector<SAPFOR::IR_Block*> result;
string filename = st -> fileName(); string filename = st->fileName();
for (auto& block: Blocks)
{ for (auto& block: Blocks) {
vector<SAPFOR::IR_Block*> instructionsInBlock = block -> getInstructions(); vector<SAPFOR::IR_Block*> instructionsInBlock = block->getInstructions();
for (auto& instruction: instructionsInBlock) for (auto& instruction: instructionsInBlock) {
{ SgStatement* curOperator = instruction->getInstruction()->getOperator();
SgStatement* curOperator = instruction -> getInstruction() -> getOperator(); // Match by line number to find corresponding IR instruction
if (curOperator -> lineNumber() == st -> lineNumber()) if (curOperator->lineNumber() == st->lineNumber()) {
result.push_back(instruction); result.push_back(instruction);
}
} }
} }
return result; return result;
} }
vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) unordered_set<int> loop_tags = {FOR_NODE};
{ unordered_set<int> control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE};
vector<SAPFOR::BasicBlock*> result; unordered_set<int> control_end_tags = {CONTROL_END};
Statement* forSt = (Statement*)st;
for (auto& func: FullIR)
{
if (func.first -> funcPointer -> getCurrProcessFile() == forSt -> getCurrProcessFile()
&& func.first -> funcPointer -> lineNumber() == forSt -> lineNumber())
result = func.second;
}
return result;
}
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st, vector<SAPFOR::BasicBlock*> blocks) struct OperatorInfo {
{
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> result;
SgStatement *lastNode = st->lastNodeOfStmt();
while (st && st != lastNode)
{
if (loop_tags.find(st -> variant()) != loop_tags.end())
{
// part with find statements of loop
SgForStmt *forSt = (SgForStmt*)st;
SgStatement *loopBody = forSt -> body();
SgStatement *lastLoopNode = st->lastNodeOfStmt();
// part with find blocks and instructions of loops
unordered_set<int> blocks_nums;
while (loopBody && loopBody != lastLoopNode)
{
SAPFOR::IR_Block* IR = findInstructionsFromOperator(loopBody, blocks).front();
if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end())
{
result[forSt].push_back(IR -> getBasicBlock());
blocks_nums.insert(IR -> getBasicBlock() -> getNumber());
}
loopBody = loopBody -> lexNext();
}
std::sort(result[forSt].begin(), result[forSt].end());
}
st = st -> lexNext();
}
return result;
}
/**
 * Builds a statement-level dependency map for the body of one loop.
 *
 * For every IR instruction in `loopBlocks`, the reaching definitions of its
 * first and second arguments are looked up and the statements that produced
 * them are recorded as dependencies of the instruction's own statement.
 *
 * @param forStatement loop header; dependencies on or from the header itself
 *                     are never recorded.
 * @param loopBlocks   basic blocks whose instructions are scanned.
 * @param FullIR       whole-program IR, used to resolve an instruction number
 *                     to the instruction (via getInstructionAndBlockByNumber).
 * @return map from each visited statement to the set of statements it depends
 *         on; a statement with no dependencies still gets an (empty) entry.
 */
map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatement, vector<SAPFOR::BasicBlock*> loopBlocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR)
{
    map<SgStatement*, set<SgStatement*>> result;
    for (SAPFOR::BasicBlock* bb : loopBlocks)
    {
        // Local working copy of the block's incoming reaching definitions; it is
        // updated below as instructions of the block are walked in order.
        map<SAPFOR::Argument*, set<int>> blockReachingDefinitions = bb->getRD_In();
        vector<SAPFOR::IR_Block*> instructions = bb->getInstructions();
        for (SAPFOR::IR_Block* irBlock : instructions)
        {
            // TODO: Think about what to do with function calls and array references. Because there are also dependencies there that are not reflected in RD, but they must be taken into account
            SAPFOR::Instruction* instr = irBlock->getInstruction();
            SgStatement* const currentOp = instr->getOperator();

            // Make sure the statement is present in the map even without dependencies.
            result[currentOp];

            // Records, for one argument, every defining statement that reaches this
            // instruction. A dependency is kept only when neither side is the loop
            // header and the definition lies strictly between the header and the
            // current statement by line number.
            auto collectDependencies = [&](SAPFOR::Argument* arg)
            {
                if (arg == nullptr)
                    return;

                const set<int>& definingInstructions = blockReachingDefinitions[arg];
                for (int number : definingInstructions)
                {
                    SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, number).first;
                    if (foundInstruction == nullptr)
                        continue;

                    SgStatement* prevOp = foundInstruction->getOperator();
                    if (prevOp != forStatement && currentOp != forStatement
                        && currentOp->lineNumber() > prevOp->lineNumber()
                        && prevOp->lineNumber() > forStatement->lineNumber())
                        result[currentOp].insert(prevOp);
                }
            };

            collectDependencies(instr->getArg1());
            collectDependencies(instr->getArg2());

            // This instruction's result replaces any earlier definition of the same
            // argument for the remaining instructions of the block.
            if (instr->getResult() != nullptr)
                blockReachingDefinitions[instr->getResult()] = {instr->getNumber()};
        }
    }
    return result;
}
/**
 * Adds structural dependencies for the statements of a loop body.
 *
 * Walks the statements between the loop header and its last node and:
 *  - makes each statement depend on the statement currently on top of the
 *    `importantDeps` stack (statements whose variant is in importantDepsTags
 *    are pushed; a variant in importantUpdDepsTags replaces the top; a variant
 *    in importantEndTags pops it);
 *  - makes the statement that follows a LOGIF_NODE depend on that LOGIF_NODE.
 *
 * @param forStatement loop whose body is scanned.
 * @param dependencies dependency map to extend in place.
 */
void buildAdditionalDeps(SgForStmt* forStatement, map<SgStatement*, set<SgStatement*>>& dependencies)
{
    SgStatement* lastNode = forStatement->lastNodeOfStmt();
    vector<SgStatement*> importantDeps;
    SgStatement* logIfOp = nullptr;

    SgStatement* st = static_cast<SgStatement*>(forStatement)->lexNext();
    while (st && st != lastNode)
    {
        // Depend on the statement currently on top of the stack, if any.
        if (!importantDeps.empty() && st != importantDeps.back())
            dependencies[st].insert(importantDeps.back());

        // The statement right after a logical IF depends on that IF.
        if (logIfOp != nullptr)
        {
            dependencies[st].insert(logIfOp);
            logIfOp = nullptr;
        }
        if (st->variant() == LOGIF_NODE)
            logIfOp = st;

        if (importantDepsTags.find(st->variant()) != importantDepsTags.end())
            importantDeps.push_back(st);

        if (importantUpdDepsTags.find(st->variant()) != importantUpdDepsTags.end())
        {
            // Replace the top entry. Guarded: pop_back() on an empty vector is
            // undefined behaviour, and nothing here guarantees a prior push.
            if (!importantDeps.empty())
                importantDeps.pop_back();
            importantDeps.push_back(st);
        }

        if (importantEndTags.find(st->variant()) != importantEndTags.end() && !importantDeps.empty())
            importantDeps.pop_back();

        st = st->lexNext();
    }
}
struct ReadyOp {
SgStatement* stmt; SgStatement* stmt;
int degree; set<string> usedVars;
size_t arrival; set<string> definedVars;
ReadyOp(SgStatement* s, int d, size_t a): stmt(s), degree(d), arrival(a) {} int lineNumber;
bool isMovable;
OperatorInfo(SgStatement* s) : stmt(s), lineNumber(s->lineNumber()), isMovable(true) {}
}; };
struct ReadyOpCompare { static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFOR::BasicBlock*> blocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) {
bool operator()(const ReadyOp& a, const ReadyOp& b) const { vector<OperatorInfo> operators;
if (a.degree != b.degree)
return a.degree > b.degree;
else
return a.arrival > b.arrival;
}
};
vector<SgStatement*> scheduleOperations(const map<SgStatement*, set<SgStatement*>>& dependencies)
{
// get all statements
unordered_set<SgStatement*> allStmtsSet;
for (const auto& pair : dependencies)
{
allStmtsSet.insert(pair.first);
for (SgStatement* dep : pair.second)
{
allStmtsSet.insert(dep);
}
}
vector<SgStatement*> allStmts(allStmtsSet.begin(), allStmtsSet.end());
// count deps and build reversed graph
unordered_map<SgStatement*, vector<SgStatement*>> graph;
unordered_map<SgStatement*, int> inDegree;
unordered_map<SgStatement*, int> degree;
for (auto op : allStmts)
inDegree[op] = 0;
// find and remember initial dependencies
unordered_set<SgStatement*> dependentStmts;
for (const auto& pair : dependencies)
{
SgStatement* op = pair.first;
const auto& deps = pair.second;
degree[op] = deps.size();
inDegree[op] = deps.size();
if (!deps.empty())
dependentStmts.insert(op);
for (auto dep : deps)
graph[dep].push_back(op);
}
for (SgStatement* op : allStmts)
{
if (!degree.count(op))
{
degree[op] = 0;
}
}
// build queues
using PQ = priority_queue<ReadyOp, vector<ReadyOp>, ReadyOpCompare>;
PQ readyDependent;
queue<SgStatement*> readyIndependent;
size_t arrivalCounter = 0;
for (auto op : allStmts)
{
if (inDegree[op] == 0)
{
if (dependentStmts.count(op))
{
readyDependent.emplace(op, degree[op], arrivalCounter++);
}
else
{
readyIndependent.push(op);
}
}
}
// main sort algorythm
vector<SgStatement*> executionOrder;
while (!readyDependent.empty() || !readyIndependent.empty())
{
SgStatement* current = nullptr;
if (!readyDependent.empty())
{
current = readyDependent.top().stmt;
readyDependent.pop();
}
else
{
current = readyIndependent.front();
readyIndependent.pop();
}
executionOrder.push_back(current);
for (SgStatement* neighbor : graph[current])
{
inDegree[neighbor]--;
if (inDegree[neighbor] == 0) {
if (dependentStmts.count(neighbor))
{
readyDependent.emplace(neighbor, degree[neighbor], arrivalCounter++);
}
else
{
readyIndependent.push(neighbor);
}
}
}
}
return executionOrder;
}
static bool buildNewAST(SgStatement* loop, vector<SgStatement*>& newBody)
{
if (!loop) {return false;}
if (newBody.empty()) {return true;}
if (loop->variant() != FOR_NODE) {return false;}
SgStatement* loopStart = loop->lexNext(); SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt(); SgStatement* loopEnd = loop->lastNodeOfStmt();
if (!loopStart || !loopEnd) {return false;}
SgStatement* current = loopStart;
while (current && current != loopEnd) {
if (isSgExecutableStatement(current)) {
OperatorInfo opInfo(current);
vector<SAPFOR::IR_Block*> irBlocks = findInstructionsFromOperator(current, blocks);
for (auto irBlock : irBlocks) {
SAPFOR::Instruction* instr = irBlock->getInstruction();
if (instr->getArg1()) {
string varName = getNameByArg(instr->getArg1());
if (!varName.empty()) {
opInfo.usedVars.insert(varName);
}
}
if (instr->getArg2()) {
string varName = getNameByArg(instr->getArg2());
if (!varName.empty()) {
opInfo.usedVars.insert(varName);
}
}
if (instr->getResult()) {
string varName = getNameByArg(instr->getResult());
if (!varName.empty()) {
opInfo.definedVars.insert(varName);
}
}
}
if (control_tags.find(current->variant()) != control_tags.end()) {
opInfo.isMovable = false;
}
operators.push_back(opInfo);
}
current = current->lexNext();
}
return operators;
}
for (SgStatement* stmt : newBody) { static map<string, vector<SgStatement*>> findVariableDefinitions(SgForStmt* loop, vector<OperatorInfo>& operators) {
if (stmt && stmt != loop && stmt != loopEnd) { map<string, vector<SgStatement*>> varDefinitions;
SgStatement* current = loopStart;
bool found = false; for (auto& op : operators) {
while (current && current != loopEnd->lexNext()) { for (const string& var : op.definedVars) {
if (current == stmt) { varDefinitions[var].push_back(op.stmt);
found = true; }
}
return varDefinitions;
}
// Distance between two statements measured as the absolute difference of
// their line numbers. Returns INT_MAX when either statement is null.
static int calculateDistance(SgStatement* from, SgStatement* to) {
    if (from && to) {
        const auto lineDelta = to->lineNumber() - from->lineNumber();
        return abs(lineDelta);
    }
    return INT_MAX;
}
// Picks, among all statements that define a variable used by `operatorStmt`,
// the one with the smallest line-number distance to `operatorStmt`.
//
// Returns nullptr when `operatorStmt` is not present in `operators`, when it
// is marked non-movable, or when none of its used variables has a recorded
// definition. On equal distances the first candidate encountered wins.
//
// NOTE(review): the candidate is not required to differ from `operatorStmt`
// or to precede it, so a statement that both reads and writes the same
// variable selects itself (distance 0) — confirm whether that is intended.
static SgStatement* findBestPosition(SgStatement* operatorStmt, vector<OperatorInfo>& operators, map<string, vector<SgStatement*>>& varDefinitions) {
    // Locate the record describing the requested statement.
    const OperatorInfo* self = nullptr;
    for (const OperatorInfo& candidate : operators) {
        if (candidate.stmt == operatorStmt) {
            self = &candidate;
            break;
        }
    }
    if (self == nullptr || !self->isMovable) {
        return nullptr;
    }

    SgStatement* closest = nullptr;
    int closestDistance = INT_MAX;
    for (const string& name : self->usedVars) {
        const auto definitions = varDefinitions.find(name);
        if (definitions == varDefinitions.end()) {
            continue;
        }
        for (SgStatement* definition : definitions->second) {
            const int candidateDistance = calculateDistance(operatorStmt, definition);
            if (candidateDistance < closestDistance) {
                closestDistance = candidateDistance;
                closest = definition;
            }
        }
    }
    return closest;
}
// Decides whether statement `from` may be relocated next to statement `to`
// inside `loop`.
//
// Returns false when any pointer is null, when `from == to`, when the loop has
// no body start or end node, when `to` lies outside the loop body by line
// number, or when a statement whose variant is in `control_tags` is met while
// walking forward from `from` (inclusive) up to `to` (inclusive) or the loop
// end. Otherwise returns true.
//
// The walk only goes forward from `from`: if `to` is never reached, every
// statement from `from` to the loop end is checked before true is returned.
static bool canMoveTo(SgStatement* from, SgStatement* to, SgForStmt* loop) {
    if (!from || !to || !loop || from == to) return false;

    SgStatement* loopStart = loop->lexNext();
    SgStatement* loopEnd = loop->lastNodeOfStmt();
    // Without both boundaries the range check below cannot be evaluated.
    if (!loopStart || !loopEnd) return false;

    const int targetLine = to->lineNumber();
    if (targetLine < loopStart->lineNumber() || targetLine > loopEnd->lineNumber()) {
        return false;
    }

    for (SgStatement* current = from; current && current != loopEnd; current = current->lexNext()) {
        if (control_tags.find(current->variant()) != control_tags.end()) {
            return false;
        }
        if (current == to) break;
    }
    return true;
}
static vector<SgStatement*> optimizeOperatorOrder(SgForStmt* loop, vector<OperatorInfo>& operators, map<string, vector<SgStatement*>>& varDefinitions) {
vector<SgStatement*> newOrder;
vector<bool> moved(operators.size(), false);
for (size_t i = 0; i < operators.size(); i++) {
if (moved[i] || !operators[i].isMovable) {
newOrder.push_back(operators[i].stmt);
moved[i] = true;
continue;
}
SgStatement* bestPos = findBestPosition(operators[i].stmt, operators, varDefinitions);
if (bestPos && canMoveTo(operators[i].stmt, bestPos, loop)) {
bool inserted = false;
for (size_t j = 0; j < newOrder.size(); j++) {
if (newOrder[j] == bestPos) {
newOrder.insert(newOrder.begin() + j + 1, operators[i].stmt);
inserted = true;
break; break;
} }
current = current->lexNext();
} }
if (!found) {return false;} if (!inserted) {
newOrder.push_back(operators[i].stmt);
}
} else {
newOrder.push_back(operators[i].stmt);
} }
moved[i] = true;
} }
return newOrder;
}
static bool applyOperatorReordering(SgForStmt* loop, vector<SgStatement*>& newOrder) {
if (!loop || newOrder.empty()) return false;
SgStatement* loopStart = loop->lexNext();
SgStatement* loopEnd = loop->lastNodeOfStmt();
vector<SgStatement*> extractedStatements; vector<SgStatement*> extractedStatements;
vector<char*> savedComments; vector<char*> savedComments;
vector<int> savedLineNumbers;
for (SgStatement* stmt : newOrder) {
for (SgStatement* stmt : newBody) {
if (stmt && stmt != loop && stmt != loopEnd) { if (stmt && stmt != loop && stmt != loopEnd) {
savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr); savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr);
savedLineNumbers.push_back(stmt->lineNumber());
SgStatement* extracted = stmt->extractStmt(); SgStatement* extracted = stmt->extractStmt();
if (extracted) {extractedStatements.push_back(extracted);} if (extracted) {
extractedStatements.push_back(extracted);
}
} }
} }
SgStatement* currentPos = loop; SgStatement* currentPos = loop;
int lineCounter = loop->lineNumber() + 1; int lineCounter = loop->lineNumber() + 1;
@@ -342,13 +230,13 @@ static bool buildNewAST(SgStatement* loop, vector<SgStatement*>& newBody)
currentPos = stmt; currentPos = stmt;
} }
} }
for (char* comment : savedComments) { for (char* comment : savedComments) {
if (comment) { if (comment) {
free(comment); free(comment);
} }
} }
if (currentPos && currentPos->lexNext() != loopEnd) { if (currentPos && currentPos->lexNext() != loopEnd) {
currentPos->setLexNext(*loopEnd); currentPos->setLexNext(*loopEnd);
} }
@@ -356,67 +244,56 @@ static bool buildNewAST(SgStatement* loop, vector<SgStatement*>& newBody)
return true; return true;
} }
static bool validateNewOrder(SgStatement* loop, const vector<SgStatement*>& newOrder) vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) {
{ vector<SAPFOR::BasicBlock*> result;
if (!loop || newOrder.empty()) { Statement* forSt = (Statement*)st;
return true; for (auto& func: FullIR) {
if (func.first -> funcPointer -> getCurrProcessFile() == forSt -> getCurrProcessFile()
&& func.first -> funcPointer -> lineNumber() == forSt -> lineNumber())
result = func.second;
} }
unordered_set<SgStatement*> seen; return result;
for (SgStatement* stmt : newOrder) {
if (stmt && stmt != loop && stmt != loop->lastNodeOfStmt()) {
if (seen.count(stmt)) {
return false;
}
seen.insert(stmt);
}
}
return true;
} }
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform) map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st, vector<SAPFOR::BasicBlock*> blocks) {
{ map<SgForStmt*, vector<SAPFOR::BasicBlock*>> result;
std::cout << "SWAP_OPERATORS Pass" << std::endl; // to remove SgStatement *lastNode = st->lastNodeOfStmt();
countOfTransform += 1; // to remove while (st && st != lastNode) {
if (loop_tags.find(st -> variant()) != loop_tags.end()) {
SgForStmt *forSt = (SgForStmt*)st;
SgStatement *loopBody = forSt -> body();
SgStatement *lastLoopNode = st->lastNodeOfStmt();
unordered_set<int> blocks_nums;
while (loopBody && loopBody != lastLoopNode) {
SAPFOR::IR_Block* IR = findInstructionsFromOperator(loopBody, blocks).front();
if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) {
result[forSt].push_back(IR -> getBasicBlock());
blocks_nums.insert(IR -> getBasicBlock() -> getNumber());
}
loopBody = loopBody -> lexNext();
}
std::sort(result[forSt].begin(), result[forSt].end());
}
st = st -> lexNext();
}
return result;
}
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform) {
countOfTransform += 1;
const int funcNum = file -> numberOfFunctions(); const int funcNum = file -> numberOfFunctions();
for (int i = 0; i < funcNum; ++i) for (int i = 0; i < funcNum; ++i) {
{
SgStatement *st = file -> functions(i); SgStatement *st = file -> functions(i);
vector<SAPFOR::BasicBlock*> blocks = findFuncBlocksByFuncStatement(st, FullIR); vector<SAPFOR::BasicBlock*> blocks = findFuncBlocksByFuncStatement(st, FullIR);
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopsMapping = findAndAnalyzeLoops(st, blocks); map<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopsMapping = findAndAnalyzeLoops(st, blocks);
for (pair<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopForAnalyze: loopsMapping)
{ for (pair<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopForAnalyze: loopsMapping) {
map<SgStatement*, set<SgStatement*>> dependencyGraph = AnalyzeLoopAndFindDeps(loopForAnalyze.first, loopForAnalyze.second, FullIR); vector<OperatorInfo> operators = analyzeOperatorsInLoop(loopForAnalyze.first, loopForAnalyze.second, FullIR);
// TODO: Write a function that will go through the operators and update all dependencies so that there are no mix-ups and splits inside the semantic blocks (for if, do and may be some other cases) map<string, vector<SgStatement*>> varDefinitions = findVariableDefinitions(loopForAnalyze.first, operators);
buildAdditionalDeps(loopForAnalyze.first, dependencyGraph);
cout << endl; vector<SgStatement*> newOrder = optimizeOperatorOrder(loopForAnalyze.first, operators, varDefinitions);
int firstLine = loopForAnalyze.first -> lineNumber(); applyOperatorReordering(loopForAnalyze.first, newOrder);
int lastLine = loopForAnalyze.first -> lastNodeOfStmt() -> lineNumber();
cout << "LOOP ANALYZE FROM " << firstLine << " TO " << lastLine << " RES" << endl;
// for (auto &v: dependencyGraph) {
// cout << "OPERATOR: " << v.first -> lineNumber() << " " << v.first -> variant() << "\nDEPENDS ON:" << endl;
// if (v.second.size() != 0)
// for (auto vv: v.second)
// cout << vv -> lineNumber() << " ";
// cout << endl;
// }
vector<SgStatement*> new_order = scheduleOperations(dependencyGraph);
cout << "RESULT ORDER:" << endl;
for (auto v: new_order)
if (v -> lineNumber() > firstLine)
cout << v -> lineNumber() << endl;
if (validateNewOrder(loopForAnalyze.first, new_order)) {
buildNewAST(loopForAnalyze.first, new_order);
}
st = loopForAnalyze.first -> lexNext();
while (st != loopForAnalyze.first -> lastNodeOfStmt())
{
cout << st -> lineNumber() << " " << st -> sunparse() << endl;
st = st -> lexNext();
}
} }
} }
}
return;
};