ShConstProp.cpp

00001 // Sh: A GPU metaprogramming language.
00002 //
00003 // Copyright 2003-2006 Serious Hack Inc.
00004 // 
00005 // This library is free software; you can redistribute it and/or
00006 // modify it under the terms of the GNU Lesser General Public
00007 // License as published by the Free Software Foundation; either
00008 // version 2.1 of the License, or (at your option) any later version.
00009 //
00010 // This library is distributed in the hope that it will be useful,
00011 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00013 // Lesser General Public License for more details.
00014 //
00015 // You should have received a copy of the GNU Lesser General Public
00016 // License along with this library; if not, write to the Free Software
00017 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, 
00018 // MA  02110-1301, USA
00020 #include "ShOptimizations.hpp"
00021 #include <map>
00022 #include <set>
00023 #include <utility>
00024 #include <iostream>
00025 #include "ShBitSet.hpp"
00026 #include "ShCtrlGraph.hpp"
00027 #include "ShDebug.hpp"
00028 #include "ShVariant.hpp"
00029 #include "ShEvaluate.hpp"
00030 #include "ShContext.hpp"
00031 #include "ShSyntax.hpp"
00032 #include "ShInfo.hpp"
00033 #include <sstream>
00034 #include <fstream>
00035 
00036 // Uncomment to enable constant/uniform propagation debugging (verbose!)
00037 // #define SH_DEBUG_CONSTPROP
00038 
00039 #ifdef SH_DEBUG_OPTIMIZER
00040 #ifndef SH_DEBUG_CONSTPROP
00041 #define SH_DEBUG_CONSTPROP
00042 #endif
00043 #endif
00044 
00045 namespace {
00046 
00047 using namespace SH;
00048 
// Work list for the propagation pass: ConstProp::setDest() pushes one
// ValueTracking::Def(statement, destination element index) every time the
// cell of that destination element changes.
00049 typedef std::queue<ValueTracking::Def> ConstWorkList;
00050 
00051 struct ConstProp : public ShInfo {
00052   ConstProp(ShStatement* stmt,
00053             ShProgramNodeCPtr prog,
00054             ConstWorkList& worklist)
00055     : stmt(stmt)
00056   {
00057     for (int i = 0; i < opInfo[stmt->op].arity ; i++) {
00058       for (int j = 0; j < stmt->src[i].size(); j++) {
00059         switch(stmt->src[i].node()->kind()) {
00060         case SH_INPUT:
00061         case SH_OUTPUT:
00062         case SH_INOUT:
00063         case SH_TEXTURE:
00064         case SH_STREAM:
00065         case SH_PALETTE:
00066           src[i].push_back(Cell(Cell::BOTTOM));
00067           break;
00068         case SH_TEMP:
00069           if (stmt->src[i].uniform()) {
00070             // Don't lift computations dependent on uniforms which
00071             // have been marked with "opt:lifting" == "never"
00072             if (stmt->src[i].meta("opt:lifting") != "never") {
00073               src[i].push_back(Cell(Cell::UNIFORM, stmt->src[i], j));
00074             } else {
00075               src[i].push_back(Cell(Cell::BOTTOM));
00076             }
00077           } else {
00078             src[i].push_back(Cell(Cell::TOP));
00079           }
00080           break;
00081         case SH_CONST:
00082           src[i].push_back(Cell(Cell::CONSTANT, stmt->src[i].getVariant(j)));
00083           break;
00084         default:
00085           SH_DEBUG_ASSERT(0 && "Invalid ShBindingType");
00086           return;
00087         }
00088       }
00089     }
00090     updateDest(worklist);
00091   }
00092 
00093   ShInfo* clone() const;
00094 
00095   int idx(int destindex, int source)
00096   {
00097     return (stmt->src[source].size() == 1 ? 0 : destindex);
00098   }
00099 
00100   void updateDest(ConstWorkList& worklist)
00101   {
00102     if(dest.empty()) {
00103       dest.resize(stmt->dest.size(), Cell(Cell::TOP)); 
00104     }
00105 
00106     // Ignore KIL, optbra, etc.
00107     if (opInfo[stmt->op].result_source == ShOperationInfo::IGNORE) return;
00108 
00109     if (stmt->op == SH_OP_ASN) {
00110       for (int i = 0; i < stmt->dest.size(); i++) {
00111         setDest(i, src[0][i], worklist); // assume src[0][i] cannot move up lattice
00112       }
00113     } else if (opInfo[stmt->op].result_source == ShOperationInfo::EXTERNAL) {
00114       // This statement never results in a constant
00115       // E.g. texture fetches, stream fetches.
00116       for (int i = 0; i < stmt->dest.size(); i++) {
00117         setDest(i, Cell(Cell::BOTTOM), worklist);
00118       }
00119     } else if (opInfo[stmt->op].result_source == ShOperationInfo::LINEAR) {
00120       // The strategy here is to ensure that 
00121       // a) whenever one src becomes bottom, dest becomes bottom
00122       // b) uniform only gets set when ALL src are uniform (because value
00123       // tracking requires it)
00124       // c) otherwise, propagate CONST state per element
00125       // @todo range (CONST may move across to UNIFORM, check that this is okay)
00126 
00127       // Consider each tuple element in turn.
00128       // Dest and sources are guaranteed to be of the same length.
00129       // Except that sources might be scalar.
00130       bool all_fields_uniform = true;
00131       bool some_field_bottom = false;
00132       for (int i = 0; i < stmt->dest.size(); i++) {
00133         bool alluniform = true;
00134         bool allconst = true;
00135         bool somebottom = false;
00136         for (int s = 0; s < opInfo[stmt->op].arity; s++) {
00137           if (src[s][idx(i,s)].state == Cell::BOTTOM) {
00138             somebottom = true;
00139           }
00140           if (src[s][idx(i,s)].state != Cell::CONSTANT) {
00141             allconst = false;
00142             if (src[s][idx(i,s)].state != Cell::UNIFORM) {
00143               alluniform = false;
00144             }
00145           }
00146         }
00147         some_field_bottom |= somebottom;
00148         if (!(alluniform && !allconst)) all_fields_uniform = false;
00149         if (allconst) {
00150           ShVariable tmpdest(new ShVariableNode(SH_CONST, 1, stmt->dest.valueType()));
00151           ShStatement eval(*stmt);
00152           eval.dest = tmpdest;
00153           for (int k = 0; k < opInfo[stmt->op].arity; k++) {
00154             ShVariantCPtr srcValue = src[k][idx(i,k)].value;
00155             ShVariable tmpsrc(new ShVariableNode(SH_CONST, 1, srcValue->valueType()));
00156             tmpsrc.setVariant(srcValue, 0);
00157             eval.src[k] = tmpsrc;
00158           }
00159           evaluate(eval);
00160           setDest(i, Cell(Cell::CONSTANT, tmpdest.getVariant(0)), worklist);
00161         } else if (somebottom) {
00162           setDest(i, Cell(Cell::BOTTOM), worklist);
00163         }
00164       } 
00165 
00166 
00167       // Because making a uniform cell based on a ConstProp requires
00168       // generating a value for the entire statement (not just one
00169       // field of the destination), we only push said cells if ALL of
00170       // the indices are uniform for ALL of their corresponding sources.
00171       if (all_fields_uniform) {
00172         for (int i = 0; i < stmt->dest.size(); i++) {
00173           setDest(i, Cell(Cell::UNIFORM, this, i), worklist);
00174         }
00175       }
00176     } else if (opInfo[stmt->op].result_source == ShOperationInfo::ALL) {
00177       // build statement ONLY if ALL elements of ALL sources are constant
00178       bool allconst = true;
00179       bool alluniform = true; // all statements are either uniform or constant
00180       bool somebottom = false;
00181       for (int s = 0; s < opInfo[stmt->op].arity && !somebottom; s++) {
00182         for (unsigned int k = 0; k < src[s].size(); k++) {
00183           if(src[s][k].state == Cell::BOTTOM) {
00184             somebottom = true;
00185           }
00186           if (src[s][k].state != Cell::CONSTANT) {
00187             allconst = false;
00188             if (src[s][k].state != Cell::UNIFORM) {
00189               alluniform = false;
00190             }
00191           }
00192         }
00193       }
00194       if (allconst) { 
00195         ShVariable tmpdest(new ShVariableNode(SH_CONST, stmt->dest.size(), stmt->dest.valueType()));
00196         ShStatement eval(*stmt);
00197         eval.dest = tmpdest;
00198         for (int i = 0; i < opInfo[stmt->op].arity; i++) {
00199           SH_DEBUG_ASSERT(src[i][0].value); // @todo type DEBUGGING
00200           ShValueType srcValueType = src[i][0].value->valueType(); 
00201           ShVariable tmpsrc(new ShVariableNode(SH_CONST, stmt->src[i].size(), srcValueType));
00202           for (int j = 0; j < stmt->src[i].size(); j++) {
00203             tmpsrc.setVariant(src[i][j].value, j);
00204           }
00205           eval.src[i] = tmpsrc;
00206         }
00207         evaluate(eval);
00208         for (int i = 0; i < stmt->dest.size(); i++) {
00209           setDest(i, Cell(Cell::CONSTANT, tmpdest.getVariant(i)), worklist);
00210         }
00211       } else if (alluniform) {
00212         for (int i = 0; i < stmt->dest.size(); i++) {
00213           setDest(i, Cell(Cell::UNIFORM, this, i), worklist);
00214         }
00215       } else if (somebottom) {
00216         for (int i = 0; i < stmt->dest.size(); i++) {
00217           setDest(i, Cell(Cell::BOTTOM), worklist);
00218         }
00219       }
00220     } else {
00221       SH_DEBUG_ASSERT(0 && "Invalid result source type");
00222     }
00223   }
00224 
00225   typedef int ValueNum;
00226 
00227   struct Uniform {
00228     Uniform()
00229       : constant(false),
00230         valuenum(-1)
00231     {
00232     }
00233 
00234     // @todo type...this is my current understanding:
00235     // May be constant or if !constval, value is not
00236     // known to be constant
00237     Uniform(const ShVariantCPtr& cval)
00238       : constant(true),
00239         constval(cval ? cval->get() : ShVariantPtr(0))
00240     {
00241     }
00242     
00243     Uniform(int valuenum, int index, bool neg)
00244       : constant(false),
00245         valuenum(valuenum), index(index), neg(neg)
00246     {
00247     }
00248 
00249     bool operator==(const Uniform& other) const
00250     {
00251       if (constant != other.constant) return false;
00252 
00253       if (constant) {
00254         // @todo type
00255         if(!constval) return false;
00256         // Check with Stefanus whether this modification is correct
00257     //    SH_DEBUG_ASSERT(constval); // @todo type debugging
00258         return constval->equals(other.constval);
00259       } else {
00260         if (valuenum != other.valuenum) return false;
00261         if (index != other.index) return false;
00262         if (neg != other.neg) return false;
00263         return true;
00264       }
00265     }
00266 
00267     ShValueType valueType() const 
00268     {
00269       if(constant) return constval->valueType();
00270       return Value::get(valuenum)->valueType();
00271     }
00272 
00273     bool operator!=(const Uniform& other) const
00274     {
00275       return !(*this == other);
00276     }
00277 
00278 
00279     bool constant;
00280 
00281     ValueNum valuenum;
00282     int index;
00283     bool neg;
00284 
00285     ShVariantPtr constval;
00286   };
00287 
00288   class Value {
00289   public:
00290     enum Type {
00291       NODE,
00292       STMT
00293     };
00294 
00295     Type type;
00296     
00297     // Only for type == NODE:
00298     ShVariableNodePtr node;
00299 
00300     // Only for type == STMT:
00301     ShOperation op;
00302     int destsize;
00303     ShValueType destValueType;
00304     std::vector<Uniform> src[3];
00305 
00306     static void clear()
00307     {
00308       m_values.clear();
00309     }
00310     
00311     static ValueNum lookup(const ShVariableNodePtr& node)
00312     {
00313       for (std::size_t i = 0; i < m_values.size(); i++) {
00314         if (m_values[i]->type == NODE && m_values[i]->node == node) return i;
00315       }
00316       m_values.push_back(new Value(node)); // FIXME: Never gets deleted
00317       return m_values.size() - 1;
00318     }
00319 
00320     bool operator==(const Value& other) const
00321     {
00322       if (type != other.type) return false;
00323       if (type == NODE) {
00324         return node == other.node;
00325       } else if (type == STMT) {
00326         if (op != other.op || destsize != other.destsize || destValueType != other.destValueType) return false;
00327         for (int i = 0; i < opInfo[op].arity; i++) {
00328           if (src[i].size() != other.src[i].size()) return false;
00329           for (std::size_t j = 0; j < src[i].size(); j++) {
00330             if (src[i][j] != other.src[i][j]) return false;
00331           }
00332         }
00333         return true;
00334       }
00335       return false;
00336     }
00337 
00338     bool operator!=(const Value& other) const
00339     {
00340       return !(*this == other);
00341     }
00342 
00343     static ValueNum lookup(ConstProp* cp)
00344     {
00345       Value* val = new Value(cp); // FIXME: Never gets deleted
00346       
00347       for (std::size_t i = 0; i < m_values.size(); i++) {
00348         if (m_values[i]->type != STMT) continue;
00349         if (m_values[i]->op != cp->stmt->op) continue;
00350 
00351         if (*val == *m_values[i]) {
00352           delete val;
00353           return i;
00354         }
00355       }
00356       m_values.push_back(val);
00357       return m_values.size() - 1;
00358     }
00359 
00360     ShValueType valueType() {
00361       if(type == NODE) {
00362         return node->valueType();
00363       } 
00364       return destValueType;
00365     }
00366 
00367     static Value* get(ValueNum n)
00368     {
00369       return m_values[n];
00370     }
00371 
00372     static void dump(std::ostream& out);
00373 
00374     std::string name() const
00375     {
00376       if(type == NODE) return node->name();
00377       else return ""; 
00378     }
00379 
00380   private:
00381     Value(const ShVariableNodePtr& node)
00382       : type(NODE), node(node)
00383     {
00384     }
00385 
00386     Value(ConstProp* cp)
00387       : type(STMT), node(0), op(cp->stmt->op), destsize(cp->stmt->dest.size()), destValueType(cp->stmt->dest.valueType())
00388     {
00389       for (int i = 0; i < opInfo[cp->stmt->op].arity; i++) {
00390         for (std::size_t j = 0; j < cp->src[i].size(); j++) {
00391           if (cp->src[i][j].state == Cell::UNIFORM) {
00392             src[i].push_back(cp->src[i][j].uniform);
00393           } else {
00394             SH_DEBUG_ASSERT(cp->src[i][j].state == Cell::CONSTANT); // @todo type should be fixed
00395             src[i].push_back(Uniform(cp->src[i][j].value));
00396           }
00397         }
00398       }
00399     }
00400     
00401     static std::vector<Value*> m_values;
00402   };
00403   
00404   struct Cell {
00405     enum State {
00406       BOTTOM,
00407       CONSTANT,
00408       UNIFORM,
00409       TOP
00410     };
00411 
00412     // @todo comments added, but may not be correct
00413    
00414     // Construct a CONSTANT Cell with non-null value
00415     // or a TOP/BOTTOM
00416     Cell(State state, ShVariantPtr value = 0)
00417       : state(state)
00418     {
00419       if(value) this->value = value->get(); 
00420       SH_DEBUG_ASSERT(this->value || (state != CONSTANT));
00421     }
00422 
00423     // Construct a UNIFORM cell from a variable
00424     Cell(State state, const ShVariable& var, int index) 
00425       : state(state), value(0),
00426         uniform(Value::lookup(var.node()), var.swizzle()[index], var.neg())
00427     {
00428     }
00429 
00430     // Construct a UNIFORM cell as a result of some
00431     // statement where all relevant sources are uniforms/constants.
00432     Cell(State state, ConstProp* cp, int index) 
00433       : state(state), value(0),
00434         uniform(Value::lookup(cp), index, false)
00435     {
00436     }
00437 
00438     bool operator==(const Cell& other) const
00439     {
00440       if(state != other.state) return false;
00441       if(value) return value->equals(other.value);
00442       // @todo type remove debug
00443       SH_DEBUG_ASSERT(!value);
00444       return value == other.value; // null
00445     }
00446 
00447     bool operator!=(const Cell& other) const
00448     {
00449       return !(*this == other);
00450     }
00451     
00452     State state;
00453     ShVariantPtr value; // Only for state == CONSTANT
00454 
00455     Uniform uniform; // Only for state == UNIFORM
00456   };
00457 
00458   // If dest[index] != cell, sets dest[index] to cell and updates the worklist 
00459   // Caller must ensure that cell does not move up the lattice. 
00460   void setDest(int index, const Cell &cell, ConstWorkList &worklist)
00461   {
00462     if(dest[index] == cell) return; 
00463     dest[index] = cell;
00464     worklist.push(ValueTracking::Def(stmt, index));
00465   }
00466 
00467   ShStatement* stmt;
00468   std::vector<Cell> dest;
00469   std::vector<Cell> src[3];
00470 };
00471 
00472 ShInfo* ConstProp::clone() const
00473 {
00474   return new ConstProp(*this);
00475 }
00476 
00477 ConstProp::Cell meet(const ConstProp::Cell& a, const ConstProp::Cell& b)
00478 {
00479   if (a.state == ConstProp::Cell::BOTTOM || b.state == ConstProp::Cell::BOTTOM) {
00480     return ConstProp::Cell(ConstProp::Cell::BOTTOM);
00481   }
00482   if (a.state != b.state
00483       && (a.state != ConstProp::Cell::TOP && b.state != ConstProp::Cell::TOP)) {
00484     return ConstProp::Cell(ConstProp::Cell::BOTTOM);
00485   }
00486   // At this point either the cells are the same or one of them is
00487   // top.
00488   if (a.state == b.state) {
00489     if (a.state == ConstProp::Cell::CONSTANT) {
00490       SH_DEBUG_ASSERT(a.value); // @todo type debugging
00491       if (a.value->equals(b.value)) {
00492         return a;
00493       } else {
00494         return ConstProp::Cell(ConstProp::Cell::BOTTOM);
00495       }
00496     }
00497     if (a.state == ConstProp::Cell::UNIFORM) {
00498       if (a.uniform == b.uniform) return a;
00499       return ConstProp::Cell(ConstProp::Cell::BOTTOM);
00500     }
00501   }
00502 
00503   if (a.state != ConstProp::Cell::TOP) return a;
00504   if (b.state != ConstProp::Cell::TOP) return b;
00505     
00506   return ConstProp::Cell(ConstProp::Cell::TOP);
00507 }
00508 
00509 std::vector<ConstProp::Value*> ConstProp::Value::m_values = std::vector<ConstProp::Value*>();
00510 
00511 std::ostream& operator<<(std::ostream& out, const ConstProp::Uniform& uniform)
00512 {
00513   if (uniform.constant) {
00514     SH_DEBUG_ASSERT(uniform.constval); // @todo type DEBUGGING
00515     out << uniform.constval->encode();
00516   } else {
00517     if (uniform.neg) out << '-';
00518     out << "v" << uniform.valuenum << "[" << uniform.index << "]";
00519   }
00520   return out;
00521 }
00522 
00523 void ConstProp::Value::dump(std::ostream& out)
00524 {
00525   out << "--- uniform values ---" << std::endl;
00526   for (std::size_t i = 0; i < m_values.size(); i++) {
00527     out << i << ": ";
00528     if (m_values[i]->type == NODE) {
00529       out << "node " << m_values[i]->node->name() << std::endl;
00530     } else if (m_values[i]->type == STMT) {
00531       out << "stmt [" << m_values[i]->destsize << "] " << opInfo[m_values[i]->op].name << " ";
00532       for (int j = 0; j < opInfo[m_values[i]->op].arity; j++) {
00533         if (j) out << ", ";
00534         for (std::vector<Uniform>::iterator U = m_values[i]->src[j].begin();
00535              U != m_values[i]->src[j].end(); ++U) {
00536           out << *U;
00537         }
00538       }
00539       out << ";" << std::endl;
00540     }
00541   }
00542 }
00543 
00544 std::ostream& operator<<(std::ostream& out, const ConstProp::Cell& cell)
00545 {
00546   switch(cell.state) {
00547   case ConstProp::Cell::BOTTOM:
00548     out << "[bot]";
00549     break;
00550   case ConstProp::Cell::CONSTANT:
00551     out << "[" << cell.value->encode() << "]";
00552     break;
00553   case ConstProp::Cell::TOP:
00554     out << "[top]";
00555     break;
00556   case ConstProp::Cell::UNIFORM:
00557     out << "<" << cell.uniform << ">";
00558     break;
00559   }
00560   return out;
00561 }
00562 
00563 struct InitConstProp {
00564   InitConstProp(const ShProgramNodeCPtr& prog, ConstWorkList& worklist)
00565     : prog(prog), worklist(worklist)
00566   {
00567   }
00568 
00569   // assignment operator could not be generated: declaration only
00570   InitConstProp& operator=(InitConstProp const&);
00571 
00572   void operator()(const ShCtrlGraphNodePtr& node)
00573   {
00574     if (!node) return;
00575     ShBasicBlockPtr block = node->block;
00576     if (!block) return;
00577     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00578       I->destroy_info<ConstProp>();
00579       ConstProp* cp = new ConstProp(&(*I), prog, worklist);
00580       I->add_info(cp);
00581     }
00582   }
00583 
00584   ShProgramNodeCPtr prog;
00585   ConstWorkList& worklist;
00586 };
00587 
00588 struct DumpConstProp {
00589   void operator()(const ShCtrlGraphNodePtr& node)
00590   {
00591     if (!node) return;
00592     ShBasicBlockPtr block = node->block;
00593     if (!block) return;
00594     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00595       std::cerr << "{" << *I << "} --- ";
00596       ConstProp* cp = I->get_info<ConstProp>();
00597 
00598       if (!cp) {
00599         std::cerr << "NO CP INFORMATION" << std::endl;
00600         continue;
00601       }
00602 
00603       std::cerr << "dest = {";
00604       for (std::size_t i = 0; i < cp->dest.size(); i++) {
00605         std::cerr << cp->dest[i];
00606       }
00607       std::cerr << "}; ";
00608       for (int s = 0; s < opInfo[I->op].arity; s++) {
00609         if (s) std::cerr << ", ";
00610         std::cerr << "src" << s << " = {";
00611         for (std::size_t i = 0; i < cp->src[s].size(); i++) {
00612           std::cerr << cp->src[s][i];
00613         }
00614         std::cerr << "}";
00615       }
00616       std::cerr << std::endl;
00617     }
00618     
00619   }
00620 };
00621 
00622 struct FinishConstProp
00623 {
00624   FinishConstProp(bool lift_uniforms)
00625     : lift_uniforms(lift_uniforms)
00626   {
00627   }
00628   
00629   void operator()(const ShCtrlGraphNodePtr& node) {
00630     if (!node) return;
00631     ShBasicBlockPtr block = node->block;
00632     if (!block) return;
00633     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00634       ConstProp* cp = I->get_info<ConstProp>();
00635 
00636       if (!cp) continue;
00637 
00638       if (!cp->dest.empty()) {
00639         // if all dest fields are constants, replace this with a
00640         // constant assignment
00641 
00642         if (I->op != SH_OP_ASN || I->src[0].node()->kind() != SH_CONST) {
00643           bool allconst = true;
00644           for (int i = 0; i < I->dest.size(); i++) {
00645             if (cp->dest[i].state != ConstProp::Cell::CONSTANT) {
00646               allconst = false;
00647               break;
00648             }
00649           }
00650           if (allconst) {
00651             SH_DEBUG_ASSERT(cp->dest[0].value); // @todo type debugging
00652             ShValueType destValueType = cp->dest[0].value->valueType(); 
00653             ShVariable newconst(new ShVariableNode(SH_CONST, I->dest.size(), destValueType));
00654             for(int i = 0; i < I->dest.size(); ++i) {
00655               newconst.setVariant(cp->dest[i].value, i);
00656             }
00657 #ifdef SH_DEBUG_CONSTPROP
00658             SH_DEBUG_PRINT("Replaced {" << *I << "} with " << newconst);
00659 #endif
00660             *I = ShStatement(I->dest, SH_OP_ASN, newconst);
00661           } else {
00662             // otherwise, do the same for each source field.
00663             for (int s = 0; s < opInfo[I->op].arity; s++) {
00664               if (I->src[s].node()->kind() == SH_CONST) continue;
00665             
00666               ShValueType srcValueType = I->src[s].valueType();
00667               ShVariable newconst(new ShVariableNode(SH_CONST, I->src[s].size(), srcValueType));
00668               bool allconst = true;
00669               for (int i = 0; i < I->src[s].size(); i++) {
00670                 if (cp->src[s][i].state != ConstProp::Cell::CONSTANT) {
00671                   allconst = false;
00672                   break;
00673                 }
00674                 newconst.setVariant(cp->src[s][i].value, i);
00675               }
00676               if (allconst) {
00677 #ifdef SH_DEBUG_CONSTPROP
00678                 SH_DEBUG_PRINT("Replaced {" << *I << "}.src[" << s << "] with " << newconst);
00679 #endif
00680                 I->src[s] = newconst;
00681               }
00682             }
00683           }
00684 
00685           if (!lift_uniforms || allconst) {
00686             //SH_DEBUG_PRINT("Skipping uniform lifting");
00687             continue;
00688           }
00689 
00690           bool alluniform = true;
00691           for (int s = 0; s < opInfo[I->op].arity; s++) {
00692             for (int i = 0; i < I->src[s].size(); i++) {
00693               if (cp->src[s][i].state != ConstProp::Cell::UNIFORM
00694                   && cp->src[s][i].state != ConstProp::Cell::CONSTANT) {
00695                 alluniform = false;
00696                 break;
00697               }
00698             }
00699             if (!alluniform) break;
00700           }
00701           if (!alluniform || I->dest.node()->kind() == SH_OUTPUT
00702               || I->dest.node()->kind() == SH_INOUT) {
00703 #ifdef SH_DEBUG_CONSTPROP
00704             SH_DEBUG_PRINT("Considering " << *I << " for uniform lifting");
00705 #endif          
00706             for (int s = 0; s < opInfo[I->op].arity; s++) {
00707               if (I->src[s].uniform()) {
00708 #ifdef SH_DEBUG_CONSTPROP
00709                 SH_DEBUG_PRINT(*I << ".src[" << s << "] is already a uniform");
00710 #endif
00711                 continue;
00712               }
00713 
00714               bool mixed = false;
00715               bool neg = false;
00716               ConstProp::ValueNum uniform = -1;
00717               std::vector<int> indices;
00718               
00719               for (int i = 0; i < I->src[s].size(); i++) {
00720                 if (cp->src[s][i].state == ConstProp::Cell::UNIFORM) {
00721                   if (uniform < 0) {
00722                     uniform = cp->src[s][i].uniform.valuenum;
00723                     neg = cp->src[s][i].uniform.neg;
00724                     indices.push_back(cp->src[s][i].uniform.index);
00725                   } else {
00726                     if (uniform != cp->src[s][i].uniform.valuenum
00727                         || neg != cp->src[s][i].uniform.neg) {
00728                       mixed = true;
00729                       break;
00730                     }
00731                     indices.push_back(cp->src[s][i].uniform.index);
00732                   }
00733                 } else {
00734                   // Can't lift this, unless we introduce intermediate instructions.
00735                   mixed = true;
00736                 }
00737               }
00738               
00739               if (uniform < 0) {
00740 #ifdef SH_DEBUG_CONSTPROP
00741                 SH_DEBUG_PRINT("{" << *I << "}.src[" << s << "] is not uniform");
00742 #endif
00743                 continue;
00744               }
00745               if (mixed) {
00746 #ifdef SH_DEBUG_CONSTPROP
00747                 SH_DEBUG_PRINT("{" << *I << "}.src[" << s << "] is mixed");
00748 #endif
00749                 continue;
00750               }
00751               ConstProp::Value* value = ConstProp::Value::get(uniform);
00752 
00753 
00754 #ifdef SH_DEBUG_CONSTPROP
00755               SH_DEBUG_PRINT("Lifting {" << *I << "}.src[" << s << "]: " << uniform);
00756 #endif
00757               int srcsize;
00758               if (value->type == ConstProp::Value::NODE) {
00759                 srcsize = value->node->size();
00760               } else {
00761                 srcsize = value->destsize;
00762               }
00763               ShSwizzle swizzle(srcsize, indices.size(), &(*indices.begin()));
00764               // Build a uniform to represent this computation.
00765               ShVariableNodePtr node = build_uniform(value, uniform);
00766               if (node) {
00767                 I->src[s] = ShVariable(node, swizzle, neg);
00768               } else {
00769 #ifdef SH_DEBUG_CONSTPROP
00770                 SH_DEBUG_PRINT("Could not lift " << *I << ".src[" << s << "] for some reason");
00771 #endif
00772               }
00773             }
00774           }
00775         }
00776       }
00777     }
00778 
00779     // Clean up
00780     for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) {
00781       // Remove constant propagation information.
00782       I->destroy_info<ConstProp>();
00783     }
00784   }
00785 
00786   ShVariableNodePtr build_uniform(ConstProp::Value* value,
00787                                   ConstProp::ValueNum valuenum)
00788   {
00789     if (value->type == ConstProp::Value::NODE) {
00790       return value->node;
00791     }
00792     
00793     ShContext::current()->enter(0);
00794     ShVariableNodePtr node = new ShVariableNode(SH_TEMP, value->destsize, value->destValueType);
00795     {
00796     std::ostringstream s;
00797     s << "dep_" << valuenum << "_" << value->name();
00798     node->name(s.str());
00799     }
00800     ShContext::current()->exit();
00801 
00802 #ifdef SH_DEBUG_CONSTPROP
00803     SH_DEBUG_PRINT("Lifting value #" << valuenum);
00804 #endif
00805 
00806     bool broken = false;
00807     
00808     ShProgram prg = SH_BEGIN_PROGRAM("uniform") {
00809       ShStatement stmt(node, value->op);
00810 
00811       for (int i = 0; i < opInfo[value->op].arity; i++) {
00812         stmt.src[i] = compute(value->src[i]);
00813         if (stmt.src[i].null()) broken = true;
00814       }
00815 
00816       ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00817     } SH_END;
00818 
00819 #ifdef SH_DEBUG_CONSTPROP
00820     {
00821       std::ostringstream s;
00822       s << "lifted_" << valuenum;
00823       std::string dotfilename(s.str() + ".dot");
00824       std::ofstream dot(dotfilename.c_str());
00825       prg.node()->ctrlGraph->graphvizDump(dot);
00826       dot.close();
00827       std::string cmdline = std::string("dot -Tps -o ") + s.str() + ".ps " + s.str() + ".dot";
00828       system(cmdline.c_str());
00829     }
00830 #endif
00831     
00832     if (broken) return 0;
00833     node->attach(prg.node());
00834 
00835     value->type = ConstProp::Value::NODE;
00836     value->node = node;
00837     
00838     return node;
00839   }
00840 
00841   ShVariable compute(const std::vector<ConstProp::Uniform>& src)
00842   {
00843     ConstProp::ValueNum v = -1;
00844     
00845     bool allsame = true;
00846     bool neg = false;
00847     std::vector<int> indices;
00848     std::vector<ShVariantCPtr> constvals;
00849     for (std::size_t i = 0; i < src.size(); i++) {
00850       if (src[i].constant) {
00851         if (v >= 0) {
00852           allsame = false;
00853           break;
00854         }
00855         constvals.push_back(src[i].constval);
00856       } else {
00857         if (v < 0 && constvals.empty()) {
00858           v = src[i].valuenum;
00859           neg = src[i].neg;
00860         } else {
00861           if (v != src[i].valuenum || neg != src[i].neg || !constvals.empty()) {
00862             allsame = false;
00863           }
00864         }
00865         indices.push_back(src[i].index);
00866       }
00867     }
00868     if (!allsame) {
00869       // Make intermediate variables, combine them together.
00870       ShVariable r = ShVariable(new ShVariableNode(SH_TEMP, src.size(), src[0].valueType()));
00871       
00872       for (std::size_t i = 0; i < src.size(); i++) {
00873         std::vector<ConstProp::Uniform> v;
00874         v.push_back(src[i]);
00875         ShVariable scalar = compute(v);
00876         ShStatement stmt(r(i), SH_OP_ASN, scalar);
00877         ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00878       }
00879       return r;
00880     }
00881 
00882     if (!constvals.empty()) {
00883       ShVariable var(new ShVariableNode(SH_CONST, constvals.size(), constvals[0]->valueType()));
00884       for(std::size_t i = 0; i < constvals.size(); ++i) var.setVariant(constvals[i], i);
00885       return var;
00886     }
00887     
00888     ConstProp::Value* value = ConstProp::Value::get(v);
00889     
00890     if (value->type == ConstProp::Value::NODE) {
00891       ShSwizzle swizzle(value->node->size(), indices.size(), &*indices.begin());
00892       return ShVariable(value->node, swizzle, neg);
00893     }
00894     if (value->type == ConstProp::Value::STMT) {
00895       ShVariableNodePtr node = new ShVariableNode(SH_TEMP, value->destsize, value->destValueType);
00896       ShStatement stmt(node, value->op);
00897 
00898       for (int i = 0; i < opInfo[value->op].arity; i++) {
00899         stmt.src[i] = compute(value->src[i]);
00900         if (stmt.src[i].null()) return ShVariable();
00901       }
00902 
00903       ShContext::current()->parsing()->tokenizer.blockList()->addStatement(stmt);
00904       ShSwizzle swizzle(node->size(), indices.size(), &*indices.begin());
00905       return ShVariable(node, swizzle, neg);
00906     }
00907 
00908 #ifdef SH_DEBUG_CONSTPROP
00909     SH_DEBUG_PRINT("Reached invalid point");
00910 #endif
00911 
00912     // Should never reach here.
00913     return ShVariable();
00914   }
00915 
00916   bool lift_uniforms;
00917 };
00918 
00919 }
00920 
00921 namespace SH {
00922 
00923 
00924 void propagate_constants(ShProgram& p)
00925 {
00926   ShCtrlGraphPtr graph = p.node()->ctrlGraph;
00927 
00928   ConstWorkList worklist;
00929 
00930   //ConstProp::Value::clear();
00931   
00932   InitConstProp init(p.node(), worklist);
00933   graph->dfs(init);
00934 
00935 #ifdef SH_DEBUG_CONSTPROP
00936   SH_DEBUG_PRINT("Const Prop Initial Values:");
00937   DumpConstProp dump_pre;
00938   graph->dfs(dump_pre);
00939 #endif
00940 
00941   
00942   while (!worklist.empty()) {
00943     ValueTracking::Def def = worklist.front(); worklist.pop();
00944     ValueTracking* vt = def.stmt->get_info<ValueTracking>();
00945     if (!vt) {
00946 #ifdef SH_DEBUG_CONSTPROP
00947       SH_DEBUG_PRINT(*def.stmt << " on worklist does not have VT information?");
00948 #endif
00949       continue;
00950     }
00951 
00952 
00953     for (ValueTracking::DefUseChain::iterator use = vt->uses[def.index].begin();
00954          use != vt->uses[def.index].end(); ++use) {
00955       if (use->kind != ValueTracking::Use::STMT) continue;
00956       ConstProp* cp = use->stmt->get_info<ConstProp>();
00957       if (!cp) {
00958 #ifdef SH_DEBUG_CONSTPROP
00959         SH_DEBUG_PRINT("Use " << *use->stmt << " does not have const prop information!");
00960 #endif
00961         continue;
00962       }
00963 
00964       ConstProp::Cell cell = cp->src[use->source][use->index];
00965 
00966       ValueTracking* ut = use->stmt->get_info<ValueTracking>();
00967       if (!ut) {
00968         // Should never happen...
00969 #ifdef SH_DEBUG_CONSTPROP
00970         SH_DEBUG_PRINT("Use " << *use->stmt << " on worklist does not have VT information?");
00971 #endif
00972         continue;
00973       }
00974 
00975 #ifdef SH_DEBUG_CONSTPROP
00976       SH_DEBUG_PRINT("Meeting cell for {" << *use->stmt
00977                      << "}.src" << use->source << "[" << use->index << "]");
00978 #endif
00979       
00980       ConstProp::Cell new_cell(ConstProp::Cell::TOP);
00981       
00982       for (ValueTracking::UseDefChain::iterator possdef
00983              = ut->defs[use->source][use->index].begin();
00984            possdef != ut->defs[use->source][use->index].end(); ++possdef) {
00985         ConstProp* dcp = possdef->stmt->get_info<ConstProp>();
00986         if (!dcp) {
00987 #ifdef SH_DEBUG_CONSTPROP
00988           SH_DEBUG_PRINT("Possible def " << *dcp->stmt << " on worklist does not have CP information?");
00989 #endif
00990           continue;
00991         }
00992         
00993         ConstProp::Cell destcell = dcp->dest[possdef->index];
00994 
00995         // If the use is negated, we need to change the cell
00996         if (use->stmt->src[use->source].neg()) {
00997           if (destcell.state == ConstProp::Cell::CONSTANT) {
00998             SH_DEBUG_ASSERT(destcell.value); // @todo type DEBUGGING
00999             destcell.value->negate();
01000           } else if (destcell.state == ConstProp::Cell::UNIFORM) {
01001             destcell.uniform.neg = !destcell.uniform.neg;
01002           }
01003         }
01004 #ifdef SH_DEBUG_CONSTPROP
01005         SH_DEBUG_PRINT("  meet(" << new_cell << ", " << destcell << ") = " <<
01006                        meet(new_cell, destcell));
01007 #endif
01008         new_cell = meet(new_cell, destcell);
01009       }
01010       
01011       if (cell != new_cell) {
01012 #ifdef SH_DEBUG_CONSTPROP
01013         SH_DEBUG_PRINT("  ...replacing cell");
01014 #endif
01015         cp->src[use->source][use->index] = new_cell;
01016         cp->updateDest(worklist);
01017       }
01018     }
01019   }
01020   // Now do something with our glorious newfound information.
01021 
01022   
01023 #ifdef SH_DEBUG_CONSTPROP
01024   ConstProp::Value::dump(std::cerr);
01025 
01026   DumpConstProp dump;
01027   graph->dfs(dump);
01028 #endif
01029   
01030   FinishConstProp finish(p.node()->target().find("gpu:") == 0
01031                          && !ShContext::current()->optimization_disabled("uniform lifting"));
01032   graph->dfs(finish);
01033   
01034 }
01035 
01036 }

Generated on Thu Feb 16 14:51:32 2006 for Sh by  doxygen 1.4.6