add restore pass

This commit is contained in:
2026-04-27 16:18:43 +03:00
parent fe7e3449e8
commit 5f25567a14
13 changed files with 470 additions and 278 deletions

View File

@@ -1,8 +1,8 @@
#include <algorithm>
#include <vector>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <set>
#include <map>
#include <string>
#include <numeric>
#include <iostream>
@@ -23,9 +23,9 @@ static bool isParentStmt(SgStatement* stmt, SgStatement* parent)
}
/*returns head block and loop*/
static pair<SAPFOR::BasicBlock*, unordered_set<SAPFOR::BasicBlock*>> GetBasicBlocksForLoop(const LoopGraph* loop, const vector<SAPFOR::BasicBlock*> blocks)
pair<SAPFOR::BasicBlock*, set<SAPFOR::BasicBlock*>> GetBasicBlocksForLoop(const LoopGraph* loop, const vector<SAPFOR::BasicBlock*> blocks)
{
unordered_set<SAPFOR::BasicBlock*> block_loop;
set<SAPFOR::BasicBlock*> block_loop;
SAPFOR::BasicBlock* head_block = nullptr;
auto loop_operator = loop->loop->GetOriginal();
for (const auto& block : blocks)
@@ -51,16 +51,16 @@ static pair<SAPFOR::BasicBlock*, unordered_set<SAPFOR::BasicBlock*>> GetBasicBlo
return { head_block, block_loop };
}
static void BuildLoopIndex(map<string, LoopGraph*>& loopForIndex, LoopGraph* loop) {
static void BuildLoopIndex(map<SgStatement*, LoopGraph*>& loopForIndex, LoopGraph* loop) {
string index = loop->loopSymbol();
loopForIndex[index] = loop;
loopForIndex[loop->loop->GetOriginal()] = loop;
for (const auto& childLoop : loop->children)
BuildLoopIndex(loopForIndex, childLoop);
}
static string FindIndexName(int pos, SAPFOR::BasicBlock* block, map<string, LoopGraph*>& loopForIndex) {
unordered_set<SAPFOR::Argument*> args = { block->getInstructions()[pos]->getInstruction()->getArg1() };
set<SAPFOR::Argument*> args = { block->getInstructions()[pos]->getInstruction()->getArg1() };
for (int i = pos - 1; i >= 0; i--)
{
@@ -95,7 +95,7 @@ static string FindIndexName(int pos, SAPFOR::BasicBlock* block, map<string, Loop
static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAccessingIndexes& def, ArrayAccessingIndexes& use, Region* region) {
auto instructions = block->getInstructions();
map<string, LoopGraph*> loopForIndex;
map<SgStatement*, LoopGraph*> loopForIndex;
BuildLoopIndex(loopForIndex, loop);
for (int i = 0; i < instructions.size(); i++)
{
@@ -136,7 +136,6 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
vector<SAPFOR::Argument*> index_vars;
vector<int> refPos;
string array_name = instruction->getInstruction()->getArg1()->getValue();
int j = i - 1;
while (j >= 0 && instructions[j]->getInstruction()->getOperation() == SAPFOR::CFG_OP::REF)
{
@@ -180,25 +179,16 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
string name, full_name = var->getValue();
int pos = full_name.find('%');
LoopGraph* currentLoop;
if (pos != -1)
{
name = full_name.substr(pos + 1);
if (loopForIndex.find(name) != loopForIndex.end())
currentLoop = loopForIndex[name];
else
return -1;
}
else
{
name = FindIndexName(currentVarPos, block, loopForIndex);
if (name == "")
return -1;
if (loopForIndex.find(name) != loopForIndex.end())
currentLoop = loopForIndex[name];
else
return -1;
}
auto serachInstr = instruction->getInstruction()->getOperator();
while (serachInstr && serachInstr->variant() != FOR_NODE)
serachInstr = serachInstr->controlParent();
name = full_name.substr(pos + 1);
if (loopForIndex.find(serachInstr) != loopForIndex.end())
currentLoop = loopForIndex[serachInstr];
else
return -1;
uint64_t start = coeffsForDims.back().second * currentLoop->startVal + coeffsForDims.back().first;
uint64_t step = currentLoop->stepVal;
@@ -243,7 +233,7 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
}
static void RemoveHeaderConnection(SAPFOR::BasicBlock* header, const unordered_set<SAPFOR::BasicBlock*>& blockSet, unordered_map<SAPFOR::BasicBlock*, Region*>& bbToRegion)
static void RemoveHeaderConnection(SAPFOR::BasicBlock* header, const set<SAPFOR::BasicBlock*>& blockSet, map<SAPFOR::BasicBlock*, Region*>& bbToRegion)
{
for (SAPFOR::BasicBlock* block : blockSet)
{
@@ -259,18 +249,35 @@ static void RemoveHeaderConnection(SAPFOR::BasicBlock* header, const unordered_s
}
}
static void DFS(Region* block, vector<Region*>& result, unordered_set<Region*> cycleBlocks)
static bool DFS(Region* block,
vector<Region*>& result,
const set<Region*>& cycleBlocks,
map<Region*, int>& color)
{
auto it = color.find(block);
if (it != color.end())
{
if (it->second == 0)
return false;
if (it->second == 1)
return true;
}
color[block] = 0;
for (Region* nextBlock : block->getNextRegions())
{
if (cycleBlocks.find(nextBlock) != cycleBlocks.end())
DFS(nextBlock, result, cycleBlocks);
if (cycleBlocks.find(nextBlock) == cycleBlocks.end())
continue;
if (!DFS(nextBlock, result, cycleBlocks, color))
return false;
}
color[block] = 1;
result.push_back(block);
return true;
}
bool HasCycle(Region* block, const std::unordered_set<Region*>& cycleBlocks, std::unordered_set<Region*>& visitedBlocks)
bool HasCycle(Region* block, const std::set<Region*>& cycleBlocks, std::set<Region*>& visitedBlocks)
{
return false;
if (visitedBlocks.find(block) != visitedBlocks.end())
return true;
visitedBlocks.insert(block);
@@ -284,18 +291,17 @@ bool HasCycle(Region* block, const std::unordered_set<Region*>& cycleBlocks, std
bool TopologySort(std::vector<Region*>& basikBlocks, Region* header)
{
unordered_set<Region*> cycleBlocks(basikBlocks.begin(), basikBlocks.end());
unordered_set<Region*> visitedBlocks;
if (HasCycle(header, cycleBlocks, visitedBlocks))
return false;
set<Region*> cycleBlocks(basikBlocks.begin(), basikBlocks.end());
vector<Region*> result;
DFS(header, result, cycleBlocks);
map<Region*, int> color;
if (!DFS(header, result, cycleBlocks, color))
return false;
reverse(result.begin(), result.end());
basikBlocks = move(result);
return true;
}
static void SetConnections(unordered_map<SAPFOR::BasicBlock*, Region*>& bbToRegion, const unordered_set<SAPFOR::BasicBlock*>& blockSet)
static void SetConnections(map<SAPFOR::BasicBlock*, Region*>& bbToRegion, const set<SAPFOR::BasicBlock*>& blockSet)
{
for (SAPFOR::BasicBlock* block : blockSet)
{
@@ -309,7 +315,7 @@ static void SetConnections(unordered_map<SAPFOR::BasicBlock*, Region*>& bbToRegi
}
}
static Region* CreateSubRegion(LoopGraph* loop, const vector<SAPFOR::BasicBlock*>& Blocks, unordered_map<SAPFOR::BasicBlock*, Region*>& bbToRegion)
static Region* CreateSubRegion(LoopGraph* loop, const vector<SAPFOR::BasicBlock*>& Blocks, map<SAPFOR::BasicBlock*, Region*>& bbToRegion)
{
Region* region = new Region;
auto [header, blockSet] = GetBasicBlocksForLoop(loop, Blocks);
@@ -340,7 +346,7 @@ static Region* CreateSubRegion(LoopGraph* loop, const vector<SAPFOR::BasicBlock*
Region::Region(LoopGraph* loop, const vector<SAPFOR::BasicBlock*>& Blocks)
{
auto [header, blockSet] = GetBasicBlocksForLoop(loop, Blocks);
unordered_map<SAPFOR::BasicBlock*, Region*> bbToRegion;
map<SAPFOR::BasicBlock*, Region*> bbToRegion;
for (auto poiner : blockSet)
{
bbToRegion[poiner] = new Region(*poiner);