add restore pass
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <numeric>
|
||||
#include <iostream>
|
||||
@@ -23,9 +23,9 @@ static bool isParentStmt(SgStatement* stmt, SgStatement* parent)
|
||||
}
|
||||
|
||||
/*returns head block and loop*/
|
||||
static pair<SAPFOR::BasicBlock*, unordered_set<SAPFOR::BasicBlock*>> GetBasicBlocksForLoop(const LoopGraph* loop, const vector<SAPFOR::BasicBlock*> blocks)
|
||||
pair<SAPFOR::BasicBlock*, set<SAPFOR::BasicBlock*>> GetBasicBlocksForLoop(const LoopGraph* loop, const vector<SAPFOR::BasicBlock*> blocks)
|
||||
{
|
||||
unordered_set<SAPFOR::BasicBlock*> block_loop;
|
||||
set<SAPFOR::BasicBlock*> block_loop;
|
||||
SAPFOR::BasicBlock* head_block = nullptr;
|
||||
auto loop_operator = loop->loop->GetOriginal();
|
||||
for (const auto& block : blocks)
|
||||
@@ -51,16 +51,16 @@ static pair<SAPFOR::BasicBlock*, unordered_set<SAPFOR::BasicBlock*>> GetBasicBlo
|
||||
return { head_block, block_loop };
|
||||
}
|
||||
|
||||
static void BuildLoopIndex(map<string, LoopGraph*>& loopForIndex, LoopGraph* loop) {
|
||||
static void BuildLoopIndex(map<SgStatement*, LoopGraph*>& loopForIndex, LoopGraph* loop) {
|
||||
string index = loop->loopSymbol();
|
||||
loopForIndex[index] = loop;
|
||||
loopForIndex[loop->loop->GetOriginal()] = loop;
|
||||
|
||||
for (const auto& childLoop : loop->children)
|
||||
BuildLoopIndex(loopForIndex, childLoop);
|
||||
}
|
||||
|
||||
static string FindIndexName(int pos, SAPFOR::BasicBlock* block, map<string, LoopGraph*>& loopForIndex) {
|
||||
unordered_set<SAPFOR::Argument*> args = { block->getInstructions()[pos]->getInstruction()->getArg1() };
|
||||
set<SAPFOR::Argument*> args = { block->getInstructions()[pos]->getInstruction()->getArg1() };
|
||||
|
||||
for (int i = pos - 1; i >= 0; i--)
|
||||
{
|
||||
@@ -95,7 +95,7 @@ static string FindIndexName(int pos, SAPFOR::BasicBlock* block, map<string, Loop
|
||||
|
||||
static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAccessingIndexes& def, ArrayAccessingIndexes& use, Region* region) {
|
||||
auto instructions = block->getInstructions();
|
||||
map<string, LoopGraph*> loopForIndex;
|
||||
map<SgStatement*, LoopGraph*> loopForIndex;
|
||||
BuildLoopIndex(loopForIndex, loop);
|
||||
for (int i = 0; i < instructions.size(); i++)
|
||||
{
|
||||
@@ -136,7 +136,6 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
|
||||
vector<SAPFOR::Argument*> index_vars;
|
||||
vector<int> refPos;
|
||||
string array_name = instruction->getInstruction()->getArg1()->getValue();
|
||||
|
||||
int j = i - 1;
|
||||
while (j >= 0 && instructions[j]->getInstruction()->getOperation() == SAPFOR::CFG_OP::REF)
|
||||
{
|
||||
@@ -180,25 +179,16 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
|
||||
string name, full_name = var->getValue();
|
||||
int pos = full_name.find('%');
|
||||
LoopGraph* currentLoop;
|
||||
if (pos != -1)
|
||||
{
|
||||
name = full_name.substr(pos + 1);
|
||||
if (loopForIndex.find(name) != loopForIndex.end())
|
||||
currentLoop = loopForIndex[name];
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
else
|
||||
{
|
||||
name = FindIndexName(currentVarPos, block, loopForIndex);
|
||||
if (name == "")
|
||||
return -1;
|
||||
|
||||
if (loopForIndex.find(name) != loopForIndex.end())
|
||||
currentLoop = loopForIndex[name];
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
auto serachInstr = instruction->getInstruction()->getOperator();
|
||||
while (serachInstr && serachInstr->variant() != FOR_NODE)
|
||||
serachInstr = serachInstr->controlParent();
|
||||
|
||||
name = full_name.substr(pos + 1);
|
||||
if (loopForIndex.find(serachInstr) != loopForIndex.end())
|
||||
currentLoop = loopForIndex[serachInstr];
|
||||
else
|
||||
return -1;
|
||||
|
||||
uint64_t start = coeffsForDims.back().second * currentLoop->startVal + coeffsForDims.back().first;
|
||||
uint64_t step = currentLoop->stepVal;
|
||||
@@ -243,7 +233,7 @@ static int GetDefUseArray(SAPFOR::BasicBlock* block, LoopGraph* loop, ArrayAcces
|
||||
|
||||
}
|
||||
|
||||
static void RemoveHeaderConnection(SAPFOR::BasicBlock* header, const unordered_set<SAPFOR::BasicBlock*>& blockSet, unordered_map<SAPFOR::BasicBlock*, Region*>& bbToRegion)
|
||||
static void RemoveHeaderConnection(SAPFOR::BasicBlock* header, const set<SAPFOR::BasicBlock*>& blockSet, map<SAPFOR::BasicBlock*, Region*>& bbToRegion)
|
||||
{
|
||||
for (SAPFOR::BasicBlock* block : blockSet)
|
||||
{
|
||||
@@ -259,18 +249,35 @@ static void RemoveHeaderConnection(SAPFOR::BasicBlock* header, const unordered_s
|
||||
}
|
||||
}
|
||||
|
||||
static void DFS(Region* block, vector<Region*>& result, unordered_set<Region*> cycleBlocks)
|
||||
static bool DFS(Region* block,
|
||||
vector<Region*>& result,
|
||||
const set<Region*>& cycleBlocks,
|
||||
map<Region*, int>& color)
|
||||
{
|
||||
auto it = color.find(block);
|
||||
if (it != color.end())
|
||||
{
|
||||
if (it->second == 0)
|
||||
return false;
|
||||
if (it->second == 1)
|
||||
return true;
|
||||
}
|
||||
color[block] = 0;
|
||||
for (Region* nextBlock : block->getNextRegions())
|
||||
{
|
||||
if (cycleBlocks.find(nextBlock) != cycleBlocks.end())
|
||||
DFS(nextBlock, result, cycleBlocks);
|
||||
if (cycleBlocks.find(nextBlock) == cycleBlocks.end())
|
||||
continue;
|
||||
if (!DFS(nextBlock, result, cycleBlocks, color))
|
||||
return false;
|
||||
}
|
||||
color[block] = 1;
|
||||
result.push_back(block);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HasCycle(Region* block, const std::unordered_set<Region*>& cycleBlocks, std::unordered_set<Region*>& visitedBlocks)
|
||||
bool HasCycle(Region* block, const std::set<Region*>& cycleBlocks, std::set<Region*>& visitedBlocks)
|
||||
{
|
||||
return false;
|
||||
if (visitedBlocks.find(block) != visitedBlocks.end())
|
||||
return true;
|
||||
visitedBlocks.insert(block);
|
||||
@@ -284,18 +291,17 @@ bool HasCycle(Region* block, const std::unordered_set<Region*>& cycleBlocks, std
|
||||
|
||||
bool TopologySort(std::vector<Region*>& basikBlocks, Region* header)
|
||||
{
|
||||
unordered_set<Region*> cycleBlocks(basikBlocks.begin(), basikBlocks.end());
|
||||
unordered_set<Region*> visitedBlocks;
|
||||
if (HasCycle(header, cycleBlocks, visitedBlocks))
|
||||
return false;
|
||||
set<Region*> cycleBlocks(basikBlocks.begin(), basikBlocks.end());
|
||||
vector<Region*> result;
|
||||
DFS(header, result, cycleBlocks);
|
||||
map<Region*, int> color;
|
||||
if (!DFS(header, result, cycleBlocks, color))
|
||||
return false;
|
||||
reverse(result.begin(), result.end());
|
||||
basikBlocks = move(result);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void SetConnections(unordered_map<SAPFOR::BasicBlock*, Region*>& bbToRegion, const unordered_set<SAPFOR::BasicBlock*>& blockSet)
|
||||
static void SetConnections(map<SAPFOR::BasicBlock*, Region*>& bbToRegion, const set<SAPFOR::BasicBlock*>& blockSet)
|
||||
{
|
||||
for (SAPFOR::BasicBlock* block : blockSet)
|
||||
{
|
||||
@@ -309,7 +315,7 @@ static void SetConnections(unordered_map<SAPFOR::BasicBlock*, Region*>& bbToRegi
|
||||
}
|
||||
}
|
||||
|
||||
static Region* CreateSubRegion(LoopGraph* loop, const vector<SAPFOR::BasicBlock*>& Blocks, unordered_map<SAPFOR::BasicBlock*, Region*>& bbToRegion)
|
||||
static Region* CreateSubRegion(LoopGraph* loop, const vector<SAPFOR::BasicBlock*>& Blocks, map<SAPFOR::BasicBlock*, Region*>& bbToRegion)
|
||||
{
|
||||
Region* region = new Region;
|
||||
auto [header, blockSet] = GetBasicBlocksForLoop(loop, Blocks);
|
||||
@@ -340,7 +346,7 @@ static Region* CreateSubRegion(LoopGraph* loop, const vector<SAPFOR::BasicBlock*
|
||||
Region::Region(LoopGraph* loop, const vector<SAPFOR::BasicBlock*>& Blocks)
|
||||
{
|
||||
auto [header, blockSet] = GetBasicBlocksForLoop(loop, Blocks);
|
||||
unordered_map<SAPFOR::BasicBlock*, Region*> bbToRegion;
|
||||
map<SAPFOR::BasicBlock*, Region*> bbToRegion;
|
||||
for (auto poiner : blockSet)
|
||||
{
|
||||
bbToRegion[poiner] = new Region(*poiner);
|
||||
|
||||
Reference in New Issue
Block a user