diff --git a/Project.toml b/Project.toml index e4e9f62..cb18cb3 100644 --- a/Project.toml +++ b/Project.toml @@ -5,9 +5,11 @@ version = "0.2.3" [deps] SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [compat] SymbolicUtils = "3" +TermInterface = "2.0" julia = "1.10" [extras] diff --git a/src/limits.jl b/src/limits.jl index 0a4259b..e490345 100644 --- a/src/limits.jl +++ b/src/limits.jl @@ -29,8 +29,8 @@ # `x`? using SymbolicUtils -using SymbolicUtils: BasicSymbolic, exprtype -using SymbolicUtils: SYM, TERM, ADD, MUL, POW, DIV +using SymbolicUtils: BasicSymbolic, isterm, issym, isaddmul, isdiv, isadd, ismul +using TermInterface """ is_exp(expr) @@ -41,7 +41,7 @@ Returns `true` if `expr` is a symbolic expression with the exponential function `false` otherwise. """ is_exp(expr) = false -is_exp(expr::BasicSymbolic) = exprtype(expr) == TERM && operation(expr) == exp +is_exp(expr::BasicSymbolic) = iscall(expr) && operation(expr) == exp # unused. This function provides a measure of the "size" of an expression, for use in proofs # of termination and debugging nontermination only: @@ -66,7 +66,7 @@ Compute the limit of `expr` as `x` approaches infinity and return `(limit, assum This is the internal API boundary between the internal limits.jl file and the public SymbolicLimits.jl file """ -function limit_inf(expr, x::BasicSymbolic{Field}) where {Field} +function limit_inf(expr, x::BasicSymbolic) assumptions = Set{Any}() limit = signed_limit_inf(expr, x, assumptions)[1] limit, assumptions @@ -95,28 +95,30 @@ A tuple `(limit, sign)` where: - `limit`: The computed limit value - `sign`: The sign of the limit """ -function signed_limit_inf(expr::Field, x::BasicSymbolic{Field}, assumptions) where {Field} +function signed_limit_inf(expr::Field, x::BasicSymbolic, assumptions) expr, sign(expr) end -function signed_limit_inf(expr::BasicSymbolic{Field}, x::BasicSymbolic{Field}, assumptions) where {Field} +function signed_limit_inf(expr::BasicSymbolic{T}, x::BasicSymbolic{T}, assumptions) where {T} + Field = symtype(expr) + @assert symtype(x) == Field expr === x && return (Inf, 1) Ω = most_rapidly_varying_subexpressions(expr, x, assumptions) isempty(Ω) && return (expr, sign(expr)) ω_val = last(Ω) - ω_sym = SymbolicUtils.Sym{Field}(Symbol(:ω, gensym())) + ω_sym = SymbolicUtils.Sym{T}(Symbol(:ω, gensym()); type = Field, shape = SymbolicUtils.ShapeVecT()) while !is_exp(ω_val) # equivalent to x ∈ Ω expr = recursive(expr) do f, ex - ex isa BasicSymbolic{Field} || return ex - exprtype(ex) == SYM && return ex === x ? exp(x) : ex + symtype(ex) === Field || return ex + iscall(ex) || return ex === x ? exp(x) : ex operation(ex)(f.(arguments(ex))...) end expr = log_exp_simplify(expr) # Ω = most_rapidly_varying_subexpressions(expr, x) NO! this line could lead to infinite recursion Ω = [log_exp_simplify(recursive(expr) do f, ex - ex isa BasicSymbolic{Field} || return ex - exprtype(ex) == SYM && return ex === x ? exp(x) : ex + symtype(ex) === Field || return ex + iscall(ex) || return ex === x ? exp(x) : ex operation(ex)(f.(arguments(ex))...) end) for expr in Ω] ω_val = last(Ω) @@ -134,13 +136,13 @@ function signed_limit_inf(expr::BasicSymbolic{Field}, x::BasicSymbolic{Field}, a # This ensures that mrv(expr2) == {ω}. TODO: do we need to do top-down with recursion even after replacement? expr2 = recursive(expr) do f, ex # This traverses from largest to smallest, as required? - ex isa BasicSymbolic{Field} || return ex - exprtype(ex) == SYM && return ex + symtype(ex) === Field || return ex + iscall(ex) || return ex # ex ∈ Ω && return rewrite(ex, ω, h, x) # ∈ uses symbolic equality which is iffy if any(x -> zero_equivalence(x - ex, assumptions), Ω) ex = rewrite(ex, ω_sym, h, x, assumptions) - ex isa BasicSymbolic{Field} || return ex - exprtype(ex) == SYM && return ex + symtype(ex) === Field || return ex + iscall(ex) || return ex end operation(ex)(f.(arguments(ex))...) end @@ -216,12 +218,12 @@ cancels log(exp(x)) and exp(log(x)), the latter may extend the domain """ strong_log_exp_simplify(expr) = expr function strong_log_exp_simplify(expr::BasicSymbolic) - exprtype(expr) == SYM && return expr - exprtype(expr) == TERM && operation(expr) in (log, exp) || + iscall(expr) || return expr + operation(expr) in (log, exp) || return operation(expr)(strong_log_exp_simplify.(arguments(expr))...) arg = strong_log_exp_simplify(only(arguments(expr))) # TODO: return _log(arg) - arg isa BasicSymbolic && exprtype(arg) == TERM && operation(arg) in (log, exp) && + iscall(arg) && operation(arg) in (log, exp) && operation(arg) != operation(expr) || return operation(expr)(arg) only(arguments(arg)) end @@ -246,22 +248,26 @@ For scalar expressions, returns an empty list. A list of the most rapidly varying subexpressions in `expr`. """ -function most_rapidly_varying_subexpressions(expr::Field, x::BasicSymbolic{Field}, assumptions) where {Field} +function most_rapidly_varying_subexpressions(expr::Field, x::BasicSymbolic{T}, assumptions) where {Field, T} + @assert symtype(x) === Field [] end function most_rapidly_varying_subexpressions( - expr::BasicSymbolic{Field}, x::BasicSymbolic{Field}, assumptions) where {Field} - exprtype(x) == SYM || + expr::BasicSymbolic{T}, x::BasicSymbolic{T}, assumptions) where {T} + Field = symtype(expr) + @assert symtype(x) == Field + issym(x) || throw(ArgumentError("Must expand with respect to a symbol. Got $x")) # TODO: this is slow. This whole algorithm is slow. Profile, benchmark, and optimize it. - et = exprtype(expr) - ret = if et == SYM + ret = if isconst(expr) + return most_rapidly_varying_subexpressions(unwrap_const(expr), x, assumptions) + elseif issym(expr) if expr.name == x.name [expr] else [] end - elseif et == TERM + elseif isterm(expr) op = operation(expr) if op == log arg = only(arguments(expr)) @@ -274,29 +280,29 @@ function most_rapidly_varying_subexpressions( mrv_join(x, assumptions)([expr], most_rapidly_varying_subexpressions(arg, x, assumptions)) # ensure that the inner most exprs stay last end res + elseif op == ^ + args = arguments(expr) + @assert length(args) == 2 + base, exponent = args + if exponent isa Real && isinteger(exponent) && exponent > 0 + most_rapidly_varying_subexpressions(base, x, assumptions) + else + error("Not implemented: POW with noninteger exponent $exponent. Transform to log/exp.") + end else error("Not implemented: $op") end - elseif et ∈ (ADD, MUL, DIV) + elseif isaddmul(expr) || isdiv(expr) mapreduce(expr -> most_rapidly_varying_subexpressions(expr, x, assumptions), mrv_join(x, assumptions), arguments(expr), init = []) - elseif et == POW - args = arguments(expr) - @assert length(args) == 2 - base, exponent = args - if exponent isa Real && isinteger(exponent) && exponent > 0 - most_rapidly_varying_subexpressions(base, x, assumptions) - else - error("Not implemented: POW with noninteger exponent $exponent. Transform to log/exp.") - end else - error("Unknown Expr type: $et") + error("Unknown Expr type") end ret end function is_exp_or_x(expr::BasicSymbolic, x::BasicSymbolic) - expr === x || exprtype(expr) == TERM && operation(expr) == exp + expr === x || iscall(expr) && operation(expr) == exp end """ @@ -393,11 +399,13 @@ The rewriting follows the formula: if `expr = exp(s)` and `ω = exp(h)`, then we An expression of the form `A⋅ω^c` where `A` is less rapidly varying than `ω`. """ -function rewrite(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Field}, - h::BasicSymbolic{Field}, x::BasicSymbolic{Field}, assumptions) where {Field} - @assert exprtype(expr) == TERM && operation(expr) == exp - @assert exprtype(ω) == SYM - @assert exprtype(x) == SYM +function rewrite(expr::BasicSymbolic{T}, ω::BasicSymbolic{T}, + h::BasicSymbolic{T}, x::BasicSymbolic{T}, assumptions) where {T} + Field = symtype(expr) + @assert symtype(ω) === symtype(h) === symtype(x) === Field + @assert iscall(expr) && operation(expr) == exp + @assert issym(ω) + @assert issym(x) s = only(arguments(expr)) t = h @@ -429,19 +437,22 @@ This is a core component of the Gruntz algorithm for computing limits. The coefficient of `ω^i` in the series expansion of `expr`. """ -function get_series_term(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Field}, - h, i::Int, assumptions) where {Field} - exprtype(ω) == SYM || +function get_series_term(expr::BasicSymbolic{T}, ω::BasicSymbolic{T}, + h, i::Int, assumptions) where {T} + Field = symtype(expr) + @assert symtype(ω) === Field + issym(ω) || throw(ArgumentError("Must expand with respect to a symbol. Got $ω")) - et = exprtype(expr) - if et == SYM + if isconst(expr) + return get_series_term(unwrap_const(expr), ω, h, i, assumptions) + elseif issym(expr) if expr.name == ω.name i == 1 ? one(Field) : zero(Field) else i == 0 ? expr : zero(Field) end - elseif et == TERM + elseif isterm(expr) op = operation(expr) if op == log arg = only(arguments(expr)) @@ -489,12 +500,26 @@ function get_series_term(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Field}, end sm end + elseif op == ^ + args = arguments(expr) + @assert length(args) == 2 + base, exponent = args + if exponent isa Real && isinteger(exponent) && exponent > 0 + t = i ÷ exponent + if t * exponent == i # integral + get_series_term(base, ω, h, t, assumptions) ^ exponent + else + zero(Field) + end + else + error("Not implemented: POW with noninteger exponent $exponent. Transform to log/exp.") + end else error("Not implemented: $op") end - elseif et == ADD + elseif isadd(expr) sum(get_series_term(arg, ω, h, i, assumptions) for arg in arguments(expr)) - elseif et == MUL + elseif ismul(expr) arg1, arg_rest = Iterators.peel(arguments(expr)) arg2 = prod(arg_rest) exponent1 = get_leading_exponent(arg1, ω, h, assumptions) @@ -506,21 +531,7 @@ function get_series_term(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Field}, sm += t1 * t2 end sm - elseif et == POW - args = arguments(expr) - @assert length(args) == 2 - base, exponent = args - if exponent isa Real && isinteger(exponent) && exponent > 0 - t = i ÷ exponent - if t * exponent == i # integral - get_series_term(base, ω, h, t, assumptions) ^ exponent - else - zero(Field) - end - else - error("Not implemented: POW with noninteger exponent $exponent. Transform to log/exp.") - end - elseif et == DIV + elseif isdiv(expr) args = arguments(expr) @assert length(args) == 2 num, den = args @@ -545,12 +556,13 @@ function get_series_term(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Field}, end sm / den_leading_term else - error("Unknown Expr type: $et") + error("Unknown Expr type") end end function get_series_term( - expr::Field, ω::BasicSymbolic{Field}, h, i::Int, assumptions) where {Field} - exprtype(ω) == SYM || + expr::Field, ω::BasicSymbolic{T}, h, i::Int, assumptions) where {Field, T} + @assert symtype(ω) === Field + issym(ω) || throw(ArgumentError("Must expand with respect to a symbol. Got $ω")) i == 0 ? expr : zero(Field) end @@ -577,20 +589,23 @@ Returns `Inf` if `expr` is equivalent to zero. The leading exponent `e`, or `Inf` if `expr` is equivalent to zero. """ -function get_leading_exponent(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Field}, h, assumptions) where {Field} - exprtype(ω) == SYM || +function get_leading_exponent(expr::BasicSymbolic{T}, ω::BasicSymbolic{T}, h, assumptions) where {T} + Field = symtype(expr) + @assert symtype(ω) === Field + issym(ω) || throw(ArgumentError("Must expand with respect to a symbol. Got $ω")) zero_equivalence(expr, assumptions) && return Inf - et = exprtype(expr) - if et == SYM + if isconst(expr) + return get_leading_exponent(unwrap_const(expr), ω, h, assumptions) + elseif issym(expr) if expr.name == ω.name 1 else 0 end - elseif et == TERM + elseif isterm(expr) op = operation(expr) if op == log arg = only(arguments(expr)) @@ -607,10 +622,19 @@ function get_leading_exponent(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Fiel end elseif op == exp 0 + elseif op == ^ + args = arguments(expr) + @assert length(args) == 2 + base, exponent = args + if exponent isa Real && isinteger(exponent) && exponent > 0 + exponent * get_leading_exponent(base, ω, h, assumptions) + else + error("Not implemented: POW with noninteger exponent $exponent. Transform to log/exp.") + end else error("Not implemented: $op") end - elseif et == ADD + elseif isadd(expr) exponent = minimum(get_leading_exponent(arg, ω, h, assumptions) for arg in arguments(expr)) for i in exponent:typemax(Int) @@ -620,18 +644,9 @@ function get_leading_exponent(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Fiel end i > exponent+1000 && error("This is likely due to known zero_equivalence bugs") end - elseif et == MUL + elseif ismul(expr) sum(get_leading_exponent(arg, ω, h, assumptions) for arg in arguments(expr)) - elseif et == POW # This is not an idiomatic representation of powers. Avoid it if possible. - args = arguments(expr) - @assert length(args) == 2 - base, exponent = args - if exponent isa Real && isinteger(exponent) && exponent > 0 - exponent * get_leading_exponent(base, ω, h, assumptions) - else - error("Not implemented: POW with noninteger exponent $exponent. Transform to log/exp.") - end - elseif et == DIV + elseif isdiv(expr) args = arguments(expr) @assert length(args) == 2 num, den = args @@ -639,11 +654,12 @@ function get_leading_exponent(expr::BasicSymbolic{Field}, ω::BasicSymbolic{Fiel get_leading_exponent(num, ω, h, assumptions) - get_leading_exponent(den, ω, h, assumptions) else - error("Unknown Expr type: $et") + error("Unknown Expr type") end end -function get_leading_exponent(expr::Field, ω::BasicSymbolic{Field}, h, assumptions) where {Field} - exprtype(ω) == SYM || +function get_leading_exponent(expr::Field, ω::BasicSymbolic{T}, h, assumptions) where {Field, T} + @assert symtype(ω) === Field + issym(ω) || throw(ArgumentError("Must expand with respect to a symbol. Got $ω")) zero_equivalence(expr, assumptions) ? Inf : 0 end @@ -651,7 +667,7 @@ end _log(x) = _log(x, nothing, nothing) _log(x, ω, h) = log(x) function _log(x::BasicSymbolic, ω, h) - exprtype(x) == TERM && operation(x) == exp && return only(arguments(x)) + iscall(x) && operation(x) == exp && return only(arguments(x)) x === ω && return h log(x) end