Adding handing of nested loops and conditional statements
This commit is contained in:
@@ -17,7 +17,6 @@
|
||||
using namespace std;
|
||||
|
||||
string getNameByArg(SAPFOR::Argument* arg);
|
||||
SgSymbol* getSybolByArg(SAPFOR::Argument* arg);
|
||||
|
||||
static vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, vector<SAPFOR::BasicBlock*> Blocks) {
|
||||
vector<SAPFOR::IR_Block*> result;
|
||||
@@ -36,9 +35,9 @@ static vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, v
|
||||
return result;
|
||||
}
|
||||
|
||||
unordered_set<int> loop_tags = {FOR_NODE};
|
||||
unordered_set<int> control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE};
|
||||
unordered_set<int> control_end_tags = {CONTROL_END};
|
||||
unordered_set<int> loop_tags = {FOR_NODE}; // Loop statements
|
||||
unordered_set<int> control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE, LOGIF_NODE}; // Control structures that cannot be moved
|
||||
unordered_set<int> control_end_tags = {CONTROL_END}; // End marker
|
||||
|
||||
struct OperatorInfo {
|
||||
SgStatement* stmt;
|
||||
@@ -46,21 +45,170 @@ struct OperatorInfo {
|
||||
set<string> definedVars;
|
||||
int lineNumber;
|
||||
bool isMovable;
|
||||
|
||||
OperatorInfo(SgStatement* s) : stmt(s), lineNumber(s->lineNumber()), isMovable(true) {}
|
||||
};
|
||||
|
||||
static bool isStatementEmbedded(SgStatement* stmt, SgStatement* parent) {
|
||||
if (!stmt || !parent || stmt == parent) return false;
|
||||
|
||||
if (parent->variant() == LOGIF_NODE) {
|
||||
if (stmt->lineNumber() == parent->lineNumber()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
SgStatement* current = parent;
|
||||
SgStatement* lastNode = parent->lastNodeOfStmt();
|
||||
|
||||
while (current && current != lastNode) {
|
||||
if (current == stmt) {
|
||||
return true;
|
||||
}
|
||||
if (current->isIncludedInStmt(*stmt)) {
|
||||
return true;
|
||||
}
|
||||
current = current->lexNext();
|
||||
}
|
||||
}
|
||||
|
||||
if (parent->isIncludedInStmt(*stmt)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isLoopBoundary(SgStatement* stmt) {
|
||||
if (!stmt) return false;
|
||||
|
||||
if (stmt->variant() == FOR_NODE || stmt->variant() == CONTROL_END) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isPartOfNestedLoop(SgStatement* stmt, SgForStmt* loop) {
|
||||
if (!stmt || !loop) return false;
|
||||
|
||||
SgStatement* loopStart = loop->lexNext();
|
||||
SgStatement* loopEnd = loop->lastNodeOfStmt();
|
||||
if (!loopStart || !loopEnd) return false;
|
||||
|
||||
if (stmt->lineNumber() < loopStart->lineNumber() || stmt->lineNumber() > loopEnd->lineNumber()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
SgStatement* current = loopStart;
|
||||
|
||||
while (current && current != loopEnd) {
|
||||
|
||||
if (current->variant() == FOR_NODE && current != loop) {
|
||||
SgForStmt* nestedLoop = (SgForStmt*)current;
|
||||
SgStatement* nestedStart = nestedLoop->lexNext();
|
||||
SgStatement* nestedEnd = nestedLoop->lastNodeOfStmt();
|
||||
|
||||
if (nestedStart && nestedEnd &&
|
||||
stmt->lineNumber() >= nestedStart->lineNumber() &&
|
||||
stmt->lineNumber() <= nestedEnd->lineNumber()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
current = current->lexNext();
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool canSafelyExtract(SgStatement* stmt, SgForStmt* loop) {
|
||||
if (!stmt || !loop) return false;
|
||||
|
||||
if (isLoopBoundary(stmt)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (control_tags.find(stmt->variant()) != control_tags.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isPartOfNestedLoop(stmt, loop)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
SgStatement* loopStart = loop->lexNext();
|
||||
SgStatement* loopEnd = loop->lastNodeOfStmt();
|
||||
if (!loopStart || !loopEnd) return false;
|
||||
|
||||
SgStatement* current = loopStart;
|
||||
|
||||
while (current && current != loopEnd) {
|
||||
if (current->variant() == LOGIF_NODE && current->lineNumber() == stmt->lineNumber()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (control_tags.find(current->variant()) != control_tags.end()) {
|
||||
if (isStatementEmbedded(stmt, current)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (current == stmt) break;
|
||||
current = current->lexNext();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFOR::BasicBlock*> blocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR) {
|
||||
vector<OperatorInfo> operators;
|
||||
SgStatement* loopStart = loop->lexNext();
|
||||
SgStatement* loopEnd = loop->lastNodeOfStmt();
|
||||
|
||||
if (!loopStart || !loopEnd) {
|
||||
return operators;
|
||||
}
|
||||
|
||||
SgStatement* current = loopStart;
|
||||
unordered_set<SgStatement*> visited;
|
||||
|
||||
while (current && current != loopEnd) {
|
||||
|
||||
if (visited.find(current) != visited.end()) {
|
||||
break;
|
||||
}
|
||||
visited.insert(current);
|
||||
|
||||
if (isLoopBoundary(current)) {
|
||||
current = current->lexNext();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (current->variant() == FOR_NODE && current != loop) {
|
||||
SgStatement* nestedEnd = current->lastNodeOfStmt();
|
||||
if (nestedEnd) {
|
||||
current = nestedEnd->lexNext();
|
||||
} else {
|
||||
current = current->lexNext();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isSgExecutableStatement(current)) {
|
||||
if (control_tags.find(current->variant()) != control_tags.end()) {
|
||||
current = current->lexNext();
|
||||
continue;
|
||||
}
|
||||
if (current->variant() != ASSIGN_STAT) {
|
||||
current = current->lexNext();
|
||||
continue;
|
||||
}
|
||||
|
||||
OperatorInfo opInfo(current);
|
||||
|
||||
vector<SAPFOR::IR_Block*> irBlocks = findInstructionsFromOperator(current, blocks);
|
||||
for (auto irBlock : irBlocks) {
|
||||
if (!irBlock || !irBlock->getInstruction()) continue;
|
||||
|
||||
SAPFOR::Instruction* instr = irBlock->getInstruction();
|
||||
|
||||
if (instr->getArg1()) {
|
||||
@@ -83,10 +231,6 @@ static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFO
|
||||
}
|
||||
}
|
||||
|
||||
if (control_tags.find(current->variant()) != control_tags.end()) {
|
||||
opInfo.isMovable = false;
|
||||
}
|
||||
|
||||
operators.push_back(opInfo);
|
||||
}
|
||||
current = current->lexNext();
|
||||
@@ -97,13 +241,11 @@ static vector<OperatorInfo> analyzeOperatorsInLoop(SgForStmt* loop, vector<SAPFO
|
||||
|
||||
static map<string, vector<SgStatement*>> findVariableDefinitions(SgForStmt* loop, vector<OperatorInfo>& operators) {
|
||||
map<string, vector<SgStatement*>> varDefinitions;
|
||||
|
||||
for (auto& op : operators) {
|
||||
for (const string& var : op.definedVars) {
|
||||
varDefinitions[var].push_back(op.stmt);
|
||||
}
|
||||
}
|
||||
|
||||
return varDefinitions;
|
||||
}
|
||||
|
||||
@@ -121,7 +263,9 @@ static SgStatement* findBestPosition(SgStatement* operatorStmt, vector<OperatorI
|
||||
}
|
||||
}
|
||||
|
||||
if (!opInfo || !opInfo->isMovable) return nullptr;
|
||||
if (!opInfo || !opInfo->isMovable) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
SgStatement* bestPos = nullptr;
|
||||
int minDistance = INT_MAX;
|
||||
@@ -147,12 +291,22 @@ static bool canMoveTo(SgStatement* from, SgStatement* to, SgForStmt* loop) {
|
||||
SgStatement* loopStart = loop->lexNext();
|
||||
SgStatement* loopEnd = loop->lastNodeOfStmt();
|
||||
|
||||
if (!loopStart || !loopEnd) return false;
|
||||
|
||||
if (to->lineNumber() < loopStart->lineNumber() || to->lineNumber() > loopEnd->lineNumber()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
SgStatement* current = from;
|
||||
unordered_set<SgStatement*> visited;
|
||||
|
||||
while (current && current != loopEnd) {
|
||||
|
||||
if (visited.find(current) != visited.end()) {
|
||||
return false;
|
||||
}
|
||||
visited.insert(current);
|
||||
|
||||
if (control_tags.find(current->variant()) != control_tags.end()) {
|
||||
return false;
|
||||
}
|
||||
@@ -203,29 +357,98 @@ static bool applyOperatorReordering(SgForStmt* loop, vector<SgStatement*>& newOr
|
||||
SgStatement* loopStart = loop->lexNext();
|
||||
SgStatement* loopEnd = loop->lastNodeOfStmt();
|
||||
|
||||
if (!loopStart || !loopEnd) return false;
|
||||
|
||||
vector<SgStatement*> originalOrder;
|
||||
SgStatement* current = loopStart;
|
||||
while (current && current != loopEnd) {
|
||||
if (isSgExecutableStatement(current) && current->variant() == ASSIGN_STAT) {
|
||||
originalOrder.push_back(current);
|
||||
}
|
||||
current = current->lexNext();
|
||||
}
|
||||
|
||||
bool orderChanged = false;
|
||||
if (originalOrder.size() == newOrder.size()) {
|
||||
for (size_t i = 0; i < originalOrder.size(); i++) {
|
||||
if (originalOrder[i] != newOrder[i]) {
|
||||
orderChanged = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
orderChanged = true;
|
||||
}
|
||||
|
||||
if (!orderChanged) {
|
||||
return false;
|
||||
}
|
||||
|
||||
vector<SgStatement*> extractedStatements;
|
||||
vector<char*> savedComments;
|
||||
unordered_set<SgStatement*> extractedSet;
|
||||
map<SgStatement*, int> originalLineNumbers;
|
||||
|
||||
for (SgStatement* stmt : newOrder) {
|
||||
if (stmt && stmt != loop && stmt != loopEnd) {
|
||||
if (stmt && stmt != loop && stmt != loopEnd && extractedSet.find(stmt) == extractedSet.end()) {
|
||||
if (control_tags.find(stmt->variant()) != control_tags.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!canSafelyExtract(stmt, loop)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool isMoving = false;
|
||||
for (size_t i = 0; i < originalOrder.size(); i++) {
|
||||
if (originalOrder[i] == stmt) {
|
||||
for (size_t j = 0; j < newOrder.size(); j++) {
|
||||
if (newOrder[j] == stmt && i != j) {
|
||||
isMoving = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!isMoving) {
|
||||
continue;
|
||||
}
|
||||
|
||||
originalLineNumbers[stmt] = stmt->lineNumber();
|
||||
savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr);
|
||||
SgStatement* extracted = stmt->extractStmt();
|
||||
if (extracted) {
|
||||
extractedStatements.push_back(extracted);
|
||||
extractedSet.insert(stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SgStatement* currentPos = loop;
|
||||
int lineCounter = loop->lineNumber() + 1;
|
||||
|
||||
for (size_t i = 0; i < extractedStatements.size(); i++) {
|
||||
SgStatement* stmt = extractedStatements[i];
|
||||
if (stmt) {
|
||||
SgStatement* nextPos = currentPos->lexNext();
|
||||
if (nextPos && nextPos != loopEnd) {
|
||||
if (nextPos->variant() == FOR_NODE && nextPos != loop) {
|
||||
continue;
|
||||
}
|
||||
if (nextPos->variant() == CONTROL_END) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (i < savedComments.size() && savedComments[i]) {
|
||||
stmt->setComments(savedComments[i]);
|
||||
}
|
||||
stmt->setlineNumber(lineCounter++);
|
||||
|
||||
if (originalLineNumbers.find(stmt) != originalLineNumbers.end()) {
|
||||
stmt->setlineNumber(originalLineNumbers[stmt]);
|
||||
}
|
||||
|
||||
currentPos->insertStmtAfter(*stmt, *loop);
|
||||
currentPos = stmt;
|
||||
}
|
||||
@@ -258,17 +481,24 @@ vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, map<F
|
||||
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st, vector<SAPFOR::BasicBlock*> blocks) {
|
||||
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> result;
|
||||
SgStatement *lastNode = st->lastNodeOfStmt();
|
||||
|
||||
while (st && st != lastNode) {
|
||||
if (loop_tags.find(st -> variant()) != loop_tags.end()) {
|
||||
SgForStmt *forSt = (SgForStmt*)st;
|
||||
SgStatement *loopBody = forSt -> body();
|
||||
SgStatement *lastLoopNode = st->lastNodeOfStmt();
|
||||
unordered_set<int> blocks_nums;
|
||||
|
||||
while (loopBody && loopBody != lastLoopNode) {
|
||||
SAPFOR::IR_Block* IR = findInstructionsFromOperator(loopBody, blocks).front();
|
||||
if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) {
|
||||
result[forSt].push_back(IR -> getBasicBlock());
|
||||
blocks_nums.insert(IR -> getBasicBlock() -> getNumber());
|
||||
vector<SAPFOR::IR_Block*> irBlocks = findInstructionsFromOperator(loopBody, blocks);
|
||||
if (!irBlocks.empty()) {
|
||||
SAPFOR::IR_Block* IR = irBlocks.front();
|
||||
if (IR && IR->getBasicBlock()) {
|
||||
if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) {
|
||||
result[forSt].push_back(IR -> getBasicBlock());
|
||||
blocks_nums.insert(IR -> getBasicBlock() -> getNumber());
|
||||
}
|
||||
}
|
||||
}
|
||||
loopBody = loopBody -> lexNext();
|
||||
}
|
||||
@@ -281,19 +511,25 @@ map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st
|
||||
|
||||
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform) {
|
||||
countOfTransform += 1;
|
||||
|
||||
std::cout << "SWAP_OPERATORS Pass Started" << std::endl;
|
||||
|
||||
const int funcNum = file -> numberOfFunctions();
|
||||
|
||||
for (int i = 0; i < funcNum; ++i) {
|
||||
SgStatement *st = file -> functions(i);
|
||||
|
||||
vector<SAPFOR::BasicBlock*> blocks = findFuncBlocksByFuncStatement(st, FullIR);
|
||||
|
||||
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopsMapping = findAndAnalyzeLoops(st, blocks);
|
||||
|
||||
for (pair<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopForAnalyze: loopsMapping) {
|
||||
vector<OperatorInfo> operators = analyzeOperatorsInLoop(loopForAnalyze.first, loopForAnalyze.second, FullIR);
|
||||
map<string, vector<SgStatement*>> varDefinitions = findVariableDefinitions(loopForAnalyze.first, operators);
|
||||
|
||||
vector<SgStatement*> newOrder = optimizeOperatorOrder(loopForAnalyze.first, operators, varDefinitions);
|
||||
applyOperatorReordering(loopForAnalyze.first, newOrder);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "SWAP_OPERATORS Pass Completed" << std::endl;
|
||||
}
|
||||
Reference in New Issue
Block a user