Attachment 'automatic_dr1.py'

Download

   1 def quick(l):
   2     l = list(l)
   3     exchanges = []
   4     n = len(l)
   5     i = 0
   6     while i < n:
   7         if (i>0) and (l[i]<l[i-1]):
   8             new_place = min([u for u in range(i) if l[u]>l[i]  ])
   9             tt = l[new_place]
  10             l[new_place] = l[i]
  11             l[i] = tt
  12             exchanges.append( ( new_place,i ) )
  13         else:
  14             i += 1
  15     return [l,exchanges]
  16 
  17 class Sum(list):
  18     '''A sum is a list of terms with integer coefficients.  Terms are arbitrary object with the only restriction that they must be hashable.'''
  19 
  20     def __init__(self, elts, coeffs=None):
  21         if coeffs == None:
  22             coeffs = [1 for e in elts]
  23         else:
  24             coeffs = coeffs
  25         self.coeffs = []
  26         for i,e in enumerate(elts):
  27             if isinstance(e,Sum):
  28                 for j,p in enumerate(e):
  29                     self.append(p)
  30                     self.coeffs.append( coeffs[i] * e.coeffs[j] )
  31             else:
  32                 self.append(e)
  33                 self.coeffs.append(coeffs[i])
  34 
  35 
  36     def diff(self,i,n=1):
  37         if n == 1:
  38             return Sum(  [ e.diff(i) for e in self ] , self.coeffs ).collect()
  39         else:
  40             return self.diff(i,n-1).diff(i)
  41 
  42     def collect(self):
  43         new_elements = []
  44         new_coeffs = []
  45         for i,e in enumerate(self):
  46             if e in new_elements:
  47                 k = new_elements.index(e)
  48                 new_coeffs[k] += self.coeffs[i]
  49             else:
  50                 new_elements.append(e)
  51                 new_coeffs.append(self.coeffs[i])
  52         return Sum(new_elements,new_coeffs)
  53 
  54 
  55 
  56     def _latex_(self):
  57         terms = []
  58         for i,e in enumerate(self):
  59             if self.coeffs[i] == 1:
  60                 terms.append( e._latex_() )
  61             else:
  62                 terms.append( str(self.coeffs[i]) + ' ' + e._latex_() )
  63         s = str.join(' + ', terms)
  64         return s
  65 
  66 class Bracket(list):
  67     '''This class represent a multilinear function. Its content is stored as a list of tuples.
  68     For instance, "[ ('a','e'),('e',) ]" would represent "f^{2}[V_ae,V_e]".
  69     Derivatives with respect to any variable 'x' is computed by added successively 'x' to each of the tuples, and to add a new tuple containing only 'x'. Hence the result is a list of brackets. Each bracket of the resulting sum is also reordered in lexical order w.r.t. to its tuples.
  70     '''
  71 
  72     def diff(self,i):
  73         new_brackets = [ ]
  74         for k in range(len(self)):
  75             n = Bracket(self)
  76             n[k] = n[k] + (i,)
  77             n.reorder()
  78             new_brackets.append(n)
  79         b = Bracket( tuple( self + [(i,)] ) )
  80         b.reorder()
  81         new_brackets.append( b )
  82         return Sum(new_brackets).collect()
  83 
  84     def reorder(self):
  85         for i,e in enumerate(self):
  86             self[i] = tuple( sorted(e) )
  87         self.sort()
  88         return None
  89 
  90     def is_ascending(self):
  91         l = []
  92         [l.extend(list(a)) for a in self]
  93         for i in range(1,len(l)):
  94             if l[i] < l[i-1]:
  95                 return False
  96         return True
  97 
  98     def order_swaps(self):
  99         l = []
 100         [l.extend(list(a)) for a in self]
 101         [junk,swaps] = quick(l)
 102         return swaps
 103 
 104 
 105 
 106     def __hash__(self):
 107         return tuple(self).__hash__()
 108 
 109     def flat_indices(self):
 110         l = []
 111         for e in self:
 112             l.extend( list(e) )
 113         return l
 114 
 115     def _latex_(self):
 116         if True:
 117             n = len(self)
 118             s = "f^{{({n})}}[{content}]"
 119             content = []
 120             for el in self:
 121                 index = str.join(' ',  [ str(e) for e in el ]   )
 122                 content.append( 'V_{{{index}}}'.format(index=index) )
 123             content = str.join(',',content)
 124             s = s.format(n=n,content=content)
 125             s = s.replace('s','\\sigma')
 126             return s
 127 
 128 class BracketG(Bracket):
 129 
 130     def diff(self,i):
 131         new_brackets = [ ]
 132         for k in range(len(self)):
 133             if self[k] not in ( ('t',), ('y',), ('z',) ):
 134                 n = BracketG(self)
 135                 n[k] = n[k] + (i,)
 136                 n.reorder()
 137                 new_brackets.append(n)
 138         if i not in ('p','s'):
 139             b = BracketG( tuple( self + [(i,)] ) )
 140             b.reorder()
 141             new_brackets.append( b )
 142         elif i == 'p':
 143             b = BracketG( tuple( self + [(i,)] ) )
 144             b1 = BracketG( tuple( self + [('t',)] ) )
 145             b.reorder()
 146             b1.reorder()
 147             new_brackets.extend( [b,b1] )
 148         elif i == 's':
 149             b = BracketG( tuple( self + [(i,)] ) )
 150             b1 = BracketG( tuple( self + [('y',)] ) )
 151             b2 = BracketG( tuple( self + [('z',)] ) )
 152             b.reorder()
 153             b1.reorder()
 154             b2.reorder()
 155             new_brackets.append( b )
 156             new_brackets.append( b1 )
 157             new_brackets.append( b2 )
 158         return Sum(new_brackets)
 159 
 160     def _latex_(self,type='f'):
 161         if True:
 162             n = len(self)
 163             s = 'g^{{( {n} )}}_{{ {ind} }}[ {content} ]'
 164             ind = []
 165             for i in self:
 166                 if i == ('t',):
 167                     ind.append('p')
 168                 elif i == ('y',):
 169                     ind.append('u')
 170                 elif i == ('z',):
 171                     ind.append('\\sigma')
 172                 else:
 173                     ind.append('a')
 174             ind = str.join('',ind)
 175             content = []
 176             for i in self:
 177                 if i == ('t',):
 178                     content.append('I_p')
 179                 elif i == ('y',):
 180                     content.append('q')
 181                 elif i == ('z',):
 182                     content.append('I_s')
 183                 else:
 184                     content.append( 'g_{{ {0} }}'.format( str.join('', i) ) )
 185             content = str.join(',',content)
 186             content = content.replace('s','\\sigma ').replace('q','\\epsilon')
 187             return s.format(n=n,ind=ind,content=content)
 188         
 189     def stoch_order(self):
 190         '''Counts the number of epsilons'''
 191         return sum( [ el.count('y') for el in self] )
 192 
 193 def kill_s(sum_expr):
 194     '''Simplifies expressions using g_s=0.'''
 195     bras = []
 196     coeffs = []
 197     for i,expr in enumerate(sum_expr):
 198         coeff = sum_expr.coeffs[i]
 199         valid = True
 200         for e in expr:
 201             if e.count('s') == 1:
 202                 valid = False
 203                 pass
 204         if expr.count(('z',)) == 1:
 205             valid = False
 206         if valid:
 207             bras.append(expr)
 208             coeffs.append(coeff)
 209     return Sum(bras,coeffs)
 210 
 211 def get_nondecreasing_indices(args,order):
 212     indices = [[(i,) for i in args]]
 213     for o in range(2,order+1):
 214         pindices = indices[-1]
 215         cindices = []
 216         for e in args:
 217             for p in pindices:
 218                 if p[-1] <= e:
 219                     cindices += [ p + (e,) ]
 220         indices.append(cindices)
 221     return indices
 222 
 223 def get_K(expr):
 224     resp =  Sum([ h for h in expr if len(h) > 1])
 225     other = Sum([ h for h in expr if len(h) == 1])
 226     return [resp,other]
 227 
 228 def get_L(expr):
 229     n = sum([len(el) for el in expr[0]])
 230     special = BracketG( [('a',)] ) * n
 231     resp =  Sum([ h for h in expr if (len(h)>1) and (h!=special) ])
 232     other = Sum([ h for h in expr if (len(h) == 1) or h==special])
 233     return [resp,other]
 234 
 235 def compute_derivatives(n,vars,):
 236     eqs_f = [ Sum([Bracket([])]) ]
 237     eqs_f.append( [ eqs_f[-1].diff(v) for v in vars ] )
 238     eqs_g = [ Sum([BracketG([])]) ]
 239     eqs_g.append( [ eqs_g[-1].diff(v) for v in vars ] )
 240     ndt = [[],[(v,) for v in vars]]
 241     for i in range(n-1):
 242         new_eqs_f = []
 243         new_eqs_g = []
 244         new_ndt = []
 245         o_eqs_f = eqs_f[-1]
 246         o_eqs_g = eqs_g[-1]
 247         o_ndt = ndt[-1]
 248         for i,e in enumerate(o_eqs_f):
 249             ind = o_ndt[i]
 250             for v in vars:
 251                 if v>= ind[-1]:
 252                     new_eqs_f.append( e.diff(v).collect() )
 253         for i,e in enumerate(o_eqs_g):
 254             ind = o_ndt[i]
 255             for v in vars:
 256                 if v>= ind[-1]:
 257                     new_eqs_g.append( e.diff(v).collect() )
 258                     new_ndt.append( ind + (v,) )
 259         eqs_f.append(new_eqs_f)
 260         eqs_g.append(new_eqs_g)
 261         ndt.append(new_ndt)
 262     return [eqs_f,eqs_g,ndt]
 263 
 264 ###############################################
 265 ###############################################
 266 #                                             #
 267 # this part of the code is specific to python #
 268 #                                             #
 269 ###############################################
 270 ###############################################
 271 
 272 def compile_bracket(bra):
 273     n = len(bra.flat_indices())    
 274     if isinstance(bra,BracketG):
 275         content = []
 276         for el in bra:
 277             if el == ('t',):
 278                 content += ['I_p']
 279             elif el == ('y',):
 280                 content  += ['I_e']
 281             elif el == ('z',):
 282                 content += ['I_s']
 283             else:
 284                 content +=  ["g_{0}".format( str.join('',el) )]
 285         indices = []
 286         for el in bra:
 287             if el == ('t',):
 288                 indices += 'p'
 289             elif el == ('y',):
 290                 indices += 'e'
 291             elif el == ('z',):
 292                 indices += 's'
 293             else:
 294                 indices += 'a'
 295         indices = str.join('',indices)
 296         txt = 'mdot(g_{0},[{1}])'.format( indices, str.join(",",content))
 297     elif isinstance(bra,Bracket):
 298         content = [ "V_{0}".format( str.join("",e) ) for e in bra]
 299         txt = 'mdot(f_{0},[{1}])'.format(n, str.join(",",content))
 300     if not bra.is_ascending():
 301         l = []
 302         for t in bra:
 303             for e in t:
 304                 l.append(e)
 305         [nl, exchanges] = quick(l)
 306         swp = [".swapaxes({0},{1})".format(e[0],e[1]) for e in exchanges]
 307         txt += str.join('',swp)
 308     return txt
 309 
 310 def compile_expectation(expr):
 311     if isinstance(expr, BracketG):
 312         bra = expr
 313         ns = bra.stoch_order()
 314         if ns == 0:
 315             return compile_bracket(bra)
 316         elif ns == 1:
 317             return 'np.zeros((n_v,{0}))'.format(str.join(',', ['n_'+i for i in bra.flat_indices() if i != 'y'] ))
 318         inds = bra.flat_indices()
 319         eps_inds = [str(n) for n,i in enumerate(inds) if inds[n] == 'y' ]
 320         sig_indices = [str(i) for i in tuple(range(ns))]
 321         resp = compile_bracket(bra)        
 322         resp = 'np.tensordot( {0}, Sigma_{1}, axes=( ({2}), ({3}) ) )'.format(resp, n+1, str.join(',', eps_inds), str.join(',',sig_indices) )
 323         return resp
 324     if isinstance(expr,Sum):
 325         resp = ''
 326         for n,term in enumerate(expr):
 327             c = expr.coeffs[n]
 328             if c == 1:
 329                 resp += ' + {0}'.format( compile_expectation(term) )
 330             elif c >= 1:
 331                 resp += ' + {0}*{1}'.format(c,compile_expectation(term) )
 332         resp = resp.lstrip(' +')
 333 
 334         return resp
 335         
 336 def compile_sum(s):
 337 
 338     resp = ''
 339     for n,bra in enumerate(s):
 340         c = s.coeffs[n]
 341         if c == 1:
 342             resp += ' + {0}'.format(compile_bracket(bra))
 343         elif c >= 1:
 344             resp += ' + {0}*{1}'.format(c,compile_bracket(bra))
 345     resp = resp.lstrip(' +')
 346     return resp
 347 
 348 def write_order(order,vars,eqs_f,eqs_g,ndt):
 349     code = "{br}Computing order {order}\n\n".format(br='\n#' + 10*'-' ,order=order)
 350     code += 'order = {0}\n'.format(order)
 351     for k in range(len(eqs_f[order])):
 352         eq_f = eqs_f[order][k]
 353         eq_g = eqs_g[order][k]
 354         inds = ndt[order][k]
 355         code += '\n#--- Computing derivatives {0}\n\n'.format(inds)
 356         [K_eq,K_other] = get_K(eq_f)
 357         K_name = 'K_'+str.join('',inds)
 358         lhs = K_name
 359         code += '{lhs} = {rhs}\n'.format(lhs=lhs,rhs=compile_sum(K_eq))
 360         [L_eq,L_other] = get_L(eq_g)
 361         L_name = 'L_'+str.join('',inds)
 362         lhs = L_name
 363         code += '{lhs} = {rhs}\n\n'.format(lhs=lhs,rhs=compile_sum(L_eq))
 364         if k == 0:
 365             code += '#We need to solve the infamous sylvester equation\n'
 366             code += '#A = f_a + sdot(f_d,g_a)\n'
 367             code += '#B = f_d\n'
 368             code += '#C = g_a\n'
 369             code += 'D = {K} + sdot(f_d,{L})\n'.format(K=K_name,L=L_name)
 370             code += 'g_{sub} = solve_sylvester(A,B,C,D)\n'.format(sub = str.join('',inds))
 371         else:
 372             code += '#We solve A*X + const = 0\n'
 373             code += 'const = sdot(f_d,{L}) + {K}\n'.format(L=L_name,K=K_name)
 374             code += 'g_{sub} = - sdot(A_inv, const)\n'.format(sub=str.join('',inds))
 375         if True:
 376             sizes = str.join(',', ['n_'+v for v in vars] )
 377             code += '\nif order < max_order:\n'
 378             code += '    Y = {L}{other}\n'.format(L = L_name, other = compile_sum(L_other) )
 379             code += '    Z = g_{sub}\n'.format(sub=str.join('',inds))
 380             code += '    V_{sub} = build_V(Y,Z,({sizes}))\n'.format(sub = str.join('',inds), sizes=sizes)
 381 #            code += 'end\n'
 382     return code
 383 
 384 def gen_code(max_order,vars):
 385     [eqs_f,eqs_g,ndt] = compute_derivatives(max_order,vars)
 386 
 387     code = ''
 388 #    code = 'max_order = {0}\n'.format(max_order)
 389     for o in range(2,max_order+1):
 390         code += write_order(o,vars,eqs_f,eqs_g,ndt)
 391     return code
 392 
 393 def sort_order(expr):
 394     orders = [term.stoch_order() for term in expr]
 395     print orders
 396     if len(orders) == 0:
 397         return []
 398     resp = []
 399     for ord in range(0,max(orders)+1):
 400         resp.append( [term for n,term in enumerate(expr) if orders[n] == ord] )
 401     return resp
 402 
 403 
 404 if __name__ == '__main__':
 405     #code = gen_code(4,('a','e'))
 406 
 407     [eqs_f,eqs_g,ndt] = compute_derivatives(5,('a','e','s'))
 408 
 409     for order in range(len(eqs_g)):
 410         for eq in eqs_g[order]:
 411             print ''
 412             print eq._latex_()
 413             print compile_expectation( kill_s(eq) )
 414             #for a in ( sort_order( eq ) ):
 415             #    print [compile_expectation( el ) for el in a]
 416     
 417     #print code

Attached Files

To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since this is subject to change and can break easily.
  • [get | view] (2010-12-09 17:35:32, 200.0 KB) [[attachment:Bifurcations_portfolios_appendix.pdf]]
  • [get | view] (2010-12-09 18:05:41, 87.2 KB) [[attachment:automatic_dr1.pdf]]
  • [get | view] (2010-12-10 09:20:27, 13.6 KB) [[attachment:automatic_dr1.py]]
  • [get | view] (2024-01-09 16:49:55, 0.4 KB) [[attachment:latex_10d4a8966448758f76ecefe0a19ecfa944eface7_p1.png]]
  • [get | view] (2024-01-09 16:49:55, 1.0 KB) [[attachment:latex_1b2264f92c07685c87478c785567d7f482cf7108_p1.png]]
  • [get | view] (2024-01-09 16:49:56, 0.2 KB) [[attachment:latex_2568a85634b9bdefb53c9fae0cd2a74d39cae4f5_p1.png]]
  • [get | view] (2024-01-09 16:49:56, 0.2 KB) [[attachment:latex_5b7accd67b8fddfab7d40de986a4f4b84d862d14_p1.png]]
  • [get | view] (2024-01-09 16:49:56, 0.5 KB) [[attachment:latex_9d3b996515a7d78756053519d04a98ea6798ca5c_p1.png]]
  • [get | view] (2024-01-09 16:49:56, 0.2 KB) [[attachment:latex_a6afa70ad0dd291879a156956894c8d75a5fdbc7_p1.png]]
  • [get | view] (2024-01-09 16:49:56, 0.7 KB) [[attachment:latex_e32ac87e8922bca40d542f64fcbb94d7ed6d433f_p1.png]]
  • [get | view] (2024-01-09 16:49:55, 0.5 KB) [[attachment:latex_fa81e0e825b58f68a34ef3272b44e94a0b18a624_p1.png]]
 All files | Selected Files: delete move to page copy to page

You are not allowed to attach a file to this page.