ShLibMiscImpl.hpp

00001 // Sh: A GPU metaprogramming language.
00002 //
00003 // Copyright 2003-2006 Serious Hack Inc.
00004 // 
00005 // This library is free software; you can redistribute it and/or
00006 // modify it under the terms of the GNU Lesser General Public
00007 // License as published by the Free Software Foundation; either
00008 // version 2.1 of the License, or (at your option) any later version.
00009 //
00010 // This library is distributed in the hope that it will be useful,
00011 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00012 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00013 // Lesser General Public License for more details.
00014 //
00015 // You should have received a copy of the GNU Lesser General Public
00016 // License along with this library; if not, write to the Free Software
00017 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, 
00018 // MA  02110-1301, USA
00020 #ifndef SHLIBMISCIMPL_HPP
00021 #define SHLIBMISCIMPL_HPP
00022 
00023 #include "ShLibMisc.hpp"
00024 #include "ShInstructions.hpp"
00025 #include "ShProgram.hpp"
00026 
00027 namespace SH {
00028 
00029 template<int M, int N, typename T> 
00030 ShGeneric<M, T> cast(const ShGeneric<N, T>& a)
00031 {
00032   int copySize = std::min(M, N);
00033   ShAttrib<M, SH_TEMP, T> result;
00034 
00035   int* indices = new int[copySize];
00036   for(int i = 0; i < copySize; ++i) indices[i] = i;
00037   if(M < N) {
00038     result = a.template swiz<M>(indices);
00039   } else if( M > N ) {
00040     for (int i=0; i < (M - N); i++) {
00041       result[M - 1 - i] = static_cast<T>(0);
00042     }
00043     result.template swiz<N>(indices) = a;
00044   } else { // M == N
00045     shASN(result, a);
00046   }
00047   delete [] indices;
00048   return result;
00049 }
00050 
00051 template<int M> 
00052 inline
00053 ShGeneric<M, double> cast(double a)
00054 {
00055   return cast<M>(ShAttrib<1, SH_CONST, double>(a));
00056 }
00057 
00058 template<int M, int N, typename T> 
00059 ShGeneric<M, T> fillcast(const ShGeneric<N, T>& a)
00060 {
00061   if( M <= N ) return cast<M>(a);
00062   int indices[M];
00063   for(int i = 0; i < M; ++i) indices[i] = i >= N ? N - 1 : i;
00064   return a.template swiz<M>(indices);
00065 }
00066 
00067 template<int M> 
00068 inline
00069 ShGeneric<M, double> fillcast(double a)
00070 {
00071   return fillcast<M>(ShAttrib<1, SH_CONST, double>(a));
00072 }
00073 
00074 template<int M, int N, typename T1, typename T2> 
00075 ShGeneric<M+N, CT1T2> join(const ShGeneric<M, T1>& a, const ShGeneric<N, T2>& b)
00076 {
00077   int indices[M+N];
00078   for(int i = 0; i < M+N; ++i) indices[i] = i; 
00079   ShAttrib<M+N, SH_TEMP, CT1T2> result;
00080   result.template swiz<M>(indices) = a;
00081   result.template swiz<N>(indices + M) = b;
00082   return result;
00083 }
00084 
00085 template<int M, typename T> 
00086 ShGeneric<M+1, T> join(const T& a, const ShGeneric<M, T>& b)
00087 {
00088   return join(ShAttrib<1, SH_CONST, T>(a), b);
00089 }
00090 
00091 template<int M, typename T> 
00092 ShGeneric<M+1, T> join(const ShGeneric<M, T>& a, const T& b)
00093 {
00094   return join(a, ShAttrib<1, SH_CONST, T>(b));
00095 }
00096 
00097 template<int M, int N, int O, typename T1, typename T2, typename T3> 
00098 ShGeneric<M+N+O, CT1T2T3> join(const ShGeneric<M, T1>& a, 
00099                                const ShGeneric<N, T2> &b, 
00100                                const ShGeneric<O, T3> &c)
00101 {
00102   int indices[M+N+O];
00103   for(int i = 0; i < M+N+O; ++i) indices[i] = i; 
00104   ShAttrib<M+N+O, SH_TEMP, CT1T2T3> result;
00105   result.template swiz<M>(indices) = a;
00106   result.template swiz<N>(indices + M) = b;
00107   result.template swiz<N>(indices + M + N) = c;
00108   return result;
00109 }
00110 
00111 template<int M, int N, int O, int P, typename T1, typename T2, typename T3, typename T4> 
00112 ShGeneric<M+N+O+P, CT1T2T3T4> join(const ShGeneric<M, T1>& a, 
00113                                    const ShGeneric<N, T2> &b, 
00114                                    const ShGeneric<O, T3> &c, 
00115                                    const ShGeneric<P, T4> &d)
00116 {
00117   int indices[M+N+O+P];
00118   for(int i = 0; i < M+N+O+P; ++i) indices[i] = i; 
00119   ShAttrib<M+N+O+P, SH_TEMP, CT1T2T3T4> result;
00120   result.template swiz<M>(indices) = a;
00121   result.template swiz<N>(indices + M) = b;
00122   result.template swiz<N>(indices + M + N) = c;
00123   result.template swiz<N>(indices + M + N + O) = d;
00124   return result;
00125 }
00126 
00127 template<int N, typename T>
00128 inline
00129 void discard(const ShGeneric<N, T>& c)
00130 {
00131   shKIL(c);
00132 }
00133 
00134 template<int N, typename T>
00135 inline
00136 void kill(const ShGeneric<N, T>& c)
00137 {
00138   discard(c);
00139 }
00140 
// Sorts the components of v[0] in place using min/max compare-exchange steps
// and applies the same exchanges to v[1] .. v[S-1].
//
//   S        number of tuples in the array v
//   VarType  tuple type; must provide typesize (component count) and
//            storage_type
//   v        array of S tuples; v[0] supplies the keys that are compared,
//            the remaining tuples are only reordered alongside it
//
// Components 0 .. NE-1 are treated as the "even" elements of the sequence
// being sorted and NE .. N-1 as the "odd" elements.  Each round performs an
// "up" exchange (even k with odd k) followed by a "down" exchange (even k+1
// with odd k); smaller keys are moved to the earlier element of each pair.
// After NE rounds the two halves are interleaved back into components
// 0, 2, 4, ... and 1, 3, 5, ...
//
// NOTE(review): for N == 1 both NO and ND are zero, giving zero-sized tuples
// and arrays -- presumably callers only use N >= 2; confirm.
00141 template<int S, typename VarType>
00142 void groupsort(VarType v[])
00143 {
00144   const int N = VarType::typesize;
00145   typedef typename VarType::storage_type T;
00146 
00147   const int NE = (N + 1) / 2; // number of even elements
00148   const int NO = N / 2; // number of odd elements
00149   const int NU = NO; // number of components to compare for (2i, 2i+1) comparisons
00150   const int ND = NE - 1; // number of components to compare for (2i, 2i-1) comparisons
00151 
00152   int i, j;
00153   // Temporaries for the even/odd operands and the comparison result of the
  // (2i, 2i+1) "up" and (2i, 2i-1) "down" exchanges.  The comparison results
  // (ccu, ccd) are only assigned and read when S > 1.
00154 
00155   ShAttrib<NU, SH_TEMP, T> eu, ou, ccu; 
00156   ShAttrib<ND, SH_TEMP, T> ed, od, ccd; 
00157 
00158   // even and odd swizzle (elms 0..NE-1 are the "even" subsequence, NE..N-1 "odd")
00159   int eswiz[NE], oswiz[NO]; 
00160   for(i = 0; i < NE; ++i) eswiz[i] = i;
00161   for(i = 0; i < NO; ++i) oswiz[i] = NE + i;
00162 
  // NE rounds, each consisting of one "up" and one "down" exchange.
00163   for(i = 0; i < NE; ++i) { 
00164     // compare 2i, 2i+1: the minimum goes to the even slot, the maximum to the odd slot
00165     eu = v[0].template swiz<NU>(eswiz);
00166     ou = v[0].template swiz<NU>(oswiz);
    // Record which pairs were already in order so the other tuples can follow.
00167     if (S > 1) ccu = eu < ou; 
00168     v[0].template swiz<NU>(eswiz) = min(eu, ou); 
00169     v[0].template swiz<NU>(oswiz) = max(eu, ou); 
00170 
    // Apply the same exchange to the remaining tuples, driven by ccu.
00171     for(j = 1; j < S; ++j) {
00172       eu = v[j].template swiz<NU>(eswiz);
00173       ou = v[j].template swiz<NU>(oswiz);
00174       v[j].template swiz<NU>(eswiz) = cond(ccu, eu, ou); 
00175       v[j].template swiz<NU>(oswiz) = cond(ccu, ou, eu); 
00176     }
00177 
00178     // compare 2i, 2i-1: the even operands start at eswiz + 1, so the maximum
    // goes to the (later) even slot and the minimum to the (earlier) odd slot
00179     ed = v[0].template swiz<ND>(eswiz + 1);
00180     od = v[0].template swiz<ND>(oswiz);
00181     if (S > 1) ccd = ed > od; 
00182     v[0].template swiz<ND>(eswiz + 1) = max(ed, od);
00183     v[0].template swiz<ND>(oswiz) = min(ed, od);
00184 
    // Apply the same exchange to the remaining tuples, driven by ccd.
00185     for(j = 1; j < S; ++j) {
00186       ed = v[j].template swiz<ND>(eswiz + 1);
00187       od = v[j].template swiz<ND>(oswiz);
00188       v[j].template swiz<ND>(eswiz + 1) = cond(ccd, ed, od); 
00189       v[j].template swiz<ND>(oswiz) = cond(ccd, od, ed); 
00190     }
00191   }
00192 
00193   // reswizzle "even" to 0, 2, 4,... "odd" to 1, 3, 5, ..
00194   int resultEswiz[NE], resultOswiz[NO]; 
00195   for(i = 0; i < NE; ++i) resultEswiz[i] = i * 2;
00196   for(i = 0; i < NO; ++i) resultOswiz[i] = i * 2 + 1; 
00197   for(i = 0; i < S; ++i) {
    // The even half is copied to a temporary first: the odd destinations
    // (1, 3, 5, ...) can overlap components 0..NE-1 that still hold evens.
00198     ShAttrib<NE, SH_TEMP, T> evens = v[i].template swiz<NE>(eswiz);
00199     v[i].template swiz<NO>(resultOswiz) = v[i].template swiz<NO>(oswiz);
00200     v[i].template swiz<NE>(resultEswiz) = evens;
00201   }
00202 }
00203 
00204 template<int N, typename T> 
00205 ShGeneric<N, T> sort(const ShGeneric<N, T>& a)
00206 {
00207   ShAttrib<N, SH_TEMP, T> result(a);
00208   groupsort<1>(&result);
00209   return result;
00210 }
00211 
00212 template<typename T>
00213 inline
00214 ShProgram freeze(const ShProgram& p,
00215                  const T& uniform)
00216 {
00217   return (p >> uniform) << (T::ConstType)(uniform);
00218 }
00219 
00220 template<int N, int M, typename T1, typename T2>
00221 ShGeneric<N, CT1T2> poly(const ShGeneric<N, T1>& a, const ShGeneric<M, T2>& b)
00222 {
00223   ShAttrib<N, SH_TEMP, CT1T2> t;
00224   for (int i=0; i < N; i++) {
00225     // Uses Horner's rule
00226     t[i] = b[M - 1];
00227     for (int j = M - 1; j > 0; j--) {
00228       t[i] = mad(a[i], t[i], b[j-1]);
00229     }
00230   }
00231   return t;
00232 }
00233 
00234 }
00235 
00236 #endif

Generated on Thu Feb 16 14:51:34 2006 for Sh by  doxygen 1.4.6