00001 #include "ShOptimizations.hpp"
00002 #include <map>
00003 #include <set>
00004 #include <utility>
00005 #include "ShBitSet.hpp"
00006 #include "ShCtrlGraph.hpp"
00007 #include "ShDebug.hpp"
00008 #include "ShEvaluate.hpp"
00009 #include "ShContext.hpp"
00010 #include "ShSyntax.hpp"
00011 #include <sstream>
00012 #include <fstream>
00013
00014 namespace {
00015 using namespace SH;
00016
00017
00018
00019 struct ReachingDefs {
00020 ReachingDefs()
00021 : defsize(0)
00022 {
00023 }
00024
00025 struct Definition {
00026 Definition(ShStatement* stmt,
00027 const ShCtrlGraphNodePtr& node,
00028 int offset,
00029 ShBitSet disable_mask)
00030 : stmt(stmt), node(node), offset(offset),
00031 disable_mask(disable_mask)
00032 {
00033 }
00034
00035 ShStatement* stmt;
00036 ShCtrlGraphNodePtr node;
00037 int offset;
00038
00039
00040
00041
00042 ShBitSet disable_mask;
00043 };
00044
00045 typedef std::map<ShCtrlGraphNodePtr, int> SizeMap;
00046 typedef std::map<ShCtrlGraphNodePtr, ShBitSet> ReachingMap;
00047
00048 std::vector<Definition> defs;
00049 int defsize;
00050
00051 ReachingMap gen, prsv;
00052 ReachingMap rchin;
00053 };
00054
00055 struct DefFinder {
00056 DefFinder(ReachingDefs& r)
00057 : r(r), offset(0)
00058 {
00059 }
00060
00061 void operator()(ShCtrlGraphNodePtr node)
00062 {
00063 if (!node) return;
00064 ShBasicBlockPtr block = node->block;
00065 if (!block) return;
00066
00067
00068
00069 std::map<ShVariableNodePtr, ShBitSet> disable_map;
00070
00071
00072 ShBasicBlock::ShStmtList::iterator I = block->end();
00073 while (1) {
00074 if (I == block->begin()) break;
00075 --I;
00076 if (I->op != SH_OP_KIL && I->op != SH_OP_OPTBRA && I->dest.node()->kind() == SH_TEMP) {
00077
00078
00079 if (disable_map.find(I->dest.node()) == disable_map.end()) {
00080 disable_map[I->dest.node()] = ShBitSet(I->dest.node()->size());
00081 }
00082
00083
00084
00085 ShBitSet defn_map(I->dest.node()->size());
00086 for (int i = 0; i < I->dest.size(); i++) defn_map[I->dest.swizzle()[i]] = true;
00087
00088
00089
00090 if ((defn_map & disable_map[I->dest.node()]) == defn_map) continue;
00091
00092
00093 r.defs.push_back(ReachingDefs::Definition(&(*I), node, offset,
00094 disable_map[I->dest.node()]));
00095 offset += I->dest.size();
00096 r.defsize += I->dest.size();
00097
00098
00099 disable_map[I->dest.node()] |= defn_map;
00100 }
00101 }
00102 }
00103
00104 ReachingDefs& r;
00105 int offset;
00106 };
00107
00108 struct InitRch {
00109 InitRch(ReachingDefs& r)
00110 : r(r)
00111 {
00112 }
00113
00114 void operator()(ShCtrlGraphNodePtr node)
00115 {
00116 if (!node) return;
00117
00118 r.rchin[node] = ShBitSet(r.defsize);
00119 r.gen[node] = ShBitSet(r.defsize);
00120 r.prsv[node] = ~ShBitSet(r.defsize);
00121
00122 ShBasicBlockPtr block = node->block;
00123 if (!block) return;
00124
00125
00126 for (unsigned int i = 0; i < r.defs.size(); i++) {
00127 if (r.defs[i].node == node) {
00128 for (int j = 0; j < r.defs[i].stmt->dest.size(); j++) {
00129 if (!r.defs[i].disable_mask[j]) {
00130 r.gen[node][r.defs[i].offset + j] = true;
00131 }
00132 }
00133 }
00134 }
00135
00136
00137 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00138 if (I->op == SH_OP_KIL
00139 || I->op == SH_OP_OPTBRA
00140 || I->dest.node()->kind() != SH_TEMP) continue;
00141 for (unsigned int i = 0; i < r.defs.size(); i++) {
00142 if (r.defs[i].stmt->dest.node() != I->dest.node()) continue;
00143
00144 for (int j = 0; j < I->dest.size(); ++j) {
00145 for (int k = 0; k < r.defs[i].stmt->dest.size(); ++k) {
00146 if (r.defs[i].stmt->dest.swizzle()[k] == I->dest.swizzle()[j]) {
00147 r.prsv[node][r.defs[i].offset + k] = false;
00148 }
00149 }
00150 }
00151 }
00152 }
00153 }
00154
00155 ReachingDefs& r;
00156 };
00157
00158 struct IterateRch {
00159 IterateRch(ReachingDefs& r, bool& changed)
00160 : r(r), changed(changed)
00161 {
00162 }
00163
00164 void operator()(const ShCtrlGraphNodePtr& node)
00165 {
00166 if (!node) return;
00167 SH_DEBUG_ASSERT(r.rchin.find(node) != r.rchin.end());
00168 ShBitSet newRchIn(r.defsize);
00169
00170 for (ShCtrlGraphNode::ShPredList::iterator I = node->predecessors.begin();
00171 I != node->predecessors.end(); ++I) {
00172 SH_DEBUG_ASSERT(r.gen.find(*I) != r.gen.end());
00173 SH_DEBUG_ASSERT(r.prsv.find(*I) != r.prsv.end());
00174 SH_DEBUG_ASSERT(r.rchin.find(*I) != r.rchin.end());
00175
00176 newRchIn |= (r.gen[*I] | (r.rchin[*I] & r.prsv[*I]));
00177 }
00178 if (newRchIn != r.rchin[node]) {
00179 r.rchin[node] = newRchIn;
00180 changed = true;
00181 }
00182 }
00183
00184 ReachingDefs& r;
00185 bool& changed;
00186 };
00187
00188 struct UdDuBuilder {
00189 UdDuBuilder(ReachingDefs& r)
00190 : r(r)
00191 {
00192 }
00193
00194 struct TupleElement {
00195 TupleElement(const ShVariableNodePtr& node, int index)
00196 : node(node), index(index)
00197 {
00198 }
00199
00200 bool operator<(const TupleElement& other) const
00201 {
00202 if (node < other.node) return true;
00203 if (node == other.node) return index < other.index;
00204 return false;
00205 }
00206
00207 ShVariableNodePtr node;
00208 int index;
00209 };
00210
00211 void operator()(ShCtrlGraphNodePtr node) {
00212 typedef std::set<ValueTracking::Def> DefSet;
00213 typedef std::map<TupleElement, DefSet> DefMap;
00214
00215
00216
00217 DefMap defs;
00218
00219 if (!node) return;
00220 ShBasicBlockPtr block = node->block;
00221 if (!block) return;
00222
00223
00224
00225
00226
00227 for (std::size_t i = 0; i < r.defs.size(); i++) {
00228 for (int j = 0; j < r.defs[i].stmt->dest().size(); j++) {
00229 if (r.rchin[node][r.defs[i].offset + j]) {
00230 ValueTracking::Def def(r.defs[i].stmt, j);
00231 defs[TupleElement(r.defs[i].stmt->dest.node(),
00232 r.defs[i].stmt->dest.swizzle()[j])].insert(def);
00233 }
00234 }
00235 }
00236
00237
00238 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00239
00240
00241 for (int j = 0; j < opInfo[I->op].arity; j++) {
00242 if (I->src[j].node()->kind() == SH_TEMP) {
00243 ValueTracking* vt = I->get_info<ValueTracking>();
00244 if (!vt) {
00245 vt = new ValueTracking(&(*I));
00246 I->add_info(vt);
00247 }
00248 for (int i = 0; i < I->src[j].size(); i++) {
00249 const DefSet& ds = defs[TupleElement(I->src[j].node(), I->src[j].swizzle()[i])];
00250
00251 vt->defs[j][i] = ds;
00252 for (DefSet::const_iterator J = ds.begin(); J != ds.end(); J++) {
00253 ValueTracking* ut = J->stmt->get_info<ValueTracking>();
00254 if (!ut) {
00255 ut = new ValueTracking(J->stmt);
00256 J->stmt->add_info(ut);
00257 }
00258 ut->uses[J->index].insert(ValueTracking::Use(&(*I), j, i));
00259 }
00260 }
00261 }
00262 }
00263
00264
00265 for (int i = 0; i < I->dest.size(); ++i) {
00266 defs[TupleElement(I->dest.node(), I->dest.swizzle()[i])].clear();
00267 ValueTracking::Def def(&(*I), i);
00268 defs[TupleElement(I->dest.node(), I->dest.swizzle()[i])].insert(def);
00269 }
00270 }
00271
00272 }
00273
00274 ReachingDefs& r;
00275 };
00276
00277 struct UdDuClearer {
00278 void operator()(ShCtrlGraphNodePtr node) {
00279 if (!node) return;
00280 ShBasicBlockPtr block = node->block;
00281 if (!block) return;
00282
00283 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00284 I->destroy_info<ValueTracking>();
00285 }
00286 }
00287 };
00288
00289 struct UdDuDumper {
00290 void operator()(ShCtrlGraphNodePtr node) {
00291 if (!node) return;
00292 ShBasicBlockPtr block = node->block;
00293 if (!block) return;
00294
00295 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00296 ValueTracking* vt = I->get_info<ValueTracking>();
00297 if (!vt) {
00298 SH_DEBUG_PRINT(*I << " HAS NO VALUE TRACKING");
00299 continue;
00300 }
00301 for (int i = 0; i < opInfo[I->op].arity; i++) {
00302
00303 int e = 0;
00304 for (ValueTracking::TupleUseDefChain::iterator E = vt->defs[i].begin();
00305 E != vt->defs[i].end(); ++E, ++e) {
00306 for (ValueTracking::UseDefChain::iterator J = E->begin(); J != E->end(); ++J) {
00307 SH_DEBUG_PRINT("{" << *I << "}.src" << i << "[" << e << "] comes from {" << *J->stmt << "}.dst[" << J->index << "]");
00308 }
00309 }
00310 }
00311 int e = 0;
00312 for (ValueTracking::TupleDefUseChain::iterator E = vt->uses.begin();
00313 E != vt->uses.end(); ++E, ++e) {
00314 for (ValueTracking::DefUseChain::iterator J = E->begin(); J != E->end(); ++J) {
00315 SH_DEBUG_PRINT("{" << *I << "}.dst[" << e << "] contributes to {" << *J->stmt << "}.src" << J->source << "[" << J->index << "]");
00316 }
00317 }
00318 }
00319 }
00320 };
00321
00322 }
00323
00324 namespace SH {
00325
00326 ValueTracking::ValueTracking(ShStatement* stmt)
00327 : uses(stmt->dest.node() ? stmt->dest.size() : 0)
00328 {
00329 for (int i = 0; i < opInfo[stmt->op].arity; i++) {
00330 for (int j = 0; j < (stmt->src[i].node() ? stmt->src[i].size() : 0); j++) {
00331 defs[i].push_back(std::set<Def>());
00332 }
00333 }
00334 }
00335
00336 ShStatementInfo* ValueTracking::clone() const
00337 {
00338 return new ValueTracking(*this);
00339 }
00340
00341 void add_value_tracking(ShProgram& p)
00342 {
00343 ReachingDefs r;
00344
00345 ShCtrlGraphPtr graph = p.node()->ctrlGraph;
00346
00347 DefFinder finder(r);
00348 graph->dfs(finder);
00349
00350 InitRch init(r);
00351 graph->dfs(init);
00352
00353 bool changed;
00354 IterateRch iter(r, changed);
00355 do {
00356 changed = false;
00357 graph->dfs(iter);
00358 } while (changed);
00359
00360 #ifdef SH_DEBUG_OPTIMIZER
00361 SH_DEBUG_PRINT("Dumping Reaching Defs");
00362 SH_DEBUG_PRINT("defsize = " << r.defsize);
00363 SH_DEBUG_PRINT("defs.size() = " << r.defs.size());
00364 for(unsigned int i = 0; i < r.defs.size(); ++i) {
00365 SH_DEBUG_PRINT(" stmt[" << i << "]: " << *(r.defs[i].stmt));
00366 SH_DEBUG_PRINT(" node[" << i << "]: " << r.defs[i].node.object());
00367 SH_DEBUG_PRINT("offset[" << i << "]: " << r.defs[i].offset);
00368 }
00369 std::cerr << std::endl;
00370
00371 for (ReachingDefs::ReachingMap::const_iterator I = r.rchin.begin(); I != r.rchin.end(); ++I) {
00372 ShCtrlGraphNodePtr node = I->first;
00373 SH_DEBUG_PRINT(" rchin[" << node.object() << "]: " << I->second);
00374 SH_DEBUG_PRINT(" gen[" << node.object() << "]: " << r.gen[I->first]);
00375 SH_DEBUG_PRINT(" prsv[" << node.object() << "]: " << r.prsv[I->first]);
00376 std::cerr << std::endl;
00377 }
00378 #endif
00379
00380 UdDuClearer clearer;
00381 graph->dfs(clearer);
00382
00383 UdDuBuilder builder(r);
00384 graph->dfs(builder);
00385
00386 #ifdef SH_DEBUG_OPTIMIZER
00387 UdDuDumper dumper;
00388 graph->dfs(dumper);
00389 #endif
00390 }
00391
00392
00393
00394 }