Files
SAPFOR/src/SwapOperators/swapOperators.cpp
2025-10-22 17:00:27 +03:00

280 lines
11 KiB
C++

#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../Utils/errors.h"
#include "../Utils/SgUtils.h"
#include "../GraphCall/graph_calls.h"
#include "../GraphCall/graph_calls_func.h"
#include "../CFGraph/CFGraph.h"
#include "../CFGraph/IR.h"
#include "../GraphLoop/graph_loops.h"
#include "swapOperators.h"
using namespace std;
unordered_set<int> loop_tags = {FOR_NODE/*, FORALL_NODE, WHILE_NODE, DO_WHILE_NODE*/};
// Collects every IR instruction in `Blocks` that was generated for the source
// statement `st`.
//
// An instruction matches when its operator is located in the same file and
// on the same source line as `st`. Instructions without an operator are
// skipped. The result keeps block order and, inside a block, instruction
// order; it is empty when nothing matches.
vector<SAPFOR::IR_Block*> findInstructionsFromOperator(SgStatement* st, vector<SAPFOR::BasicBlock*> Blocks)
{
    vector<SAPFOR::IR_Block*> result;
    const string filename = st->fileName();
    const int line = st->lineNumber();
    for (SAPFOR::BasicBlock* block : Blocks)
    {
        for (SAPFOR::IR_Block* instruction : block->getInstructions())
        {
            SgStatement* curOperator = instruction->getInstruction()->getOperator();
            if (curOperator == NULL)
                continue;
            // Line numbers are only unique within one file (e.g. code pulled in
            // through INCLUDE has its own numbering), so compare the file too.
            if (curOperator->lineNumber() == line && filename == curOperator->fileName())
                result.push_back(instruction);
        }
    }
    return result;
}
// Returns the IR basic blocks stored in `FullIR` for the function whose header
// statement is `st`, or an empty vector when no entry matches.
// A function is identified by its processed file and its header line; if
// several entries match, the one that comes last in FullIR's order wins.
vector<SAPFOR::BasicBlock*> findFuncBlocksByFuncStatement(SgStatement *st, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR)
{
    Statement* header = (Statement*)st;
    // Scan backwards: the first hit here is the last match in forward order.
    for (auto it = FullIR.rbegin(); it != FullIR.rend(); ++it)
    {
        FuncInfo* info = it->first;
        const bool sameFile = info->funcPointer->getCurrProcessFile() == header->getCurrProcessFile();
        if (sameFile && info->funcPointer->lineNumber() == header->lineNumber())
            return it->second;
    }
    return vector<SAPFOR::BasicBlock*>();
}
// Finds every loop statement (a variant listed in loop_tags) between `st` and
// its last node and maps it to the basic blocks holding IR of its body.
//
// For each body statement only the block of its first IR instruction is
// taken. Statements with no IR in `blocks` are skipped; calling front() on
// their empty instruction list would be undefined behaviour. The blocks of a
// loop are de-duplicated and ordered by block number. A loop whose body has no
// IR still gets an (empty) entry.
map<SgForStmt*, vector<SAPFOR::BasicBlock*>> findAndAnalyzeLoops(SgStatement *st, vector<SAPFOR::BasicBlock*> blocks)
{
    map<SgForStmt*, vector<SAPFOR::BasicBlock*>> result;
    SgStatement *lastNode = st->lastNodeOfStmt();
    while (st && st != lastNode)
    {
        if (loop_tags.find(st->variant()) != loop_tags.end())
        {
            SgForStmt *forSt = (SgForStmt*)st;
            SgStatement *lastLoopNode = st->lastNodeOfStmt();
            vector<SAPFOR::BasicBlock*>& loopBlocks = result[forSt];
            unordered_set<int> blocks_nums;
            for (SgStatement *loopBody = forSt->body(); loopBody && loopBody != lastLoopNode; loopBody = loopBody->lexNext())
            {
                const vector<SAPFOR::IR_Block*> instructions = findInstructionsFromOperator(loopBody, blocks);
                if (instructions.empty())
                    continue;
                // NOTE(review): only the first instruction's block is used, as before;
                // a statement whose IR spans several blocks contributes just one.
                SAPFOR::BasicBlock* bb = instructions.front()->getBasicBlock();
                if (blocks_nums.insert(bb->getNumber()).second)
                    loopBlocks.push_back(bb);
            }
            // Sort by block number, not by pointer value, so the order is reproducible.
            std::sort(loopBlocks.begin(), loopBlocks.end(),
                      [](SAPFOR::BasicBlock* a, SAPFOR::BasicBlock* b) { return a->getNumber() < b->getNumber(); });
        }
        st = st->lexNext();
    }
    return result;
}
map<SgStatement*, set<SgStatement*>> AnalyzeLoopAndFindDeps(SgForStmt* forStatement, vector<SAPFOR::BasicBlock*> loopBlocks, map<FuncInfo*, vector<SAPFOR::BasicBlock*>>& FullIR)
{
map<SgStatement*, set<SgStatement*>> result;
for (SAPFOR::BasicBlock* bb: loopBlocks) {
std::map<SAPFOR::Argument*, std::set<int>> blockReachingDefinitions = bb -> getRD_In();
vector<SAPFOR::IR_Block*> instructions = bb -> getInstructions();
for (SAPFOR::IR_Block* irBlock: instructions) {
// TODO: Think about what to do with function calls and array references. Because there are also dependencies there that are not reflected in RD, but they must be taken into account
SAPFOR::Instruction* instr = irBlock -> getInstruction();
result[instr -> getOperator()];
// take Argument 1 and it's RD and push operators to final set
if (instr -> getArg1() != NULL) {
SAPFOR::Argument* arg = instr -> getArg1();
set<int> prevInstructionsNumbers = blockReachingDefinitions[arg];
for (int i: prevInstructionsNumbers) {
SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first;
if (foundInstruction != NULL) {
SgStatement* prevOp = foundInstruction -> getOperator();
if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber())
result[instr -> getOperator()].insert(prevOp);
}
}
}
// take Argument 2 (if exists) and it's RD and push operators to final set
if (instr -> getArg2() != NULL) {
SAPFOR::Argument* arg = instr -> getArg2();
set<int> prevInstructionsNumbers = blockReachingDefinitions[arg];
for (int i: prevInstructionsNumbers) {
SAPFOR::Instruction* foundInstruction = getInstructionAndBlockByNumber(FullIR, i).first;
if (foundInstruction != NULL) {
SgStatement* prevOp = foundInstruction -> getOperator();
if (prevOp != forStatement && instr -> getOperator() != forStatement && instr -> getOperator() -> lineNumber() > prevOp -> lineNumber())
result[instr -> getOperator()].insert(prevOp);
}
}
}
// update RD
if (instr -> getResult() != NULL)
blockReachingDefinitions[instr -> getResult()] = {instr -> getNumber()};
}
}
return result;
}
// int PrintSmthFromLoop(int firstLine, int lastLine, map<SgStatement*, unordered_set<SgStatement*>> moveRules) {
// // only printing to stdout is implemented so far
// cout << "LOOP ANALYZE FROM " << firstLine << " TO " << lastLine << " RES\n" << endl;
// for (auto r: moveRules) {
// cout << "OPERATOR: " << endl;
// cout << r.first -> lineNumber() << r.first -> sunparse();
// cout << "DEPENDS FROM NEXT: " << endl;
// for (SgStatement* st: r.second)
// cout << st -> lineNumber() << endl;
// }
// cout << "\n\n\n";
// return 0;
// }
// Turns a dependency map (statement -> statements it depends on) into an
// explicit graph.
//   allNodes - every statement that appears as a key or as a dependency;
//   outEdges - for each statement, the statements that depend on it;
//   inEdges  - for each statement, the statements it depends on.
// Every node in allNodes gets an entry (possibly empty) in both edge maps.
void GenNodesOfGraph(
    const map<SgStatement*, set<SgStatement*>>& dependencies,
    set<SgStatement*>& allNodes,
    map<SgStatement*, set<SgStatement*>>& outEdges,
    map<SgStatement*, set<SgStatement*>>& inEdges)
{
    for (auto it = dependencies.begin(); it != dependencies.end(); ++it)
    {
        SgStatement* dependent = it->first;
        const set<SgStatement*>& sources = it->second;

        allNodes.insert(dependent);
        // Make sure the dependent node has adjacency entries even without edges.
        outEdges[dependent];
        inEdges[dependent];

        for (SgStatement* source : sources)
        {
            allNodes.insert(source);
            outEdges[source].insert(dependent);
            inEdges[source];
            inEdges[dependent].insert(source);
        }
    }
}
// Splits the graph into weakly connected components. Edge direction is
// ignored, so two statements share a component when any chain of
// dependencies links them. Components are returned in the order of their
// first member in `allNodes`. Every node must have an entry in both
// `outEdges` and `inEdges`; `.at()` throws std::out_of_range otherwise.
vector<set<SgStatement*>> FindLinksInGraph(
    const set<SgStatement*>& allNodes,
    const map<SgStatement*, set<SgStatement*>>& outEdges,
    const map<SgStatement*, set<SgStatement*>>& inEdges)
{
    vector<set<SgStatement*>> result;
    set<SgStatement*> seen;
    for (auto startIt = allNodes.begin(); startIt != allNodes.end(); ++startIt)
    {
        SgStatement* start = *startIt;
        if (!seen.insert(start).second)
            continue; // already part of an earlier component

        set<SgStatement*> members;
        vector<SgStatement*> pending(1, start);
        while (!pending.empty())
        {
            SgStatement* node = pending.back();
            pending.pop_back();
            members.insert(node);

            const set<SgStatement*>* adjacency[2] = { &outEdges.at(node), &inEdges.at(node) };
            for (const set<SgStatement*>* neighbours : adjacency)
                for (SgStatement* next : *neighbours)
                    if (seen.insert(next).second)
                        pending.push_back(next);
        }
        result.push_back(members);
    }
    return result;
}
// Topologically sorts the statements of one component (Kahn's algorithm).
// An edge u -> v in `outEdges` means v depends on u, so u is emitted before v.
//
// Among statements that become ready at the same time, the one on the
// smallest source line goes first (ties broken by std::less on the pointer).
// This keeps the original program order wherever the dependencies allow and,
// unlike plain pointer order, gives the same result on every run.
// Statements that never become ready (e.g. if the dependencies form a cycle)
// are appended at the end in line order instead of being silently dropped.
vector<SgStatement*> SortComponent(
    const set<SgStatement*>& component,
    const map<SgStatement*, set<SgStatement*>>& outEdges,
    const map<SgStatement*, set<SgStatement*>>& inEdges)
{
    // Heap comparator: a statement on a later line has lower priority.
    auto laterLine = [](SgStatement* a, SgStatement* b) {
        const int lineA = a->lineNumber();
        const int lineB = b->lineNumber();
        if (lineA != lineB)
            return lineA > lineB;
        return std::less<SgStatement*>()(b, a);
    };

    map<SgStatement*, int> inDegree;
    priority_queue<SgStatement*, vector<SgStatement*>, decltype(laterLine)> ready(laterLine);
    for (SgStatement* v : component) {
        inDegree[v] = static_cast<int>(inEdges.at(v).size());
        if (inDegree[v] == 0)
            ready.push(v);
    }

    vector<SgStatement*> order;
    order.reserve(component.size());
    while (!ready.empty()) {
        SgStatement* curr = ready.top();
        ready.pop();
        order.push_back(curr);
        for (SgStatement* next : outEdges.at(curr)) {
            if (component.count(next) && --inDegree[next] == 0)
                ready.push(next);
        }
    }

    if (order.size() < component.size()) {
        // Statements still waiting on a predecessor: keep them, in line order.
        vector<SgStatement*> rest;
        for (SgStatement* v : component)
            if (inDegree[v] > 0)
                rest.push_back(v);
        std::sort(rest.begin(), rest.end(),
                  [&laterLine](SgStatement* a, SgStatement* b) { return laterLine(b, a); });
        order.insert(order.end(), rest.begin(), rest.end());
    }
    return order;
}
// Produces one order of all statements mentioned in `dependencies`
// (statement -> statements it depends on) that respects every dependency.
// Statements of different weakly connected components are never interleaved:
// each component is emitted as one contiguous run (ordered by SortComponent).
// Components are emitted by their earliest source line, with ties kept in
// discovery order, so the result does not depend on pointer values.
vector<SgStatement*> SortNoInterleaving(const map<SgStatement*, set<SgStatement*>>& dependencies)
{
    set<SgStatement*> allNodes;
    map<SgStatement*, set<SgStatement*>> outEdges, inEdges;
    GenNodesOfGraph(dependencies, allNodes, outEdges, inEdges);
    const vector<set<SgStatement*>> components = FindLinksInGraph(allNodes, outEdges, inEdges);

    // (earliest line in the component, component index)
    vector<pair<int, size_t>> componentOrder;
    componentOrder.reserve(components.size());
    for (size_t i = 0; i < components.size(); ++i) {
        // Components are never empty: each one contains at least its start node.
        int firstLine = (*components[i].begin())->lineNumber();
        for (SgStatement* st : components[i])
            firstLine = std::min(firstLine, st->lineNumber());
        componentOrder.push_back(make_pair(firstLine, i));
    }
    std::sort(componentOrder.begin(), componentOrder.end());

    vector<SgStatement*> totalOrder;
    totalOrder.reserve(allNodes.size());
    for (const pair<int, size_t>& entry : componentOrder) {
        const vector<SgStatement*> part = SortComponent(components[entry.second], outEdges, inEdges);
        totalOrder.insert(totalOrder.end(), part.begin(), part.end());
    }
    return totalOrder;
}
// Entry point of the SWAP_OPERATORS pass (work in progress).
//
// For every function of `file` it looks up the function's IR in `FullIR`,
// finds the loops, builds the dependency graph between the statements of each
// loop and prints that graph plus a dependency-respecting statement order to
// stdout. The source code is not modified yet.
// `loopGraph` is currently unused. `countOfTransform` is incremented once per
// call (temporary, marked for removal).
void runSwapOperators(SgFile *file, std::map<std::string, std::vector<LoopGraph*>>& loopGraph, std::map<FuncInfo*, std::vector<SAPFOR::BasicBlock*>>& FullIR, int& countOfTransform)
{
    std::cout << "SWAP_OPERATORS Pass" << std::endl; // to remove
    countOfTransform += 1; // to remove
    const int funcNum = file->numberOfFunctions();
    for (int i = 0; i < funcNum; ++i)
    {
        SgStatement *st = file->functions(i);
        vector<SAPFOR::BasicBlock*> blocks = findFuncBlocksByFuncStatement(st, FullIR);
        // No IR for this function: there is nothing to analyze.
        if (blocks.empty())
            continue;
        map<SgForStmt*, vector<SAPFOR::BasicBlock*>> loopsMapping = findAndAnalyzeLoops(st, blocks);
        for (const auto& loopForAnalyze : loopsMapping)
        {
            map<SgStatement*, set<SgStatement*>> dependencyGraph = AnalyzeLoopAndFindDeps(loopForAnalyze.first, loopForAnalyze.second, FullIR);
            // TODO: Write a function that will go through the operators and update all dependencies so that there are no mix-ups and splits inside the semantic blocks (for if, do and may be some other cases)
            cout << "\n\n";
            for (const auto& v : dependencyGraph) {
                cout << "OPERATOR: " << v.first->lineNumber() << "\nDEPENDS ON:" << endl;
                for (SgStatement* vv : v.second) {
                    cout << vv->lineNumber() << " ";
                }
                cout << endl;
            }
            if (!dependencyGraph.empty()) {
                int firstLine = loopForAnalyze.first->lineNumber();
                int lastLine = loopForAnalyze.first->lastNodeOfStmt()->lineNumber();
                // countOfTransform += PrintSmthFromLoop(firstLine, lastLine, dependencyGraph);
                vector<SgStatement*> new_order = SortNoInterleaving(dependencyGraph);
                cout << "\n\nLOOP ANALYZE FROM " << firstLine << " TO " << lastLine << " RES\n" << endl;
                // The loop header itself (and anything before it) is not printed.
                for (SgStatement* v : new_order)
                    if (v->lineNumber() > firstLine)
                        cout << v->lineNumber() << " ";
                // Terminate the order line so later output does not run into it.
                cout << endl;
            }
        }
    }
}