diff --git a/src/Transformations/SwapOperators/swap_operators.cpp b/src/Transformations/SwapOperators/swap_operators.cpp index 47565fa..3f5d28a 100644 --- a/src/Transformations/SwapOperators/swap_operators.cpp +++ b/src/Transformations/SwapOperators/swap_operators.cpp @@ -17,7 +17,6 @@ using namespace std; string getNameByArg(SAPFOR::Argument* arg); -SgSymbol* getSybolByArg(SAPFOR::Argument* arg); static vector findInstructionsFromOperator(SgStatement* st, vector Blocks) { vector result; @@ -36,9 +35,9 @@ static vector findInstructionsFromOperator(SgStatement* st, v return result; } -unordered_set loop_tags = {FOR_NODE}; -unordered_set control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE}; -unordered_set control_end_tags = {CONTROL_END}; +unordered_set loop_tags = {FOR_NODE}; // Loop statements +unordered_set control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE, LOGIF_NODE}; // Control structures that cannot be moved +unordered_set control_end_tags = {CONTROL_END}; // End marker struct OperatorInfo { SgStatement* stmt; @@ -46,21 +45,170 @@ struct OperatorInfo { set definedVars; int lineNumber; bool isMovable; + OperatorInfo(SgStatement* s) : stmt(s), lineNumber(s->lineNumber()), isMovable(true) {} }; +static bool isStatementEmbedded(SgStatement* stmt, SgStatement* parent) { + if (!stmt || !parent || stmt == parent) return false; + + if (parent->variant() == LOGIF_NODE) { + if (stmt->lineNumber() == parent->lineNumber()) { + return true; + } + + SgStatement* current = parent; + SgStatement* lastNode = parent->lastNodeOfStmt(); + + while (current && current != lastNode) { + if (current == stmt) { + return true; + } + if (current->isIncludedInStmt(*stmt)) { + return true; + } + current = current->lexNext(); + } + } + + if (parent->isIncludedInStmt(*stmt)) { + return true; + } + + return false; +} + +static bool isLoopBoundary(SgStatement* stmt) { + if (!stmt) return false; + + if (stmt->variant() == FOR_NODE || stmt->variant() == CONTROL_END) { + return true; + } + + return false; +} + +static bool isPartOfNestedLoop(SgStatement* stmt, SgForStmt* loop) { + if (!stmt || !loop) return false; + + SgStatement* loopStart = loop->lexNext(); + SgStatement* loopEnd = loop->lastNodeOfStmt(); + if (!loopStart || !loopEnd) return false; + + if (stmt->lineNumber() < loopStart->lineNumber() || stmt->lineNumber() > loopEnd->lineNumber()) { + return false; + } + + SgStatement* current = loopStart; + + while (current && current != loopEnd) { + + if (current->variant() == FOR_NODE && current != loop) { + SgForStmt* nestedLoop = (SgForStmt*)current; + SgStatement* nestedStart = nestedLoop->lexNext(); + SgStatement* nestedEnd = nestedLoop->lastNodeOfStmt(); + + if (nestedStart && nestedEnd && + stmt->lineNumber() >= nestedStart->lineNumber() && + stmt->lineNumber() <= nestedEnd->lineNumber()) { + return true; + } + } + + current = current->lexNext(); + } + + return false; +} + +static bool canSafelyExtract(SgStatement* stmt, SgForStmt* loop) { + if (!stmt || !loop) return false; + + if (isLoopBoundary(stmt)) { + return false; + } + + if (control_tags.find(stmt->variant()) != control_tags.end()) { + return false; + } + + if (isPartOfNestedLoop(stmt, loop)) { + return false; + } + + SgStatement* loopStart = loop->lexNext(); + SgStatement* loopEnd = loop->lastNodeOfStmt(); + if (!loopStart || !loopEnd) return false; + + SgStatement* current = loopStart; + + while (current && current != loopEnd) { + if (current->variant() == LOGIF_NODE && current->lineNumber() == stmt->lineNumber()) { + return false; + } + + if (control_tags.find(current->variant()) != control_tags.end()) { + if (isStatementEmbedded(stmt, current)) { + return false; + } + } + if (current == stmt) break; + current = current->lexNext(); + } + + return true; +} + static vector analyzeOperatorsInLoop(SgForStmt* loop, vector blocks, map>& FullIR) { vector operators; SgStatement* loopStart = loop->lexNext(); SgStatement* loopEnd = loop->lastNodeOfStmt(); + if (!loopStart || !loopEnd) { + return operators; + } + SgStatement* current = loopStart; + unordered_set visited; + while (current && current != loopEnd) { + + if (visited.find(current) != visited.end()) { + break; + } + visited.insert(current); + + if (isLoopBoundary(current)) { + current = current->lexNext(); + continue; + } + + if (current->variant() == FOR_NODE && current != loop) { + SgStatement* nestedEnd = current->lastNodeOfStmt(); + if (nestedEnd) { + current = nestedEnd->lexNext(); + } else { + current = current->lexNext(); + } + continue; + } + if (isSgExecutableStatement(current)) { + if (control_tags.find(current->variant()) != control_tags.end()) { + current = current->lexNext(); + continue; + } + if (current->variant() != ASSIGN_STAT) { + current = current->lexNext(); + continue; + } + OperatorInfo opInfo(current); vector irBlocks = findInstructionsFromOperator(current, blocks); for (auto irBlock : irBlocks) { + if (!irBlock || !irBlock->getInstruction()) continue; + SAPFOR::Instruction* instr = irBlock->getInstruction(); if (instr->getArg1()) { @@ -83,10 +231,6 @@ static vector analyzeOperatorsInLoop(SgForStmt* loop, vectorvariant()) != control_tags.end()) { - opInfo.isMovable = false; - } - operators.push_back(opInfo); } current = current->lexNext(); @@ -97,13 +241,11 @@ static vector analyzeOperatorsInLoop(SgForStmt* loop, vector> findVariableDefinitions(SgForStmt* loop, vector& operators) { map> varDefinitions; - for (auto& op : operators) { for (const string& var : op.definedVars) { varDefinitions[var].push_back(op.stmt); } } - return varDefinitions; } @@ -121,7 +263,9 @@ static SgStatement* findBestPosition(SgStatement* operatorStmt, vectorisMovable) return nullptr; + if (!opInfo || !opInfo->isMovable) { + return nullptr; + } SgStatement* bestPos = nullptr; int minDistance = INT_MAX; @@ -147,12 +291,22 @@ static bool canMoveTo(SgStatement* from, SgStatement* to, SgForStmt* loop) { SgStatement* loopStart = loop->lexNext(); SgStatement* loopEnd = loop->lastNodeOfStmt(); + if (!loopStart || !loopEnd) return false; + if (to->lineNumber() < loopStart->lineNumber() || to->lineNumber() > loopEnd->lineNumber()) { return false; } SgStatement* current = from; + unordered_set visited; + while (current && current != loopEnd) { + + if (visited.find(current) != visited.end()) { + return false; + } + visited.insert(current); + if (control_tags.find(current->variant()) != control_tags.end()) { return false; } @@ -203,29 +357,98 @@ static bool applyOperatorReordering(SgForStmt* loop, vector& newOr SgStatement* loopStart = loop->lexNext(); SgStatement* loopEnd = loop->lastNodeOfStmt(); + if (!loopStart || !loopEnd) return false; + + vector originalOrder; + SgStatement* current = loopStart; + while (current && current != loopEnd) { + if (isSgExecutableStatement(current) && current->variant() == ASSIGN_STAT) { + originalOrder.push_back(current); + } + current = current->lexNext(); + } + + bool orderChanged = false; + if (originalOrder.size() == newOrder.size()) { + for (size_t i = 0; i < originalOrder.size(); i++) { + if (originalOrder[i] != newOrder[i]) { + orderChanged = true; + break; + } + } + } else { + orderChanged = true; + } + + if (!orderChanged) { + return false; + } + vector extractedStatements; vector savedComments; + unordered_set extractedSet; + map originalLineNumbers; for (SgStatement* stmt : newOrder) { - if (stmt && stmt != loop && stmt != loopEnd) { + if (stmt && stmt != loop && stmt != loopEnd && extractedSet.find(stmt) == extractedSet.end()) { + if (control_tags.find(stmt->variant()) != control_tags.end()) { + continue; + } + + if (!canSafelyExtract(stmt, loop)) { + continue; + } + + bool isMoving = false; + for (size_t i = 0; i < originalOrder.size(); i++) { + if (originalOrder[i] == stmt) { + for (size_t j = 0; j < newOrder.size(); j++) { + if (newOrder[j] == stmt && i != j) { + isMoving = true; + break; + } + } + break; + } + } + + if (!isMoving) { + continue; + } + + originalLineNumbers[stmt] = stmt->lineNumber(); savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr); SgStatement* extracted = stmt->extractStmt(); if (extracted) { extractedStatements.push_back(extracted); + extractedSet.insert(stmt); } } } SgStatement* currentPos = loop; - int lineCounter = loop->lineNumber() + 1; for (size_t i = 0; i < extractedStatements.size(); i++) { SgStatement* stmt = extractedStatements[i]; if (stmt) { + SgStatement* nextPos = currentPos->lexNext(); + if (nextPos && nextPos != loopEnd) { + if (nextPos->variant() == FOR_NODE && nextPos != loop) { + continue; + } + if (nextPos->variant() == CONTROL_END) { + continue; + } + } + if (i < savedComments.size() && savedComments[i]) { stmt->setComments(savedComments[i]); } - stmt->setlineNumber(lineCounter++); + + if (originalLineNumbers.find(stmt) != originalLineNumbers.end()) { + stmt->setlineNumber(originalLineNumbers[stmt]); + } + currentPos->insertStmtAfter(*stmt, *loop); currentPos = stmt; } @@ -258,17 +481,24 @@ vector findFuncBlocksByFuncStatement(SgStatement *st, map> findAndAnalyzeLoops(SgStatement *st, vector blocks) { map> result; SgStatement *lastNode = st->lastNodeOfStmt(); + while (st && st != lastNode) { if (loop_tags.find(st -> variant()) != loop_tags.end()) { SgForStmt *forSt = (SgForStmt*)st; SgStatement *loopBody = forSt -> body(); SgStatement *lastLoopNode = st->lastNodeOfStmt(); unordered_set blocks_nums; + while (loopBody && loopBody != lastLoopNode) { - SAPFOR::IR_Block* IR = findInstructionsFromOperator(loopBody, blocks).front(); - if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) { - result[forSt].push_back(IR -> getBasicBlock()); - blocks_nums.insert(IR -> getBasicBlock() -> getNumber()); + vector irBlocks = findInstructionsFromOperator(loopBody, blocks); + if (!irBlocks.empty()) { + SAPFOR::IR_Block* IR = irBlocks.front(); + if (IR && IR->getBasicBlock()) { + if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) { + result[forSt].push_back(IR -> getBasicBlock()); + blocks_nums.insert(IR -> getBasicBlock() -> getNumber()); + } + } } loopBody = loopBody -> lexNext(); } @@ -281,19 +511,25 @@ map> findAndAnalyzeLoops(SgStatement *st void runSwapOperators(SgFile *file, std::map>& loopGraph, std::map>& FullIR, int& countOfTransform) { countOfTransform += 1; + + std::cout << "SWAP_OPERATORS Pass Started" << std::endl; const int funcNum = file -> numberOfFunctions(); + for (int i = 0; i < funcNum; ++i) { SgStatement *st = file -> functions(i); + vector blocks = findFuncBlocksByFuncStatement(st, FullIR); + map> loopsMapping = findAndAnalyzeLoops(st, blocks); for (pair> loopForAnalyze: loopsMapping) { vector operators = analyzeOperatorsInLoop(loopForAnalyze.first, loopForAnalyze.second, FullIR); map> varDefinitions = findVariableDefinitions(loopForAnalyze.first, operators); - vector newOrder = optimizeOperatorOrder(loopForAnalyze.first, operators, varDefinitions); applyOperatorReordering(loopForAnalyze.first, newOrder); } } + + std::cout << "SWAP_OPERATORS Pass Completed" << std::endl; } \ No newline at end of file