Main Page | Modules | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members | Related Pages

ShTransformer.cpp

00001 // Sh: A GPU metaprogramming language. 00002 // 00003 // Copyright (c) 2003 University of Waterloo Computer Graphics Laboratory 00004 // Project administrator: Michael D. McCool 00005 // Authors: Zheng Qin, Stefanus Du Toit, Kevin Moule, Tiberiu S. Popa, 00006 // Michael D. McCool 00007 // 00008 // This software is provided 'as-is', without any express or implied 00009 // warranty. In no event will the authors be held liable for any damages 00010 // arising from the use of this software. 00011 // 00012 // Permission is granted to anyone to use this software for any purpose, 00013 // including commercial applications, and to alter it and redistribute it 00014 // freely, subject to the following restrictions: 00015 // 00016 // 1. The origin of this software must not be misrepresented; you must 00017 // not claim that you wrote the original software. If you use this 00018 // software in a product, an acknowledgment in the product documentation 00019 // would be appreciated but is not required. 00020 // 00021 // 2. Altered source versions must be plainly marked as such, and must 00022 // not be misrepresented as being the original software. 00023 // 00024 // 3. This notice may not be removed or altered from any source 00025 // distribution. 00027 #include <algorithm> 00028 #include <map> 00029 #include <list> 00030 #include "ShContext.hpp" 00031 #include "ShError.hpp" 00032 #include "ShDebug.hpp" 00033 #include "ShVariableNode.hpp" 00034 #include "ShInternals.hpp" 00035 #include "ShTransformer.hpp" 00036 00037 namespace SH { 00038 00039 ShTransformer::ShTransformer(const ShProgramNodePtr& program) 00040 : m_program(program), m_changed(false) 00041 { 00042 } 00043 00044 ShTransformer::~ShTransformer() 00045 { 00046 } 00047 00048 bool ShTransformer::changed() { return m_changed; } 00049 00050 // Variable splitting, marks statements for which some variable is split 00051 struct VariableSplitter { 00052 00053 VariableSplitter(int maxTuple, ShTransformer::VarSplitMap& splits, bool& changed) 00054 : maxTuple(maxTuple), splits(splits), changed(changed) {} 00055 00056 void operator()(ShCtrlGraphNodePtr node) { 00057 if (!node) return; 00058 ShBasicBlockPtr block = node->block; 00059 if (!block) return; 00060 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) { 00061 splitVars(*I); 00062 } 00063 } 00064 00065 // this must be called BEFORE running a DFS on the program 00066 // to split temporaries (otherwise the stupid hack marked (#) does not work) 00067 void splitVarList(ShProgramNode::VarList &vars) { 00068 for(ShProgramNode::VarList::iterator I = vars.begin(); 00069 I != vars.end();) { 00070 if(split(*I)) { 00071 // (#) erase the stuff that split added to the end of the var list 00072 vars.resize(vars.size() - splits[*I].size()); 00073 00074 vars.insert(I, splits[*I].begin(), splits[*I].end()); 00075 I = vars.erase(I); 00076 } else ++I; 00077 } 00078 } 00079 00080 void splitVars(ShStatement& stmt) { 00081 stmt.marked = false; 00082 if(stmt.dest.node()) stmt.marked = split(stmt.dest.node()) || stmt.marked; 00083 for(int i = 0; i < 3; ++i) if(stmt.src[i].node()) stmt.marked = split(stmt.src[i].node()) || stmt.marked; 00084 } 00085 00086 // returns true if variable split 00087 bool split(ShVariableNodePtr node) 00088 { 00089 int i, offset; 00090 int n = node->size(); 00091 if(n <= maxTuple ) return false; 00092 else if(splits.count(node) > 0) return true; 00093 if( node->kind() == SH_TEXTURE || node->kind() == SH_STREAM ) { 00094 shError( ShTransformerException( 00095 "Long tuple support is not implemented for textures or streams")); 00096 00097 } 00098 changed = true; 00099 ShTransformer::VarNodeVec &nodeVarNodeVec = splits[node]; 00100 ShVariableNodePtr newNode; 00101 for(offset = 0; n > 0; offset += maxTuple, n -= maxTuple) { 00102 ShProgramNodePtr prev = ShContext::current()->parsing(); 00103 if(node->uniform()) ShContext::current()->exit(); 00104 newNode = new ShVariableNode(node->kind(), n < maxTuple ? n : maxTuple, node->specialType()); 00105 newNode->name(node->name()); 00106 if(node->uniform()) ShContext::current()->enter(prev); 00107 00108 if( node->hasValues() ) { 00109 for(i = 0; i < newNode->size(); ++i){ 00110 newNode->setValue(i, node->getValue(offset + i)); 00111 } 00112 } 00113 nodeVarNodeVec.push_back( newNode ); 00114 } 00115 return true; 00116 } 00117 00118 int maxTuple; 00119 ShTransformer::VarSplitMap &splits; 00120 bool& changed; 00121 }; 00122 00123 struct StatementSplitter { 00124 typedef std::vector<ShVariable> VarVec; 00125 00126 StatementSplitter(int maxTuple, ShTransformer::VarSplitMap &splits, bool& changed) 00127 : maxTuple(maxTuple), splits(splits), changed(changed) {} 00128 00129 void operator()(ShCtrlGraphNodePtr node) { 00130 if (!node) return; 00131 ShBasicBlockPtr block = node->block; 00132 if (!block) return; 00133 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end();) { 00134 splitStatement(block, I); 00135 } 00136 } 00137 00138 void makeSrcTemps(const ShVariable &v, VarVec &vv, ShBasicBlock::ShStmtList &stmts) { 00139 if( v.size() <= maxTuple && v.node()->size() <= maxTuple ) { 00140 vv.push_back(v); 00141 return; 00142 } 00143 std::size_t i, j, k; 00144 int n; 00145 const ShSwizzle &swiz = v.swizzle(); 00146 00147 // get VarNodeVec for src 00148 ShTransformer::VarNodeVec srcVec; 00149 if(splits.count(v.node()) > 0) { 00150 srcVec = splits[v.node()]; 00151 } else srcVec.push_back(v.node()); 00152 00153 // make and assign to a VarVec for temps 00154 for(i = 0, n = v.size(); n > 0; i += maxTuple, n -= maxTuple) { 00155 std::size_t tsize = (int)n < maxTuple ? n : maxTuple; 00156 //TODO make this smarter so that it reuses variable nodes if it's just reswizlling a src node 00157 // (check that move elimination doesn't do this for us already) 00158 ShVariable tempVar(new ShVariableNode(SH_TEMP, tsize, SH_ATTRIB)); 00159 vv.push_back(tempVar); 00160 00161 int* tempSwiz = new int[tsize]; 00162 int* srcSwiz = new int[tsize]; 00163 int tempSize; 00164 for(j = 0; j < srcVec.size(); ++j) { 00165 tempSize = 0; 00166 for(k = 0; k < tsize; ++k ) { 00167 if(swiz[i + k] / maxTuple == (int)j) { 00168 tempSwiz[tempSize] = k; 00169 srcSwiz[tempSize] = swiz[i + k] % maxTuple; 00170 tempSize++; 00171 } 00172 } 00173 if( tempSize > 0 ) { 00174 ShVariable srcVar(srcVec[j]); 00175 stmts.push_back(ShStatement(tempVar(tempSize, tempSwiz), SH_OP_ASN, srcVar(tempSize, srcSwiz))); 00176 } 00177 } 00178 delete [] tempSwiz; 00179 delete [] srcSwiz; 00180 } 00181 } 00182 00183 // moves the result to the destination based on the destination swizzle 00184 void movToDest(ShTransformer::VarNodeVec &destVec, const ShSwizzle &destSwiz, 00185 const VarVec &resultVec, ShBasicBlock::ShStmtList &stmts) { 00186 std::size_t j; 00187 int k; 00188 int offset = 0; 00189 int* swizd = new int[maxTuple]; 00190 int* swizr = new int[maxTuple]; 00191 int size; 00192 for(VarVec::const_iterator I = resultVec.begin(); I != resultVec.end(); 00193 offset += I->size(), ++I) { 00194 for(j = 0; j < destVec.size(); ++j) { 00195 size = 0; 00196 for(k = 0; k < I->size(); ++k) { 00197 if( destSwiz[k + offset] / maxTuple == (int)j) { 00198 swizd[size] = destSwiz[k + offset] % maxTuple; 00199 swizr[size] = k; 00200 size++; 00201 } 00202 } 00203 if( size > 0 ) { 00204 ShVariable destVar(destVec[j]); 00205 stmts.push_back(ShStatement(destVar(size, swizd), SH_OP_ASN, (*I)(size, swizr))); 00206 } 00207 } 00208 } 00209 delete [] swizd; 00210 delete [] swizr; 00211 } 00212 00213 // works on two assumptions 00214 // 1) special cases for DOT, XPD (and any other future non-componentwise ops) implemented separately 00215 // 2) Everything else is in the form N = [1|N]+ in terms of tuple sizes involved in dest and src 00216 void updateStatement(ShStatement &oldStmt, VarVec srcVec[3], ShBasicBlock::ShStmtList &stmts) { 00217 std::size_t i, j; 00218 ShVariable &dest = oldStmt.dest; 00219 const ShSwizzle &destSwiz = dest.swizzle(); 00220 ShTransformer::VarNodeVec destVec; 00221 VarVec resultVec; 00222 00223 if(splits.count(dest.node()) > 0) { 00224 destVec = splits[dest.node()]; 00225 } else destVec.push_back(dest.node()); 00226 00227 switch(oldStmt.op) { 00228 case SH_OP_DOT: 00229 { 00230 // TODO for large tuples, may want to use another dot to sum up results instead of 00231 // SH_OP_ADD. For now, do naive method 00232 SH_DEBUG_ASSERT(destSwiz.size() == 1); 00233 ShVariable dott = ShVariable(new ShVariableNode(SH_TEMP, 1, SH_ATTRIB)); 00234 ShVariable sumt = ShVariable(new ShVariableNode(SH_TEMP, 1, SH_ATTRIB)); 00235 stmts.push_back(ShStatement(sumt, srcVec[0][0], SH_OP_DOT, srcVec[1][0])); 00236 for(i = 1; i < srcVec[0].size(); ++i) { 00237 stmts.push_back(ShStatement(dott, srcVec[0][i], SH_OP_DOT, srcVec[1][i])); 00238 stmts.push_back(ShStatement(sumt, sumt, SH_OP_ADD, dott)); 00239 } 00240 resultVec.push_back(sumt); 00241 } 00242 break; 00243 case SH_OP_XPD: 00244 { 00245 SH_DEBUG_ASSERT( srcVec[0].size() == 1 && srcVec[0][0].size() == 3 && 00246 srcVec[1].size() == 1 && srcVec[1][0].size() == 3); 00247 ShVariable result = ShVariable(new ShVariableNode(SH_TEMP, 3, SH_ATTRIB)); 00248 stmts.push_back(ShStatement(result, srcVec[0][0], SH_OP_XPD, srcVec[1][0])); 00249 resultVec.push_back(result); 00250 } 00251 break; 00252 00253 default: 00254 { 00255 int maxi = 0; 00256 if( srcVec[1].size() > srcVec[0].size() ) maxi = 1; 00257 if( srcVec[2].size() > srcVec[maxi].size() ) maxi = 2; 00258 for(i = 0; i < srcVec[maxi].size(); ++i) { 00259 ShVariable resultPart(new ShVariableNode(SH_TEMP, srcVec[maxi][i].size(), SH_ATTRIB)); 00260 ShStatement newStmt(resultPart, oldStmt.op); 00261 for(j = 0; j < 3 && !srcVec[j].empty(); ++j) { 00262 newStmt.src[j] = srcVec[j].size() > i ? srcVec[j][i] : srcVec[j][0]; 00263 } 00264 stmts.push_back(newStmt); 00265 resultVec.push_back(resultPart); 00266 } 00267 } 00268 break; 00269 } 00270 movToDest(destVec, destSwiz, resultVec, stmts); 00271 } 00272 00275 void splitStatement(ShBasicBlockPtr block, ShBasicBlock::ShStmtList::iterator &stit) { 00276 ShStatement &stmt = *stit; 00277 int i; 00278 if(!stmt.marked && stmt.dest.size() <= maxTuple) { 00279 for(i = 0; i < 3; ++i) if(stmt.src[i].size() > maxTuple) break; 00280 if(i == 3) { // nothing needs splitting 00281 ++stit; 00282 return; 00283 } 00284 } 00285 changed = true; 00286 ShBasicBlock::ShStmtList newStmts; 00287 VarVec srcVec[3]; 00288 00289 for(i = 0; i < 3; ++i) if(stmt.src[i].node()) makeSrcTemps(stmt.src[i], srcVec[i], newStmts); 00290 updateStatement(stmt, srcVec, newStmts); 00291 00292 // remove old statmeent and splice in new statements 00293 stit = block->erase(stit); 00294 block->splice(stit, newStmts); 00295 } 00296 00297 int maxTuple; 00298 ShTransformer::VarSplitMap &splits; 00299 bool& changed; 00300 }; 00301 00302 void ShTransformer::splitTuples(int maxTuple, ShTransformer::VarSplitMap &splits) { 00303 SH_DEBUG_ASSERT(maxTuple > 0); 00304 00305 VariableSplitter vs(maxTuple, splits, m_changed); 00306 vs.splitVarList(m_program->inputs); 00307 vs.splitVarList(m_program->outputs); 00308 m_program->ctrlGraph->dfs(vs); 00309 00310 00311 StatementSplitter ss(maxTuple, splits, m_changed); 00312 m_program->ctrlGraph->dfs(ss); 00313 } 00314 00315 static int id = 0; 00316 00317 // Output Convertion to temporaries 00318 struct InputOutputConvertor { 00319 InputOutputConvertor(const ShProgramNodePtr& program, 00320 ShVariableReplacer::VarMap &varMap, bool& changed) 00321 : m_program(program), m_varMap( varMap ), m_changed(changed), m_id(++id) 00322 {} 00323 00324 void operator()(ShCtrlGraphNodePtr node) { 00325 if (!node) return; 00326 ShBasicBlockPtr block = node->block; 00327 if (!block) return; 00328 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) { 00329 convertIO(*I); 00330 } 00331 } 00332 00333 /* Convert all INOUT nodes that appear in a VarList (use std::for_each with this object) 00334 * (currently InOuts are always converted) */ 00335 void operator()(ShVariableNodePtr node) { 00336 if (node->kind() != SH_INOUT || m_varMap.count(node) > 0) return; 00337 m_varMap[node] = dupNode(node); 00338 } 00339 00340 // dup that works only on nodes without values (inputs, outputs fall into this category) 00341 ShVariableNodePtr dupNode(ShVariableNodePtr node, ShBindingType newBinding = SH_TEMP) { 00342 ShVariableNodePtr result( new ShVariableNode(newBinding, 00343 node->size(), node->specialType())); 00344 if (node->has_name()) { 00345 // TODO: should really copy all meta information here. 00346 result->name(node->name()); 00347 } 00348 return result; 00349 } 00350 00351 // Convert inputs, outputs only when they appear in incompatible locations 00352 // (inputs used as dest, outputs used as src) 00353 void convertIO(ShStatement& stmt) 00354 { 00355 if(!stmt.dest.null()) { 00356 const ShVariableNodePtr &oldNode = stmt.dest.node(); 00357 if(oldNode->kind() == SH_INPUT) { 00358 if(m_varMap.count(oldNode) == 0) { 00359 m_varMap[oldNode] = dupNode(oldNode); 00360 } 00361 } 00362 } 00363 for(int i = 0; i < 3; ++i) { 00364 if(!stmt.src[i].null()) { 00365 const ShVariableNodePtr &oldNode = stmt.src[i].node(); 00366 if(oldNode->kind() == SH_OUTPUT) { 00367 if(m_varMap.count(oldNode) == 0) { 00368 m_varMap[oldNode] = dupNode(oldNode); 00369 } 00370 } 00371 } 00372 } 00373 } 00374 00375 void updateGraph() { 00376 if(m_varMap.empty()) return; 00377 m_changed = true; 00378 00379 // create block after exit 00380 ShCtrlGraphNodePtr oldExit = m_program->ctrlGraph->appendExit(); 00381 ShCtrlGraphNodePtr oldEntry = m_program->ctrlGraph->prependEntry(); 00382 00383 for(ShVariableReplacer::VarMap::const_iterator it = m_varMap.begin(); it != m_varMap.end(); ++it) { 00384 // assign temporary to output 00385 ShVariableNodePtr oldNode = it->first; 00386 if(oldNode->kind() == SH_OUTPUT) { 00387 oldExit->block->addStatement(ShStatement( 00388 ShVariable(oldNode), SH_OP_ASN, ShVariable(it->second))); 00389 } else if(oldNode->kind() == SH_INPUT) { 00390 oldEntry->block->addStatement(ShStatement( 00391 ShVariable(it->second), SH_OP_ASN, ShVariable(oldNode))); 00392 } else if(oldNode->kind() == SH_INOUT) { 00393 // replace INOUT nodes in input/output lists with INPUT and OUTPUT nodes 00394 ShVariableNodePtr newInNode(dupNode(oldNode, SH_INPUT)); 00395 ShVariableNodePtr newOutNode(dupNode(oldNode, SH_OUTPUT)); 00396 00397 std::replace(m_program->inputs.begin(), m_program->inputs.end(), 00398 oldNode, newInNode); 00399 m_program->inputs.pop_back(); 00400 00401 std::replace(m_program->outputs.begin(), m_program->outputs.end(), 00402 oldNode, newOutNode); 00403 m_program->outputs.pop_back(); 00404 00405 // add mov statements to/from temporary 00406 oldEntry->block->addStatement(ShStatement( 00407 ShVariable(it->second), SH_OP_ASN, ShVariable(newInNode))); 00408 oldExit->block->addStatement(ShStatement( 00409 ShVariable(newOutNode), SH_OP_ASN, ShVariable(it->second))); 00410 } 00411 } 00412 } 00413 00414 ShProgramNodePtr m_program; 00415 ShVariableReplacer::VarMap &m_varMap; // maps from outputs used as srcs in computation to their temporary variables 00416 bool& m_changed; 00417 int m_id; 00418 }; 00419 00420 void ShTransformer::convertInputOutput() 00421 { 00422 ShVariableReplacer::VarMap varMap; // maps from outputs used as srcs in computation to their temporary variables 00423 00424 InputOutputConvertor ioc(m_program, varMap, m_changed); 00425 std::for_each(m_program->inputs.begin(), m_program->inputs.end(), ioc); 00426 std::for_each(m_program->outputs.begin(), m_program->outputs.end(), ioc); 00427 m_program->ctrlGraph->dfs(ioc); 00428 00429 ShVariableReplacer vr(varMap); 00430 m_program->ctrlGraph->dfs(vr); 00431 00432 ioc.updateGraph(); 00433 } 00434 00435 struct TextureLookupConverter { 00436 TextureLookupConverter() : changed(false) {} 00437 00438 void operator()(const ShCtrlGraphNodePtr& node) 00439 { 00440 if (!node) return; 00441 ShBasicBlockPtr block = node->block; 00442 if (!block) return; 00443 for (ShBasicBlock::ShStmtList::iterator I = block->begin(); I != block->end(); ++I) { 00444 convert(block, I); 00445 } 00446 } 00447 00448 void convert(ShBasicBlockPtr block, ShBasicBlock::ShStmtList::iterator& I) 00449 { 00450 const ShStatement& stmt = *I; 00451 if (stmt.op != SH_OP_TEX && stmt.op != SH_OP_TEXI) return; 00452 ShTextureNodePtr tn = shref_dynamic_cast<ShTextureNode>(stmt.src[0].node()); 00453 00454 ShBasicBlock::ShStmtList newStmts; 00455 00456 if (!tn) { SH_DEBUG_ERROR("TEX Instruction from non-texture"); return; } 00457 if (stmt.op == SH_OP_TEX && tn->dims() == SH_TEXTURE_RECT) { 00458 ShVariable tc(new ShVariableNode(SH_TEMP, tn->texSizeVar().size())); 00459 newStmts.push_back(ShStatement(tc, stmt.src[1], SH_OP_MUL, tn->texSizeVar())); 00460 newStmts.push_back(ShStatement(stmt.dest, stmt.src[0], SH_OP_TEXI, tc)); 00461 } else if (stmt.op == SH_OP_TEXI && tn->dims() != SH_TEXTURE_RECT) { 00462 ShVariable tc(new ShVariableNode(SH_TEMP, tn->texSizeVar().size())); 00463 newStmts.push_back(ShStatement(tc, stmt.src[1], SH_OP_DIV, tn->texSizeVar())); 00464 newStmts.push_back(ShStatement(stmt.dest, stmt.src[0], SH_OP_TEX, tc)); 00465 } else { 00466 return; 00467 } 00468 I = block->erase(I); // I is pointing one past its original value now 00469 block->splice(I, newStmts); 00470 I--; // Make I point to its original value, it will be inc'd later. 00471 changed = true; 00472 return; 00473 } 00474 00475 bool changed; 00476 }; 00477 00478 void ShTransformer::convertTextureLookups() 00479 { 00480 TextureLookupConverter conv; 00481 m_program->ctrlGraph->dfs(conv); 00482 if (conv.changed) m_changed = true; 00483 } 00484 00485 } 00486

Generated on Mon Oct 18 14:17:40 2004 for Sh by doxygen 1.3.7