// SWAP_OPERATORS pass: for every FOR_NODE loop of every function in a file,
// collects the assignment statements of the loop body together with the
// variable names their IR instructions read and write, computes a new order
// for them, and applies that order by extracting and re-inserting statements.
//
// NOTE(review): in this view of the file the header names of the first six
// #include directives and the argument lists of every template (vector, map,
// set, unordered_set, pair) appear to be missing. The tokens are left exactly
// as found; restore them from the original file before building.
#include
#include
#include
#include
#include
#include
#include "../../Utils/errors.h"
#include "../../Utils/SgUtils.h"
#include "../../GraphCall/graph_calls.h"
#include "../../GraphCall/graph_calls_func.h"
#include "../../CFGraph/CFGraph.h"
#include "../../CFGraph/IR.h"
#include "../../GraphLoop/graph_loops.h"
#include "swap_operators.h"

using namespace std;

// Not defined in this view. Below, its result is treated as a variable name
// and an empty string is treated as "no name".
string getNameByArg(SAPFOR::Argument* arg);

/**
 * Collects the IR entries whose operator is on the same source line as st.
 *
 * Every block in Blocks is scanned, and inside it every entry returned by
 * getInstructions(); an entry is kept when
 * getInstruction()->getOperator()->lineNumber() equals st->lineNumber().
 *
 * @param st      statement to look up; dereferenced without a null check.
 * @param Blocks  blocks to scan (passed by value).
 * @return matching entries in scan order; empty when nothing matches.
 */
static vector findInstructionsFromOperator(SgStatement* st, vector Blocks) {
    vector result;
    // NOTE(review): filename is never read, so the match below relies on the
    // line number alone.
    string filename = st->fileName();
    for (auto& block: Blocks) {
        // NOTE(review): declared without a reference, so this looks like a
        // per-block copy of the instruction list - confirm against the
        // restored declaration.
        vector instructionsInBlock = block->getInstructions();
        for (auto& instruction: instructionsInBlock) {
            // NOTE(review): getInstruction() and getOperator() results are
            // dereferenced unchecked - confirm they cannot be null here.
            SgStatement* curOperator = instruction->getInstruction()->getOperator();
            if (curOperator->lineNumber() == st->lineNumber()) {
                result.push_back(instruction);
            }
        }
    }
    return result;
}

// Statement variants that findAndAnalyzeLoops treats as loops.
unordered_set loop_tags = {FOR_NODE};
// Statement variants that are skipped when collecting operators and that are
// never extracted or moved.
unordered_set control_tags = {IF_NODE, ELSEIF_NODE, DO_WHILE_NODE, WHILE_NODE, LOGIF_NODE};
// NOTE(review): not referenced anywhere in the visible part of this file.
unordered_set control_end_tags = {CONTROL_END};

/**
 * Per-statement record used by the reordering: the statement, the names it
 * reads (usedVars) and writes (definedVars), and its line number captured at
 * construction time.
 */
struct OperatorInfo {
    SgStatement* stmt;
    set usedVars;
    set definedVars;
    int lineNumber;
    // NOTE(review): initialised to true and never set to false in this view.
    bool isMovable;
    // s is dereferenced for its line number, so it must not be null.
    OperatorInfo(SgStatement* s) : stmt(s), lineNumber(s->lineNumber()), isMovable(true) {}
};

/**
 * Tells whether stmt counts as embedded in parent.
 *
 * Returns true when parent->isIncludedInStmt(*stmt) holds. For a LOGIF_NODE
 * parent it also returns true when stmt is on the parent's line, or when the
 * walk from parent up to (not including) parent->lastNodeOfStmt() meets stmt
 * itself or a node whose isIncludedInStmt(*stmt) holds.
 *
 * @return false for null arguments or when stmt == parent.
 */
static bool isStatementEmbedded(SgStatement* stmt, SgStatement* parent) {
    if (!stmt || !parent || stmt == parent)
        return false;
    if (parent->variant() == LOGIF_NODE) {
        if (stmt->lineNumber() == parent->lineNumber()) {
            return true;
        }
        SgStatement* current = parent;
        SgStatement* lastNode = parent->lastNodeOfStmt();
        while (current && current != lastNode) {
            if (current == stmt) {
                return true;
            }
            if (current->isIncludedInStmt(*stmt)) {
                return true;
            }
            current = current->lexNext();
        }
    }
    if (parent->isIncludedInStmt(*stmt)) {
        return true;
    }
    return false;
}

/**
 * Returns true when stmt is non-null and its variant is FOR_NODE or
 * CONTROL_END.
 */
static bool isLoopBoundary(SgStatement* stmt) {
    if (!stmt)
        return false;
    if (stmt->variant() == FOR_NODE || stmt->variant() == CONTROL_END) {
        return true;
    }
    return false;
}

/**
 * Tells, by line numbers only, whether stmt lies inside a FOR_NODE nested in
 * loop.
 *
 * stmt must first fall within the line range [loop->lexNext(),
 * loop->lastNodeOfStmt()]. The body is then walked and, for every FOR_NODE
 * other than loop, stmt's line is compared with the range spanned by that
 * nested loop's lexNext() and lastNodeOfStmt().
 *
 * @return false for null arguments or when the loop bounds are unavailable.
 */
static bool isPartOfNestedLoop(SgStatement* stmt, SgForStmt* loop) {
    if (!stmt || !loop)
        return false;
    SgStatement* loopStart = loop->lexNext();
    SgStatement* loopEnd = loop->lastNodeOfStmt();
    if (!loopStart || !loopEnd)
        return false;
    if (stmt->lineNumber() < loopStart->lineNumber() || stmt->lineNumber() > loopEnd->lineNumber()) {
        return false;
    }
    SgStatement* current = loopStart;
    while (current && current != loopEnd) {
        if (current->variant() == FOR_NODE && current != loop) {
            SgForStmt* nestedLoop = (SgForStmt*)current;
            SgStatement* nestedStart = nestedLoop->lexNext();
            SgStatement* nestedEnd = nestedLoop->lastNodeOfStmt();
            if (nestedStart && nestedEnd && stmt->lineNumber() >= nestedStart->lineNumber() && stmt->lineNumber() <= nestedEnd->lineNumber()) {
                return true;
            }
        }
        current = current->lexNext();
    }
    return false;
}

/**
 * Decides whether stmt may be pulled out of loop's body.
 *
 * Rejected: null arguments, loop boundaries (FOR_NODE / CONTROL_END),
 * statements whose variant is in control_tags, statements inside a nested
 * loop, statements sharing a line with a LOGIF_NODE, and statements that
 * isStatementEmbedded() reports as embedded in a control_tags statement met
 * while walking the body up to stmt.
 */
static bool canSafelyExtract(SgStatement* stmt, SgForStmt* loop) {
    if (!stmt || !loop)
        return false;
    if (isLoopBoundary(stmt)) {
        return false;
    }
    if (control_tags.find(stmt->variant()) != control_tags.end()) {
        return false;
    }
    if (isPartOfNestedLoop(stmt, loop)) {
        return false;
    }
    SgStatement* loopStart = loop->lexNext();
    SgStatement* loopEnd = loop->lastNodeOfStmt();
    if (!loopStart || !loopEnd)
        return false;
    SgStatement* current = loopStart;
    while (current && current != loopEnd) {
        if (current->variant() == LOGIF_NODE && current->lineNumber() == stmt->lineNumber()) {
            return false;
        }
        if (control_tags.find(current->variant()) != control_tags.end()) {
            if (isStatementEmbedded(stmt, current)) {
                return false;
            }
        }
        // The walk stops once stmt itself has been reached.
        if (current == stmt)
            break;
        current = current->lexNext();
    }
    return true;
}

/**
 * Builds one OperatorInfo per executable ASSIGN_STAT statement met while
 * walking from loop->lexNext() to loop->lastNodeOfStmt().
 *
 * For each such statement the IR entries on the same line are looked up in
 * blocks; the names of getArg1()/getArg2() go to usedVars and the name of
 * getResult() goes to definedVars (empty names are dropped).
 *
 * @param loop    loop to analyse; dereferenced without a null check.
 * @param blocks  blocks searched by findInstructionsFromOperator.
 * @param FullIR  not used by this function.
 * @return the collected records in walk order; empty when the loop bounds
 *         are unavailable.
 */
static vector analyzeOperatorsInLoop(SgForStmt* loop, vector blocks, map>& FullIR) {
    vector operators;
    SgStatement* loopStart = loop->lexNext();
    SgStatement* loopEnd = loop->lastNodeOfStmt();
    if (!loopStart || !loopEnd) {
        return operators;
    }
    SgStatement* current = loopStart;
    // Guards against visiting the same statement twice during the walk.
    unordered_set visited;
    while (current && current != loopEnd) {
        if (visited.find(current) != visited.end()) {
            break;
        }
        visited.insert(current);
        if (isLoopBoundary(current)) {
            current = current->lexNext();
            continue;
        }
        // NOTE(review): isLoopBoundary() above already matches FOR_NODE and
        // continues, so this branch looks unreachable and the bodies of
        // nested loops are walked rather than skipped - confirm the intent.
        if (current->variant() == FOR_NODE && current != loop) {
            SgStatement* nestedEnd = current->lastNodeOfStmt();
            if (nestedEnd) {
                current = nestedEnd->lexNext();
            } else {
                current = current->lexNext();
            }
            continue;
        }
        if (isSgExecutableStatement(current)) {
            if (control_tags.find(current->variant()) != control_tags.end()) {
                current = current->lexNext();
                continue;
            }
            // Only plain assignments are candidates for reordering.
            if (current->variant() != ASSIGN_STAT) {
                current = current->lexNext();
                continue;
            }
            OperatorInfo opInfo(current);
            vector irBlocks = findInstructionsFromOperator(current, blocks);
            for (auto irBlock : irBlocks) {
                if (!irBlock || !irBlock->getInstruction())
                    continue;
                SAPFOR::Instruction* instr = irBlock->getInstruction();
                if (instr->getArg1()) {
                    string varName = getNameByArg(instr->getArg1());
                    if (!varName.empty()) {
                        opInfo.usedVars.insert(varName);
                    }
                }
                if (instr->getArg2()) {
                    string varName = getNameByArg(instr->getArg2());
                    if (!varName.empty()) {
                        opInfo.usedVars.insert(varName);
                    }
                }
                if (instr->getResult()) {
                    string varName = getNameByArg(instr->getResult());
                    if (!varName.empty()) {
                        opInfo.definedVars.insert(varName);
                    }
                }
            }
            operators.push_back(opInfo);
        }
        current = current->lexNext();
    }
    return operators;
}

/**
 * Maps every defined variable name to the statements that define it, in the
 * order the statements appear in operators.
 *
 * @param loop  not used by this function.
 */
static map> findVariableDefinitions(SgForStmt* loop, vector& operators) {
    map> varDefinitions;
    for (auto& op : operators) {
        for (const string& var : op.definedVars) {
            varDefinitions[var].push_back(op.stmt);
        }
    }
    return varDefinitions;
}

/**
 * Absolute difference between the line numbers of two statements, or INT_MAX
 * when either pointer is null.
 *
 * NOTE(review): no call to this function is visible in this file.
 */
static int calculateDistance(SgStatement* from, SgStatement* to) {
    if (!from || !to)
        return INT_MAX;
    return abs(to->lineNumber() - from->lineNumber());
}

/**
 * Picks the statement after which operatorStmt would ideally be placed.
 *
 * Among the definitions of every variable operatorStmt uses, the one with the
 * largest line number that still precedes operatorStmt by line and shares its
 * controlParent() is chosen. When there is none, loop->lexNext() is returned
 * if the fallback condition below holds.
 *
 * @return the chosen statement, loop->lexNext() as the "start of body"
 *         marker, or nullptr when operatorStmt has no record in operators or
 *         is not movable.
 */
static SgStatement* findBestPosition(SgStatement* operatorStmt, vector& operators, map>& varDefinitions, SgForStmt* loop) {
    OperatorInfo* opInfo = nullptr;
    for (auto& op : operators) {
        if (op.stmt == operatorStmt) {
            opInfo = &op;
            break;
        }
    }
    if (!opInfo || !opInfo->isMovable) {
        return nullptr;
    }
    SgStatement* bestPos = nullptr;
    int bestLine = -1;
    for (const string& usedVar : opInfo->usedVars) {
        if (varDefinitions.find(usedVar) != varDefinitions.end()) {
            for (SgStatement* defStmt : varDefinitions[usedVar]) {
                if (defStmt->lineNumber() < operatorStmt->lineNumber()) {
                    if (defStmt->controlParent() == operatorStmt->controlParent()) {
                        // Keep the closest preceding definition.
                        if (defStmt->lineNumber() > bestLine) {
                            bestLine = defStmt->lineNumber();
                            bestPos = defStmt;
                        }
                    }
                }
            }
        }
    }
    if (!bestPos) {
        // NOTE(review): a definition that would clear allLoopCarried here
        // (earlier line, same control parent) would also have set bestPos in
        // the loop above, assuming line numbers are non-negative. This branch
        // therefore looks like it always returns loop->lexNext() - confirm.
        bool allLoopCarried = true;
        bool hasAnyDefinition = false;
        for (const string& usedVar : opInfo->usedVars) {
            if (varDefinitions.find(usedVar) != varDefinitions.end()) {
                for (SgStatement* defStmt : varDefinitions[usedVar]) {
                    if (defStmt == operatorStmt) {
                        continue;
                    }
                    hasAnyDefinition = true;
                    if (defStmt->lineNumber() < operatorStmt->lineNumber() && defStmt->controlParent() == operatorStmt->controlParent()) {
                        allLoopCarried = false;
                        break;
                    }
                }
            }
            if (!allLoopCarried)
                break;
        }
        if (allLoopCarried || (!hasAnyDefinition && !opInfo->usedVars.empty())) {
            SgStatement* loopStart = loop->lexNext();
            return loopStart;
        }
    }
    return bestPos;
}

/**
 * Checks whether moving from next to to is allowed inside loop.
 *
 * When to is loop->lexNext(), the move is allowed only if from's control
 * parent (loop when null) is loop or equals the control parent of the loop's
 * first statement. Otherwise both statements must lie within the loop's line
 * range, to must be on an earlier line than from, both must share the same
 * controlParent(), and the scan over the statements between them must not
 * find a control_tags statement that is from's control parent but not to's.
 *
 * @param loop  dereferenced without a null check.
 * @return false for null from/to, from == to, or unavailable loop bounds.
 */
static bool canMoveTo(SgStatement* from, SgStatement* to, SgForStmt* loop) {
    if (!from || !to || from == to)
        return false;
    SgStatement* loopStart = loop->lexNext();
    SgStatement* loopEnd = loop->lastNodeOfStmt();
    if (!loopStart || !loopEnd)
        return false;
    if (to == loopStart) {
        SgStatement* fromControlParent = from->controlParent();
        if (!fromControlParent)
            fromControlParent = loop;
        return fromControlParent == loop || fromControlParent == loopStart->controlParent();
    }
    if (from->lineNumber() < loopStart->lineNumber() || from->lineNumber() > loopEnd->lineNumber()) {
        return false;
    }
    if (to->lineNumber() < loopStart->lineNumber() || to->lineNumber() > loopEnd->lineNumber()) {
        return false;
    }
    // Statements are only ever moved towards the start of the body.
    if (to->lineNumber() >= from->lineNumber()) {
        return false;
    }
    if (from->controlParent() != to->controlParent()) {
        return false;
    }
    SgStatement* current = to->lexNext();
    while (current && current != from && current != loopEnd) {
        if (control_tags.find(current->variant()) != control_tags.end()) {
            SgStatement* controlEnd = current->lastNodeOfStmt();
            if (controlEnd && from->lineNumber() <= controlEnd->lineNumber()) {
                // NOTE(review): both control parents were required to be
                // equal just above, so this condition looks unsatisfiable.
                if (from->controlParent() == current && to->controlParent() != current) {
                    return false;
                }
            }
        }
        current = current->lexNext();
    }
    return true;
}

/**
 * Computes a new order for the statements recorded in operators.
 *
 * Phase 1 repeats up to MAX_ITERATIONS times while something changes. Going
 * through operators from last to first, it asks findBestPosition() where a
 * statement should go and moves it earlier in newOrder when canMoveTo()
 * agrees: either towards the front (stopping before the first earlier entry
 * that uses a variable it defines) or right after the chosen definition
 * (pushed further back past any entry that defines a variable it uses).
 *
 * Phase 2 repeatedly looks for a pair where a later entry defines a variable
 * that an earlier entry uses and, if nothing between them defines a variable
 * the later entry uses, moves the later entry in front of the earlier one.
 *
 * @return the statements of operators in their new order.
 */
static vector optimizeOperatorOrder(SgForStmt* loop, vector& operators, map>& varDefinitions) {
    vector newOrder;
    for (auto& op : operators) {
        newOrder.push_back(op.stmt);
    }
    // Statement -> its record inside operators (pointers into that vector).
    map stmtToOpInfo;
    for (auto& op : operators) {
        stmtToOpInfo[op.stmt] = &op;
    }
    bool changed = true;
    int iterations = 0;
    const int MAX_ITERATIONS = 50;
    while (changed && iterations < MAX_ITERATIONS) {
        changed = false;
        iterations++;
        for (int i = operators.size() - 1; i >= 0; i--) {
            if (!operators[i].isMovable)
                continue;
            SgStatement* stmt = operators[i].stmt;
            OperatorInfo* opInfo = stmtToOpInfo[stmt];
            if (!opInfo)
                continue;
            // Current index of stmt inside newOrder.
            size_t currentPos = 0;
            for (size_t j = 0; j < newOrder.size(); j++) {
                if (newOrder[j] == stmt) {
                    currentPos = j;
                    break;
                }
            }
            SgStatement* bestPos = findBestPosition(stmt, operators, varDefinitions, loop);
            if (!bestPos) {
                // NOTE(review): hasDependents is computed but never read; the
                // statement is simply left where it is.
                bool hasDependents = false;
                for (size_t j = currentPos + 1; j < newOrder.size(); j++) {
                    SgStatement* candidate = newOrder[j];
                    OperatorInfo* candidateOpInfo = stmtToOpInfo[candidate];
                    if (candidateOpInfo) {
                        for (const string& definedVar : opInfo->definedVars) {
                            if (candidateOpInfo->usedVars.find(definedVar) != candidateOpInfo->usedVars.end()) {
                                hasDependents = true;
                                break;
                            }
                        }
                        if (hasDependents)
                            break;
                    }
                }
                continue;
            }
            size_t targetPos = 0;
            // NOTE(review): foundTarget is written but never read.
            bool foundTarget = false;
            if (bestPos == loop->lexNext()) {
                // Move towards the front: the target is the first earlier
                // entry that uses a variable defined by stmt, or index 0.
                targetPos = 0;
                for (size_t j = 0; j < currentPos && j < newOrder.size(); j++) {
                    SgStatement* candidate = newOrder[j];
                    OperatorInfo* candidateOpInfo = stmtToOpInfo[candidate];
                    if (candidateOpInfo) {
                        bool usesDefinedVar = false;
                        for (const string& definedVar : opInfo->definedVars) {
                            if (candidateOpInfo->usedVars.find(definedVar) != candidateOpInfo->usedVars.end()) {
                                usesDefinedVar = true;
                                break;
                            }
                        }
                        if (usesDefinedVar) {
                            targetPos = j;
                            break;
                        }
                    }
                }
                foundTarget = true;
                if (currentPos != targetPos && canMoveTo(stmt, bestPos, loop)) {
                    newOrder.erase(newOrder.begin() + currentPos);
                    newOrder.insert(newOrder.begin() + targetPos, stmt);
                    changed = true;
                }
            } else {
                size_t bestPosIdx = 0;
                bool foundBestPos = false;
                for (size_t j = 0; j < newOrder.size(); j++) {
                    if (newOrder[j] == bestPos) {
                        bestPosIdx = j;
                        foundBestPos = true;
                        break;
                    }
                }
                if (foundBestPos) {
                    // Start right after the chosen definition and slide past
                    // every later entry that also defines a used variable.
                    targetPos = bestPosIdx + 1;
                    for (size_t j = bestPosIdx + 1; j < currentPos && j < newOrder.size(); j++) {
                        SgStatement* candidate = newOrder[j];
                        OperatorInfo* candidateOpInfo = stmtToOpInfo[candidate];
                        if (candidateOpInfo) {
                            bool definesUsedVar = false;
                            for (const string& usedVar : opInfo->usedVars) {
                                if (candidateOpInfo->definedVars.find(usedVar) != candidateOpInfo->definedVars.end()) {
                                    definesUsedVar = true;
                                    break;
                                }
                            }
                            if (definesUsedVar) {
                                targetPos = j + 1;
                            }
                        }
                    }
                    // Refuse the move if an entry that would end up after
                    // stmt uses a variable stmt defines.
                    bool wouldBreakDependency = false;
                    for (size_t j = targetPos; j < currentPos && j < newOrder.size(); j++) {
                        SgStatement* candidate = newOrder[j];
                        OperatorInfo* candidateOpInfo = stmtToOpInfo[candidate];
                        if (candidateOpInfo) {
                            for (const string& definedVar : opInfo->definedVars) {
                                if (candidateOpInfo->usedVars.find(definedVar) != candidateOpInfo->usedVars.end()) {
                                    wouldBreakDependency = true;
                                    break;
                                }
                            }
                            if (wouldBreakDependency)
                                break;
                        }
                    }
                    if (!wouldBreakDependency && currentPos > targetPos && canMoveTo(stmt, bestPos, loop)) {
                        newOrder.erase(newOrder.begin() + currentPos);
                        newOrder.insert(newOrder.begin() + targetPos, stmt);
                        changed = true;
                    }
                }
            }
        }
    }
    // NOTE(review): unlike phase 1, this loop has no iteration cap, and the
    // moves it makes are not checked with canMoveTo() - confirm termination
    // and legality of the resulting order.
    bool dependencyViolation = true;
    // NOTE(review): triedPairs is cleared at the start of every pass and each
    // (stmt, prevStmt) pair is visited once per pass, so it appears to have
    // no effect.
    set> triedPairs;
    while (dependencyViolation) {
        dependencyViolation = false;
        triedPairs.clear();
        for (size_t i = 0; i < newOrder.size(); i++) {
            SgStatement* stmt = newOrder[i];
            OperatorInfo* opInfo = stmtToOpInfo[stmt];
            if (!opInfo)
                continue;
            for (size_t j = 0; j < i; j++) {
                SgStatement* prevStmt = newOrder[j];
                OperatorInfo* prevOpInfo = stmtToOpInfo[prevStmt];
                if (!prevOpInfo)
                    continue;
                pair key = make_pair(stmt, prevStmt);
                if (triedPairs.find(key) != triedPairs.end()) {
                    continue;
                }
                // Violation: an earlier entry uses a variable that this later
                // entry defines.
                bool violation = false;
                for (const string& definedVar : opInfo->definedVars) {
                    if (prevOpInfo->usedVars.find(definedVar) != prevOpInfo->usedVars.end()) {
                        violation = true;
                        break;
                    }
                }
                if (violation) {
                    triedPairs.insert(key);
                    // Moving stmt to index j is refused if any entry in
                    // [j, i) defines a variable stmt uses.
                    bool wouldCreateViolation = false;
                    for (size_t k = j; k < i; k++) {
                        SgStatement* betweenStmt = newOrder[k];
                        OperatorInfo* betweenOpInfo = stmtToOpInfo[betweenStmt];
                        if (!betweenOpInfo)
                            continue;
                        for (const string& usedVar : opInfo->usedVars) {
                            if (betweenOpInfo->definedVars.find(usedVar) != betweenOpInfo->definedVars.end()) {
                                wouldCreateViolation = true;
                                break;
                            }
                        }
                        if (wouldCreateViolation)
                            break;
                    }
                    if (!wouldCreateViolation) {
                        newOrder.erase(newOrder.begin() + i);
                        newOrder.insert(newOrder.begin() + j, stmt);
                        dependencyViolation = true;
                        break;
                    }
                }
            }
            // Restart the scan after every move.
            if (dependencyViolation)
                break;
        }
    }
    return newOrder;
}

/**
 * Rewrites the body of loop so that its assignments follow newOrder.
 *
 * The current order of executable ASSIGN_STAT statements is collected first;
 * if it already equals newOrder nothing is done. Otherwise every statement of
 * newOrder whose index differs from its original index, and which
 * canSafelyExtract() accepts, is extracted (its comments and line number are
 * saved). The extracted statements are then re-inserted in newOrder order,
 * each after the nearest preceding entry of newOrder that is either already
 * re-inserted or still present in the body (after loop itself when none is
 * found), and their comments and line numbers are restored.
 *
 * @return false for a null loop, empty newOrder, unavailable loop bounds, or
 *         an unchanged order; true otherwise.
 */
static bool applyOperatorReordering(SgForStmt* loop, vector& newOrder) {
    if (!loop || newOrder.empty())
        return false;
    SgStatement* loopStart = loop->lexNext();
    SgStatement* loopEnd = loop->lastNodeOfStmt();
    if (!loopStart || !loopEnd)
        return false;
    vector originalOrder;
    SgStatement* current = loopStart;
    while (current && current != loopEnd) {
        if (isSgExecutableStatement(current) && current->variant() == ASSIGN_STAT) {
            originalOrder.push_back(current);
        }
        current = current->lexNext();
    }
    bool orderChanged = false;
    if (originalOrder.size() == newOrder.size()) {
        for (size_t i = 0; i < originalOrder.size(); i++) {
            if (originalOrder[i] != newOrder[i]) {
                orderChanged = true;
                break;
            }
        }
    } else {
        orderChanged = true;
    }
    if (!orderChanged) {
        return false;
    }
    vector extractedStatements;
    // strdup'ed copies of the statements' comments; freed at the end.
    vector savedComments;
    unordered_set extractedSet;
    map originalLineNumbers;
    map stmtToExtracted;
    for (SgStatement* stmt : newOrder) {
        if (stmt && stmt != loop && stmt != loopEnd && extractedSet.find(stmt) == extractedSet.end()) {
            if (control_tags.find(stmt->variant()) != control_tags.end()) {
                continue;
            }
            if (!canSafelyExtract(stmt, loop)) {
                continue;
            }
            // A statement is "moving" when its index in newOrder differs from
            // its index in originalOrder.
            bool isMoving = false;
            for (size_t i = 0; i < originalOrder.size(); i++) {
                if (originalOrder[i] == stmt) {
                    for (size_t j = 0; j < newOrder.size(); j++) {
                        if (newOrder[j] == stmt && i != j) {
                            isMoving = true;
                            break;
                        }
                    }
                    break;
                }
            }
            if (!isMoving) {
                continue;
            }
            originalLineNumbers[stmt] = stmt->lineNumber();
            // NOTE(review): the comment is pushed before extractStmt() is
            // known to succeed, while extractedStatements only grows on
            // success. After one failed extraction the two vectors would be
            // out of step and the commentIdx lookup below would pick the
            // wrong comment - confirm extractStmt() cannot return null.
            savedComments.push_back(stmt->comments() ? strdup(stmt->comments()) : nullptr);
            SgStatement* extracted = stmt->extractStmt();
            if (extracted) {
                extractedStatements.push_back(extracted);
                extractedSet.insert(stmt);
                stmtToExtracted[stmt] = extracted;
            }
        }
    }
    // Original statement -> the node that was re-inserted for it.
    map insertedStatements;
    for (size_t idx = 0; idx < newOrder.size(); idx++) {
        SgStatement* stmt = newOrder[idx];
        if (extractedSet.find(stmt) != extractedSet.end()) {
            SgStatement* stmtToInsert = stmtToExtracted[stmt];
            if (!stmtToInsert)
                continue;
            // Default anchor: directly after the loop header.
            SgStatement* insertAfter = loop;
            // NOTE(review): idx is size_t, so for idx == 0 the initial value
            // relies on the unsigned-to-int conversion of idx - 1 yielding a
            // negative number.
            for (int i = idx - 1; i >= 0; i--) {
                SgStatement* prevStmt = newOrder[i];
                if (extractedSet.find(prevStmt) != extractedSet.end()) {
                    if (insertedStatements.find(prevStmt) != insertedStatements.end()) {
                        insertAfter = insertedStatements[prevStmt];
                        break;
                    }
                } else {
                    // prevStmt was not extracted: look for it in the body.
                    SgStatement* search = loop->lexNext();
                    while (search && search != loopEnd) {
                        bool skip = false;
                        for (size_t j = idx; j < newOrder.size(); j++) {
                            if (extractedSet.find(newOrder[j]) != extractedSet.end() && search == newOrder[j]) {
                                skip = true;
                                break;
                            }
                        }
                        if (skip) {
                            search = search->lexNext();
                            continue;
                        }
                        if (search == prevStmt) {
                            insertAfter = search;
                            break;
                        }
                        search = search->lexNext();
                    }
                    if (insertAfter != loop)
                        break;
                }
            }
            size_t commentIdx = 0;
            for (size_t i = 0; i < extractedStatements.size(); i++) {
                if (extractedStatements[i] == stmtToInsert) {
                    commentIdx = i;
                    break;
                }
            }
            if (commentIdx < savedComments.size() && savedComments[commentIdx]) {
                stmtToInsert->setComments(savedComments[commentIdx]);
            }
            if (originalLineNumbers.find(stmt) != originalLineNumbers.end()) {
                stmtToInsert->setlineNumber(originalLineNumbers[stmt]);
            }
            SgStatement* controlParent = stmt->controlParent();
            if (!controlParent)
                controlParent = loop;
            insertAfter->insertStmtAfter(*stmtToInsert, *controlParent);
            insertedStatements[stmt] = stmtToInsert;
        }
    }
    // NOTE(review): the buffers are freed after having been handed to
    // setComments() - presumably setComments() copies the text; verify.
    for (char* comment : savedComments) {
        if (comment) {
            free(comment);
        }
    }
    // NOTE(review): true is returned even when no statement was extracted.
    return true;
}

/**
 * Returns the value stored in FullIR for the function whose funcPointer has
 * the same getCurrProcessFile() and lineNumber() as st.
 *
 * The whole map is scanned without an early exit, so when several entries
 * match the last one wins.
 *
 * @return the matching value, or an empty vector when nothing matches.
 */
vector findFuncBlocksByFuncStatement(SgStatement *st, map>& FullIR) {
    vector result;
    Statement* forSt = (Statement*)st;
    for (auto& func: FullIR) {
        if (func.first -> funcPointer -> getCurrProcessFile() == forSt -> getCurrProcessFile() && func.first -> funcPointer -> lineNumber() == forSt -> lineNumber())
            result = func.second;
    }
    return result;
}

/**
 * Finds every loop_tags statement between st and st->lastNodeOfStmt() and
 * gathers, per loop, the basic blocks of the statements in its body.
 *
 * For each body statement the first IR entry found on its line supplies the
 * basic block; a block is added once per loop, keyed by getNumber().
 *
 * @param st      first statement to scan; dereferenced before the null test
 *                in the loop condition.
 * @param blocks  blocks searched by findInstructionsFromOperator.
 * @return map from loop statement to its collected blocks. Because of the
 *         result[forSt] access in the std::sort call, every loop found gets
 *         an entry, possibly empty.
 */
map> findAndAnalyzeLoops(SgStatement *st, vector blocks) {
    map> result;
    SgStatement *lastNode = st->lastNodeOfStmt();
    while (st && st != lastNode) {
        if (loop_tags.find(st -> variant()) != loop_tags.end()) {
            SgForStmt *forSt = (SgForStmt*)st;
            SgStatement *loopBody = forSt -> body();
            SgStatement *lastLoopNode = st->lastNodeOfStmt();
            // Numbers of the basic blocks already added for this loop.
            unordered_set blocks_nums;
            while (loopBody && loopBody != lastLoopNode) {
                vector irBlocks = findInstructionsFromOperator(loopBody, blocks);
                if (!irBlocks.empty()) {
                    SAPFOR::IR_Block* IR = irBlocks.front();
                    if (IR && IR->getBasicBlock()) {
                        if (blocks_nums.find(IR -> getBasicBlock() -> getNumber()) == blocks_nums.end()) {
                            result[forSt].push_back(IR -> getBasicBlock());
                            blocks_nums.insert(IR -> getBasicBlock() -> getNumber());
                        }
                    }
                }
                loopBody = loopBody -> lexNext();
            }
            // NOTE(review): the elements look like pointers (they are used
            // with -> above), so this presumably orders them by address
            // rather than by block number - confirm that is intended.
            std::sort(result[forSt].begin(), result[forSt].end());
        }
        st = st -> lexNext();
    }
    return result;
}

/**
 * Processes loop and every FOR_NODE nested inside it, innermost first.
 *
 * Nested loops found while walking the body are handled by a recursive call
 * and then stepped over; afterwards the operators of loop itself are
 * analysed, reordered and the new order is applied.
 *
 * @param loop    loop to process; a null pointer is ignored.
 * @param blocks  blocks handed on to the analysis (the same vector is reused
 *                for nested loops).
 * @param FullIR  forwarded to analyzeOperatorsInLoop.
 */
static void processLoopRecursively(SgForStmt* loop, vector blocks, map>& FullIR) {
    if (!loop)
        return;
    SgStatement* loopStart = loop->lexNext();
    SgStatement* loopEnd = loop->lastNodeOfStmt();
    if (loopStart && loopEnd) {
        SgStatement* current = loopStart;
        while (current && current != loopEnd) {
            if (current->variant() == FOR_NODE && current != loop) {
                SgForStmt* nestedLoop = (SgForStmt*)current;
                processLoopRecursively(nestedLoop, blocks, FullIR);
                SgStatement* nestedEnd = nestedLoop->lastNodeOfStmt();
                if (nestedEnd) {
                    current = nestedEnd->lexNext();
                } else {
                    current = current->lexNext();
                }
            } else {
                current = current->lexNext();
            }
        }
    }
    vector operators = analyzeOperatorsInLoop(loop, blocks, FullIR);
    if (!operators.empty()) {
        map> varDefinitions = findVariableDefinitions(loop, operators);
        vector newOrder = optimizeOperatorOrder(loop, operators, varDefinitions);
        // The boolean result of the rewrite is not used.
        applyOperatorReordering(loop, newOrder);
    }
}

/**
 * Entry point of the pass: runs the operator reordering on every loop of
 * every function in file and prints a start and a completion message to
 * std::cout.
 *
 * @param file              file whose functions are processed; dereferenced
 *                          without a null check.
 * @param loopGraph         not used by this function.
 * @param FullIR            per-function data used to find each function's
 *                          blocks.
 * @param countOfTransform  incremented by one on every call, regardless of
 *                          whether any statement was moved.
 */
void runSwapOperators(SgFile *file, std::map>& loopGraph, std::map>& FullIR, int& countOfTransform) {
    countOfTransform += 1;
    std::cout << "SWAP_OPERATORS Pass Started" << std::endl;
    const int funcNum = file -> numberOfFunctions();
    for (int i = 0; i < funcNum; ++i) {
        SgStatement *st = file -> functions(i);
        vector blocks = findFuncBlocksByFuncStatement(st, FullIR);
        map> loopsMapping = findAndAnalyzeLoops(st, blocks);
        // NOTE(review): findAndAnalyzeLoops() visits every FOR_NODE it meets,
        // and processLoopRecursively() descends into nested loops itself, so
        // a nested loop may be processed more than once - confirm.
        for (pair> loopForAnalyze: loopsMapping) {
            processLoopRecursively(loopForAnalyze.first, loopForAnalyze.second, FullIR);
        }
    }
    std::cout << "SWAP_OPERATORS Pass Completed" << std::endl;
}