Skip to content

Commit e45cf8c

Browse files
authored
Refactor tests of ForwardDiff and ChainRules (#42)
1 parent 4239684 commit e45cf8c

8 files changed

Lines changed: 38 additions & 51 deletions

File tree

test/ForwardDiff.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
if isdefined(Base, :get_extension)
2+
import ForwardDiff
3+
import ChainRulesTestUtils
4+
5+
@testset "ForwardDiff" begin
6+
@test ForwardDiff.derivative(PolyLog.li0, float(pi)) == 1/(1 - pi)^2
7+
@test ForwardDiff.derivative(PolyLog.li0, 0.0) == 1
8+
@test ForwardDiff.derivative(PolyLog.li0, 1.0) == Inf
9+
10+
@test ForwardDiff.derivative(PolyLog.reli1, float(pi)) == 1/(1 - pi)
11+
@test ForwardDiff.derivative(PolyLog.reli1, 0.0) == 1.0
12+
ChainRulesTestUtils.test_frule(PolyLog.reli1, 0.0)
13+
ChainRulesTestUtils.test_rrule(PolyLog.reli1, float(pi))
14+
15+
@test ForwardDiff.derivative(PolyLog.reli2, float(pi)) == PolyLog.reli1(pi)/pi
16+
@test ForwardDiff.derivative(PolyLog.reli2, 0.0) == 1.0
17+
ChainRulesTestUtils.test_frule(PolyLog.reli2, 0.0)
18+
ChainRulesTestUtils.test_rrule(PolyLog.reli2, float(pi))
19+
20+
@test ForwardDiff.derivative(PolyLog.reli3, float(pi)) == PolyLog.reli2(pi)/pi
21+
@test ForwardDiff.derivative(PolyLog.reli3, 0.0) == 1.0
22+
ChainRulesTestUtils.test_frule(PolyLog.reli3, 0.0)
23+
ChainRulesTestUtils.test_rrule(PolyLog.reli3, float(pi))
24+
25+
@test ForwardDiff.derivative(PolyLog.reli4, float(pi)) == PolyLog.reli3(pi)/pi
26+
@test ForwardDiff.derivative(PolyLog.reli4, 0.0) == 1.0
27+
ChainRulesTestUtils.test_frule(PolyLog.reli4, 0.0)
28+
ChainRulesTestUtils.test_rrule(PolyLog.reli4, float(pi))
29+
30+
for n in vcat(collect(-10:10), [100, 1000000])
31+
@test ForwardDiff.derivative(z -> PolyLog.reli(n, z), float(pi)) == PolyLog.reli(n - 1, pi)/pi
32+
@test ForwardDiff.derivative(z -> PolyLog.reli(n, z), 0.0) == 1.0
33+
ChainRulesTestUtils.test_frule(PolyLog.reli, n, 0.0)
34+
ChainRulesTestUtils.test_rrule(PolyLog.reli, n, float(pi))
35+
end
36+
end
37+
end

test/Li.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,6 @@ end
9595
@test PolyLog.li(n, 1//1 + 0//1im) zeta
9696
@test PolyLog.li(n, 1 + 0im) zeta
9797
@test PolyLog.li(n, BigFloat("1.0") + 0im) == PolyLog.zeta(n, BigFloat)
98-
99-
# ForwardDiff Test
100-
if isdefined(Base, :get_extension)
101-
@test ForwardDiff.derivative(z -> PolyLog.reli(n, z), float(pi)) == PolyLog.reli(n - 1, pi)/pi
102-
@test ForwardDiff.derivative(z -> PolyLog.reli(n, z), 0.0) == 1.0
103-
ChainRulesTestUtils.test_frule(PolyLog.reli, n, 0.0)
104-
ChainRulesTestUtils.test_rrule(PolyLog.reli, n, float(pi))
105-
end
10698
end
10799

108100
# value close to boundary between series 1 and 2 in arXiv:2010.09860

test/Li0.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,4 @@
5050
# test value that causes overflow if squared
5151
@test PolyLog.li0(1e300 + 1im) -1.0 rtol=eps(Float64)
5252
@test PolyLog.li0(1.0 + 1e300im) -1.0 rtol=eps(Float64)
53-
54-
# ForwardDiff Test
55-
if isdefined(Base, :get_extension)
56-
@test ForwardDiff.derivative(PolyLog.li0, float(pi)) == 1/(1 - pi)^2
57-
@test ForwardDiff.derivative(PolyLog.li0, 0.0) == 1
58-
@test ForwardDiff.derivative(PolyLog.li0, 1.0) == Inf
59-
end
6053
end

test/Li1.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,4 @@
5959
# test value that causes overflow if squared
6060
@test PolyLog.li1(1e300 + 1im) -690.77552789821371 + 3.14159265358979im rtol=eps(Float64)
6161
@test PolyLog.li1(1.0 + 1e300im) -690.77552789821371 + 1.5707963267948966im rtol=eps(Float64)
62-
63-
# ForwardDiff Test
64-
if isdefined(Base, :get_extension)
65-
@test ForwardDiff.derivative(PolyLog.reli1, float(pi)) == 1/(1 - pi)
66-
@test ForwardDiff.derivative(PolyLog.reli1, 0.0) == 1.0
67-
ChainRulesTestUtils.test_frule(PolyLog.reli1, 0.0)
68-
ChainRulesTestUtils.test_rrule(PolyLog.reli1, float(pi))
69-
end
7062
end

test/Li2.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,4 @@ end
119119
# test value that causes overflow if squared
120120
@test PolyLog.li2(1e300 + 1im) -238582.12510339421 + 2170.13532372464im rtol=eps(Float64)
121121
@test PolyLog.li2(1.0 + 1e300im) -238585.82620504462 + 1085.06766186232im rtol=eps(Float64)
122-
123-
# ForwardDiff Test
124-
if isdefined(Base, :get_extension)
125-
@test ForwardDiff.derivative(PolyLog.reli2, float(pi)) == PolyLog.reli1(pi)/pi
126-
@test ForwardDiff.derivative(PolyLog.reli2, 0.0) == 1.0
127-
ChainRulesTestUtils.test_frule(PolyLog.reli2, 0.0)
128-
ChainRulesTestUtils.test_rrule(PolyLog.reli2, float(pi))
129-
end
130122
end

test/Li3.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,4 @@
6060
# test value that causes overflow if squared
6161
@test PolyLog.li3(1e300 + 1im) -5.4934049431527088e7 + 749538.186928224im rtol=eps(Float64)
6262
@test PolyLog.li3(1.0 + 1e300im) -5.4936606061973454e7 + 374771.031356405im rtol=eps(Float64)
63-
64-
# ForwardDiff Test
65-
if isdefined(Base, :get_extension)
66-
@test ForwardDiff.derivative(PolyLog.reli3, float(pi)) == PolyLog.reli2(pi)/pi
67-
@test ForwardDiff.derivative(PolyLog.reli3, 0.0) == 1.0
68-
ChainRulesTestUtils.test_frule(PolyLog.reli3, 0.0)
69-
ChainRulesTestUtils.test_rrule(PolyLog.reli3, float(pi))
70-
end
7163
end

test/Li4.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,4 @@
5252
# test value that causes overflow if squared
5353
@test PolyLog.li4(1e300 + 1im) -9.4863817894708364e9 + 1.725875455850714e8im rtol=eps(Float64)
5454
@test PolyLog.li4(1.0 + 1e300im) -9.4872648206269765e9 + 8.62951114411071e7im rtol=eps(Float64)
55-
56-
# ForwardDiff Test
57-
if isdefined(Base, :get_extension)
58-
@test ForwardDiff.derivative(PolyLog.reli4, float(pi)) == PolyLog.reli3(pi)/pi
59-
@test ForwardDiff.derivative(PolyLog.reli4, 0.0) == 1.0
60-
ChainRulesTestUtils.test_frule(PolyLog.reli4, 0.0)
61-
ChainRulesTestUtils.test_rrule(PolyLog.reli4, float(pi))
62-
end
6355
end

test/runtests.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
using Test
22
import PolyLog
3-
if isdefined(Base, :get_extension)
4-
import ForwardDiff
5-
import ChainRulesTestUtils
6-
end
73

84
include("TestPrecision.jl")
95
include("DataReader.jl")
@@ -12,6 +8,7 @@ include("Digamma.jl")
128
include("Dual.jl")
139
include("Eta.jl")
1410
include("Factorial.jl")
11+
include("ForwardDiff.jl")
1512
include("Harmonic.jl")
1613
include("Li0.jl")
1714
include("Li1.jl")

0 commit comments

Comments
 (0)