On Tue, 1 Jun 2021 at 05:16, Neil Girdhar <mistersheik@gmail.com> wrote:
Hi Oscar,
The problem that the original poster was trying to address with additional syntax is the automatic naming of symbols. He wants to omit this line:
x = symbols("x")
You're right that if you have many one-character symbol names, you can use a shortcut, but this benefit is lost if you want descriptive names like:
momentum = symbols('momentum')
He is proposing new syntax to eliminate the repeated name. The function approach specifies each name exactly once. This is one of the benefits of JAX over TensorFlow.
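Here is a minimal stdlib-only sketch of the "specify each name exactly once" idea. It uses a hypothetical `trace` helper, which is not a SymPy or JAX API. The helper reads a function's parameter names with `inspect.signature`, creates one placeholder symbol per parameter, and calls the function on those placeholders to record an expression tree. `Expr`, `Sym` and `trace` are illustrative names only.

```python
import inspect


class Expr:
    """A tiny expression-tree node: an operator applied to operands."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args

    # Each overloaded operator records the operation instead of computing it.
    def __add__(self, other):
        return Expr("+", self, other)

    def __mul__(self, other):
        return Expr("*", self, other)

    def __rmul__(self, other):
        return Expr("*", other, self)

    def __truediv__(self, other):
        return Expr("/", self, other)

    def __pow__(self, other):
        return Expr("**", self, other)

    def __str__(self):
        left, right = self.args
        return f"({left} {self.op} {right})"


class Sym(Expr):
    """A leaf node: a named symbol."""

    def __init__(self, name):
        super().__init__("sym")
        self.name = name

    def __str__(self):
        return self.name


def trace(fn):
    """Hypothetical helper: build a tree by calling fn with one Sym per parameter.

    The symbol names come from the function's own signature, so each name
    is written exactly once (in the def line).
    """
    params = inspect.signature(fn).parameters
    return fn(*(Sym(name) for name in params))


def kinetic_energy(momentum, mass):
    return momentum**2 / (2 * mass)


print(trace(kinetic_energy))  # ((momentum ** 2) / (2 * mass))
```

In this sketch the name `momentum` appears only in the `def` line. No separate `momentum = symbols('momentum')` statement is needed.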
Second, the function approach allows the function to be a single object that can be used in calculations. You might ask for:
grad(equation, 2)(2, 3, 4, 5) # derivative with respect to parameter 2 of equation evaluated at (2, 3, 4, 5)
With the symbolic approach, you need to keep the equation object as well as the symbols that compose it to interact with it.
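The `grad(equation, 2)` call above picks the variable of differentiation by position. The sketch below illustrates that idea using only the standard library. It uses forward-mode differentiation with dual numbers, which is a swapped-in technique: JAX itself traces the function and uses reverse mode, and its `jax.grad` spells the option `argnums`. `Dual`, `grad` and `equation` are hypothetical. The thread never defines `equation`; this one is made up so that `equation(2, 3, 4, 5)` gives 25, matching the comment further down.

```python
class Dual:
    """A number value + deriv*eps with eps**2 == 0; deriv carries the derivative."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule for the derivative part.
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def grad(fn, argnum):
    """Hypothetical: return a function computing d fn / d(positional arg argnum)."""
    def wrapped(*args):
        # Seed only the chosen argument with derivative 1.
        duals = [Dual(x, 1.0 if i == argnum else 0.0) for i, x in enumerate(args)]
        return fn(*duals).deriv
    return wrapped


def equation(a, b, c, d):
    # Made-up example; equation(2, 3, 4, 5) == 25.
    return a + b + c * d


print(equation(2, 3, 4, 5))           # 25
print(grad(equation, 2)(2, 3, 4, 5))  # 5.0, since d/dc (a + b + c*d) = d
```

The derivative is requested by the index of the parameter. No symbol object has to be kept alongside `equation`.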
This makes more sense in a limited context for symbolic manipulation where symbols only represent function parameters so that all symbols are bound. How would you handle the situation where the same symbols are free in two different expressions that you want to manipulate in tandem though?
In this example we have two different equations containing the same symbols and we want to solve them as a system of equations:
from sympy import symbols, Eq, solve
p, m, h = symbols('p, m, h')
E = p**2 / (2*m)
lamda = h / p
E1 = 5
lamda1 = 2
[(p1, m1)] = solve([Eq(E, E1), Eq(lamda, lamda1)], [p, m])
I don't see a good way of doing this without keeping track of the symbols as separate objects. I don't think this kind of thing comes up in Jax because it is only designed for the more limited symbolic task of evaluating and differentiating Python functions.
Also for simple expressions like this I think that a decorated function seems quite cumbersome:
@symbolic
def E(p, m):
    return p**2 / (2*m)
@symbolic
def lamda(h, p):
    return h / p
Finally, the function can just be called with concrete values:
equation(2, 3, 4, 5) # gives 25
which is convenient.
That is convenient but I think again this only really makes sense if all of your expressions are really just functions and all of your symbols are bound symbols representing function parameters. It is possible in sympy to convert an expression into a function but you need to specify the ordering of the symbols as function parameters:
expression = p**2 / (2*m)
function = lambdify([p, m], expression)
function(1, 2) # 0.25
The need to specify the ordering comes from the fact that the expression itself is not conceptually a function and does not have an ordered parameter list.
-- Oscar
On Tue, Jun 1, 2021 at 5:39 AM Oscar Benjamin <oscar.j.benjamin@gmail.com> wrote:
This makes more sense in a limited context for symbolic manipulation where symbols only represent function parameters so that all symbols are bound. How would you handle the situation where the same symbols are free in two different expressions that you want to manipulate in tandem though?
In this example we have two different equations containing the same symbols and we want to solve them as a system of equations:
from sympy import symbols, Eq, solve
p, m, h = symbols('p, m, h')
E = p**2 / (2*m)
lamda = h / p
E1 = 5
lamda1 = 2
[(p1, m1)] = solve([Eq(E, E1), Eq(lamda, lamda1)], [p, m])
I don't see a good way of doing this without keeping track of the symbols as separate objects. I don't think this kind of thing comes up in Jax because it is only designed for the more limited symbolic task of evaluating and differentiating Python functions.
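For reference, the system in this example can be solved by hand, reading E as p**2/(2m). The result below is what one would expect `solve` to find for p and m, with h left free:

```latex
\begin{aligned}
\frac{h}{p} &= 2 &&\Longrightarrow\quad p = \frac{h}{2},\\
\frac{p^2}{2m} &= 5 &&\Longrightarrow\quad m = \frac{p^2}{10} = \frac{h^2}{40}.
\end{aligned}
```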
This is a really cool design question.
One of the things I like about JAX is that they stayed extremely close to NumPy's interface. In NumPy, comparison operators applied to matrices return Boolean matrices.
I would ideally express what you wrote as
def E(p, m): ...
def lamda(h, p): ...
def f(p, m): return jnp.all(E(p, m) == E1) and jnp.all(lamda(h, p) == lamda1)
p1, m1 = solve(f)
Also for simple expressions like this I think that a decorated function seems quite cumbersome:
@symbolic
def E(p, m):
    return p**2 / (2*m)
@symbolic
def lamda(h, p):
    return h / p
I guess we don't need the decorator unless we want to memoize some results. JAX uses a decorator to memoize the jitted code, for example. I thought it might be nice for example to memoize the parse tree so that the function doesn't have to be called every time it's used.
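Memoizing the parse tree could look roughly like the sketch below. `symbolic` here is a hypothetical decorator, not an existing JAX or SymPy API. On first use it traces the function once with placeholder nodes named after the parameters, caches the resulting tree, and leaves the function itself callable on ordinary numbers. `Node` is an illustrative stand-in for a real expression type.

```python
import functools
import inspect


class Node:
    """Expression-tree node; a leaf symbol has op == "sym" and args == (name,)."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args

    # Arithmetic on nodes records the operation instead of evaluating it.
    def __mul__(self, other):
        return Node("*", self, other)

    def __rmul__(self, other):
        return Node("*", other, self)

    def __truediv__(self, other):
        return Node("/", self, other)

    def __pow__(self, other):
        return Node("**", self, other)

    def __repr__(self):
        if self.op == "sym":
            return self.args[0]
        left, right = self.args
        return f"({left!r} {self.op} {right!r})"


def symbolic(fn):
    """Hypothetical decorator: trace fn once, lazily, and cache the tree."""

    @functools.lru_cache(maxsize=None)
    def tree():
        # One placeholder per parameter, named after the parameter.
        names = inspect.signature(fn).parameters
        return fn(*(Node("sym", name) for name in names))

    fn.tree = tree  # fn itself is returned unchanged, so numeric calls still work
    return fn


@symbolic
def E(p, m):
    return p**2 / (2*m)


print(E.tree())              # ((p ** 2) / (2 * m))
print(E.tree() is E.tree())  # True: built once, then served from the cache
print(E(2.0, 1.0))           # 2.0: still an ordinary function on numbers
```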
Finally, the function can just be called with concrete values:
equation(2, 3, 4, 5) # gives 25
which is convenient.
That is convenient but I think again this only really makes sense if all of your expressions are really just functions and all of your symbols are bound symbols representing function parameters. It is possible in sympy to convert an expression into a function but you need to specify the ordering of the symbols as function parameters:
expression = p**2 / (2*m)
function = lambdify([p, m], expression)
function(1, 2) # 0.25
The need to specify the ordering comes from the fact that the expression itself is not conceptually a function and does not have an ordered parameter list.
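The ordering point can be shown without SymPy. In this stdlib-only sketch, an "expression" is modelled as a function of a name-to-value mapping, which is an assumption made for illustration, so it has no parameter order of its own. A `lambdify`-like helper must then be told which name each positional argument binds to. `expression` and `to_function` are hypothetical and are not SymPy's implementation.

```python
def expression(env):
    # p**2 / (2*m), looking symbols up by name: no positional order exists.
    return env["p"] ** 2 / (2 * env["m"])


def to_function(order, expr):
    """Hypothetical: turn a name-based expression into a positional function.

    `order` plays the role of the symbol list passed to lambdify.
    """
    def function(*args):
        if len(args) != len(order):
            raise TypeError(f"expected {len(order)} arguments, got {len(args)}")
        return expr(dict(zip(order, args)))
    return function


f_pm = to_function(["p", "m"], expression)
f_mp = to_function(["m", "p"], expression)
print(f_pm(1, 2))  # 0.25, with p=1, m=2
print(f_mp(1, 2))  # 2.0, with m=1, p=2
```

The same expression yields different functions depending on the order supplied. That order is information the expression itself does not carry.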
Right. I think it's simpler to just specify symbolic functions as Python functions, but I can see why that does make the very simplest cases slightly more wordy.
On Tue, 1 Jun 2021 at 10:53, Neil Girdhar <mistersheik@gmail.com> wrote:
This is a really cool design question.
One of the things I like about JAX is that they stayed extremely close to NumPy's interface. In NumPy, comparison operators applied to matrices return Boolean matrices.
I would ideally express what you wrote as
def E(p, m): ...
def lamda(h, p): ...
def f(p, m): return jnp.all(E(p, m) == E1) and jnp.all(lamda(h, p) == lamda1)
p1, m1 = solve(f)
So how does solve know to solve for p and m rather than h? Note that I deliberately included a third symbol and made the parameter lists of E and lamda inconsistent.
Should Jax recognise that the 2nd parameter of lamda has the same name as the 1st parameter of E? Or should symbols at the same parameter index be considered the same regardless of their name?
In Jax everything is a function so I would expect it to ignore the symbol names so that if args = solve([f1, f2]) then f1(*args) == f2(*args) == 0. This is usually how the API works for numerical rather than symbolic root-finding algorithms. But then how do I solve a system of equations that has a symbolic parameter like `h`?
-- Oscar
On Tue, Jun 1, 2021 at 6:28 AM Oscar Benjamin <oscar.j.benjamin@gmail.com> wrote:
So how does solve know to solve for p and m rather than h?
Because those are the parameters of f.
Note that I deliberately included a third symbol and made the parameter lists of E and lamda inconsistent.
Should Jax recognise that the 2nd parameter of lamda has the same name as the 1st parameter of E? Or should symbols at the same parameter index be considered the same regardless of their name?
It doesn't need to because the same variable p (which will ultimately point to a tracer object) is passed to the functions E and lamda. It's not using the names.
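The stdlib-only sketch below illustrates this point with a toy `Tracer` class, which is an assumption; JAX's real tracers are far richer. `f` creates one tracer per unknown and passes those same objects into `E` and `lamda`, so the differing parameter names and orders never matter. The two expressions are linked by object identity alone.

```python
class Tracer:
    """Toy tracer: arithmetic records an operation instead of computing it."""

    def __init__(self, label, op=None, args=()):
        self.label = label  # for humans only; never used to match symbols
        self.op = op
        self.args = args

    def __mul__(self, other):
        return Tracer("expr", "*", (self, other))

    def __rmul__(self, other):
        return Tracer("expr", "*", (other, self))

    def __truediv__(self, other):
        return Tracer("expr", "/", (self, other))

    def __pow__(self, other):
        return Tracer("expr", "**", (self, other))


def leaves(node):
    """Return the leaf tracers of an expression, in order (constants skipped)."""
    if not isinstance(node, Tracer):
        return []
    if node.op is None:
        return [node]
    return [leaf for arg in node.args for leaf in leaves(arg)]


def E(p, m):
    return p**2 / (2 * m)


def lamda(h, p):  # parameter order deliberately differs from E
    return h / p


h = Tracer("h")  # closed over by f, as in the reply above


def f(p, m):
    return E(p, m), lamda(h, p)


p, m = Tracer("p"), Tracer("m")
energy, wavelength = f(p, m)
# Both expressions contain the very same p object; names played no part.
print(leaves(energy)[0] is p, leaves(wavelength)[1] is p)  # True True
```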
In Jax everything is a function so I would expect it to ignore the symbol names so that if args = solve([f1, f2]) then f1(*args) == f2(*args) == 0.
This is usually how the API works for numerical rather than symbolic root-finding algorithms. But then how do I solve a system of equations that has a symbolic parameter like `h`?
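The positional convention for numerical root-finding can be seen in a small stdlib-only Newton iteration for two equations in two unknowns. This is a sketch, not any particular library's solver. `solve` returns a tuple `args` with `f1(*args)` and `f2(*args)` close to zero, and the parameter names of `f1` and `f2` play no role. The residual functions, the fixed value `h = 1.0` and the starting guess are illustrative assumptions. The residuals are the example equations with their denominators cleared.

```python
def solve(funcs, guess, tol=1e-12, max_iter=50, step=1e-6):
    """Newton's method for two equations in two unknowns.

    Unknowns are matched by position only: the parameter names of the
    functions are never looked at. Partial derivatives are estimated
    with central differences.
    """
    f1, f2 = funcs
    x, y = guess
    for _ in range(max_iter):
        r1, r2 = f1(x, y), f2(x, y)
        if abs(r1) < tol and abs(r2) < tol:
            break
        a = (f1(x + step, y) - f1(x - step, y)) / (2 * step)  # df1/dx
        b = (f1(x, y + step) - f1(x, y - step)) / (2 * step)  # df1/dy
        c = (f2(x + step, y) - f2(x - step, y)) / (2 * step)  # df2/dx
        d = (f2(x, y + step) - f2(x, y - step)) / (2 * step)  # df2/dy
        det = a * d - b * c
        # Solve [[a, b], [c, d]] @ [dx, dy] = [-r1, -r2] by Cramer's rule.
        dx = (-r1 * d + b * r2) / det
        dy = (-a * r2 + c * r1) / det
        x, y = x + dx, y + dy
    return x, y


h = 1.0  # a fixed number here, not a symbol


def f1(p, m):
    return p**2 - 10 * m  # E == 5 with E = p**2/(2*m), times 2*m


def f2(momentum, mass):  # different names on purpose: solve ignores them
    return h - 2 * momentum  # lamda == 2 with lamda = h/p, times p


args = solve([f1, f2], (1.0, 1.0))
print(args)  # approximately (0.5, 0.025), i.e. p = h/2 and m = h**2/40
```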
I assumed that h was a constant and was being closed over by f. If h is a symbol, I think to stay consistent with functions creating symbols, we could do:
from functools import partial
def f(p, m, h): return E(p, m) == E1 and lamda(h, p) == lamda1
def g(h): return solve(partial(f, h=h))
g is now a symbolic equation that returns p, m in terms of h.
By the way, it occurred to me that it might be reasonable to build a system like this fairly quickly using sympy as a backend.
Although, there are some other benefits of Jax's symbolic expression-builder that might be a lot harder to capture: Last time I used sympy though (years ago), I had a really hard time making matrices of symbols, and there were some incongruities between numpy and sympy functions.
Best,
Neil
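Both answers, "those are the parameters of f" and binding `h` with `partial`, can be checked with the standard library alone. The sketch assumes a hypothetical signature-driven `solve` that treats the positional parameters of the function it receives as the unknowns. `inspect.signature` then shows that `partial(f, h=2.0)` leaves exactly `p` and `m` as positional parameters, while `h` becomes a keyword-only parameter with a default. `unknowns` is an illustrative helper only.

```python
import inspect
from functools import partial


def f(p, m, h):
    # Body irrelevant here; only the signature is inspected.
    return p, m, h


def unknowns(fn):
    """Hypothetical: names a signature-driven solve() might treat as unknowns."""
    return [
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
        and param.default is inspect.Parameter.empty
    ]


print(unknowns(f))                  # ['p', 'm', 'h']
print(unknowns(partial(f, h=2.0)))  # ['p', 'm']
```

This only shows how the unknowns could be discovered. Returning p and m in terms of a symbolic h would still need a symbolic backend.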
Oh, and I see you're a sympy developer! I hope I'm not coming across as critical in any way. I love that sympy exists and thought it was a really cool project when I first learned about what it can do.
Perhaps we should move our discussion to a "GitHub discussion" under the sympy GitHub? We might be able to build a toy example that takes functions, inspects the signature, and converts it into a traditional sympy expression. Then we could examine a variety of examples in the docs to see whether the function-based symbols are easier or harder to use in general.
Best,
Neil
On Tue, Jun 1, 2021 at 6:47 AM Neil Girdhar <mistersheik@gmail.com> wrote:
participants (2)
- Neil Girdhar
- Oscar Benjamin