|
| 1 | +if isdefined(Base, :get_extension) |
| 2 | + import ForwardDiff |
| 3 | + import ChainRulesTestUtils |
| 4 | + |
| 5 | + @testset "ForwardDiff" begin |
| 6 | + @test ForwardDiff.derivative(PolyLog.li0, float(pi)) == 1/(1 - pi)^2 |
| 7 | + @test ForwardDiff.derivative(PolyLog.li0, 0.0) == 1 |
| 8 | + @test ForwardDiff.derivative(PolyLog.li0, 1.0) == Inf |
| 9 | + |
| 10 | + @test ForwardDiff.derivative(PolyLog.reli1, float(pi)) == 1/(1 - pi) |
| 11 | + @test ForwardDiff.derivative(PolyLog.reli1, 0.0) == 1.0 |
| 12 | + ChainRulesTestUtils.test_frule(PolyLog.reli1, 0.0) |
| 13 | + ChainRulesTestUtils.test_rrule(PolyLog.reli1, float(pi)) |
| 14 | + |
| 15 | + @test ForwardDiff.derivative(PolyLog.reli2, float(pi)) == PolyLog.reli1(pi)/pi |
| 16 | + @test ForwardDiff.derivative(PolyLog.reli2, 0.0) == 1.0 |
| 17 | + ChainRulesTestUtils.test_frule(PolyLog.reli2, 0.0) |
| 18 | + ChainRulesTestUtils.test_rrule(PolyLog.reli2, float(pi)) |
| 19 | + |
| 20 | + @test ForwardDiff.derivative(PolyLog.reli3, float(pi)) == PolyLog.reli2(pi)/pi |
| 21 | + @test ForwardDiff.derivative(PolyLog.reli3, 0.0) == 1.0 |
| 22 | + ChainRulesTestUtils.test_frule(PolyLog.reli3, 0.0) |
| 23 | + ChainRulesTestUtils.test_rrule(PolyLog.reli3, float(pi)) |
| 24 | + |
| 25 | + @test ForwardDiff.derivative(PolyLog.reli4, float(pi)) == PolyLog.reli3(pi)/pi |
| 26 | + @test ForwardDiff.derivative(PolyLog.reli4, 0.0) == 1.0 |
| 27 | + ChainRulesTestUtils.test_frule(PolyLog.reli4, 0.0) |
| 28 | + ChainRulesTestUtils.test_rrule(PolyLog.reli4, float(pi)) |
| 29 | + |
| 30 | + for n in vcat(collect(-10:10), [100, 1000000]) |
| 31 | + @test ForwardDiff.derivative(z -> PolyLog.reli(n, z), float(pi)) == PolyLog.reli(n - 1, pi)/pi |
| 32 | + @test ForwardDiff.derivative(z -> PolyLog.reli(n, z), 0.0) == 1.0 |
| 33 | + ChainRulesTestUtils.test_frule(PolyLog.reli, n, 0.0) |
| 34 | + ChainRulesTestUtils.test_rrule(PolyLog.reli, n, float(pi)) |
| 35 | + end |
| 36 | + end |
| 37 | +end |
0 commit comments