Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | Directories | File List | Namespace Members | Class Members | File Members | Related Pages

ShConstProp.cpp

00001 // Sh: A GPU metaprogramming language.
00002 //
00003 // Copyright 2003-2005 Serious Hack Inc.
00004 // 
00005 // This software is provided 'as-is', without any express or implied
00006 // warranty. In no event will the authors be held liable for any damages
00007 // arising from the use of this software.
00008 // 
00009 // Permission is granted to anyone to use this software for any purpose,
00010 // including commercial applications, and to alter it and redistribute it
00011 // freely, subject to the following restrictions:
00012 // 
00013 // 1. The origin of this software must not be misrepresented; you must
00014 // not claim that you wrote the original software. If you use this
00015 // software in a product, an acknowledgment in the product documentation
00016 // would be appreciated but is not required.
00017 // 
00018 // 2. Altered source versions must be plainly marked as such, and must
00019 // not be misrepresented as being the original software.
00020 // 
00021 // 3. This notice may not be removed or altered from any source
00022 // distribution.
#include "ShOptimizations.hpp"
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "ShBitSet.hpp"
#include "ShCtrlGraph.hpp"
#include "ShDebug.hpp"
#include "ShVariant.hpp"
#include "ShEvaluate.hpp"
#include "ShContext.hpp"
#include "ShSyntax.hpp"
#include "ShInfo.hpp"
00039 
// Uncomment to enable constant/uniform propagation debugging (verbose!)
// #define SH_DEBUG_CONSTPROP

// Optimizer-wide debugging also turns on constant propagation debugging.
#if defined(SH_DEBUG_OPTIMIZER) && !defined(SH_DEBUG_CONSTPROP)
#define SH_DEBUG_CONSTPROP
#endif
00048 
00049 namespace {
00050 
00051 using namespace SH;
00052 
00053 typedef std::queue<ValueTracking::Def> ConstWorkList;
00054 
00055 struct ConstProp : public ShInfo {
00056   ConstProp(ShStatement* stmt,
00057             ShProgramNodeCPtr prog,
00058             ConstWorkList& worklist)
00059     : stmt(stmt)
00060   {
00061     for (int i = 0; i < opInfo[stmt->op].arity ; i++) {
00062       for (int j = 0; j < stmt->src[i].size(); j++) {
00063         switch(stmt->src[i].node()->kind()) {
00064         case SH_INPUT:
00065         case SH_OUTPUT:
00066         case SH_INOUT:
00067         case SH_TEXTURE:
00068         case SH_STREAM:
00069         case SH_PALETTE:
00070           src[i].push_back(Cell(Cell::BOTTOM));
00071           break;
00072         case SH_TEMP:
00073           if (stmt->src[i].uniform()) {
00074             // Don't lift computations dependent on uniforms which
00075             // have been marked with "opt:lifting" == "never"
00076             if (stmt->src[i].meta("opt:lifting") != "never") {
00077               src[i].push_back(Cell(Cell::UNIFORM, stmt->src[i], j));
00078             } else {
00079               src[i].push_back(Cell(Cell::BOTTOM));
00080             }
00081           } else {
00082             src[i].push_back(Cell(Cell::TOP));
00083           }
00084           break;
00085         case SH_CONST:
00086           src[i].push_back(Cell(Cell::CONSTANT, stmt->src[i].getVariant(j)));
00087           break;
00088         default:
00089           SH_DEBUG_ASSERT(0 && "Invalid ShBindingType");
00090           return;
00091         }
00092       }
00093     }
00094     updateDest(worklist);
00095   }
00096 
00097   ShInfo* clone() const;
00098 
00099   int idx(int destindex, int source)
00100   {
00101     return (stmt->src[source].size() == 1 ? 0 : destindex);
00102   }
00103 
00104   void updateDest(ConstWorkList& worklist)
00105   {
00106     if(dest.empty()) {
00107       dest.resize(stmt->dest.size(), Cell(Cell::TOP)); 
00108     }
00109 
00110     // Ignore KIL, optbra, etc.
00111     if (opInfo[stmt->op].result_source == ShOperationInfo::IGNORE) return;
00112 
00113     if (stmt->op == SH_OP_ASN) {
00114       for (int i = 0; i < stmt->dest.size(); i++) {
00115         setDest(i, src[0][i], worklist); // assume src[0][i] cannot move up lattice
00116       }
00117     } else if (opInfo[stmt->op].result_source == ShOperationInfo::EXTERNAL) {
00118       // This statement never results in a constant
00119       // E.g. texture fetches, stream fetches.
00120       for (int i = 0; i < stmt->dest.size(); i++) {
00121         setDest(i, Cell(Cell::BOTTOM), worklist);
00122       }
00123     } else if (opInfo[stmt->op].result_source == ShOperationInfo::LINEAR) {
00124       // The strategy here is to ensure that 
00125       // a) whenever one src becomes bottom, dest becomes bottom
00126       // b) uniform only gets set when ALL src are uniform (because value
00127       // tracking requires it)
00128       // c) otherwise, propagate CONST state per element
00129       // @todo range (CONST may move across to UNIFORM, check that this is okay)
00130 
00131       // Consider each tuple element in turn.
00132       // Dest and sources are guaranteed to be of the same length.
00133       // Except that sources might be scalar.
00134       bool all_fields_uniform = true;
00135       bool some_field_bottom = false;
00136       for (int i = 0; i < stmt->dest.size(); i++) {
00137         bool alluniform = true;
00138         bool allconst = true;
00139         bool somebottom = false;
00140         for (int s = 0; s < opInfo[stmt->op].arity; s++) {
00141           if (src[s][idx(i,s)].state == Cell::BOTTOM) {
00142             somebottom = true;
00143           }
00144           if (src[s][idx(i,s)].state != Cell::CONSTANT) {
00145             allconst = false;
00146             if (src[s][idx(i,s)].state != Cell::UNIFORM) {
00147               alluniform = false;
00148             }
00149           }
00150         }
00151         some_field_bottom |= somebottom;
00152         if (!(alluniform && !allconst)) all_fields_uniform = false;
00153         if (allconst) {
00154           ShVariable tmpdest(new ShVariableNode(SH_CONST, 1, stmt->dest.valueType()));
00155           ShStatement eval(*stmt);
00156           eval.dest = tmpdest;
00157           for (int k = 0; k < opInfo[stmt->op].arity; k++) {
00158             ShVariantCPtr srcValue = src[k][idx(i,k)].value;
00159             ShVariable tmpsrc(new ShVariableNode(SH_CONST, 1, srcValue->valueType()));
00160             tmpsrc.setVariant(srcValue, 0);
00161             eval.src[k] = tmpsrc;
00162           }
00163           evaluate(eval);
00164           setDest(i, Cell(Cell::CONSTANT, tmpdest.getVariant(0)), worklist);
00165         } else if (somebottom) {
00166           setDest(i, Cell(Cell::BOTTOM), worklist);
00167         }
00168       } 
00169 
00170 
00171       // Because making a uniform cell based on a ConstProp requires
00172       // generating a value for the entire statement (not just one
00173       // field of the destination), we only push said cells if ALL of
00174       // the indices are uniform for ALL of their corresponding sources.
00175       if (all_fields_uniform) {
00176         for (int i = 0; i < stmt->dest.size(); i++) {
00177           setDest(i, Cell(Cell::UNIFORM, this, i), worklist);
00178         }
00179       }
00180     } else if (opInfo[stmt->op].result_source == ShOperationInfo::ALL) {
00181       // build statement ONLY if ALL elements of ALL sources are constant
00182       bool allconst = true;
00183       bool alluniform = true; // all statements are either uniform or constant
00184       bool somebottom = false;
00185       for (int s = 0; s < opInfo[stmt->op].arity && !somebottom; s++) {
00186         for (unsigned int k = 0; k < src[s].size(); k++) {
00187           if(src[s][k].state == Cell::BOTTOM) {
00188             somebottom = true;
00189           }
00190           if (src[s][k].state != Cell::CONSTANT) {
00191             allconst = false;
00192             if (src[s][k].state != Cell::UNIFORM) {
00193               alluniform = false;
00194             }
00195           }
00196         }
00197       }
00198       if (allconst) { 
00199         ShVariable tmpdest(new ShVariableNode(SH_CONST, stmt->dest.size(), stmt->dest.valueType()));
00200         ShStatement eval(*stmt);
00201         eval.dest = tmpdest;
00202         for (int i = 0; i < opInfo[stmt->op].arity; i++) {
00203           SH_DEBUG_ASSERT(src[i][0].value); // @todo type DEBUGGING
00204           ShValueType srcValueType = src[i][0].value->valueType(); 
00205           ShVariable tmpsrc(new ShVariableNode(SH_CONST, stmt->src[i].size(), srcValueType));
00206           for (int j = 0; j < stmt->src[i].size(); j++) {
00207             tmpsrc.setVariant(src[i][j].value, j);
00208           }
00209           eval.src[i] = tmpsrc;
00210         }
00211         evaluate(eval);
00212         for (int i = 0; i < stmt->dest.size(); i++) {
00213           setDest(i, Cell(Cell::CONSTANT, tmpdest.getVariant(i)), worklist);
00214         }
00215       } else if (alluniform) {
00216         for (int i = 0; i < stmt->dest.size(); i++) {
00217           setDest(i, Cell(Cell::UNIFORM, this, i), worklist);
00218         }
00219       } else if (somebottom) {
00220         for (int i = 0; i < stmt->dest.size(); i++) {
00221           setDest(i, Cell(Cell::BOTTOM), worklist);
00222         }
00223       }
00224     } else {
00225       SH_DEBUG_ASSERT(0 && "Invalid result source type");
00226     }
00227   }
00228 
00229   typedef int ValueNum;
00230 
00231   struct Uniform {
00232     Uniform()
00233       : constant(false),
00234         valuenum(-1)
00235     {
00236     }
00237 
00238     // @todo type...this is my current understanding:
00239     // May be constant or if !constval, value is not
00240     // known to be constant
00241     Uniform(ShVariantCPtr cval)
00242       : constant(true),
00243         constval(cval ? cval->get() : ShVariantPtr(0))
00244     {
00245     }
00246     
00247     Uniform(int valuenum, int index, bool neg)
00248       : constant(false),
00249         valuenum(valuenum), index(index), neg(neg)
00250     {
00251     }
00252 
00253     bool operator==(const Uniform& other) const
00254     {
00255       if (constant != other.constant) return false;
00256 
00257       if (constant) {
00258         // @todo type
00259         if(!constval) return false;
00260         // Check with Stefanus whether this modification is correct
00261     //    SH_DEBUG_ASSERT(constval); // @todo type debugging
00262         return constval->equals(other.constval);
00263       } else {
00264         if (valuenum != other.valuenum) return false;
00265         if (index != other.index) return false;
00266         if (neg != other.neg) return false;
00267         return true;
00268       }
00269     }
00270 
00271     ShValueType valueType() const 
00272     {
00273       if(constant) return constval->valueType();
00274       return Value::get(valuenum)->valueType();
00275     }
00276 
00277     bool operator!=(const Uniform& other) const
00278     {
00279       return !(*this == other);
00280     }
00281 
00282 
00283     bool constant;
00284 
00285     ValueNum valuenum;
00286     int index;
00287     bool neg;
00288 
00289     ShVariantPtr constval;
00290   };
00291 
00292   class Value {
00293   public:
00294     enum Type {
00295       NODE,
00296       STMT
00297     };
00298 
00299     Type type;
00300     
00301     // Only for type == NODE:
00302     ShVariableNodePtr node;
00303 
00304     // Only for type == STMT:
00305     ShOperation op;
00306     int destsize;
00307     ShValueType destValueType;
00308     std::vector<Uniform> src[3];
00309 
00310     static void clear()
00311     {
00312       m_values.clear();
00313     }
00314     
00315     static ValueNum lookup(const ShVariableNodePtr& node)
00316     {
00317       for (std::size_t i = 0; i < m_values.size(); i++) {
00318         if (m_values[i]->type == NODE && m_values[i]->node == node) return i;
00319       }
00320       m_values.push_back(new Value(node));
00321       return m_values.size() - 1;
00322     }
00323 
00324     bool operator==(const Value& other) const
00325     {
00326       if (type != other.type) return false;
00327       if (type == NODE) {
00328         return node == other.node;
00329       } else if (type == STMT) {
00330         if (op != other.op || destsize != other.destsize || destValueType != other.destValueType) return false;
00331         for (int i = 0; i < opInfo[op].arity; i++) {
00332           if (src[i].size() != other.src[i].size()) return false;
00333           for (std::size_t j = 0; j < src[i].size(); j++) {
00334             if (src[i][j] != other.src[i][j]) return false;
00335           }
00336         }
00337         return true;
00338       }
00339       return false;
00340     }
00341 
00342     bool operator!=(const Value& other) const
00343     {
00344       return !(*this == other);
00345     }
00346 
00347     static ValueNum lookup(ConstProp* cp)
00348     {
00349       Value* val = new Value(cp);
00350       
00351       for (std::size_t i = 0; i < m_values.size(); i++) {
00352         if (m_values[i]->type != STMT) continue;
00353         if (m_values[i]->op != cp->stmt->op) continue;
00354 
00355         if (*val == *m_values[i]) {
00356           delete val;
00357           return i;
00358         }
00359       }
00360       m_values.push_back(val);
00361       return m_values.size() - 1;
00362     }
00363 
00364     ShValueType valueType() {
00365       if(type == NODE) {
00366         return node->valueType();
00367       } 
00368       return destValueType;
00369     }
00370 
00371     static Value* get(ValueNum n)
00372     {
00373       return m_values[n];
00374     }
00375 
00376     static void dump(std::ostream& out);
00377 
00378     std::string name() const
00379     {
00380       if(type == NODE) return node->name();
00381       else return ""; 
00382     }
00383 
00384   private:
00385     Value(const ShVariableNodePtr& node)
00386       : type(NODE), node(node)
00387     {
00388     }
00389 
00390     Value(ConstProp* cp)
00391       : type(STMT), node(0), op(cp->stmt->op), destsize(cp->stmt->dest.size()), destValueType(cp->stmt->dest.valueType())
00392     {
00393       for (int i = 0; i < opInfo[cp->stmt->op].arity; i++) {
00394         for (std::size_t j = 0; j < cp->src[i].size(); j++) {
00395           if (cp->src[i][j].state == Cell::UNIFORM) {
00396             src[i].push_back(cp->src[i][j].uniform);
00397           } else {
00398             SH_DEBUG_ASSERT(cp->src[i][j].state == Cell::CONSTANT); // @todo type should be fixed
00399             src[i].push_back(Uniform(cp->src[i][j].value));
00400           }
00401         }
00402       }
00403     }
00404     
00405     static std::vector<Value*> m_values;
00406   };
00407   
00408   struct Cell {
00409     enum State {
00410       BOTTOM,
00411       CONSTANT,
00412       UNIFORM,
00413       TOP
00414     };
00415 
00416     // @todo comments added, but may not be correct
00417    
00418     // Construct a CONSTANT Cell with non-null value
00419     // or a TOP/BOTTOM
00420     Cell(State state, ShVariantPtr value = 0)
00421       : state(state)
00422     {
00423       if(value) this->value = value->get(); 
00424       SH_DEBUG_ASSERT(this->value || (state != CONSTANT));
00425     }
00426 
00427     // Construct a UNIFORM cell from a variable
00428     Cell(State state, const ShVariable& var, int index) 
00429       : state(state), value(0),
00430         uniform(Value::lookup(var.node()), var.swizzle()[index], var.neg())
00431     {
00432     }
00433 
00434     // Construct a UNIFORM cell as a result of some
00435     // statement where all relevant sources are uniforms/constants.
00436     Cell(State state, ConstProp* cp, int index) 
00437       : state(state), value(0),
00438         uniform(Value::lookup(cp), index, false)
00439     {
00440     }
00441 
00442     bool operator==(const Cell& other) const
00443     {
00444       if(state != other.state) return false;
00445       if(value) return value->equals(other.value);
00446       // @todo type remove debug
00447       SH_DEBUG_ASSERT(!value);
00448       return value == other.value; // null
00449     }
00450 
00451     bool operator!=(const Cell& other) const
00452     {
00453       return !(*this == other);
00454     }
00455     
00456     State state;
00457     ShVariantPtr value; // Only for state == CONSTANT
00458 
00459     Uniform uniform; // Only for state == UNIFORM
00460   };
00461 
00462   // If dest[index] != cell, sets dest[index] to cell and updates the worklist 
00463   // Caller must ensure that cell does not move up the lattice. 
00464   void setDest(int index, const Cell &cell, ConstWorkList &worklist)
00465   {
00466     if(dest[index] == cell) return; 
00467     dest[index] = cell;
00468     worklist.push(ValueTracking::Def(stmt, index));
00469   }
00470 
00471   ShStatement* stmt;
00472   std::vector<Cell> dest;
00473   std::vector<Cell> src[3];
00474 };
00475 
00476 ShInfo* ConstProp::clone() const
00477 {
00478   return new ConstProp(*this);
00479 }
00480 
00481 ConstProp::Cell meet(const ConstProp::Cell& a, const ConstProp::Cell& b)
00482 {
00483   if (a.state == ConstProp::Cell::BOTTOM || b.state == ConstProp::Cell::BOTTOM) {
00484     return ConstProp::Cell(ConstProp::Cell::BOTTOM);
00485   }
00486   if (a.state != b.state
00487       && (a.state != ConstProp::Cell::TOP && b.state != ConstProp::Cell::TOP)) {
00488     return ConstProp::Cell(ConstProp::Cell::BOTTOM);
00489   }
00490   // At this point either the cells are the same or one of them is
00491   // top.
00492   if (a.state == b.state) {
00493     if (a.state == ConstProp::Cell::CONSTANT) {
00494       SH_DEBUG_ASSERT(a.value); // @todo type debugging
00495       if (a.value->equals(b.value)) {
00496         return a;
00497       } else {
00498         return ConstProp::Cell(ConstProp::Cell::BOTTOM);
00499       }
00500     }
00501     if (a.state == ConstProp::Cell::UNIFORM) {
00502       if (a.uniform == b.uniform) return a;
00503       return ConstProp::Cell(ConstProp::Cell::BOTTOM);
00504     }
00505   }
00506 
00507   if (a.state != ConstProp::Cell::TOP) return a;
00508   if (b.state != ConstProp::Cell::TOP) return b;
00509     
00510   return ConstProp::Cell(ConstProp::Cell::TOP);
00511 }
00512 
00513 std::vector<ConstProp::Value*> ConstProp::Value::m_values = std::vector<ConstProp::Value*>();
00514 
00515 std::ostream& operator<<(std::ostream& out, const ConstProp::Uniform& uniform)
00516 {
00517   if (uniform.constant) {
00518     SH_DEBUG_ASSERT(uniform.constval); // @todo type DEBUGGING
00519     out << uniform.constval->encode();
00520   } else {
00521     if (uniform.neg) out << '-';
00522     out << "v" << uniform.valuenum << "[" << uniform.index << "]";
00523   }
00524   return out;
00525 }
00526 
00527 void ConstProp::Value::dump(std::ostream& out)
00528 {
00529   out << "--- uniform values ---" << std::endl;
00530   for (std::size_t i = 0; i < m_values.size(); i++) {
00531     out << i << ": ";
00532     if (m_values[i]->type == NODE) {
00533       out << "node " << m_values[i]->node->name() << std::endl;
00534     } else if (m_values[i]->type == STMT) {
00535       out << "stmt [" << m_values[i]->destsize << "] " << opInfo[m_values[i]->op].name << " ";
00536       for (int j = 0; j < opInfo[m_values[i]->op].arity; j++) {
00537         if (j) out << ", ";
00538         for (std::vector<Uniform>::iterator U = m_values[i]->src[j].begin();
00539              U != m_values[i]->src[j].end(); ++U) {
00540           out << *U;
00541         }
00542       }
00543       out << ";" << std::endl;
00544     }
00545   }
00546 }
00547 
00548 std::ostream& operator<<(std::ostream& out, const ConstProp::Cell& cell)
00549 {
00550   switch(cell.state) {
00551   case ConstProp::Cell::BOTTOM:
00552     out << "[bot]";
00553     break;
00554   case ConstProp::Cell::CONSTANT:
00555     out << "[" << cell.value->encode() << "]";
00556     break;
00557   case ConstProp::Cell::TOP:
00558     out << "[top]";
00559     break;
00560   case ConstProp::Cell::UNIFORM:
00561     out << "<" << cell.uniform << ">";
00562     break;
00563   }
00564   return out;
00565 }
00566 
00567 struct InitConstProp {
00568   InitConstProp(ShProgramNodeCPtr prog, ConstWorkList& worklist)
00569     : prog(prog), worklist(worklist)
00570   {
00571   }
00572 
00573   void operator()(const ShCtrlGraphNodePtr& node)
00574   {
00575     if (!node) return;
00576     ShBasicBlockPtr block = node->block;
00577     if (!block) return;
00578     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00579       I->destroy_info<ConstProp>();
00580       ConstProp* cp = new ConstProp(&(*I), prog, worklist);
00581       I->add_info(cp);
00582     }
00583   }
00584 
00585   ShProgramNodeCPtr prog;
00586   ConstWorkList& worklist;
00587 };
00588 
00589 struct DumpConstProp {
00590   void operator()(const ShCtrlGraphNodePtr& node)
00591   {
00592     if (!node) return;
00593     ShBasicBlockPtr block = node->block;
00594     if (!block) return;
00595     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00596       std::cerr << "{" << *I << "} --- ";
00597       ConstProp* cp = I->get_info<ConstProp>();
00598 
00599       if (!cp) {
00600         std::cerr << "NO CP INFORMATION" << std::endl;
00601         continue;
00602       }
00603 
00604       std::cerr << "dest = {";
00605       for (std::size_t i = 0; i < cp->dest.size(); i++) {
00606         std::cerr << cp->dest[i];
00607       }
00608       std::cerr << "}; ";
00609       for (int s = 0; s < opInfo[I->op].arity; s++) {
00610         if (s) std::cerr << ", ";
00611         std::cerr << "src" << s << " = {";
00612         for (std::size_t i = 0; i < cp->src[s].size(); i++) {
00613           std::cerr << cp->src[s][i];
00614         }
00615         std::cerr << "}";
00616       }
00617       std::cerr << std::endl;
00618     }
00619     
00620   }
00621 };
00622 
00623 struct FinishConstProp
00624 {
00625   FinishConstProp(bool lift_uniforms)
00626     : lift_uniforms(lift_uniforms)
00627   {
00628   }
00629   
00630   void operator()(const ShCtrlGraphNodePtr& node) {
00631     if (!node) return;
00632     ShBasicBlockPtr block = node->block;
00633     if (!block) return;
00634     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00635       ConstProp* cp = I->get_info<ConstProp>();
00636 
00637       if (!cp) continue;
00638 
00639       if (!cp->dest.empty()) {
00640         // if all dest fields are constants, replace this with a
00641         // constant assignment
00642 
00643         if (I->op != SH_OP_ASN || I->src[0].node()->kind() != SH_CONST) {
00644           bool allconst = true;
00645           for (int i = 0; i < I->dest.size(); i++) {
00646             if (cp->dest[i].state != ConstProp::Cell::CONSTANT) {
00647               allconst = false;
00648               break;
00649             }
00650           }
00651           if (allconst) {
00652             SH_DEBUG_ASSERT(cp->dest[0].value); // @todo type debugging
00653             ShValueType destValueType = cp->dest[0].value->valueType(); 
00654             ShVariable newconst(new ShVariableNode(SH_CONST, I->dest.size(), destValueType));
00655             for(int i = 0; i < I->dest.size(); ++i) {
00656               newconst.setVariant(cp->dest[i].value, i);
00657             }
00658 #ifdef SH_DEBUG_CONSTPROP
00659             SH_DEBUG_PRINT("Replaced {" << *I << "} with " << newconst);
00660 #endif
00661             *I = ShStatement(I->dest, SH_OP_ASN, newconst);
00662           } else {
00663             // otherwise, do the same for each source field.
00664             for (int s = 0; s < opInfo[I->op].arity; s++) {
00665               if (I->src[s].node()->kind() == SH_CONST) continue;
00666             
00667               ShValueType srcValueType = I->src[s].valueType();
00668               ShVariable newconst(new ShVariableNode(SH_CONST, I->src[s].size(), srcValueType));
00669               bool allconst = true;
00670               for (int i = 0; i < I->src[s].size(); i++) {
00671                 if (cp->src[s][i].state != ConstProp::Cell::CONSTANT) {
00672                   allconst = false;
00673                   break;
00674                 }
00675                 newconst.setVariant(cp->src[s][i].value, i);
00676               }
00677               if (allconst) {
00678 #ifdef SH_DEBUG_CONSTPROP
00679                 SH_DEBUG_PRINT("Replaced {" << *I << "}.src[" << s << "] with " << newconst);
00680 #endif
00681                 I->src[s] = newconst;
00682               }
00683             }
00684           }
00685 
00686           if (!lift_uniforms || allconst) {
00687             //SH_DEBUG_PRINT("Skipping uniform lifting");
00688             continue;
00689           }
00690 
00691           bool alluniform = true;
00692           for (int s = 0; s < opInfo[I->op].arity; s++) {
00693             for (int i = 0; i < I->src[s].size(); i++) {
00694               if (cp->src[s][i].state != ConstProp::Cell::UNIFORM
00695                   && cp->src[s][i].state != ConstProp::Cell::CONSTANT) {
00696                 alluniform = false;
00697                 break;
00698               }
00699             }
00700             if (!alluniform) break;
00701           }
00702           if (!alluniform || I->dest.node()->kind() == SH_OUTPUT
00703               || I->dest.node()->kind() == SH_INOUT) {
00704 #ifdef SH_DEBUG_CONSTPROP
00705             SH_DEBUG_PRINT("Considering " << *I << " for uniform lifting");
00706 #endif          
00707             for (int s = 0; s < opInfo[I->op].arity; s++) {
00708               if (I->src[s].uniform()) {
00709 #ifdef SH_DEBUG_CONSTPROP
00710                 SH_DEBUG_PRINT(*I << ".src[" << s << "] is already a uniform");
00711 #endif
00712                 continue;
00713               }
00714 
00715               bool mixed = false;
00716               bool neg = false;
00717               ConstProp::ValueNum uniform = -1;
00718               std::vector<int> indices;
00719               
00720               for (int i = 0; i < I->src[s].size(); i++) {
00721                 if (cp->src[s][i].state == ConstProp::Cell::UNIFORM) {
00722                   if (uniform < 0) {
00723                     uniform = cp->src[s][i].uniform.valuenum;
00724                     neg = cp->src[s][i].uniform.neg;
00725                     indices.push_back(cp->src[s][i].uniform.index);
00726                   } else {
00727                     if (uniform != cp->src[s][i].uniform.valuenum
00728                         || neg != cp->src[s][i].uniform.neg) {
00729                       mixed = true;
00730                       break;
00731                     }
00732                     indices.push_back(cp->src[s][i].uniform.index);
00733                   }
00734                 } else {
00735                   // Can't lift this, unless we introduce intermediate instructions.
00736                   mixed = true;
00737                 }
00738               }
00739               
00740               if (uniform < 0) {
00741 #ifdef SH_DEBUG_CONSTPROP
00742                 SH_DEBUG_PRINT("{" << *I << "}.src[" << s << "] is not uniform");
00743 #endif
00744                 continue;
00745               }
00746               if (mixed) {
00747 #ifdef SH_DEBUG_CONSTPROP
00748                 SH_DEBUG_PRINT("{" << *I << "}.src[" << s << "] is mixed");
00749 #endif
00750                 continue;
00751               }
00752               ConstProp::Value* value = ConstProp::Value::get(uniform);
00753 
00754 
00755 #ifdef SH_DEBUG_CONSTPROP
00756               SH_DEBUG_PRINT("Lifting {" << *I << "}.src[" << s << "]: " << uniform);
00757 #endif
00758               int srcsize;
00759               if (value->type == ConstProp::Value::NODE) {
00760                 srcsize = value->node->size();
00761               } else {
00762                 srcsize = value->destsize;
00763               }
00764               ShSwizzle swizzle(srcsize, indices.size(), &(*indices.begin()));
00765               // Build a uniform to represent this computation.
00766               ShVariableNodePtr node = build_uniform(value, uniform);
00767               if (node) {
00768                 I->src[s] = ShVariable(node, swizzle, neg);
00769               } else {
00770 #ifdef SH_DEBUG_CONSTPROP
00771                 SH_DEBUG_PRINT("Could not lift " << *I << ".src[" << s << "] for some reason");
00772 #endif
00773               }
00774             }
00775           }
00776         }
00777       }
00778     }
00779 
00780     // Clean up
00781     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00782       // Remove constant propagation information.
00783       I->destroy_info<ConstProp>();
00784     }
00785   }
00786 
00787   ShVariableNodePtr build_uniform(ConstProp::Value* value,
00788                                   ConstProp::ValueNum valuenum)
00789   {
00790     if (value->type == ConstProp::Value::NODE) {
00791       return value->node;
00792     }
00793     
00794     ShContext::current()->enter(0);
00795     ShVariableNodePtr node = new ShVariableNode(SH_TEMP, value->destsize, value->destValueType);
00796     {
00797     std::ostringstream s;
00798     s << "dep_" << valuenum << "_" << value->name();
00799     node->name(s.str());
00800     }
00801     ShContext::current()->exit();
00802 
00803 #ifdef SH_DEBUG_CONSTPROP
00804     SH_DEBUG_PRINT("Lifting value #" << valuenum);
00805 #endif
00806 
00807     bool broken = false;
00808     
00809     ShProgram prg = SH_BEGIN_PROGRAM("uniform") {
00810       ShStatement stmt(node, value->op);
00811 
00812       for (int i = 0; i < opInfo[value->op].arity; i++) {
00813         stmt.src[i] = compute(value->src[i]);
00814         if (stmt.src[i].null()) broken = true;
00815       }
00816 
00817       ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00818     } SH_END;
00819 
00820 #ifdef SH_DEBUG_CONSTPROP
00821     {
00822       std::ostringstream s;
00823       s << "lifted_" << valuenum;
00824       std::string dotfilename(s.str() + ".dot");
00825       std::ofstream dot(dotfilename.c_str());
00826       prg.node()->ctrlGraph->graphvizDump(dot);
00827       dot.close();
00828       std::string cmdline = std::string("dot -Tps -o ") + s.str() + ".ps " + s.str() + ".dot";
00829       system(cmdline.c_str());
00830     }
00831 #endif
00832     
00833     if (broken) return 0;
00834     node->attach(prg.node());
00835 
00836     value->type = ConstProp::Value::NODE;
00837     value->node = node;
00838     
00839     return node;
00840   }
00841 
00842   ShVariable compute(const std::vector<ConstProp::Uniform>& src)
00843   {
00844     ConstProp::ValueNum v = -1;
00845     
00846     bool allsame = true;
00847     bool neg = false;
00848     std::vector<int> indices;
00849     std::vector<ShVariantCPtr> constvals;
00850     for (std::size_t i = 0; i < src.size(); i++) {
00851       if (src[i].constant) {
00852         if (v >= 0) {
00853           allsame = false;
00854           break;
00855         }
00856         constvals.push_back(src[i].constval);
00857       } else {
00858         if (v < 0 && constvals.empty()) {
00859           v = src[i].valuenum;
00860           neg = src[i].neg;
00861         } else {
00862           if (v != src[i].valuenum || neg != src[i].neg || !constvals.empty()) {
00863             allsame = false;
00864           }
00865         }
00866         indices.push_back(src[i].index);
00867       }
00868     }
00869     if (!allsame) {
00870       // Make intermediate variables, combine them together.
00871       ShVariable r = ShVariable(new ShVariableNode(SH_TEMP, src.size(), src[0].valueType()));
00872       
00873       for (std::size_t i = 0; i < src.size(); i++) {
00874         std::vector<ConstProp::Uniform> v;
00875         v.push_back(src[i]);
00876         ShVariable scalar = compute(v);
00877         ShStatement stmt(r(i), SH_OP_ASN, scalar);
00878         ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00879       }
00880       return r;
00881     }
00882 
00883     if (!constvals.empty()) {
00884       ShVariable var(new ShVariableNode(SH_CONST, constvals.size(), constvals[0]->valueType()));
00885       for(std::size_t i = 0; i < constvals.size(); ++i) var.setVariant(constvals[i], i);
00886       return var;
00887     }
00888     
00889     ConstProp::Value* value = ConstProp::Value::get(v);
00890     
00891     if (value->type == ConstProp::Value::NODE) {
00892       ShSwizzle swizzle(value->node->size(), indices.size(), &*indices.begin());
00893       return ShVariable(value->node, swizzle, neg);
00894     }
00895     if (value->type == ConstProp::Value::STMT) {
00896       ShVariableNodePtr node = new ShVariableNode(SH_TEMP, value->destsize, value->destValueType);
00897       ShStatement stmt(node, value->op);
00898 
00899       for (int i = 0; i < opInfo[value->op].arity; i++) {
00900         stmt.src[i] = compute(value->src[i]);
00901         if (stmt.src[i].null()) return ShVariable();
00902       }
00903 
00904       ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00905       ShSwizzle swizzle(node->size(), indices.size(), &*indices.begin());
00906       return ShVariable(node, swizzle, neg);
00907     }
00908 
00909 #ifdef SH_DEBUG_CONSTPROP
00910     SH_DEBUG_PRINT("Reached invalid point");
00911 #endif
00912 
00913     // Should never reach here.
00914     return ShVariable();
00915   }
00916 
00917   bool lift_uniforms;
00918 };
00919 
00920 }
00921 
00922 namespace SH {
00923 
00924 
00925 void propagate_constants(ShProgram& p)
00926 {
00927   ShCtrlGraphPtr graph = p.node()->ctrlGraph;
00928 
00929   ConstWorkList worklist;
00930 
00931   //ConstProp::Value::clear();
00932   
00933   InitConstProp init(p.node(), worklist);
00934   graph->dfs(init);
00935 
00936 #ifdef SH_DEBUG_CONSTPROP
00937   SH_DEBUG_PRINT("Const Prop Initial Values:");
00938   DumpConstProp dump_pre;
00939   graph->dfs(dump_pre);
00940 #endif
00941 
00942   
00943   while (!worklist.empty()) {
00944     ValueTracking::Def def = worklist.front(); worklist.pop();
00945     ValueTracking* vt = def.stmt->get_info<ValueTracking>();
00946     if (!vt) {
00947 #ifdef SH_DEBUG_CONSTPROP
00948       SH_DEBUG_PRINT(*def.stmt << " on worklist does not have VT information?");
00949 #endif
00950       continue;
00951     }
00952 
00953 
00954     for (ValueTracking::DefUseChain::iterator use = vt->uses[def.index].begin();
00955          use != vt->uses[def.index].end(); ++use) {
00956       if (use->kind != ValueTracking::Use::STMT) continue;
00957       ConstProp* cp = use->stmt->get_info<ConstProp>();
00958       if (!cp) {
00959 #ifdef SH_DEBUG_CONSTPROP
00960         SH_DEBUG_PRINT("Use " << *use->stmt << " does not have const prop information!");
00961 #endif
00962         continue;
00963       }
00964 
00965       ConstProp::Cell cell = cp->src[use->source][use->index];
00966 
00967       ValueTracking* ut = use->stmt->get_info<ValueTracking>();
00968       if (!ut) {
00969         // Should never happen...
00970 #ifdef SH_DEBUG_CONSTPROP
00971         SH_DEBUG_PRINT("Use " << *use->stmt << " on worklist does not have VT information?");
00972 #endif
00973         continue;
00974       }
00975 
00976 #ifdef SH_DEBUG_CONSTPROP
00977       SH_DEBUG_PRINT("Meeting cell for {" << *use->stmt
00978                      << "}.src" << use->source << "[" << use->index << "]");
00979 #endif
00980       
00981       ConstProp::Cell new_cell(ConstProp::Cell::TOP);
00982       
00983       for (ValueTracking::UseDefChain::iterator possdef
00984              = ut->defs[use->source][use->index].begin();
00985            possdef != ut->defs[use->source][use->index].end(); ++possdef) {
00986         ConstProp* dcp = possdef->stmt->get_info<ConstProp>();
00987         if (!dcp) {
00988 #ifdef SH_DEBUG_CONSTPROP
00989           SH_DEBUG_PRINT("Possible def " << *dcp->stmt << " on worklist does not have CP information?");
00990 #endif
00991           continue;
00992         }
00993         
00994         ConstProp::Cell destcell = dcp->dest[possdef->index];
00995 
00996         // If the use is negated, we need to change the cell
00997         if (use->stmt->src[use->source].neg()) {
00998           if (destcell.state == ConstProp::Cell::CONSTANT) {
00999             SH_DEBUG_ASSERT(destcell.value); // @todo type DEBUGGING
01000             destcell.value->negate();
01001           } else if (destcell.state == ConstProp::Cell::UNIFORM) {
01002             destcell.uniform.neg = !destcell.uniform.neg;
01003           }
01004         }
01005 #ifdef SH_DEBUG_CONSTPROP
01006         SH_DEBUG_PRINT("  meet(" << new_cell << ", " << destcell << ") = " <<
01007                        meet(new_cell, destcell));
01008 #endif
01009         new_cell = meet(new_cell, destcell);
01010       }
01011       
01012       if (cell != new_cell) {
01013 #ifdef SH_DEBUG_CONSTPROP
01014         SH_DEBUG_PRINT("  ...replacing cell");
01015 #endif
01016         cp->src[use->source][use->index] = new_cell;
01017         cp->updateDest(worklist);
01018       }
01019     }
01020   }
01021   // Now do something with our glorious newfound information.
01022 
01023   
01024 #ifdef SH_DEBUG_CONSTPROP
01025   ConstProp::Value::dump(std::cerr);
01026 
01027   DumpConstProp dump;
01028   graph->dfs(dump);
01029 #endif
01030   
01031   FinishConstProp finish(p.node()->target().find("gpu:") == 0
01032                          && !ShContext::current()->optimization_disabled("uniform lifting"));
01033   graph->dfs(finish);
01034   
01035 }
01036 
01037 }

Generated on Thu Apr 21 17:32:46 2005 for Sh by  doxygen 1.4.2