Attachment 'automatic_dr1.py'
def quick(l):
    '''Sort a copy of ``l`` by pairwise exchanges and record every exchange.

    The argument is copied first and is never modified.

    Returns a two-element list ``[sorted_items, exchanges]`` where
    ``exchanges`` holds the ``(new_place, i)`` position pairs in the order
    in which they were swapped.
    '''
    items = list(l)
    swaps = []
    size = len(items)
    pos = 0
    while pos < size:
        # Nothing to do while the current item is not smaller than its
        # left neighbour: just move on.
        if pos == 0 or not (items[pos] < items[pos - 1]):
            pos += 1
            continue
        # Leftmost earlier slot that holds something strictly larger.
        target = next(u for u in range(pos) if items[u] > items[pos])
        items[target], items[pos] = items[pos], items[target]
        swaps.append((target, pos))
        # ``pos`` is deliberately not advanced: the value that just landed
        # here is compared with its left neighbour again.
    return [items, swaps]
16
class Sum(list):
    '''A sum is a list of terms with integer coefficients.

    Terms are arbitrary objects with the only restriction that they must be
    hashable. The terms are the list items themselves; the coefficient of
    the k-th term is ``self.coeffs[k]``.
    '''

    def __init__(self, elts, coeffs=None):
        '''Build the sum of ``elts`` weighted by ``coeffs``.

        ``coeffs`` defaults to 1 for every element. An element that is
        itself a ``Sum`` is flattened: each of its terms is added with the
        product of the outer and inner coefficients.
        '''
        super(Sum, self).__init__()
        # Materialise first so a one-shot iterable is not exhausted while
        # the default coefficients are being built.
        elts = list(elts)
        if coeffs is None:
            coeffs = [1] * len(elts)
        self.coeffs = []
        for i, e in enumerate(elts):
            if isinstance(e, Sum):
                for j, p in enumerate(e):
                    self.append(p)
                    self.coeffs.append(coeffs[i] * e.coeffs[j])
            else:
                self.append(e)
                self.coeffs.append(coeffs[i])

    def diff(self, i, n=1):
        '''Differentiate ``n`` times with respect to ``i``.

        Every term must provide ``diff(i)``. The result of each step is
        collected. ``n == 0`` returns a copy of this sum; a negative ``n``
        raises ``ValueError``.
        '''
        if n < 0:
            raise ValueError('n must be non-negative, got {0}'.format(n))
        result = Sum(list(self), list(self.coeffs))
        # Iterative form: terminates for every non-negative n, including 0.
        for _ in range(n):
            result = Sum([e.diff(i) for e in result], result.coeffs).collect()
        return result

    def collect(self):
        '''Return a new sum in which equal terms are merged.

        Terms keep the order of their first occurrence and the
        coefficients of equal terms are added together.
        '''
        position = {}
        new_elements = []
        new_coeffs = []
        for e, c in zip(self, self.coeffs):
            k = position.get(e)
            if k is None:
                position[e] = len(new_elements)
                new_elements.append(e)
                new_coeffs.append(c)
            else:
                new_coeffs[k] += c
        return Sum(new_elements, new_coeffs)

    def _latex_(self):
        '''Return the terms' ``_latex_()`` strings joined by ``' + '``.

        A coefficient equal to 1 is omitted; any other coefficient is
        written in front of its term, separated by a space.
        '''
        terms = []
        for e, c in zip(self, self.coeffs):
            if c == 1:
                terms.append(e._latex_())
            else:
                terms.append(str(c) + ' ' + e._latex_())
        return ' + '.join(terms)
65
class Bracket(list):
    '''A multilinear function, stored as a list of tuples of indices.

    For instance, ``[('a','e'), ('e',)]`` represents ``f^{2}[V_ae, V_e]``.
    Differentiating with respect to a variable ``x`` appends ``x`` to each
    tuple in turn and also adds one new tuple containing only ``x``, so
    the derivative is a sum of brackets. Each resulting bracket is
    reordered: indices are sorted inside every tuple, then the tuples are
    sorted.
    '''

    def diff(self, i):
        '''Return the collected ``Sum`` of brackets obtained by deriving w.r.t. ``i``.'''
        derived = []
        for position, entry in enumerate(self):
            grown = Bracket(self)
            grown[position] = entry + (i,)
            grown.reorder()
            derived.append(grown)
        widened = Bracket(tuple(self + [(i,)]))
        widened.reorder()
        derived.append(widened)
        return Sum(derived).collect()

    def reorder(self):
        '''Sort the indices inside every tuple, then sort the tuples, in place.'''
        self[:] = [tuple(sorted(entry)) for entry in self]
        self.sort()
        return None

    def is_ascending(self):
        '''Tell whether the flattened indices never decrease from left to right.'''
        flat = [index for entry in self for index in entry]
        return not any(flat[k] < flat[k - 1] for k in range(1, len(flat)))

    def order_swaps(self):
        '''Return the exchanges ``quick`` records while sorting the flattened indices.'''
        flat = [index for entry in self for index in entry]
        return quick(flat)[1]

    def __hash__(self):
        # Hash by content so that equal brackets hash alike.
        return hash(tuple(self))

    def flat_indices(self):
        '''Return all indices of all tuples as one flat list.'''
        return [index for entry in self for index in entry]

    def _latex_(self):
        '''Return the LaTeX string ``f^{(n)}[V_{...},...]`` for this bracket.

        ``n`` is the number of tuples. Every letter ``s`` of the finished
        string is replaced by ``\\sigma``.
        '''
        arguments = ','.join(
            'V_{{{index}}}'.format(index=' '.join([str(e) for e in entry]))
            for entry in self
        )
        rendered = "f^{{({n})}}[{content}]".format(n=len(self), content=arguments)
        return rendered.replace('s', '\\sigma')
127
class BracketG(Bracket):
    '''A ``Bracket`` with special handling of the tuples ``('t',)``, ``('y',)`` and ``('z',)``.

    Those three tuples are never extended when deriving; deriving with
    respect to ``'p'`` or ``'s'`` adds extra brackets that contain them.
    '''

    def diff(self, i):
        '''Return the ``Sum`` of brackets obtained by deriving w.r.t. ``i``.

        The result is not collected.
        '''
        frozen = (('t',), ('y',), ('z',))
        derived = []
        for position, entry in enumerate(self):
            if entry in frozen:
                continue
            grown = BracketG(self)
            grown[position] = entry + (i,)
            grown.reorder()
            derived.append(grown)

        # Besides the tuple holding ``i`` itself, 'p' also brings ('t',)
        # and 's' also brings ('y',) and ('z',), each in its own bracket.
        if i == 'p':
            appended = (i, 't')
        elif i == 's':
            appended = (i, 'y', 'z')
        else:
            appended = (i,)
        for index in appended:
            widened = BracketG(tuple(self + [(index,)]))
            widened.reorder()
            derived.append(widened)
        return Sum(derived)

    def _latex_(self, type='f'):
        '''Return the LaTeX string ``g^{( n )}_{ subscripts }[ arguments ]``.

        ``type`` is accepted but not used. In the argument list every
        ``s`` becomes ``\\sigma `` and every ``q`` becomes ``\\epsilon``.
        '''
        # (tuple, subscript letter, argument text) for the special tuples.
        special = (
            (('t',), 'p', 'I_p'),
            (('y',), 'u', 'q'),
            (('z',), '\\sigma', 'I_s'),
        )
        subscripts = []
        arguments = []
        for entry in self:
            for key, letter, text in special:
                if entry == key:
                    subscripts.append(letter)
                    arguments.append(text)
                    break
            else:
                subscripts.append('a')
                arguments.append('g_{{ {0} }}'.format(''.join(entry)))
        content = ','.join(arguments)
        # Order matters: 's' must be replaced before '\\epsilon' (which
        # contains an 's') is introduced.
        content = content.replace('s', '\\sigma ').replace('q', '\\epsilon')
        template = 'g^{{( {n} )}}_{{ {ind} }}[ {content} ]'
        return template.format(n=len(self), ind=''.join(subscripts), content=content)

    def stoch_order(self):
        '''Counts the number of epsilons'''
        return sum(entry.count('y') for entry in self)
192
def kill_s(sum_expr):
    '''Simplify ``sum_expr`` using g_s=0.

    A term is dropped when one of its tuples contains ``'s'`` exactly
    once, or when the term contains the tuple ``('z',)`` exactly once.
    The surviving terms are returned, with their coefficients, as a new
    ``Sum``.
    '''
    kept_terms = []
    kept_coeffs = []
    for position, term in enumerate(sum_expr):
        weight = sum_expr.coeffs[position]
        has_single_s = any(entry.count('s') == 1 for entry in term)
        has_single_z = term.count(('z',)) == 1
        if has_single_s or has_single_z:
            continue
        kept_terms.append(term)
        kept_coeffs.append(weight)
    return Sum(kept_terms, kept_coeffs)
210
def get_nondecreasing_indices(args, order):
    '''List the non-decreasing index tuples over ``args``, length by length.

    Returns a list whose k-th entry holds the tuples of length ``k + 1``;
    lengths run from 1 up to ``order`` (only length 1 when ``order < 2``).
    A tuple is extended by an element only if its last item is ``<=``
    that element.
    '''
    levels = [[(a,) for a in args]]
    for _ in range(2, order + 1):
        previous = levels[-1]
        levels.append([
            head + (last,)
            for last in args
            for head in previous
            if head[-1] <= last
        ])
    return levels
222
def get_K(expr):
    '''Split ``expr`` into its multi-argument and single-argument terms.

    Returns ``[resp, other]``: ``resp`` is the ``Sum`` of the terms whose
    length is greater than 1 and ``other`` the ``Sum`` of the terms whose
    length is exactly 1. Terms of length 0 appear in neither.

    Each term keeps the coefficient it has in ``expr``. If ``expr`` has
    no ``coeffs`` attribute, every coefficient is taken to be 1.
    '''
    terms = list(expr)
    coeffs = getattr(expr, 'coeffs', None)
    if coeffs is None:
        coeffs = [1] * len(terms)

    resp_terms, resp_coeffs = [], []
    other_terms, other_coeffs = [], []
    for h, c in zip(terms, coeffs):
        if len(h) > 1:
            resp_terms.append(h)
            resp_coeffs.append(c)
        elif len(h) == 1:
            other_terms.append(h)
            other_coeffs.append(c)
    # Pass the coefficients on explicitly: Sum defaults them all to 1,
    # which would discard the multiplicities held by ``expr``.
    return [Sum(resp_terms, resp_coeffs), Sum(other_terms, other_coeffs)]
227
def get_L(expr):
    '''Split ``expr`` into ``[resp, other]``.

    ``n`` is the total number of indices in the first term of ``expr``
    and the "special" term is ``n`` copies of the tuple ``('a',)``.
    ``other`` receives the terms of length 1 and the special term;
    ``resp`` receives the remaining terms of length greater than 1.

    Each term keeps the coefficient it has in ``expr``. If ``expr`` has
    no ``coeffs`` attribute, every coefficient is taken to be 1. An empty
    ``expr`` gives two empty sums.
    '''
    terms = list(expr)
    if not terms:
        return [Sum([]), Sum([])]
    coeffs = getattr(expr, 'coeffs', None)
    if coeffs is None:
        coeffs = [1] * len(terms)

    n = sum(len(el) for el in terms[0])
    # List repetition yields a plain list; brackets compare equal to it
    # by content.
    special = BracketG([('a',)]) * n

    resp_terms, resp_coeffs = [], []
    other_terms, other_coeffs = [], []
    for h, c in zip(terms, coeffs):
        if len(h) == 1 or h == special:
            other_terms.append(h)
            other_coeffs.append(c)
        elif len(h) > 1:
            resp_terms.append(h)
            resp_coeffs.append(c)
    # Pass the coefficients on explicitly: Sum defaults them all to 1,
    # which would discard the multiplicities held by ``expr``.
    return [Sum(resp_terms, resp_coeffs), Sum(other_terms, other_coeffs)]
234
def compute_derivatives(n, vars):
    '''Derive the empty ``Bracket`` and the empty ``BracketG`` up to order ``n``.

    Returns ``[eqs_f, eqs_g, ndt]``. Entry 0 of ``eqs_f`` / ``eqs_g`` is
    the ``Sum`` holding the empty bracket; entry ``k >= 1`` is the list
    of order-``k`` derivatives. ``ndt[k]`` lists the tuples of variables
    each derivative was taken with respect to (``ndt[0]`` is empty). A
    derivative is only extended by variables ``v`` with
    ``v >= `` the last variable already used.
    '''
    eqs_f = [Sum([Bracket([])])]
    eqs_f.append([eqs_f[-1].diff(v) for v in vars])
    eqs_g = [Sum([BracketG([])])]
    eqs_g.append([eqs_g[-1].diff(v) for v in vars])
    ndt = [[], [(v,) for v in vars]]

    for _ in range(n - 1):
        prev_f, prev_g, prev_ndt = eqs_f[-1], eqs_g[-1], ndt[-1]
        next_f = [
            eq.diff(v).collect()
            for eq, ind in zip(prev_f, prev_ndt)
            for v in vars
            if v >= ind[-1]
        ]
        next_g = [
            eq.diff(v).collect()
            for eq, ind in zip(prev_g, prev_ndt)
            for v in vars
            if v >= ind[-1]
        ]
        next_ndt = [
            ind + (v,)
            for _eq, ind in zip(prev_g, prev_ndt)
            for v in vars
            if v >= ind[-1]
        ]
        eqs_f.append(next_f)
        eqs_g.append(next_g)
        ndt.append(next_ndt)
    return [eqs_f, eqs_g, ndt]
263
264 ###############################################
265 ###############################################
266 # #
267 # this part of the code is specific to python #
268 # #
269 ###############################################
270 ###############################################
271
def compile_bracket(bra):
    '''Return the source text of an ``mdot(...)`` call for one bracket.

    A ``BracketG`` gives ``mdot(g_<letters>,[...])`` with one letter and
    one argument per tuple; any other ``Bracket`` gives
    ``mdot(f_<n>,[...])`` where ``n`` is the total number of indices in
    the bracket. When the flattened indices are not ascending, one
    ``.swapaxes(i,j)`` call is appended per exchange recorded by
    ``quick``.
    '''
    n = len(bra.flat_indices())
    if isinstance(bra, BracketG):
        content = []
        indices = []
        for el in bra:
            if el == ('t',):
                argument, letter = 'I_p', 'p'
            elif el == ('y',):
                argument, letter = 'I_e', 'e'
            elif el == ('z',):
                argument, letter = 'I_s', 's'
            else:
                argument, letter = "g_{0}".format(''.join(el)), 'a'
            content.append(argument)
            indices.append(letter)
        txt = 'mdot(g_{0},[{1}])'.format(''.join(indices), ",".join(content))
    elif isinstance(bra, Bracket):
        content = ["V_{0}".format("".join(e)) for e in bra]
        txt = 'mdot(f_{0},[{1}])'.format(n, ",".join(content))
    # NOTE(review): the axis swaps are applied to both kinds of bracket
    # here. The text this was recovered from had lost its indentation, so
    # confirm they were not meant for the plain Bracket branch only.
    if not bra.is_ascending():
        flat = [e for t in bra for e in t]
        exchanges = quick(flat)[1]
        txt += ''.join(
            ".swapaxes({0},{1})".format(first, second)
            for first, second in exchanges
        )
    return txt
309
def compile_expectation(expr):
    '''Return the source text for ``expr`` once its ``'y'`` indices are handled.

    For a ``BracketG`` the result depends on ``stoch_order()``, the
    number of ``'y'`` indices:

    * 0: the plain ``compile_bracket`` text;
    * 1: an ``np.zeros((n_v,...))`` expression with one ``n_<index>``
      size per remaining index;
    * 2 or more: an ``np.tensordot`` of the compiled bracket with
      ``Sigma_<ns>``, pairing the positions of the ``'y'`` indices with
      axes ``0 .. ns-1``.

    For a ``Sum`` the compiled terms are joined with ``' + '``; a
    coefficient greater than 1 is written as ``c*term`` and terms whose
    coefficient is below 1 are left out.

    Any other argument yields ``None``.

    NOTE(review): the Sigma subscript used to be ``n+1``, where ``n`` was
    the loop variable of a list comprehension. That name only survives
    the comprehension on Python 2 (making the subscript the total number
    of indices) and is undefined on Python 3. ``ns`` is used instead
    because Sigma is contracted over exactly ``ns`` axes; please confirm
    this is the intended tensor.
    '''
    if isinstance(expr, BracketG):
        bra = expr
        ns = bra.stoch_order()
        if ns == 0:
            return compile_bracket(bra)
        inds = bra.flat_indices()
        if ns == 1:
            sizes = ','.join(['n_' + i for i in inds if i != 'y'])
            return 'np.zeros((n_v,{0}))'.format(sizes)
        eps_inds = [str(pos) for pos, i in enumerate(inds) if i == 'y']
        sig_indices = [str(k) for k in range(ns)]
        return 'np.tensordot( {0}, Sigma_{1}, axes=( ({2}), ({3}) ) )'.format(
            compile_bracket(bra), ns, ','.join(eps_inds), ','.join(sig_indices))
    if isinstance(expr, Sum):
        pieces = []
        for position, term in enumerate(expr):
            c = expr.coeffs[position]
            if c == 1:
                pieces.append(' + {0}'.format(compile_expectation(term)))
            elif c >= 1:
                pieces.append(' + {0}*{1}'.format(c, compile_expectation(term)))
        # Drop the separator in front of the first term.
        return ''.join(pieces).lstrip(' +')
    return None
335
def compile_sum(s):
    '''Return the source text adding up the compiled brackets of ``s``.

    Terms are joined with ``' + '``. A coefficient of 1 is omitted, a
    larger one is written as ``c*term``, and terms whose coefficient is
    below 1 are left out. An empty sum gives an empty string.
    '''
    pieces = []
    for position, bra in enumerate(s):
        weight = s.coeffs[position]
        if weight == 1:
            pieces.append(' + {0}'.format(compile_bracket(bra)))
        elif weight >= 1:
            pieces.append(' + {0}*{1}'.format(weight, compile_bracket(bra)))
    # Drop the separator in front of the first term.
    return ''.join(pieces).lstrip(' +')
347
def write_order(order, vars, eqs_f, eqs_g, ndt):
    '''Return the generated source text for one derivative order.

    ``eqs_f[order]``, ``eqs_g[order]`` and ``ndt[order]`` are read in
    parallel. For each index tuple the text assigns ``K_<inds>`` and
    ``L_<inds>``, then ``g_<inds>`` (through ``solve_sylvester`` for the
    first tuple, through ``A_inv`` for the others), and finally a guarded
    ``V_<inds> = build_V(...)`` block. Names such as ``sdot``,
    ``solve_sylvester``, ``build_V``, ``A``, ``A_inv``, ``f_d`` and
    ``max_order`` only appear inside the emitted text; they are not
    defined here.
    '''
    parts = [
        "{br}Computing order {order}\n\n".format(br='\n#' + 10*'-', order=order),
        'order = {0}\n'.format(order),
    ]
    sizes = ','.join(['n_' + v for v in vars])
    rows = zip(eqs_f[order], eqs_g[order], ndt[order])
    for k, (eq_f, eq_g, inds) in enumerate(rows):
        sub = ''.join(inds)
        K_name = 'K_' + sub
        L_name = 'L_' + sub
        parts.append('\n#--- Computing derivatives {0}\n\n'.format(inds))

        [K_eq, K_other] = get_K(eq_f)
        [L_eq, L_other] = get_L(eq_g)
        # An empty sum compiles to '', which would leave an assignment
        # without a right-hand side; write the empty sum as 0 instead.
        parts.append('{lhs} = {rhs}\n'.format(lhs=K_name, rhs=compile_sum(K_eq) or '0'))
        parts.append('{lhs} = {rhs}\n\n'.format(lhs=L_name, rhs=compile_sum(L_eq) or '0'))

        if k == 0:
            parts.append('#We need to solve the infamous sylvester equation\n')
            parts.append('#A = f_a + sdot(f_d,g_a)\n')
            parts.append('#B = f_d\n')
            parts.append('#C = g_a\n')
            parts.append('D = {K} + sdot(f_d,{L})\n'.format(K=K_name, L=L_name))
            parts.append('g_{sub} = solve_sylvester(A,B,C,D)\n'.format(sub=sub))
        else:
            parts.append('#We solve A*X + const = 0\n')
            parts.append('const = sdot(f_d,{L}) + {K}\n'.format(L=L_name, K=K_name))
            parts.append('g_{sub} = - sdot(A_inv, const)\n'.format(sub=sub))

        # compile_sum strips its leading ' + ', so the separator between
        # the L name and the remaining terms has to be written here.
        other = compile_sum(L_other)
        y_rhs = L_name + ' + ' + other if other else L_name
        parts.append('\nif order < max_order:\n')
        parts.append(' Y = {rhs}\n'.format(rhs=y_rhs))
        parts.append(' Z = g_{sub}\n'.format(sub=sub))
        parts.append(' V_{sub} = build_V(Y,Z,({sizes}))\n'.format(sub=sub, sizes=sizes))
    return ''.join(parts)
383
def gen_code(max_order, vars):
    '''Return the generated source text for orders 2 up to ``max_order``.

    The derivatives are computed once with ``compute_derivatives`` and
    the text of each order, produced by ``write_order``, is concatenated
    in increasing order.
    '''
    [eqs_f, eqs_g, ndt] = compute_derivatives(max_order, vars)
    return ''.join(
        write_order(o, vars, eqs_f, eqs_g, ndt)
        for o in range(2, max_order + 1)
    )
392
def sort_order(expr):
    '''Group the terms of ``expr`` by their ``stoch_order()``.

    Returns a list whose k-th entry is the list of terms of order ``k``,
    for ``k`` from 0 up to the largest order found; an empty ``expr``
    gives an empty list. The list of orders is printed as a side effect.
    '''
    terms = list(expr)
    orders = [term.stoch_order() for term in terms]
    # Parenthesised single argument: prints the same on Python 2 and 3.
    print(orders)
    if not orders:
        return []
    return [
        [term for term, found in zip(terms, orders) if found == level]
        for level in range(max(orders) + 1)
    ]
402
403
if __name__ == '__main__':
    # Demo: print, for every derivative of the empty BracketG up to order
    # 5 in ('a','e','s'), its LaTeX form and the compiled expectation of
    # what is left after kill_s.
    #code = gen_code(4,('a','e'))

    [eqs_f, eqs_g, ndt] = compute_derivatives(5, ('a', 'e', 's'))

    for order in range(len(eqs_g)):
        for eq in eqs_g[order]:
            # Single-argument print calls behave the same on Python 2 and 3.
            print('')
            print(eq._latex_())
            print(compile_expectation(kill_s(eq)))
            #for a in ( sort_order( eq ) ):
            #    print([compile_expectation( el ) for el in a])

    #print(code)
Attached Files
To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since this is subject to change and can break easily. You are not allowed to attach a file to this page.