Coding for Performance

4: More Examples

Author

Paul Schrimpf

Published

April 26, 2024

For the assignment, most students focused on optimizing either the share function from our earlier BLP example, or the statparts function from GMMInference.

Let’s see how we can improve these functions’ performance.

share

Here’s a copy of the original version of the share function:

using LinearAlgebra, Distributions, Random, BenchmarkTools, Profile, ProfileCanvas

function share(δ, Σ, x, ∫)
  J,K = size(x)
  (length(δ) == J) || error("length(δ)=$(length(δ)) != size(x,1)=$J")
  (K,K) == size(Σ) || error("size(x,2)=$K != size(Σ)=$(size(Σ))")
  function shareν(ν)
    s = δ .+ x*Σ*ν
    smax=max(0,maximum(s))
    s .-= smax
    s .= exp.(s)
    s ./= (sum(s) + exp(0-smax))
    return(s)
  end
  return(∫(shareν))
end
share (generic function with 1 method)

and a simple function for the integral.

struct Integrator{Tx, Tw}
    x::Tx
    w::Tw
end

function Integrator(dx::Distribution, n=100)
    x = [rand(dx) for _ in 1:n]
    w = Base.Iterators.Repeated(1/n)
    Integrator(x,w)
end

(∫::Integrator)(f) = sum((xw)->f(xw[1])*xw[2], zip(∫.x, ∫.w))
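
For example, here is a quick check (illustrative only, with a hypothetical name ∫ex) that the Integrator functor approximates an expectation over the distribution it was built from:

∫ex = Integrator(MvNormal(zeros(2),I), 1000)
∫ex(ν -> ν.^2)    # ≈ [1.0, 1.0], the variances of a standard normal, up to Monte Carlo error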

For simplicity and to focus on the share function, we just use Monte Carlo integration since it’s easy to implement, but we could get more accuracy by using quasi-Monte Carlo or quadrature.
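
For instance, a tensor-product Gauss-Hermite rule could be plugged into the same Integrator struct. The sketch below is illustrative only; it assumes the FastGaussQuadrature package and is not part of the original code.

using FastGaussQuadrature

# gausshermite uses the weight exp(-t^2), so for an N(0,I) expectation the
# nodes are scaled by √2 and the weights are normalized by π^(K/2)
function hermite_integrator(K, n=5)
  t, w = gausshermite(n)
  x  = vec([sqrt(2).*collect(ν) for ν in Iterators.product(fill(t,K)...)])
  ws = vec([prod(ω)/π^(K/2) for ω in Iterators.product(fill(w,K)...)])
  Integrator(x, ws)
end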

Initial Benchmark

J = 10
K = 5
δ = rand(J)
X = randn(J,K)
Σ = I + zeros(K,K)
∫ = Integrator(MvNormal(zeros(K),I));
@benchmark share(δ,Σ,X,∫)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  46.078 μs …  10.655 ms  ┊ GC (min … max): 0.00% … 98.70%
 Time  (median):     55.581 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   67.335 μs ± 199.461 μs  ┊ GC (mean ± σ):  8.98% ±  3.38%

     ▁▁    ▂▃▅▇█▇▄▄▄▂▁▁                    ▁▂▂  ▂▂▄▄▃▃▄▂▁▁     ▂
  ▆▇████████████████████▇▆▅▇▇▆▆▆▅▆▆▆▄▅▃▄▄▁▇█████████████████▇▆ █
  46.1 μs       Histogram: log(frequency) by time      87.2 μs <

 Memory estimate: 65.48 KiB, allocs estimate: 499.

Profiling it:

using Profile
Profile.clear();
Profile.init(n=10^7,delay=0.00001);
@profile sum(share(δ,Σ,X,∫) for i ∈ 1:1000)
profilehtmlstring()

It’s a good idea to check @code_warntype and verify that there are no type instabilities (variables with type Any or Union). I checked and there are no type instabilities here.
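
For reference, the check is just the following; look for variables or return values reported as Any or Union (highlighted in the REPL):

@code_warntype share(δ,Σ,X,∫)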

LoopVectorization

Examining the profile, we see that the multiplication x*Σ*ν is the single most costly operation. The next most costly are the lines computing the softmax, s .= exp.(s) and s ./= (sum(s) + exp(0-smax)).

using LoopVectorization

@inline function dplusxMy(d::AbstractVector,x::AbstractMatrix,M::AbstractMatrix,y::AbstractVector)
  out = similar(d)
  @turbo for i ∈ axes(x,1)
    r = d[i]
    for j ∈ axes(x,2)
      for k ∈ axes(M,2)
        r += x[i,j]*M[j,k]*y[k]
      end
    end
    out[i] = r
  end
  out
end
@inline function softmax0!(s)
  smax=@turbo reduce(max, s, init=zero(eltype(s)))
  s .= exp.(s .- smax)
  s ./= (sum(s) + exp(zero(smax)-smax))
  return(s)
end
function share_v2(δ, Σ, x, ∫)
  J,K = size(x)
  (length(δ) == J) || error("length(δ)=$(length(δ)) != size(x,1)=$J")
  (K,K) == size(Σ) || error("size(x,2)=$K != size(Σ)=$(size(Σ))")
  function shareν(ν)
    s = dplusxMy(δ,x,Σ,ν)
    softmax0!(s)
    s
  end
  return(∫(shareν))
end
@benchmark share_v2(δ,Σ,X,∫)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  28.545 μs …   4.955 ms  ┊ GC (min … max): 0.00% … 98.41%
 Time  (median):     33.795 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   41.119 μs ± 117.942 μs  ┊ GC (mean ± σ):  6.94% ±  2.41%

             ▆█▂                                                
  ▂▂▂▂▂▂▂▂▂▃▇███▄▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▂▂▂▃▃▂▄▄▃▃▃▄▅▃ ▃
  28.5 μs         Histogram: frequency by time         53.3 μs <

 Memory estimate: 42.05 KiB, allocs estimate: 299.

By writing δ .+ x*Σ*ν as a loop, we avoid some allocations and can use @turbo to insert SIMD instructions. This gives a noticeable speedup. About half the gains here are from writing out the loops in dplusxMy and half from using @turbo. The changes to the softmax calculation did not make much difference.

using Profile
Profile.clear();
Profile.init(n=10^7,delay=0.00001);
ProfileCanvas.@profview sum(share_v2(δ,Σ,X,∫) for i ∈ 1:1000)
profilehtmlstring()

The execution time is still dominated by the two functions mentioned above.

StaticArrays

Since the calculations involve small arrays, the code is likely to benefit from using StaticArrays.

using StaticArrays
sδ = SVector{J}(δ)
sΣ = SMatrix{K,K}(Σ)
sX = SMatrix{J,K}(X)
nd = length(∫.x)
iw = SVector{nd}(fill(1/nd,nd))
ix = [SVector{K}(x) for x ∈ ∫.x]
s∫ = Integrator(ix,iw)
@benchmark share_v2(sδ,sΣ,sX,s∫)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  18.543 μs …   5.089 ms  ┊ GC (min … max): 0.00% … 98.57%
 Time  (median):     23.061 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   28.211 μs ± 110.996 μs  ┊ GC (mean ± σ):  8.71% ±  2.20%

               ▄█▆▂                                             
  ▂▃▃▃▃▃▃▂▂▂▃▄▇████▆▃▂▂▂▂▂▁▂▂▂▁▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▂▂▂▂▃▇▆▄▄▇▅▃▂▂▂ ▃
  18.5 μs         Histogram: frequency by time         36.1 μs <

 Memory estimate: 28.03 KiB, allocs estimate: 299.

The code is now about 2.4 times faster than what we started with.

Interestingly, slightly better performance can be achieved by simply passing StaticArrays to the original code.

@benchmark share($(MVector(sδ)),sΣ,sX,s∫) # the working vector s is mutated in place, so δ must be a mutable MVector
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  15.258 μs …   5.028 ms  ┊ GC (min … max): 0.00% … 98.74%
 Time  (median):     18.066 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   25.243 μs ± 111.254 μs  ┊ GC (mean ± σ):  9.76% ±  2.21%

         ▇█▁                                                    
  ▂▃▃▃▃▂▃███▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▂▁▁▂▁▁▁▂▂▃▅▅▆▆▅▅▆▄▂▂▂▂▂ ▃
  15.3 μs         Histogram: frequency by time         33.2 μs <

 Memory estimate: 28.03 KiB, allocs estimate: 299.

Non-allocating

We can do still better by eliminating the allocations. With normal arrays, in-place broadcast operations are usually non-allocating, while ordinary array operations allocate. With StaticArrays, ordinary operations do not allocate.1
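
As a quick illustration (not from the original notes), compare the allocations from adding two ordinary Vectors with those from adding two SVectors:

a = randn(5); b = randn(5)
sa = SVector{5}(a); sb = SVector{5}(b)
@btime $a + $b;    # allocates a new Vector on the heap
@btime $sa + $sb;  # 0 allocations; the result is a stack-allocated SVector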

Additionally, we can precompute x*Σ outside of shareν (we could have done this at any time earlier too). This change accounts for about 1μs of the speed up.

function share_v3(δ, Σ, x, ∫)
  J,K = size(x)
  (length(δ) == J) || error("length(δ)=$(length(δ)) != size(x,1)=$J")
  (K,K) == size(Σ) || error("size(x,2)=$K != size(Σ)=$(size(Σ))")
  xΣ = x*Σ  # precomputed once, outside shareν
  function shareν(ν)
    s = δ + xΣ*ν
    smax=max(0,maximum(s))
    s -= smax*ones(typeof(s))
    s = exp.(s)
    s *= 1/(sum(s) + exp(0-smax))
    return(s)
  end
  return(∫(shareν))
end
share_v3 (generic function with 1 method)
@benchmark share_v3(sδ,sΣ,sX,s∫)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  9.609 μs …  24.355 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     9.706 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.743 μs ± 337.037 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

         ▁▂▃█▁▂▅                                               
  ▁▁▁▁▂▄▄███████▇▆▇▄▅▆▅▇█▅▆▆▄▅▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  9.61 μs         Histogram: frequency by time          10 μs <

 Memory estimate: 96 bytes, allocs estimate: 1.

The code is now 5 times faster than the original, and only allocates once (to store the return value).

Multi-Threading

Multi-threading this code is difficult because it is already quite fast. The sum in the integral can be parallelized, but unless there are a large number of integration points, the overhead from creating threads will likely outweigh the benefits.
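
For comparison, here is a sketch of what a base-Julia threaded version might look like. It is illustrative only, not part of the original solution, and it assumes the integrator’s nodes and weights are both indexable (as they are for s∫ above).

function share_threaded(δ, Σ, x, ∫)
  xΣ = x*Σ
  function shareν(ν)
    s = δ + xΣ*ν
    smax = max(0, maximum(s))
    s = exp.(s .- smax)
    return s*(1/(sum(s) + exp(-smax)))
  end
  # one task per chunk of integration points; with only ~100 points the
  # microsecond-scale task overhead is comparable to the work itself
  chunks = Iterators.partition(eachindex(∫.x), cld(length(∫.x), Threads.nthreads()))
  tasks = map(chunks) do idx
    Threads.@spawn sum(i -> shareν(∫.x[i])*∫.w[i], idx)
  end
  return sum(fetch.(tasks))
end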

The Polyester package provides a faster, but more limited threading model than base Julia. Using it, we can see modest gains, even with just 100 integration points.

To get type stability, I had to write the integration sum loop inside the share function. It would have been cleaner to keep the integration sum in a separate function, but I could not make it type stable that way.

import Polyester
function share_v4(δ, Σ, x, ∫, ::Val{B}=Val(length(∫.x) ÷ 10)) where {B}
  J,K = size(x)
  (length(δ) == J) || error("length(δ)=$(length(δ)) != size(x,1)=$J")
  (K,K) == size(Σ) || error("size(x,2)=$K != size(Σ)=$(size(Σ))")
  xΣ = x*Σ
  function shareν(ν)
    s = δ + xΣ*ν
    smax=max(0,maximum(s))
    s -= smax*ones(typeof(s))
    s = exp.(s)
    s *= 1/(sum(s) + exp(0-smax))
    return(s)
  end
  batchlen= length(∫.x)÷B
  @assert B*batchlen==length(∫.x)
  out = MVector{B,typeof(δ)}(undef)
  Polyester.@batch for b ∈ 1:B
    batch = ((b-1)*(batchlen)+1):(b*batchlen)
    out[b] = zero(typeof(δ))
    for i ∈ batch
      out[b] += shareν(∫.x[i])*∫.w[i]
    end
  end
  return(sum(out))
end
Polyester.reset_threads!()
@benchmark share_v4(sδ,sΣ,sX,s∫, Val(20))
BenchmarkTools.Trial: 10000 samples with 8 evaluations.
 Range (min … max):  4.327 μs …  1.581 ms  ┊ GC (min … max): 0.00% … 37.99%
 Time  (median):     4.769 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.510 μs ± 22.203 μs  ┊ GC (mean ± σ):  2.16% ±  0.53%

       ▄▆█▇▄▁                                 ▁▂▁             
  ▁▂▂▅███████▇▅▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▄▆▇████▇▆▄▄▃▂▂▁▁▁ ▃
  4.33 μs        Histogram: frequency by time        6.39 μs <

 Memory estimate: 1.55 KiB, allocs estimate: 2.

statparts

A few people worked on optimizing the klm or clr function from GMMInference.jl. Here is the original code for klm.

using ForwardDiff
function statparts(gi::Function)
  function P(A::AbstractMatrix) # projection matrix
    A*pinv(A'*A)*A'
  end
  function(θ)
    giθ = gi(θ)
    p = length(θ)
    (n, k) = size(giθ)
    Ω = cov(giθ)
    gn=mean(gi(θ), dims=1)'
    Gi= ForwardDiff.jacobian(gi,θ)
    Gi = reshape(Gi, n , k, p)
    G = mean(Gi, dims=1)
    Γ = zeros(eltype(Gi),p,k,k)
    D = zeros(eltype(Gi),k, p)
    for j in 1:p
      for i in 1:n
        Γ[j,:,:] += (Gi[i,:,j] .- G[1,:,j]) * giθ[i,:]'
      end
      Γ[j,:,:] ./= n
      D[:,j] = G[1,:,j] - Γ[j,:,:]*inv(Ω)*gn
    end
    return(n,k,p,gn, Ω, D, P)
  end
end

function klm(gi::Function)
  SP = statparts(gi)
  function(θ)
    (n,k,p,gn, Ω, D, P) = SP(θ)
    return n*(gn'*Ω^(-1/2)*P(Ω^(-1/2)*D)*Ω^(-1/2)*gn)[1]
  end
end
klm (generic function with 1 method)

To run the code, we need an example gi function. We’ll just copy the example from the docs.

import Random
function simulate_ivshare(n,β,γ,ρ)
  z = randn(n, size(γ)[1])
  endo = randn(n, length(β))
  x = z*γ .+ endo
  ξ = rand(Normal(0,sqrt((1.0-ρ^2))),n).+endo[:,1]*ρ
  y = cdf.(Logistic(), x*β .+ ξ)
  return((y=y,x=x,z=z))
end
n = 100
k = 2
iv = 3
β0 = ones(k)
π0 = vcat(5*I,ones(iv-k,k))
ρ = 0.5
Random.seed!(622)
(y,x,z) = simulate_ivshare(n,β0,π0,ρ)

function gi_ivshare(β,y,x,z)
  ξ = quantile.(Logistic(),y) .- x*β
  ξ.*z
end

gi = let y=y, x=x, z=z
  β->gi_ivshare(β,y,x,z)
end
#26 (generic function with 1 method)

Initial Benchmark

@benchmark klm(gi)(β0)
BenchmarkTools.Trial: 9747 samples with 1 evaluation.
 Range (min … max):  458.320 μs …  14.717 ms  ┊ GC (min … max): 0.00% … 95.33%
 Time  (median):     473.890 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   506.388 μs ± 513.723 μs  ┊ GC (mean ± σ):  5.24% ±  4.99%

        ▁▄▇█▇▆▄▂                                                 
  ▂▂▂▃▄▆█████████▇▆▆▅▅▅▄▄▄▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▃
  458 μs           Histogram: frequency by time          542 μs <

 Memory estimate: 211.50 KiB, allocs estimate: 2759.

Fixing Type Instabilities

From @code_warntype, we see that the compiler is unable to infer the type of some variables. The problem seems to start with D. This is quite puzzling because D is explicitly initialized as zeros(eltype(Gi),...).

julia> @code_warntype klm(gi)(β0)
MethodInstance for (::var"#24#25"{var"#21#23"{var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#22"}})(::Vector{Float64})
  from (::var"#24#25")(θ) @ Main ~/.julia/dev/ARGridBootstrap/docs/jmd/assignment.jmd:31
Arguments
  #self#::var"#24#25"{var"#21#23"{var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#22"}}
  θ::Vector{Float64}
Locals
  @_3::Int64
  P::var"#P#22"
  D::ANY
  Ω::Matrix{Float64}
  gn::Adjoint{Float64, Matrix{Float64}}
  p::Int64
  k::Int64
  n::Int64
Body::ANY
1%1  = Core.getfield(#self#, :SP)::var"#21#23"{var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#22"}
%2  = (%1)(θ)::TUPLE{INT64, INT64, INT64, ADJOINT{FLOAT64, MATRIX{FLOAT64}}, MATRIX{FLOAT64}, ANY, VAR"#P#22"}
%3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (n = Core.getfield(%3, 1))
│         (@_3 = Core.getfield(%3, 2))
%6  = Base.indexed_iterate(%2, 2, @_3::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (k = Core.getfield(%6, 1))
│         (@_3 = Core.getfield(%6, 2))
%9  = Base.indexed_iterate(%2, 3, @_3::Core.Const(3))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (p = Core.getfield(%9, 1))
│         (@_3 = Core.getfield(%9, 2))
%12 = Base.indexed_iterate(%2, 4, @_3::Core.Const(4))::Core.PartialStruct(Tuple{Adjoint{Float64, Matrix{Float64}}, Int64}, Any[Adjoint{Float64, Matrix{Float64}}, Core.Const(5)])
│         (gn = Core.getfield(%12, 1))
│         (@_3 = Core.getfield(%12, 2))
%15 = Base.indexed_iterate(%2, 5, @_3::Core.Const(5))::Core.PartialStruct(Tuple{Matrix{Float64}, Int64}, Any[Matrix{Float64}, Core.Const(6)])
│         (Ω = Core.getfield(%15, 1))
│         (@_3 = Core.getfield(%15, 2))
%18 = Base.indexed_iterate(%2, 6, @_3::Core.Const(6))::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(7)])
│         (D = Core.getfield(%18, 1))
│         (@_3 = Core.getfield(%18, 2))
%21 = Base.indexed_iterate(%2, 7, @_3::Core.Const(7))::Core.PartialStruct(Tuple{var"#P#22", Int64}, Any[var"#P#22", Core.Const(8)])
│         (P = Core.getfield(%21, 1))
%23 = n::Int64
%24 = Main.:var"'"(gn)::Matrix{Float64}
%25 = Ω::Matrix{Float64}
%26 = (-1 / 2)::Core.Const(-0.5)
%27 = (%25 ^ %26)::ANY
%28 = Ω::Matrix{Float64}
%29 = (-1 / 2)::Core.Const(-0.5)
%30 = (%28 ^ %29)::ANY
%31 = (%30 * D)::ANY
%32 = (P)(%31)::ANY
%33 = Ω::Matrix{Float64}
%34 = (-1 / 2)::Core.Const(-0.5)
%35 = (%33 ^ %34)::ANY
%36 = (%24 * %27 * %32 * %35 * gn)::ANY
%37 = Base.getindex(%36, 1)::ANY
%38 = (%23 * %37)::ANY
└──       return %38

To investigate further, let us focus on statparts.

julia> @code_warntype statparts(gi)(β0)
MethodInstance for (::var"#21#23"{var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#22"})(::Vector{Float64})
  from (::var"#21#23")(θ) @ Main ~/.julia/dev/ARGridBootstrap/docs/jmd/assignment.jmd:7
Arguments
  #self#::var"#21#23"{var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#22"}
  θ::Vector{Float64}
Locals
  @_3::UNION{NOTHING, TUPLE{INT64, INT64}}
  @_4::Int64
  D::ANY
  Γ::ANY
  G::ANY
  Gi::ANY
  gn::Adjoint{Float64, Matrix{Float64}}
  Ω::Matrix{Float64}
  k::Int64
  n::Int64
  p::Int64
  giθ::Matrix{Float64}
  @_15::UNION{NOTHING, TUPLE{INT64, INT64}}
  j::Int64
  i::Int64
Body::TUPLE{INT64, INT64, INT64, ADJOINT{FLOAT64, MATRIX{FLOAT64}}, MATRIX{FLOAT64}, ANY, VAR"#P#22"}
1%1  = Core.getfield(#self#, :gi)::var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
│         (giθ = (%1)(θ))
│         (p = Main.length(θ))
%4  = Main.size(giθ)::Tuple{Int64, Int64}
%5  = Base.indexed_iterate(%4, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (n = Core.getfield(%5, 1))
│         (@_4 = Core.getfield(%5, 2))
%8  = Base.indexed_iterate(%4, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (k = Core.getfield(%8, 1))
│         (Ω = Main.cov(giθ))
%11 = Core.getfield(#self#, :gi)::var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
%12 = (%11)(θ)::Matrix{Float64}
%13 = (:dims,)::Core.Const((:dims,))
%14 = Core.apply_type(Core.NamedTuple, %13)::Core.Const(NamedTuple{(:dims,)})
%15 = Core.tuple(1)::Core.Const((1,))
%16 = (%14)(%15)::Core.Const((dims = 1,))
%17 = Core.kwcall(%16, Main.mean, %12)::Matrix{Float64}
│         (gn = Main.:var"'"(%17))
%19 = ForwardDiff.jacobian::Core.Const(ForwardDiff.jacobian)
%20 = Core.getfield(#self#, :gi)::var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
│         (Gi = (%19)(%20, θ))
│         (Gi = Main.reshape(Gi, n, k, p))
%23 = (:dims,)::Core.Const((:dims,))
%24 = Core.apply_type(Core.NamedTuple, %23)::Core.Const(NamedTuple{(:dims,)})
%25 = Core.tuple(1)::Core.Const((1,))
%26 = (%24)(%25)::Core.Const((dims = 1,))
│         (G = Core.kwcall(%26, Main.mean, Gi))
%28 = Main.eltype(Gi)::ANY
%29 = p::Int64
%30 = k::Int64
│         (Γ = Main.zeros(%28, %29, %30, k))
%32 = Main.eltype(Gi)::ANY
%33 = k::Int64
│         (D = Main.zeros(%32, %33, p))
%35 = (1:p)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_3 = Base.iterate(%35))
%37 = (@_3 === nothing)::Bool
%38 = Base.not_int(%37)::Bool
└──       goto #7 if not %38
2%40 = @_3::Tuple{Int64, Int64}
│         (j = Core.getfield(%40, 1))
%42 = Core.getfield(%40, 2)::Int64
%43 = (1:n)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_15 = Base.iterate(%43))
%45 = (@_15 === nothing)::Bool
%46 = Base.not_int(%45)::Bool
└──       goto #5 if not %46
3%48 = @_15::Tuple{Int64, Int64}
│         (i = Core.getfield(%48, 1))
%50 = Core.getfield(%48, 2)::Int64
%51 = Base.getindex(Γ, j, Main.:(:), Main.:(:))::ANY
%52 = Main.:-::Core.Const(-)
%53 = Base.getindex(Gi, i, Main.:(:), j)::ANY
%54 = Base.getindex(G, 1, Main.:(:), j)::ANY
%55 = Base.broadcasted(%52, %53, %54)::ANY
%56 = Base.materialize(%55)::ANY
%57 = Base.getindex(giθ, i, Main.:(:))::Vector{Float64}
%58 = Main.:var"'"(%57)::Adjoint{Float64, Vector{Float64}}
%59 = (%56 * %58)::ANY
%60 = (%51 + %59)::ANY
Base.setindex!(Γ, %60, j, Main.:(:), Main.:(:))
│         (@_15 = Base.iterate(%43, %50))
%63 = (@_15 === nothing)::Bool
%64 = Base.not_int(%63)::Bool
└──       goto #5 if not %64
4 ─       goto #3
5%67 = Base.dotview(Γ, j, Main.:(:), Main.:(:))::ANY
%68 = Main.:/::Core.Const(/)
%69 = Base.getindex(Γ, j, Main.:(:), Main.:(:))::ANY
%70 = Base.broadcasted(%68, %69, n)::ANY
Base.materialize!(%67, %70)
%72 = Base.getindex(G, 1, Main.:(:), j)::ANY
%73 = Base.getindex(Γ, j, Main.:(:), Main.:(:))::ANY
%74 = Main.inv(Ω)::Matrix{Float64}
%75 = (%73 * %74 * gn)::ANY
%76 = (%72 - %75)::ANY
Base.setindex!(D, %76, Main.:(:), j)
│         (@_3 = Base.iterate(%35, %42))
%79 = (@_3 === nothing)::Bool
%80 = Base.not_int(%79)::Bool
└──       goto #7 if not %80
6 ─       goto #2
7%83 = n::Int64
%84 = k::Int64
%85 = p::Int64
%86 = gn::Adjoint{Float64, Matrix{Float64}}
%87 = Ω::Matrix{Float64}
%88 = D::ANY
%89 = Core.getfield(#self#, :P)::Core.Const(var"#P#22"())
%90 = Core.tuple(%83, %84, %85, %86, %87, %88, %89)::TUPLE{INT64, INT64, INT64, ADJOINT{FLOAT64, MATRIX{FLOAT64}}, MATRIX{FLOAT64}, ANY, VAR"#P#22"}
└──       return %90

We see that G, Gi, Γ, and D are all of type Any. For some reason, the return type of ForwardDiff.jacobian is not being inferred. We can work around this by using jacobian! instead.

function statparts(gi::F) where {F <: Function}
  function P(A::AbstractMatrix) # projection matrix
    A*pinv(A'*A)*A'
  end
  let gi=gi
    function(θ)
      giθ = gi(θ)
      p = length(θ)
      (n, k) = size(giθ)
      Ω = Hermitian(cov(giθ))
      gn=mean(gi(θ), dims=1)'
      Gi = zeros(n,k,p)
      ForwardDiff.jacobian!(Gi,gi,θ)
      Gi = reshape(Gi, n , k, p)
      G = mean(Gi, dims=1)
      Γ = zeros(eltype(Gi),p,k,k)
      D = zeros(eltype(Gi),k, p)
      for j in 1:p
        for i in 1:n
          Γ[j,:,:] += (Gi[i,:,j] .- G[1,:,j]) * giθ[i,:]'
        end
        Γ[j,:,:] ./= n
        D[:,j] = G[1,:,j] - Γ[j,:,:]*inv(Ω)*gn
      end
      return(n,k,p,gn, Ω, D, P)
    end
  end
end
statparts (generic function with 1 method)
julia> @code_warntype statparts(gi)(β0)
MethodInstance for (::var"#28#30"{var"#P#29", var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}})(::Vector{Float64})
  from (::var"#28#30")(θ) @ Main ~/.julia/dev/ARGridBootstrap/docs/jmd/assignment.jmd:7
Arguments
  #self#::var"#28#30"{var"#P#29", var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}
  θ::Vector{Float64}
Locals
  @_3::UNION{NOTHING, TUPLE{INT64, INT64}}
  @_4::Int64
  D::Matrix{Float64}
  Γ::Array{Float64, 3}
  G::Array{Float64, 3}
  Gi::Array{Float64, 3}
  gn::Adjoint{Float64, Matrix{Float64}}
  Ω::Hermitian{Float64, Matrix{Float64}}
  k::Int64
  n::Int64
  p::Int64
  giθ::Matrix{Float64}
  @_15::UNION{NOTHING, TUPLE{INT64, INT64}}
  j::Int64
  i::Int64
Body::Tuple{Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Matrix{Float64}, var"#P#29"}
1%1  = Core.getfield(#self#, Symbol("#457#gi"))::var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
│         (giθ = (%1)(θ))
│         (p = Main.length(θ))
%4  = Main.size(giθ)::Tuple{Int64, Int64}
%5  = Base.indexed_iterate(%4, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (n = Core.getfield(%5, 1))
│         (@_4 = Core.getfield(%5, 2))
%8  = Base.indexed_iterate(%4, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (k = Core.getfield(%8, 1))
%10 = Main.cov(giθ)::Matrix{Float64}
│         (Ω = Main.Hermitian(%10))
%12 = Core.getfield(#self#, Symbol("#457#gi"))::var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
%13 = (%12)(θ)::Matrix{Float64}
%14 = (:dims,)::Core.Const((:dims,))
%15 = Core.apply_type(Core.NamedTuple, %14)::Core.Const(NamedTuple{(:dims,)})
%16 = Core.tuple(1)::Core.Const((1,))
%17 = (%15)(%16)::Core.Const((dims = 1,))
%18 = Core.kwcall(%17, Main.mean, %13)::Matrix{Float64}
│         (gn = Main.:var"'"(%18))
│         (Gi = Main.zeros(n, k, p))
%21 = ForwardDiff.jacobian!::Core.Const(ForwardDiff.jacobian!)
%22 = Gi::Array{Float64, 3}
%23 = Core.getfield(#self#, Symbol("#457#gi"))::var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
│         (%21)(%22, %23, θ)
│         (Gi = Main.reshape(Gi, n, k, p))
%26 = (:dims,)::Core.Const((:dims,))
%27 = Core.apply_type(Core.NamedTuple, %26)::Core.Const(NamedTuple{(:dims,)})
%28 = Core.tuple(1)::Core.Const((1,))
%29 = (%27)(%28)::Core.Const((dims = 1,))
│         (G = Core.kwcall(%29, Main.mean, Gi))
%31 = Main.eltype(Gi)::Core.Const(Float64)
%32 = p::Int64
%33 = k::Int64
│         (Γ = Main.zeros(%31, %32, %33, k))
%35 = Main.eltype(Gi)::Core.Const(Float64)
%36 = k::Int64
│         (D = Main.zeros(%35, %36, p))
%38 = (1:p)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_3 = Base.iterate(%38))
%40 = (@_3 === nothing)::Bool
%41 = Base.not_int(%40)::Bool
└──       goto #7 if not %41
2%43 = @_3::Tuple{Int64, Int64}
│         (j = Core.getfield(%43, 1))
%45 = Core.getfield(%43, 2)::Int64
%46 = (1:n)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_15 = Base.iterate(%46))
%48 = (@_15 === nothing)::Bool
%49 = Base.not_int(%48)::Bool
└──       goto #5 if not %49
3%51 = @_15::Tuple{Int64, Int64}
│         (i = Core.getfield(%51, 1))
%53 = Core.getfield(%51, 2)::Int64
%54 = Base.getindex(Γ, j, Main.:(:), Main.:(:))::Matrix{Float64}
%55 = Main.:-::Core.Const(-)
%56 = Base.getindex(Gi, i, Main.:(:), j)::Vector{Float64}
%57 = Base.getindex(G, 1, Main.:(:), j)::Vector{Float64}
%58 = Base.broadcasted(%55, %56, %57)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}, Vector{Float64}}}
%59 = Base.materialize(%58)::Vector{Float64}
%60 = Base.getindex(giθ, i, Main.:(:))::Vector{Float64}
%61 = Main.:var"'"(%60)::Adjoint{Float64, Vector{Float64}}
%62 = (%59 * %61)::Matrix{Float64}
%63 = (%54 + %62)::Matrix{Float64}
Base.setindex!(Γ, %63, j, Main.:(:), Main.:(:))
│         (@_15 = Base.iterate(%46, %53))
%66 = (@_15 === nothing)::Bool
%67 = Base.not_int(%66)::Bool
└──       goto #5 if not %67
4 ─       goto #3
5%70 = Base.dotview(Γ, j, Main.:(:), Main.:(:))::SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}
%71 = Main.:/::Core.Const(/)
%72 = Base.getindex(Γ, j, Main.:(:), Main.:(:))::Matrix{Float64}
%73 = Base.broadcasted(%71, %72, n)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(/), Tuple{Matrix{Float64}, Int64}}
Base.materialize!(%70, %73)
%75 = Base.getindex(G, 1, Main.:(:), j)::Vector{Float64}
%76 = Base.getindex(Γ, j, Main.:(:), Main.:(:))::Matrix{Float64}
%77 = Main.inv::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]))::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%78 = (%76 * %77 * gn)::Matrix{Float64}
%79 = (%75 - %78)::Matrix{Float64}
Base.setindex!(D, %79, Main.:(:), j)
│         (@_3 = Base.iterate(%38, %45))
%82 = (@_3 === nothing)::Bool
%83 = Base.not_int(%82)::Bool
└──       goto #7 if not %83
6 ─       goto #2
7%86 = n::Int64
%87 = k::Int64
%88 = p::Int64
%89 = gn::Adjoint{Float64, Matrix{Float64}}
%90 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%91 = D::Matrix{Float64}
%92 = Core.getfield(#self#, :P)::Core.Const(var"#P#29"())
%93 = Core.tuple(%86, %87, %88, %89, %90, %91, %92)::Core.PartialStruct(Tuple{Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Matrix{Float64}, var"#P#29"}, Any[Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Matrix{Float64}, var"#P#29"])
└──       return %93

I also added the where {F <: Function} clause to ensure compiler specialization, and the let gi=gi block to help with the performance of captured variables.
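
The where {F} pattern relates to a specialization heuristic described in the Julia performance tips. A minimal sketch with made-up names (not from the original):

# Julia heuristically avoids specializing on an argument annotated ::Function when
# that argument is only passed along to another function rather than called directly
apply_generic(f::Function, x) = map(f, x)
# putting the type in the signature forces specialization on the concrete closure type
apply_specialized(f::F, x) where {F<:Function} = map(f, x)

The let gi=gi block follows the “performance of captured variables” tip: the closure captures a fresh binding that is never reassigned, which keeps the captured field concretely typed.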

julia> @code_warntype klm(gi)(β0)
MethodInstance for (::var"#24#25"{var"#28#30"{var"#P#29", var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}})(::Vector{Float64})
  from (::var"#24#25")(θ) @ Main ~/.julia/dev/ARGridBootstrap/docs/jmd/assignment.jmd:31
Arguments
  #self#::var"#24#25"{var"#28#30"{var"#P#29", var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}}
  θ::Vector{Float64}
Locals
  @_3::Int64
  P::var"#P#29"
  D::Matrix{Float64}
  Ω::Hermitian{Float64, Matrix{Float64}}
  gn::Adjoint{Float64, Matrix{Float64}}
  p::Int64
  k::Int64
  n::Int64
Body::UNION{FLOAT64, COMPLEXF64}
1%1  = Core.getfield(#self#, :SP)::var"#28#30"{var"#P#29", var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}
%2  = (%1)(θ)::Core.PartialStruct(Tuple{Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Matrix{Float64}, var"#P#29"}, Any[Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Matrix{Float64}, var"#P#29"])
%3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (n = Core.getfield(%3, 1))
│         (@_3 = Core.getfield(%3, 2))
%6  = Base.indexed_iterate(%2, 2, @_3::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (k = Core.getfield(%6, 1))
│         (@_3 = Core.getfield(%6, 2))
%9  = Base.indexed_iterate(%2, 3, @_3::Core.Const(3))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (p = Core.getfield(%9, 1))
│         (@_3 = Core.getfield(%9, 2))
%12 = Base.indexed_iterate(%2, 4, @_3::Core.Const(4))::Core.PartialStruct(Tuple{Adjoint{Float64, Matrix{Float64}}, Int64}, Any[Adjoint{Float64, Matrix{Float64}}, Core.Const(5)])
│         (gn = Core.getfield(%12, 1))
│         (@_3 = Core.getfield(%12, 2))
%15 = Base.indexed_iterate(%2, 5, @_3::Core.Const(5))::Core.PartialStruct(Tuple{Hermitian{Float64, Matrix{Float64}}, Int64}, Any[Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Core.Const(6)])
│         (Ω = Core.getfield(%15, 1))
│         (@_3 = Core.getfield(%15, 2))
%18 = Base.indexed_iterate(%2, 6, @_3::Core.Const(6))::Core.PartialStruct(Tuple{Matrix{Float64}, Int64}, Any[Matrix{Float64}, Core.Const(7)])
│         (D = Core.getfield(%18, 1))
│         (@_3 = Core.getfield(%18, 2))
%21 = Base.indexed_iterate(%2, 7, @_3::Core.Const(7))::Core.PartialStruct(Tuple{var"#P#29", Int64}, Any[var"#P#29", Core.Const(8)])
│         (P = Core.getfield(%21, 1))
%23 = n::Int64
%24 = Main.:var"'"(gn)::Matrix{Float64}
%25 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%26 = (-1 / 2)::Core.Const(-0.5)
%27 = (%25 ^ %26)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, MATRIX{COMPLEXF64}}
%28 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%29 = (-1 / 2)::Core.Const(-0.5)
%30 = (%28 ^ %29)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, MATRIX{COMPLEXF64}}
%31 = (%30 * D)::UNION{MATRIX{COMPLEXF64}, MATRIX{FLOAT64}}
%32 = (P)(%31)::UNION{MATRIX{COMPLEXF64}, MATRIX{FLOAT64}}
%33 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%34 = (-1 / 2)::Core.Const(-0.5)
%35 = (%33 ^ %34)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, MATRIX{COMPLEXF64}}
%36 = (%24 * %27 * %32 * %35 * gn)::UNION{MATRIX{COMPLEXF64}, MATRIX{FLOAT64}}
%37 = Base.getindex(%36, 1)::UNION{FLOAT64, COMPLEXF64}
%38 = (%23 * %37)::UNION{FLOAT64, COMPLEXF64}
└──       return %38

There’s still a type instability in klm. This one is harder to understand. It arises because the appropriate meaning of a matrix square root depends on the matrix: the result could be a complex-valued matrix rather than a real-valued one. We know that Ω should be positive definite with a real matrix square root, so we can compute that square root from the eigendecomposition of Ω and avoid the type instability.
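
The instability is easy to reproduce in isolation (an illustrative check, not from the original):

A = Hermitian([2.0 0.5; 0.5 1.0])   # positive definite
B = Hermitian([1.0 0.0; 0.0 -1.0])  # indefinite
typeof(A^(-1/2))  # a real (Hermitian) matrix
typeof(B^(-1/2))  # Matrix{ComplexF64}: a negative eigenvalue forces a complex result

Since inference cannot know the sign of the eigenvalues at compile time, it must allow for both return types.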

function klm(gi::F ) where {F <: Function}
  let gi=gi
    function(θ)
      (n,k,p,gn, Ω, D, P) = statparts(gi)(θ)
      λ, v = eigen(Ω)
      irΩ = v*diagm(λ.^(-1/2))*v'
      return n*(gn'*irΩ*P(irΩ*D)*irΩ*gn)[1]
    end
  end
end
klm (generic function with 1 method)
julia> @code_warntype klm(gi)(β0)
MethodInstance for (::var"#31#32"{var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}})(::Vector{Float64})
  from (::var"#31#32")(θ) @ Main ~/.julia/dev/ARGridBootstrap/docs/jmd/assignment.jmd:4
Arguments
  #self#::var"#31#32"{var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}
  θ::Vector{Float64}
Locals
  @_3::Val{:vectors}
  @_4::Int64
  irΩ::Matrix{Float64}
  v::Matrix{Float64}
  λ::Vector{Float64}
  P::var"#P#29"
  D::Matrix{Float64}
  Ω::Hermitian{Float64, Matrix{Float64}}
  gn::Adjoint{Float64, Matrix{Float64}}
  p::Int64
  k::Int64
  n::Int64
Body::Float64
1%1  = Core.getfield(#self#, Symbol("#458#gi"))::var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
%2  = Main.statparts(%1)::var"#28#30"{var"#P#29", var"#26#27"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}
%3  = (%2)(θ)::Core.PartialStruct(Tuple{Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Matrix{Float64}, var"#P#29"}, Any[Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Matrix{Float64}, var"#P#29"])
%4  = Base.indexed_iterate(%3, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (n = Core.getfield(%4, 1))
│         (@_4 = Core.getfield(%4, 2))
%7  = Base.indexed_iterate(%3, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (k = Core.getfield(%7, 1))
│         (@_4 = Core.getfield(%7, 2))
%10 = Base.indexed_iterate(%3, 3, @_4::Core.Const(3))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (p = Core.getfield(%10, 1))
│         (@_4 = Core.getfield(%10, 2))
%13 = Base.indexed_iterate(%3, 4, @_4::Core.Const(4))::Core.PartialStruct(Tuple{Adjoint{Float64, Matrix{Float64}}, Int64}, Any[Adjoint{Float64, Matrix{Float64}}, Core.Const(5)])
│         (gn = Core.getfield(%13, 1))
│         (@_4 = Core.getfield(%13, 2))
%16 = Base.indexed_iterate(%3, 5, @_4::Core.Const(5))::Core.PartialStruct(Tuple{Hermitian{Float64, Matrix{Float64}}, Int64}, Any[Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Core.Const(6)])
│         (Ω = Core.getfield(%16, 1))
│         (@_4 = Core.getfield(%16, 2))
%19 = Base.indexed_iterate(%3, 6, @_4::Core.Const(6))::Core.PartialStruct(Tuple{Matrix{Float64}, Int64}, Any[Matrix{Float64}, Core.Const(7)])
│         (D = Core.getfield(%19, 1))
│         (@_4 = Core.getfield(%19, 2))
%22 = Base.indexed_iterate(%3, 7, @_4::Core.Const(7))::Core.PartialStruct(Tuple{var"#P#29", Int64}, Any[var"#P#29", Core.Const(8)])
│         (P = Core.getfield(%22, 1))
%24 = Main.eigen::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]))::Eigen{Float64, Float64, Matrix{Float64}, Vector{Float64}}
%25 = Base.indexed_iterate(%24, 1)::Tuple{Vector{Float64}, Val{:vectors}}
│         (λ = Core.getfield(%25, 1))
│         (@_3 = Core.getfield(%25, 2))
%28 = Base.indexed_iterate(%24, 2, @_3)::Tuple{Matrix{Float64}, Val{:done}}
│         (v = Core.getfield(%28, 1))
%30 = v::Matrix{Float64}
%31 = Main.:^::Core.Const(^)
%32 = λ::Vector{Float64}
%33 = (-1 / 2)::Core.Const(-0.5)
%34 = Base.broadcasted(%31, %32, %33)::Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(^), Tuple{Vector{Float64}, Float64}}, Any[Core.Const(Base.Broadcast.DefaultArrayStyle{1}()), Core.Const(^), Core.PartialStruct(Tuple{Vector{Float64}, Float64}, Any[Vector{Float64}, Core.Const(-0.5)]), Nothing])
%35 = Base.materialize(%34)::Vector{Float64}
%36 = Main.diagm(%35)::Matrix{Float64}
%37 = Main.:var"'"(v)::Adjoint{Float64, Matrix{Float64}}
│         (irΩ = %30 * %36 * %37)
%39 = n::Int64
%40 = Main.:var"'"(gn)::Matrix{Float64}
%41 = irΩ::Matrix{Float64}
%42 = (irΩ * D)::Matrix{Float64}
%43 = (P)(%42)::Matrix{Float64}
%44 = irΩ::Matrix{Float64}
%45 = (%40 * %41 * %43 * %44 * gn)::Matrix{Float64}
%46 = Base.getindex(%45, 1)::Float64
%47 = (%39 * %46)::Float64
└──       return %47

Fixing these type instabilities speeds up the code by a factor of about three.

@benchmark klm(gi)(β0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  132.812 μs …   9.251 ms  ┊ GC (min … max):  0.00% … 90.89%
 Time  (median):     148.332 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   174.327 μs ± 435.225 μs  ┊ GC (mean ± σ):  12.87% ±  5.07%

                ▃▆██▅▂                                           
  ▂▂▂▂▂▂▃▄▄▄▅▅▆████████▇▆▆▅▅▄▄▄▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂ ▃
  133 μs           Histogram: frequency by time          186 μs <

 Memory estimate: 219.61 KiB, allocs estimate: 1516.

Reducing Allocations and Other Optimizations

Profiling reveals that the majority of the time is spent in the innermost loop of the statparts function. This loop allocates quite a bit because indexing arrays with slices creates copies. We can avoid these allocations by using @views and more broadcasting. See “Consider using views for slices” and “More dots” in the Julia performance tips.

function statparts(gi::F) where {F <: Function}
  function P(A::AbstractMatrix) # projection matrix
    A*pinv(A'*A)*A'
  end
  let gi=gi
    function(θ)
      giθ = gi(θ)
      p = length(θ)
      (n, k) = size(giθ)
      Ω = Hermitian(cov(giθ))
      gn=mean(gi(θ), dims=1)'
      iΩgn = Ω \ gn
      Gi = zeros(n,k,p)
      ForwardDiff.jacobian!(Gi,gi,θ)
      Gi = reshape(Gi, n , k, p)
      G = mean(Gi, dims=1)
      Γ = zeros(eltype(Gi),p,k,k)
      D = zeros(eltype(Gi),k, p)
      @inbounds for j in 1:p
        @inbounds for i in 1:n
          @views Γ[j,:,:] .+= (Gi[i,:,j] .- G[1,:,j]) * giθ[i,:]'
        end
        Γ[j,:,:] ./= n
        @views D[:,j] .= G[1,:,j] .- Γ[j,:,:]*iΩgn
      end
      return(n,k,p,gn, Ω, D, P)
    end
  end
end
@benchmark klm(gi)(β0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  69.363 μs …   6.704 ms  ┊ GC (min … max):  0.00% … 92.90%
 Time  (median):     76.354 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   86.922 μs ± 255.393 μs  ┊ GC (mean ± σ):  11.10% ±  3.73%

               ▁▁▃▄▆▅▇▇▆█▆▇▄▃▂                                  
  ▁▁▁▁▁▂▂▃▃▄▅▆█████████████████▇▆▅▄▃▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  69.4 μs         Histogram: frequency by time           89 μs <

 Memory estimate: 118.53 KiB, allocs estimate: 501.

The code is now about six times faster than the original.

Footnotes

  1. Like scalar variables, StaticArrays exist on the stack instead of the heap, so creating them is much less costly and they do not count toward the reported allocations.↩︎