Note
Go to the end to download the full example code.
Custom: Simple scalar network
Creating a network with multiple custom modules
This example demonstrates how multiple simple modules are implemented. Different methods of data storage inside the
module are shown. Next to that, the effect of ordering the modules in a pymoto.Network is demonstrated. This
allows for different mathematical behavior, while keeping the same implementation of modules: no additional effort needs
to be made to keep the sensitivities consistent. Basic usage of pymoto.Signal, response and sensitivity
calculation, and validation with pymoto.finite_difference() are also demonstrated.
This example is identical in behavior to MathExpression: General math expressions, but uses manually implemented sensitivities instead of automatically generated ones.
15 import pymoto as pym
16 import math
17
18
19 # Module definitions
class ModuleA(pym.Module):
    """Scalar module computing ``y = x2 * sin(x1)``.

    Both inputs are cached as attributes during the forward call
    (``__call__``) so that ``_sensitivity`` can reuse them afterwards.
    """

    def __call__(self, x1, x2):
        """Return ``x2 * sin(x1)`` and remember ``x1`` and ``x2``."""
        # Keep the inputs around; the backward pass needs them
        self.x1, self.x2 = x1, x2
        return x2 * math.sin(x1)

    def _sensitivity(self, df_dy):
        """Return ``(df_dx1, df_dx2)`` for a given output sensitivity ``df_dy``.

        Uses the values of ``x1`` and ``x2`` cached by the last ``__call__``.
        """
        # Chain rule with dy/dx1 = x2 * cos(x1) and dy/dx2 = sin(x1)
        wrt_x1 = df_dy * self.x2 * math.cos(self.x1)
        wrt_x2 = df_dy * math.sin(self.x1)
        return wrt_x1, wrt_x2
34
35
class ModuleB(pym.Module):
    """Evaluates ``y = cos(x1) * cos(x2)``.

    The partial derivatives of ``y`` are already calculated during the forward
    call (``__call__``) and cached, so ``_sensitivity`` only has to apply the
    chain rule.
    """

    def __call__(self, x1, x2):
        """Return ``cos(x1) * cos(x2)`` and cache ``dy/dx1`` and ``dy/dx2``."""
        sin_x1, cos_x1 = math.sin(x1), math.cos(x1)
        sin_x2, cos_x2 = math.sin(x2), math.cos(x2)
        # Store the true partial derivatives, including the minus sign that
        # comes from d(cos(x))/dx = -sin(x), so the attribute names match
        # the values they hold.
        self.dy_dx1 = -sin_x1 * cos_x2
        self.dy_dx2 = -cos_x1 * sin_x2
        return cos_x2 * cos_x1

    def _sensitivity(self, df_dy):
        """Return ``(df_dx1, df_dx2)`` for a given output sensitivity ``df_dy``.

        Uses the derivatives cached by the last ``__call__``.
        """
        df_dx1 = df_dy * self.dy_dx1
        df_dx2 = df_dy * self.dy_dx2
        return df_dx1, df_dx2
50
51
class ModuleC(pym.Module):
    """Scalar module computing ``y = x1**2 * (1 + x2)``.

    Nothing is cached in the forward call; ``_sensitivity`` reads the input
    values back from the ``state`` of the module's input signals instead.
    """

    def __call__(self, x1, x2):
        """Return ``x1**2 * (1 + x2)``."""
        return x1 ** 2 * (1 + x2)

    def _sensitivity(self, df_dy):
        """Return ``(df_dx1, df_dx2)`` for a given output sensitivity ``df_dy``."""
        # Fetch the current input values from the first two input signals
        first_in, second_in = self.sig_in[0], self.sig_in[1]
        u = first_in.state
        v = second_in.state
        # Chain rule with dy/dx1 = 2 * x1 * (1 + x2) and dy/dx2 = x1**2
        wrt_x1 = 2 * df_dy * u * (1 + v)
        wrt_x2 = df_dy * u * u
        return wrt_x1, wrt_x2
66
67
if __name__ == '__main__':
    print(__doc__)

    # --- SETUP ---
    # Declare the signals and set initial values
    x = pym.Signal('x', 1.0)
    y = pym.Signal('y', 0.8)
    z = pym.Signal('z', 3.4)

    # Start building a network of modules
    # NOTE(review): the extent of this `with` block was reconstructed from a
    # listing without indentation; confirm it ends after the modules are created.
    with pym.Network() as fn:
        # Create the modules here
        # Depending on how the input and output signals are routed between the modules, different behavior can be
        # implemented

        ordering = 0  # Change this to 1 to see a different ordering of the modules
        if ordering == 0:
            # A __
            #      \
            #       --> C
            # B __/
            a = ModuleA()(x, y)
            b = ModuleB()(y, z)
            g = ModuleC()(a, b)
        elif ordering == 1:
            # B __
            #      \
            #       --> A
            # C __/
            print("Using an alternative module order")
            a = ModuleB()(x, y)
            b = ModuleC()(y, z)
            g = ModuleA()(a, b)
        else:
            # Fail early with a clear message instead of a NameError on `a`, `b`, `g` further down
            raise ValueError(f"Unknown ordering {ordering!r}; expected 0 or 1")

        # Set tags for the signals (identical for both orderings)
        a.tag, b.tag, g.tag = 'a', 'b', 'g'

    print("\nCurrent network:")
    print(" -> ".join([type(m).__name__ for m in fn.mods]))

    print(f"The response is g(x={x.state}, y={y.state}, z={z.state}) = {g.state}")

    # --- FORWARD ANALYSIS ---
    # Perform an extra forward analysis
    # Change the values of the input state
    x.state *= 2
    y.state += 0.1
    z.state -= 0.2

    fn.response()  # Run the forward analysis again
    print(f"The updated response is g(x={x.state}, y={y.state}, z={z.state}) = {g.state}")

    # --- BACKPROPAGATION ---
    # Clear previous sensitivities (in this case it is redundant as the network is just created, but good practice)
    fn.reset()

    # Seed the response sensitivity
    g.sensitivity = 1.0

    # Calculate the sensitivities
    fn.sensitivity()

    print("\nThe sensitivities are:")
    print(f"d{g.tag}/d{x.tag} = {x.sensitivity}")
    print(f"d{g.tag}/d{y.tag} = {y.sensitivity}")
    print(f"d{g.tag}/d{z.tag} = {z.sensitivity}")

    # --- Finite difference checks ---
    # On the individual modules
    pym.finite_difference([x, y], a, random=False)  # From (x, y) to a
    pym.finite_difference([y, z], b, random=False)  # From (y, z) to b
    pym.finite_difference([a, b], g, random=False)  # From (a, b) to g

    # Do a finite difference check on the entire network
    pym.finite_difference([x, y, z], g, random=False)