Coding for Performance

4: More Examples

Author

Paul Schrimpf

Published

March 22, 2026

For the assignment, most students focused on optimizing either the share function from our earlier BLP example, or the statparts function from GMMInference.

Let’s see how we can improve these functions’ performance.

share

Here’s a copy of the original version of the share function

using LinearAlgebra, Distributions, Random, BenchmarkTools, Profile, ProfileCanvas

function share(δ, Σ, x, ∫)
  J,K = size(x)
  (length(δ) == J) || error("length(δ)=$(length(δ)) != size(x,1)=$J")
  (K,K) == size(Σ) || error("size(x,2)=$K != size(Σ)=$(size(Σ))")
  function shareν(ν)
    s = δ .+ x*Σ*ν
    smax=max(0,maximum(s))
    s .-= smax
    s .= exp.(s)
    s ./= (sum(s) + exp(0-smax))
    return(s)
  end
  return((shareν))
end
share (generic function with 1 method)

and a simple function for the integral.

struct Integrator{Tx, Tw}
    x::Tx
    w::Tw
end

function Integrator(dx::Distribution, n=100)
    x = [rand(dx) for _ in 1:n]
    w = Base.Iterators.Repeated(1/n)
    Integrator(x,w)
end

(∫::Integrator)(f) = sum((xw)->f(xw[1])*xw[2], zip(∫.x, ∫.w))

For simplicity, and to keep the focus on the share function, we just use Monte Carlo integration since it's easy to implement, but we could get more accuracy from quasi-Monte Carlo or quadrature.
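As a quick sanity check, the integrator approximates E[f(ν)]. A minimal usage sketch, using the Integrator defined above (the variable names here are illustrative):

```julia
using Distributions, LinearAlgebra

# Monte Carlo approximation of E[f(ν)] for ν ~ N(0, I₂)
∫ = Integrator(MvNormal(zeros(2), I), 10_000)
m = ∫(ν -> ν)      # ≈ E[ν] = [0, 0], up to Monte Carlo error
v = ∫(ν -> ν.^2)   # ≈ E[ν²] = [1, 1], up to Monte Carlo error
```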

Initial Benchmark

J = 10
K = 5
δ = rand(J)
X = randn(J,K)
Σ = I + zeros(K,K)
∫ = Integrator(MvNormal(zeros(K),I));
@benchmark share(δ,Σ,X,∫)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  13.125 μs …  33.016 ms  ┊ GC (min … max):  0.00% … 99.73%
 Time  (median):     20.319 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   24.440 μs ± 330.084 μs  ┊ GC (mean ± σ):  13.47% ±  1.00%

    ▁▃▄▅▄▄▃▂▁▁  ▁▁▄▅▇███▇▇▆▅▄▃▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁          ▃
  ▃▆██████████████████████████████████████████████████▇▆▆▆▇▆▆▅ █
  13.1 μs       Histogram: log(frequency) by time      33.6 μs <

 Memory estimate: 65.48 KiB, allocs estimate: 998.

Profiling it:

using Profile
Profile.clear();
Profile.init(n=10^7,delay=0.00001);
@profile sum(share(δ,Σ,X,∫) for i ∈ 1:1000)
profilehtmlstring()

It’s a good idea to check @code_warntype and verify that there are no type instabilities (variables with type Any or a Union). I checked, and there are no type instabilities here.
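For reference, the check looks like this; since share returns a closure, it is worth inspecting both the outer call and a call of the returned function:

```julia
# Inspect inferred types; look for Any or Union in the output.
@code_warntype share(δ, Σ, X, ∫)
shareν = share(δ, Σ, X, ∫)
@code_warntype shareν(∫.x[1])
```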

LoopVectorization

Examining the profile, we see that the multiplication x*Σ*ν is the single most costly operation. The second most costly are the lines computing the softmax normalization, s .= exp.(s) and s ./= (sum(s) + exp(0-smax)).

using LoopVectorization

@inline function dplusxMy(d::AbstractVector,x::AbstractMatrix,M::AbstractMatrix,y::AbstractVector)
  out = similar(d)
  @turbo for i ∈ axes(x,1)
    r = d[i]
    for j ∈ axes(x,2)
      for k ∈ axes(M,2)
        r += x[i,j]*M[j,k]*y[k]
      end
    end
    out[i] = r
  end
  out
end
@inline function softmax0!(s)
  smax=@turbo reduce(max, s, init=zero(eltype(s)))
  s .= exp.(s .- smax)
  s ./= (sum(s) + exp(zero(smax)-smax))
  return(s)
end
function share_v2(δ, Σ, x, ∫)
  J,K = size(x)
  (length(δ) == J) || error("length(δ)=$(length(δ)) != size(x,1)=$J")
  (K,K) == size(Σ) || error("size(x,2)=$K != size(Σ)=$(size(Σ))")
  function shareν(ν)
    s = dplusxMy(δ,x,Σ,ν)
    softmax0!(s)
    s
  end
  return((shareν))
end
@benchmark share_v2(δ,Σ,X,∫)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  10.700 μs … 244.813 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.088 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   15.453 μs ±  11.460 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                               ▃▆▆▇█▇▅▄▃▂                       
  ▂▃▄▅▇▇▆▅▄▃▃▂▂▂▂▂▁▂▁▂▂▂▂▂▂▂▃▅████████████▇▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂ ▄
  10.7 μs         Histogram: frequency by time         18.6 μs <

 Memory estimate: 42.05 KiB, allocs estimate: 598.

By writing δ + x*Σ*ν as a loop, we avoid some allocations and let @turbo insert SIMD instructions. This gives a noticeable speedup. About half the gains here come from writing out the loops in dplusxMy and half from using @turbo. The changes to the softmax calculation did not make much difference.
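To separate the two effects, we can benchmark a loops-only variant that replaces @turbo with @inbounds. This is a sketch for comparison purposes, and dplusxMy_loops is a hypothetical name:

```julia
@inline function dplusxMy_loops(d::AbstractVector, x::AbstractMatrix,
                                M::AbstractMatrix, y::AbstractVector)
  out = similar(d)
  @inbounds for i ∈ axes(x,1)
    r = d[i]
    for j ∈ axes(x,2), k ∈ axes(M,2)
      r += x[i,j]*M[j,k]*y[k]
    end
    out[i] = r
  end
  out
end

# ν = ∫.x[1]
# @btime $δ .+ $X*$Σ*$ν              # original: allocates temporaries
# @btime dplusxMy_loops($δ,$X,$Σ,$ν) # loops only
# @btime dplusxMy($δ,$X,$Σ,$ν)       # loops + @turbo
```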

using Profile
Profile.clear();
Profile.init(n=10^7,delay=0.00001);
ProfileCanvas.@profview sum(share_v2(δ,Σ,X,∫) for i ∈ 1:1000)
profilehtmlstring()

The execution time is still dominated by the two functions mentioned above.

StaticArrays

Since the calculations involve small arrays, they are likely to benefit from using StaticArrays.

using StaticArrays
sδ = SVector{J}(δ)
sΣ = SMatrix{K,K}(Σ)
sX = SMatrix{J,K}(X)
nd = length(∫.x)
iw = SVector{nd}(fill(1/nd,nd))
ix = [SVector{K}(x) for x ∈ ∫.x]
s∫ = Integrator(ix,iw)
@benchmark share_v2(sδ,sΣ,sX,s∫)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  10.259 μs … 222.170 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     13.906 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.045 μs ±   9.905 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                                   ▃▇█▇▃                        
  ▁▂▃▅▇▆▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▅███████▆▆▅▆▅▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁ ▃
  10.3 μs         Histogram: frequency by time         16.4 μs <

 Memory estimate: 28.03 KiB, allocs estimate: 299.

The code is now 2 times faster than what we started with.

Interestingly, slightly better performance can be achieved by simply passing StaticArrays to the original code.

@benchmark share($(MVector(sδ)),sΣ,sX,s∫) # δ gets mutated, so we must make it a Mutable SArray.
BenchmarkTools.Trial: 10000 samples with 4 evaluations per sample.
 Range (min … max):   7.454 μs …  6.347 ms  ┊ GC (min … max):  0.00% … 99.73%
 Time  (median):     10.988 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   12.519 μs ± 92.640 μs  ┊ GC (mean ± σ):  15.06% ±  2.22%

    ▃▆▇▅▄▂                    ▁▃▅▇███▇▆▅▄▃▂▂▁▁                ▃
  ▂▆███████████▇▇██▇███▇█▇▇▇▆█████████████████████▇▄▆▆▃▆▄▅▄▃▅ █
  7.45 μs      Histogram: log(frequency) by time      13.8 μs <

 Memory estimate: 28.03 KiB, allocs estimate: 299.

Non-allocating

We can do still better by eliminating the allocations. With normal arrays, in-place broadcast operations are usually non-allocating, while ordinary array operations allocate. With StaticArrays, ordinary operations do not allocate.1
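We can check this claim directly with @allocated (wrapping the operation in a function so compilation and global-scope effects do not pollute the measurement):

```julia
using StaticArrays

f(x) = x + x           # ordinary (non-broadcast) array addition
a = @SVector rand(5)   # static vector: lives on the stack
b = rand(5)            # ordinary Vector: heap-allocated

f(a); f(b)             # warm up: compile both methods first
@allocated f(a)        # 0 — SVector arithmetic does not allocate
@allocated f(b)        # > 0 — Vector addition allocates a new array
```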

Additionally, we can precompute x*Σ outside of shareν (we could have done this at any earlier stage too). This change accounts for about 1 μs of the speedup.

function share_v3(δ, Σ, x, ∫)
  J,K = size(x)
  (length(δ) == J) || error("length(δ)=$(length(δ)) != size(x,1)=$J")
  (K,K) == size(Σ) || error("size(x,2)=$K != size(Σ)=$(size(Σ))")
  xΣ = x*Σ
  function shareν(ν)
    s = δ + xΣ*ν
    smax=max(0,maximum(s))
    s -= smax*ones(typeof(s))
    s = exp.(s)
    s *= 1/(sum(s) + exp(0-smax))
    return(s)
  end
  return((shareν))
end
share_v3 (generic function with 1 method)
@benchmark share_v3(sδ,sΣ,sX,s∫)
BenchmarkTools.Trial: 10000 samples with 7 evaluations per sample.
 Range (min … max):  4.287 μs …  6.528 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.321 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.333 μs ± 55.434 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▁▂▄▇██▆▅▄▃▂▁▂▂▂▂ ▁                                       ▂
  ▃▅██████████████████▇█▆▅▆▄▄▅▁▄▃▄▁▁▁▄▇▅▆▆▆▅▅▆▅▆▇▇█▇█▇▇▇▆▇▆▅ █
  4.29 μs      Histogram: log(frequency) by time     4.57 μs <

 Memory estimate: 96 bytes, allocs estimate: 1.

The code is now 5 times faster than the original, and only allocates once (to store the return value).

Multi-Threading

Multi-threading this code is difficult because it is already quite fast. The sum in the integral can be parallelized, but unless there are a large number of integration points, the overhead from creating threads will likely outweigh the benefits.
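For comparison, here is a sketch of parallelizing the integral with base Julia tasks; integrate_threads is a hypothetical helper, and it assumes the points x and weights w are both indexable. With only 100 points, the per-task overhead of Threads.@spawn (roughly a microsecond) is similar in size to the work being parallelized.

```julia
function integrate_threads(f, x, w, nchunks=Threads.nthreads())
  # Split indices into contiguous chunks, one task per chunk.
  chunks = collect(Iterators.partition(eachindex(x), cld(length(x), nchunks)))
  tasks = map(chunks) do idx
    Threads.@spawn sum(f(x[i])*w[i] for i ∈ idx)
  end
  sum(fetch.(tasks))   # combine the partial sums
end
```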

The Polyester package provides a faster but more limited threading model than base Julia's. Using it, we can see modest gains even with just 100 integration points.

To get type stability, I had to write the integration sum loop inside the share function. It would have been cleaner to keep the integration sum in a separate function, but I could not make it type stable that way.

import Polyester
function share_v4(δ, Σ, x, ∫, ::Val{B}=Val(length(∫.x) ÷ 10)) where {B}
  J,K = size(x)
  (length(δ) == J) || error("length(δ)=$(length(δ)) != size(x,1)=$J")
  (K,K) == size(Σ) || error("size(x,2)=$K != size(Σ)=$(size(Σ))")
  xΣ = x*Σ
  function shareν(ν)
    s = δ + xΣ*ν
    smax=max(0,maximum(s))
    s -= smax*ones(typeof(s))
    s = exp.(s)
    s *= 1/(sum(s) + exp(0-smax))
    return(s)
  end
  batchlen= length(∫.x)÷B
  @assert B*batchlen==length(∫.x)
  out = MVector{B,typeof(δ)}(undef)
  Polyester.@batch for b ∈ 1:B
    batch = ((b-1)*(batchlen)+1):(b*batchlen)
    out[b] = zero(typeof(δ))
    for i ∈ batch
      out[b] += shareν(∫.x[i])*∫.w[i]
    end
  end
  return(sum(out))
end
Polyester.reset_threads!()
@benchmark share_v4(sδ,sΣ,sX,s∫, Val(20))
BenchmarkTools.Trial: 10000 samples with 10 evaluations per sample.
 Range (min … max):  1.872 μs …   8.854 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.189 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.205 μs ± 297.213 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

             ▁▁▁▃▂▂▂▂▃▄▂▅▆▆▆▇█▆▅▅▄▂▂                           
  ▂▁▂▂▃▃▄▄▅▅▇████████████████████████▇▇▅▆▅▄▃▃▃▃▃▃▃▂▂▃▂▂▂▂▂▂▂▂ ▅
  1.87 μs         Histogram: frequency by time        2.62 μs <

 Memory estimate: 1.55 KiB, allocs estimate: 2.

statparts

A few people worked on optimizing the klm or clr function from GMMInference.jl. Here is the original code for klm.

using ForwardDiff
function statparts(gi::Function)
  function P(A::AbstractMatrix) # projection matrix
    A*pinv(A'*A)*A'
  end
  function(θ)
    giθ = gi(θ)
    p = length(θ)
    (n, k) = size(giθ)
    Ω = cov(giθ)
    gn=mean(gi(θ), dims=1)'
    Gi= ForwardDiff.jacobian(gi,θ)
    Gi = reshape(Gi, n , k, p)
    G = mean(Gi, dims=1)
    Γ = zeros(eltype(Gi),p,k,k)
    D = zeros(eltype(Gi),k, p)
    for j in 1:p
      for i in 1:n
        Γ[j,:,:] += (Gi[i,:,j] .- G[1,:,j]) * giθ[i,:]'
      end
      Γ[j,:,:] ./= n
      D[:,j] = G[1,:,j] - Γ[j,:,:]*inv(Ω)*gn
    end
    return(n,k,p,gn, Ω, D, P)
  end
end

function klm(gi::Function)
  SP = statparts(gi)
  function(θ)
    (n,k,p,gn, Ω, D, P) = SP(θ)
    return n*(gn'*Ω^(-1/2)*P(Ω^(-1/2)*D)*Ω^(-1/2)*gn)[1]
  end
end
klm (generic function with 1 method)

To run the code, we need an example gi function. We’ll just copy the example from the docs.

import Random
function simulate_ivshare(n,β,γ,ρ)
  z = randn(n, size(γ)[1])
  endo = randn(n, length(β))
  x = z*γ .+ endo
  ξ = rand(Normal(0,sqrt((1.0-ρ^2))),n).+endo[:,1]*ρ
  y = cdf.(Logistic(), x*β .+ ξ)
  return((y=y,x=x,z=z))
end
n = 100
k = 2
iv = 3
β0 = ones(k)
π0 = vcat(5*I,ones(iv-k,k))
ρ = 0.5
Random.seed!(622)
(y,x,z) = simulate_ivshare(n,β0,π0,ρ)

function gi_ivshare(β,y,x,z)
  ξ = quantile.(Logistic(),y) .- x*β
  ξ.*z
end

gi = let y=y, x=x, z=z
  β->gi_ivshare(β,y,x,z)
end
#69 (generic function with 1 method)

Initial Benchmark

@benchmark klm(gi)(β0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  176.563 μs …  36.670 ms  ┊ GC (min … max): 0.00% … 99.13%
 Time  (median):     196.281 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   213.042 μs ± 545.559 μs  ┊ GC (mean ± σ):  7.25% ±  3.08%

      ▄██▇▅▃▁              ▃▅▅▅▃                                 
  ▁▁▃▇████████▇▇▆▅▆▅▆▅▆▆▅▆███████▇▇▅▄▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁ ▄
  177 μs           Histogram: frequency by time          236 μs <

 Memory estimate: 251.80 KiB, allocs estimate: 5450.

Fixing Type Instabilities

From @code_warntype, we see that the compiler is unable to infer the types of some variables. The problem seems to start with D. This is quite puzzling because D is explicitly initialized as zeros(eltype(Gi),...).

julia> @code_warntype klm(gi)(β0)
MethodInstance for (::var"#klm##0#klm##1"{var"#statparts##1#statparts##2"{var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#statparts##0"}})(::Vector{Float64})
  from (::var"#klm##0#klm##1")(θ) @ Main ~/ARGridBootstrap/docs/jmd/assignment.jmd:31
Arguments
  #self#::var"#klm##0#klm##1"{var"#statparts##1#statparts##2"{var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#statparts##0"}}
  θ::Vector{Float64}
Locals
  @_3::Int64
  P::var"#P#statparts##0"
  D::ANY
  Ω::Matrix{Float64}
  gn::Adjoint{Float64, Matrix{Float64}}
  p::Int64
  k::Int64
  n::Int64
Body::ANY
1%1  = Core.getfield(#self#, :SP)::var"#statparts##1#statparts##2"{var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#statparts##0"}
%2  = (%1)(θ)::TUPLE{INT64, INT64, INT64, ADJOINT{FLOAT64, MATRIX{FLOAT64}}, MATRIX{FLOAT64}, ANY, VAR"#P#STATPARTS##0"}
%3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (n = Core.getfield(%3, 1))
│         (@_3 = Core.getfield(%3, 2))
%6  = @_3::Core.Const(2)
%7  = Base.indexed_iterate(%2, 2, %6)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (k = Core.getfield(%7, 1))
│         (@_3 = Core.getfield(%7, 2))
%10 = @_3::Core.Const(3)
%11 = Base.indexed_iterate(%2, 3, %10)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (p = Core.getfield(%11, 1))
│         (@_3 = Core.getfield(%11, 2))
%14 = @_3::Core.Const(4)
%15 = Base.indexed_iterate(%2, 4, %14)::Core.PartialStruct(Tuple{Adjoint{Float64, Matrix{Float64}}, Int64}, Any[Adjoint{Float64, Matrix{Float64}}, Core.Const(5)])
│         (gn = Core.getfield(%15, 1))
│         (@_3 = Core.getfield(%15, 2))
%18 = @_3::Core.Const(5)
%19 = Base.indexed_iterate(%2, 5, %18)::Core.PartialStruct(Tuple{Matrix{Float64}, Int64}, Any[Matrix{Float64}, Core.Const(6)])
│         (Ω = Core.getfield(%19, 1))
│         (@_3 = Core.getfield(%19, 2))
%22 = @_3::Core.Const(6)
%23 = Base.indexed_iterate(%2, 6, %22)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(7)])
│         (D = Core.getfield(%23, 1))
│         (@_3 = Core.getfield(%23, 2))
%26 = @_3::Core.Const(7)
%27 = Base.indexed_iterate(%2, 7, %26)::Core.Const((var"#P#statparts##0"(), 8))
│         (P = Core.getfield(%27, 1))
%29 = Main.:*::Core.Const(*)
%30 = n::Int64
%31 = Main.:*::Core.Const(*)
%32 = Main.:var"'"::Core.Const(adjoint)
%33 = gn::Adjoint{Float64, Matrix{Float64}}
%34 = (%32)(%33)::Matrix{Float64}
%35 = Main.:^::Core.Const(^)
%36 = Ω::Matrix{Float64}
%37 = Main.:/::Core.Const(/)
%38 = (%37)(-1, 2)::Core.Const(-0.5)
%39 = (%35)(%36, %38)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, SYMMETRIC{COMPLEXF64, MATRIX{COMPLEXF64}}, MATRIX}
%40 = P::Core.Const(var"#P#statparts##0"())
%41 = Main.:*::Core.Const(*)
%42 = Main.:^::Core.Const(^)
%43 = Ω::Matrix{Float64}
%44 = Main.:/::Core.Const(/)
%45 = (%44)(-1, 2)::Core.Const(-0.5)
%46 = (%42)(%43, %45)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, SYMMETRIC{COMPLEXF64, MATRIX{COMPLEXF64}}, MATRIX}
%47 = D::ANY
%48 = (%41)(%46, %47)::ANY
%49 = (%40)(%48)::ANY
%50 = Main.:^::Core.Const(^)
%51 = Ω::Matrix{Float64}
%52 = Main.:/::Core.Const(/)
%53 = (%52)(-1, 2)::Core.Const(-0.5)
%54 = (%50)(%51, %53)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, SYMMETRIC{COMPLEXF64, MATRIX{COMPLEXF64}}, MATRIX}
%55 = gn::Adjoint{Float64, Matrix{Float64}}
%56 = (%31)(%34, %39, %49, %54, %55)::ANY
%57 = Base.getindex(%56, 1)::ANY
%58 = (%29)(%30, %57)::ANY
└──       return %58

To investigate further, let us focus on statparts.

julia> @code_warntype statparts(gi)(β0)
MethodInstance for (::var"#statparts##1#statparts##2"{var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#statparts##0"})(::Vector{Float64})
  from (::var"#statparts##1#statparts##2")(θ) @ Main ~/ARGridBootstrap/docs/jmd/assignment.jmd:7
Arguments
  #self#::var"#statparts##1#statparts##2"{var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}, var"#P#statparts##0"}
  θ::Vector{Float64}
Locals
  @_3::UNION{NOTHING, TUPLE{INT64, INT64}}
  @_4::Int64
  D::ANY
  Γ::ANY
  G::ANY
  Gi::ANY
  gn::Adjoint{Float64, Matrix{Float64}}
  Ω::Matrix{Float64}
  k::Int64
  n::Int64
  p::Int64
  giθ::Matrix{Float64}
  @_15::UNION{NOTHING, TUPLE{INT64, INT64}}
  j::Int64
  i::Int64
Body::TUPLE{INT64, INT64, INT64, ADJOINT{FLOAT64, MATRIX{FLOAT64}}, MATRIX{FLOAT64}, ANY, VAR"#P#STATPARTS##0"}
1%1   = Core.getfield(#self#, :gi)::var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
│          (giθ = (%1)(θ))
%3   = Main.length::Core.Const(length)
│          (p = (%3)(θ))
%5   = Main.size::Core.Const(size)
%6   = giθ::Matrix{Float64}
%7   = (%5)(%6)::Tuple{Int64, Int64}
%8   = Base.indexed_iterate(%7, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│          (n = Core.getfield(%8, 1))
│          (@_4 = Core.getfield(%8, 2))
%11  = @_4::Core.Const(2)
%12  = Base.indexed_iterate(%7, 2, %11)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│          (k = Core.getfield(%12, 1))
%14  = Main.cov::Core.Const(Statistics.cov)
%15  = giθ::Matrix{Float64}
│          (Ω = (%14)(%15))
%17  = Main.:var"'"::Core.Const(adjoint)
%18  = Core.getfield(#self#, :gi)::var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
%19  = (%18)(θ)::Matrix{Float64}
%20  = (:dims,)::Core.Const((:dims,))
%21  = Core.apply_type(Core.NamedTuple, %20)::Core.Const(NamedTuple{(:dims,)})
%22  = Core.tuple(1)::Core.Const((1,))
%23  = (%21)(%22)::Core.Const((dims = 1,))
%24  = Main.mean::Core.Const(Statistics.mean)
%25  = Core.kwcall(%23, %24, %19)::Matrix{Float64}
│          (gn = (%17)(%25))
%27  = Main.ForwardDiff::Core.Const(ForwardDiff)
%28  = Base.getproperty(%27, :jacobian)::Core.Const(ForwardDiff.jacobian)
%29  = Core.getfield(#self#, :gi)::var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
│          (Gi = (%28)(%29, θ))
%31  = Main.reshape::Core.Const(reshape)
%32  = Gi::ANY
%33  = n::Int64
%34  = k::Int64
%35  = p::Int64
│          (Gi = (%31)(%32, %33, %34, %35))
%37  = (:dims,)::Core.Const((:dims,))
%38  = Core.apply_type(Core.NamedTuple, %37)::Core.Const(NamedTuple{(:dims,)})
%39  = Core.tuple(1)::Core.Const((1,))
%40  = (%38)(%39)::Core.Const((dims = 1,))
%41  = Main.mean::Core.Const(Statistics.mean)
%42  = Gi::ANY
│          (G = Core.kwcall(%40, %41, %42))
%44  = Main.zeros::Core.Const(zeros)
%45  = Main.eltype::Core.Const(eltype)
%46  = Gi::ANY
%47  = (%45)(%46)::ANY
%48  = p::Int64
%49  = k::Int64
%50  = k::Int64
│          (Γ = (%44)(%47, %48, %49, %50))
%52  = Main.zeros::Core.Const(zeros)
%53  = Main.eltype::Core.Const(eltype)
%54  = Gi::ANY
%55  = (%53)(%54)::ANY
%56  = k::Int64
%57  = p::Int64
│          (D = (%52)(%55, %56, %57))
%59  = Main.:(:)::Core.Const(Colon())
%60  = p::Int64
%61  = (%59)(1, %60)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│          (@_3 = Base.iterate(%61))
%63  = @_3::UNION{NOTHING, TUPLE{INT64, INT64}}
%64  = (%63 === nothing)::Bool
%65  = Base.not_int(%64)::Bool
└──        goto #7 if not %65
2%67  = @_3::Tuple{Int64, Int64}
│          (j = Core.getfield(%67, 1))
%69  = Core.getfield(%67, 2)::Int64
%70  = Main.:(:)::Core.Const(Colon())
%71  = n::Int64
%72  = (%70)(1, %71)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│          (@_15 = Base.iterate(%72))
%74  = @_15::UNION{NOTHING, TUPLE{INT64, INT64}}
%75  = (%74 === nothing)::Bool
%76  = Base.not_int(%75)::Bool
└──        goto #5 if not %76
3%78  = @_15::Tuple{Int64, Int64}
│          (i = Core.getfield(%78, 1))
%80  = Core.getfield(%78, 2)::Int64
%81  = Main.:+::Core.Const(+)
%82  = Γ::ANY
%83  = j::Int64
%84  = Main.:(:)::Core.Const(Colon())
%85  = Main.:(:)::Core.Const(Colon())
%86  = Base.getindex(%82, %83, %84, %85)::ANY
%87  = Main.:*::Core.Const(*)
%88  = Main.:-::Core.Const(-)
%89  = Gi::ANY
%90  = i::Int64
%91  = Main.:(:)::Core.Const(Colon())
%92  = j::Int64
%93  = Base.getindex(%89, %90, %91, %92)::ANY
%94  = G::ANY
%95  = Main.:(:)::Core.Const(Colon())
%96  = j::Int64
%97  = Base.getindex(%94, 1, %95, %96)::ANY
%98  = Base.broadcasted(%88, %93, %97)::ANY
%99  = Base.materialize(%98)::ANY
%100 = Main.:var"'"::Core.Const(adjoint)
%101 = giθ::Matrix{Float64}
%102 = i::Int64
%103 = Main.:(:)::Core.Const(Colon())
%104 = Base.getindex(%101, %102, %103)::Vector{Float64}
%105 = (%100)(%104)::Adjoint{Float64, Vector{Float64}}
%106 = (%87)(%99, %105)::ANY
%107 = (%81)(%86, %106)::ANY
%108 = Γ::ANY
%109 = j::Int64
%110 = Main.:(:)::Core.Const(Colon())
%111 = Main.:(:)::Core.Const(Colon())
Base.setindex!(%108, %107, %109, %110, %111)
│          (@_15 = Base.iterate(%72, %80))
%114 = @_15::UNION{NOTHING, TUPLE{INT64, INT64}}
%115 = (%114 === nothing)::Bool
%116 = Base.not_int(%115)::Bool
└──        goto #5 if not %116
4 ─        goto #3
5%119 = Γ::ANY
%120 = j::Int64
%121 = Main.:(:)::Core.Const(Colon())
%122 = Main.:(:)::Core.Const(Colon())
%123 = Base.dotview(%119, %120, %121, %122)::ANY
%124 = Main.:/::Core.Const(/)
%125 = Γ::ANY
%126 = j::Int64
%127 = Main.:(:)::Core.Const(Colon())
%128 = Main.:(:)::Core.Const(Colon())
%129 = Base.getindex(%125, %126, %127, %128)::ANY
%130 = n::Int64
%131 = Base.broadcasted(%124, %129, %130)::ANY
Base.materialize!(%123, %131)
%133 = Main.:-::Core.Const(-)
%134 = G::ANY
%135 = Main.:(:)::Core.Const(Colon())
%136 = j::Int64
%137 = Base.getindex(%134, 1, %135, %136)::ANY
%138 = Main.:*::Core.Const(*)
%139 = Γ::ANY
%140 = j::Int64
%141 = Main.:(:)::Core.Const(Colon())
%142 = Main.:(:)::Core.Const(Colon())
%143 = Base.getindex(%139, %140, %141, %142)::ANY
%144 = Main.inv::Core.Const(inv)
%145 = Ω::Matrix{Float64}
%146 = (%144)(%145)::Matrix{Float64}
%147 = gn::Adjoint{Float64, Matrix{Float64}}
%148 = (%138)(%143, %146, %147)::ANY
%149 = (%133)(%137, %148)::ANY
%150 = D::ANY
%151 = Main.:(:)::Core.Const(Colon())
%152 = j::Int64
Base.setindex!(%150, %149, %151, %152)
│          (@_3 = Base.iterate(%61, %69))
%155 = @_3::UNION{NOTHING, TUPLE{INT64, INT64}}
%156 = (%155 === nothing)::Bool
%157 = Base.not_int(%156)::Bool
└──        goto #7 if not %157
6 ─        goto #2
7%160 = n::Int64
%161 = k::Int64
%162 = p::Int64
%163 = gn::Adjoint{Float64, Matrix{Float64}}
%164 = Ω::Matrix{Float64}
%165 = D::ANY
%166 = Core.getfield(#self#, :P)::Core.Const(var"#P#statparts##0"())
%167 = Core.tuple(%160, %161, %162, %163, %164, %165, %166)::TUPLE{INT64, INT64, INT64, ADJOINT{FLOAT64, MATRIX{FLOAT64}}, MATRIX{FLOAT64}, ANY, VAR"#P#STATPARTS##0"}
└──        return %167

We see that G, Gi, Γ, and D all have type Any. For some reason, the return type of ForwardDiff.jacobian is not being inferred. We can work around this by using jacobian! instead.

function statparts(gi::F) where {F <: Function}
  function P(A::AbstractMatrix) # projection matrix
    A*pinv(A'*A)*A'
  end
  let gi=gi
    function(θ)
      giθ = gi(θ)
      p = length(θ)
      (n, k) = size(giθ)
      Ω = Hermitian(cov(giθ))
      gn=mean(gi(θ), dims=1)'
      Gi = zeros(n,k,p)
      ForwardDiff.jacobian!(Gi,gi,θ)
      Gi = reshape(Gi, n , k, p)
      G = mean(Gi, dims=1)
      Γ = zeros(eltype(Gi),p,k,k)
      D = zeros(eltype(Gi),k, p)
      for j in 1:p
        for i in 1:n
          Γ[j,:,:] += (Gi[i,:,j] .- G[1,:,j]) * giθ[i,:]'
        end
        Γ[j,:,:] ./= n
        D[:,j] = G[1,:,j] - Γ[j,:,:]*inv(Ω)*gn
      end
      return(n,k,p,gn, Ω, D, P)
    end
  end
end
statparts (generic function with 1 method)
julia> @code_warntype statparts(gi)(β0)
MethodInstance for (::var"#statparts##4#statparts##5"{var"#P#statparts##3", var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}})(::Vector{Float64})
  from (::var"#statparts##4#statparts##5")(θ) @ Main ~/ARGridBootstrap/docs/jmd/assignment.jmd:7
Arguments
  #self#::var"#statparts##4#statparts##5"{var"#P#statparts##3", var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}
  θ::Vector{Float64}
Locals
  @_3::UNION{NOTHING, TUPLE{INT64, INT64}}
  @_4::Int64
  D::Matrix{Float64}
  Γ::Array{Float64, 3}
  G::Array{Float64, 3}
  Gi::Array{Float64, 3}
  gn::Adjoint{Float64, Matrix{Float64}}
  Ω::Hermitian{Float64, Matrix{Float64}}
  k::Int64
  n::Int64
  p::Int64
  giθ::Matrix{Float64}
  @_15::UNION{NOTHING, TUPLE{INT64, INT64}}
  j::Int64
  i::Int64
Body::Tuple{Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Matrix{Float64}, var"#P#statparts##3"}
1%1   = Core.getfield(#self#, Symbol("#2836#gi"))::var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
│          (giθ = (%1)(θ))
%3   = Main.length::Core.Const(length)
│          (p = (%3)(θ))
%5   = Main.size::Core.Const(size)
%6   = giθ::Matrix{Float64}
%7   = (%5)(%6)::Tuple{Int64, Int64}
%8   = Base.indexed_iterate(%7, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│          (n = Core.getfield(%8, 1))
│          (@_4 = Core.getfield(%8, 2))
%11  = @_4::Core.Const(2)
%12  = Base.indexed_iterate(%7, 2, %11)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│          (k = Core.getfield(%12, 1))
%14  = Main.Hermitian::Core.Const(Hermitian)
%15  = Main.cov::Core.Const(Statistics.cov)
%16  = giθ::Matrix{Float64}
%17  = (%15)(%16)::Matrix{Float64}
│          (Ω = (%14)(%17))
%19  = Main.:var"'"::Core.Const(adjoint)
%20  = Core.getfield(#self#, Symbol("#2836#gi"))::var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
%21  = (%20)(θ)::Matrix{Float64}
%22  = (:dims,)::Core.Const((:dims,))
%23  = Core.apply_type(Core.NamedTuple, %22)::Core.Const(NamedTuple{(:dims,)})
%24  = Core.tuple(1)::Core.Const((1,))
%25  = (%23)(%24)::Core.Const((dims = 1,))
%26  = Main.mean::Core.Const(Statistics.mean)
%27  = Core.kwcall(%25, %26, %21)::Matrix{Float64}
│          (gn = (%19)(%27))
%29  = Main.zeros::Core.Const(zeros)
%30  = n::Int64
%31  = k::Int64
%32  = p::Int64
│          (Gi = (%29)(%30, %31, %32))
%34  = Main.ForwardDiff::Core.Const(ForwardDiff)
%35  = Base.getproperty(%34, :jacobian!)::Core.Const(ForwardDiff.jacobian!)
%36  = Gi::Array{Float64, 3}
%37  = Core.getfield(#self#, Symbol("#2836#gi"))::var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
│          (%35)(%36, %37, θ)
%39  = Main.reshape::Core.Const(reshape)
%40  = Gi::Array{Float64, 3}
%41  = n::Int64
%42  = k::Int64
%43  = p::Int64
│          (Gi = (%39)(%40, %41, %42, %43))
%45  = (:dims,)::Core.Const((:dims,))
%46  = Core.apply_type(Core.NamedTuple, %45)::Core.Const(NamedTuple{(:dims,)})
%47  = Core.tuple(1)::Core.Const((1,))
%48  = (%46)(%47)::Core.Const((dims = 1,))
%49  = Main.mean::Core.Const(Statistics.mean)
%50  = Gi::Array{Float64, 3}
│          (G = Core.kwcall(%48, %49, %50))
%52  = Main.zeros::Core.Const(zeros)
%53  = Main.eltype::Core.Const(eltype)
%54  = Gi::Array{Float64, 3}
%55  = (%53)(%54)::Core.Const(Float64)
%56  = p::Int64
%57  = k::Int64
%58  = k::Int64
│          (Γ = (%52)(%55, %56, %57, %58))
%60  = Main.zeros::Core.Const(zeros)
%61  = Main.eltype::Core.Const(eltype)
%62  = Gi::Array{Float64, 3}
%63  = (%61)(%62)::Core.Const(Float64)
%64  = k::Int64
%65  = p::Int64
│          (D = (%60)(%63, %64, %65))
%67  = Main.:(:)::Core.Const(Colon())
%68  = p::Int64
%69  = (%67)(1, %68)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│          (@_3 = Base.iterate(%69))
%71  = @_3::UNION{NOTHING, TUPLE{INT64, INT64}}
%72  = (%71 === nothing)::Bool
%73  = Base.not_int(%72)::Bool
└──        goto #7 if not %73
2%75  = @_3::Tuple{Int64, Int64}
│          (j = Core.getfield(%75, 1))
%77  = Core.getfield(%75, 2)::Int64
%78  = Main.:(:)::Core.Const(Colon())
%79  = n::Int64
%80  = (%78)(1, %79)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│          (@_15 = Base.iterate(%80))
%82  = @_15::UNION{NOTHING, TUPLE{INT64, INT64}}
%83  = (%82 === nothing)::Bool
%84  = Base.not_int(%83)::Bool
└──        goto #5 if not %84
3%86  = @_15::Tuple{Int64, Int64}
│          (i = Core.getfield(%86, 1))
%88  = Core.getfield(%86, 2)::Int64
%89  = Main.:+::Core.Const(+)
%90  = Γ::Array{Float64, 3}
%91  = j::Int64
%92  = Main.:(:)::Core.Const(Colon())
%93  = Main.:(:)::Core.Const(Colon())
%94  = Base.getindex(%90, %91, %92, %93)::Matrix{Float64}
%95  = Main.:*::Core.Const(*)
%96  = Main.:-::Core.Const(-)
%97  = Gi::Array{Float64, 3}
%98  = i::Int64
%99  = Main.:(:)::Core.Const(Colon())
%100 = j::Int64
%101 = Base.getindex(%97, %98, %99, %100)::Vector{Float64}
%102 = G::Array{Float64, 3}
%103 = Main.:(:)::Core.Const(Colon())
%104 = j::Int64
%105 = Base.getindex(%102, 1, %103, %104)::Vector{Float64}
%106 = Base.broadcasted(%96, %101, %105)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}, Vector{Float64}}}
%107 = Base.materialize(%106)::Vector{Float64}
%108 = Main.:var"'"::Core.Const(adjoint)
%109 = giθ::Matrix{Float64}
%110 = i::Int64
%111 = Main.:(:)::Core.Const(Colon())
%112 = Base.getindex(%109, %110, %111)::Vector{Float64}
%113 = (%108)(%112)::Adjoint{Float64, Vector{Float64}}
%114 = (%95)(%107, %113)::Matrix{Float64}
%115 = (%89)(%94, %114)::Matrix{Float64}
%116 = Γ::Array{Float64, 3}
%117 = j::Int64
%118 = Main.:(:)::Core.Const(Colon())
%119 = Main.:(:)::Core.Const(Colon())
Base.setindex!(%116, %115, %117, %118, %119)
│          (@_15 = Base.iterate(%80, %88))
%122 = @_15::UNION{NOTHING, TUPLE{INT64, INT64}}
%123 = (%122 === nothing)::Bool
%124 = Base.not_int(%123)::Bool
└──        goto #5 if not %124
4 ─        goto #3
5%127 = Γ::Array{Float64, 3}
%128 = j::Int64
%129 = Main.:(:)::Core.Const(Colon())
%130 = Main.:(:)::Core.Const(Colon())
%131 = Base.dotview(%127, %128, %129, %130)::SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}
%132 = Main.:/::Core.Const(/)
%133 = Γ::Array{Float64, 3}
%134 = j::Int64
%135 = Main.:(:)::Core.Const(Colon())
%136 = Main.:(:)::Core.Const(Colon())
%137 = Base.getindex(%133, %134, %135, %136)::Matrix{Float64}
%138 = n::Int64
%139 = Base.broadcasted(%132, %137, %138)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(/), Tuple{Matrix{Float64}, Int64}}
Base.materialize!(%131, %139)
%141 = Main.:-::Core.Const(-)
%142 = G::Array{Float64, 3}
%143 = Main.:(:)::Core.Const(Colon())
%144 = j::Int64
%145 = Base.getindex(%142, 1, %143, %144)::Vector{Float64}
%146 = Main.:*::Core.Const(*)
%147 = Γ::Array{Float64, 3}
%148 = j::Int64
%149 = Main.:(:)::Core.Const(Colon())
%150 = Main.:(:)::Core.Const(Colon())
%151 = Base.getindex(%147, %148, %149, %150)::Matrix{Float64}
%152 = Main.inv::Core.Const(inv)
%153 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%154 = (%152)(%153)::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%155 = gn::Adjoint{Float64, Matrix{Float64}}
%156 = (%146)(%151, %154, %155)::Matrix{Float64}
%157 = (%141)(%145, %156)::Matrix{Float64}
%158 = D::Matrix{Float64}
%159 = Main.:(:)::Core.Const(Colon())
%160 = j::Int64
Base.setindex!(%158, %157, %159, %160)
│          (@_3 = Base.iterate(%69, %77))
%163 = @_3::UNION{NOTHING, TUPLE{INT64, INT64}}
%164 = (%163 === nothing)::Bool
%165 = Base.not_int(%164)::Bool
└──        goto #7 if not %165
6 ─        goto #2
7%168 = n::Int64
%169 = k::Int64
%170 = p::Int64
%171 = gn::Adjoint{Float64, Matrix{Float64}}
%172 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%173 = D::Matrix{Float64}
%174 = Core.getfield(#self#, :P)::Core.Const(var"#P#statparts##3"())
%175 = Core.tuple(%168, %169, %170, %171, %172, %173, %174)::Core.PartialStruct(Tuple{Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Matrix{Float64}, var"#P#statparts##3"}, Any[Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Matrix{Float64}, var"#P#statparts##3"])
└──        return %175

I also added the where {F <: Function} clause to ensure compiler specialization, and added the let gi=gi block to help with the performance of captured variables.
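The captured-variable issue can be seen in isolation. In this sketch, the first closure reassigns its captured variable, so Julia boxes it and inference fails, while the let block gives the second closure fresh, never-reassigned bindings with concrete types:

```julia
function make_counter()
  c = 0
  () -> (c += 1)     # c is reassigned inside the closure: it gets boxed
end                  # (Core.Box), and @code_warntype on the closure shows Any

function make_scaled(g, a)
  let g = g, a = a   # fresh bindings, never reassigned: captured with concrete types
    θ -> a * g(θ)
  end
end
```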

julia> @code_warntype klm(gi)(β0)
MethodInstance for (::var"#klm##0#klm##1"{var"#statparts##4#statparts##5"{var"#P#statparts##3", var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}})(::Vector{Float64})
  from (::var"#klm##0#klm##1")(θ) @ Main ~/ARGridBootstrap/docs/jmd/assignment.jmd:31
Arguments
  #self#::var"#klm##0#klm##1"{var"#statparts##4#statparts##5"{var"#P#statparts##3", var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}}
  θ::Vector{Float64}
Locals
  @_3::Int64
  P::var"#P#statparts##3"
  D::Matrix{Float64}
  Ω::Hermitian{Float64, Matrix{Float64}}
  gn::Adjoint{Float64, Matrix{Float64}}
  p::Int64
  k::Int64
  n::Int64
Body::UNION{FLOAT64, COMPLEXF64}
1%1  = Core.getfield(#self#, :SP)::var"#statparts##4#statparts##5"{var"#P#statparts##3", var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}
%2  = (%1)(θ)::Core.PartialStruct(Tuple{Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Matrix{Float64}, var"#P#statparts##3"}, Any[Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Matrix{Float64}, var"#P#statparts##3"])
%3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (n = Core.getfield(%3, 1))
│         (@_3 = Core.getfield(%3, 2))
%6  = @_3::Core.Const(2)
%7  = Base.indexed_iterate(%2, 2, %6)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (k = Core.getfield(%7, 1))
│         (@_3 = Core.getfield(%7, 2))
%10 = @_3::Core.Const(3)
%11 = Base.indexed_iterate(%2, 3, %10)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (p = Core.getfield(%11, 1))
│         (@_3 = Core.getfield(%11, 2))
%14 = @_3::Core.Const(4)
%15 = Base.indexed_iterate(%2, 4, %14)::Core.PartialStruct(Tuple{Adjoint{Float64, Matrix{Float64}}, Int64}, Any[Adjoint{Float64, Matrix{Float64}}, Core.Const(5)])
│         (gn = Core.getfield(%15, 1))
│         (@_3 = Core.getfield(%15, 2))
%18 = @_3::Core.Const(5)
%19 = Base.indexed_iterate(%2, 5, %18)::Core.PartialStruct(Tuple{Hermitian{Float64, Matrix{Float64}}, Int64}, Any[Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Core.Const(6)])
│         (Ω = Core.getfield(%19, 1))
│         (@_3 = Core.getfield(%19, 2))
%22 = @_3::Core.Const(6)
%23 = Base.indexed_iterate(%2, 6, %22)::Core.PartialStruct(Tuple{Matrix{Float64}, Int64}, Any[Matrix{Float64}, Core.Const(7)])
│         (D = Core.getfield(%23, 1))
│         (@_3 = Core.getfield(%23, 2))
%26 = @_3::Core.Const(7)
%27 = Base.indexed_iterate(%2, 7, %26)::Core.Const((var"#P#statparts##3"(), 8))
│         (P = Core.getfield(%27, 1))
%29 = Main.:*::Core.Const(*)
%30 = n::Int64
%31 = Main.:*::Core.Const(*)
%32 = Main.:var"'"::Core.Const(adjoint)
%33 = gn::Adjoint{Float64, Matrix{Float64}}
%34 = (%32)(%33)::Matrix{Float64}
%35 = Main.:^::Core.Const(^)
%36 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%37 = Main.:/::Core.Const(/)
%38 = (%37)(-1, 2)::Core.Const(-0.5)
%39 = (%35)(%36, %38)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, SYMMETRIC{COMPLEXF64, MATRIX{COMPLEXF64}}}
%40 = P::Core.Const(var"#P#statparts##3"())
%41 = Main.:*::Core.Const(*)
%42 = Main.:^::Core.Const(^)
%43 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%44 = Main.:/::Core.Const(/)
%45 = (%44)(-1, 2)::Core.Const(-0.5)
%46 = (%42)(%43, %45)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, SYMMETRIC{COMPLEXF64, MATRIX{COMPLEXF64}}}
%47 = D::Matrix{Float64}
%48 = (%41)(%46, %47)::UNION{MATRIX{COMPLEXF64}, MATRIX{FLOAT64}}
%49 = (%40)(%48)::UNION{MATRIX{COMPLEXF64}, MATRIX{FLOAT64}}
%50 = Main.:^::Core.Const(^)
%51 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%52 = Main.:/::Core.Const(/)
%53 = (%52)(-1, 2)::Core.Const(-0.5)
%54 = (%50)(%51, %53)::UNION{HERMITIAN{FLOAT64, MATRIX{FLOAT64}}, SYMMETRIC{COMPLEXF64, MATRIX{COMPLEXF64}}}
%55 = gn::Adjoint{Float64, Matrix{Float64}}
%56 = (%31)(%34, %39, %49, %54, %55)::UNION{MATRIX{COMPLEXF64}, MATRIX{FLOAT64}}
%57 = Base.getindex(%56, 1)::UNION{FLOAT64, COMPLEXF64}
%58 = (%29)(%30, %57)::UNION{FLOAT64, COMPLEXF64}
└──       return %58

There’s still a type instability in klm. This one is harder to understand. It arises because the appropriate meaning of a matrix square root depends on the nature of the matrix: if the matrix is not positive semidefinite, the square root is complex valued instead of real valued, so the compiler must allow both possibilities. We know that Ω should be positive definite with a real matrix square root. We can compute its square root from its eigendecomposition and avoid the type instability.
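We can see the source of the Union directly: raising a Hermitian matrix to a fractional power gives a real result only when the matrix is positive semidefinite, so inference must allow a complex result. A small illustration (not part of the assignment code):

```julia
using LinearAlgebra

Apd  = Hermitian([2.0 0.0; 0.0 3.0])   # positive definite: real square root exists
Aind = Hermitian([1.0 0.0; 0.0 -1.0])  # indefinite: square root must be complex

eltype(Apd^(-1/2))   # Float64
eltype(Aind^(-1/2))  # ComplexF64
```

Because the element type of the result depends on the runtime values of the eigenvalues, the compiler infers a Union of a real and a complex matrix type, which is exactly what @code_warntype reported above.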

function klm(gi::F ) where {F <: Function}
  let gi=gi
    function(θ)
      (n,k,p,gn, Ω, D, P) = statparts(gi)(θ)
      λ, v = eigen(Ω)
      irΩ = v*diagm(λ.^(-1/2))*v'
      return n*(gn'*irΩ*P(irΩ*D)*irΩ*gn)[1]
    end
  end
end
klm (generic function with 1 method)
julia> @code_warntype klm(gi)(β0)
MethodInstance for (::var"#klm##2#klm##3"{var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}})(::Vector{Float64})
  from (::var"#klm##2#klm##3")(θ) @ Main ~/ARGridBootstrap/docs/jmd/assignment.jmd:4
Arguments
  #self#::var"#klm##2#klm##3"{var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}
  θ::Vector{Float64}
Locals
  @_3::Val{:vectors}
  @_4::Int64
  irΩ::Matrix{Float64}
  v::Matrix{Float64}
  λ::Vector{Float64}
  P::var"#P#statparts##3"
  D::Matrix{Float64}
  Ω::Hermitian{Float64, Matrix{Float64}}
  gn::Adjoint{Float64, Matrix{Float64}}
  p::Int64
  k::Int64
  n::Int64
Body::Float64
1%1  = Main.statparts::Core.Const(Main.statparts)
%2  = Core.getfield(#self#, Symbol("#2839#gi"))::var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}
%3  = (%1)(%2)::var"#statparts##4#statparts##5"{var"#P#statparts##3", var"#69#70"{Vector{Float64}, Matrix{Float64}, Matrix{Float64}}}
%4  = (%3)(θ)::Core.PartialStruct(Tuple{Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Hermitian{Float64, Matrix{Float64}}, Matrix{Float64}, var"#P#statparts##3"}, Any[Int64, Int64, Int64, Adjoint{Float64, Matrix{Float64}}, Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Matrix{Float64}, var"#P#statparts##3"])
%5  = Base.indexed_iterate(%4, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (n = Core.getfield(%5, 1))
│         (@_4 = Core.getfield(%5, 2))
%8  = @_4::Core.Const(2)
%9  = Base.indexed_iterate(%4, 2, %8)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│         (k = Core.getfield(%9, 1))
│         (@_4 = Core.getfield(%9, 2))
%12 = @_4::Core.Const(3)
%13 = Base.indexed_iterate(%4, 3, %12)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (p = Core.getfield(%13, 1))
│         (@_4 = Core.getfield(%13, 2))
%16 = @_4::Core.Const(4)
%17 = Base.indexed_iterate(%4, 4, %16)::Core.PartialStruct(Tuple{Adjoint{Float64, Matrix{Float64}}, Int64}, Any[Adjoint{Float64, Matrix{Float64}}, Core.Const(5)])
│         (gn = Core.getfield(%17, 1))
│         (@_4 = Core.getfield(%17, 2))
%20 = @_4::Core.Const(5)
%21 = Base.indexed_iterate(%4, 5, %20)::Core.PartialStruct(Tuple{Hermitian{Float64, Matrix{Float64}}, Int64}, Any[Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')]), Core.Const(6)])
│         (Ω = Core.getfield(%21, 1))
│         (@_4 = Core.getfield(%21, 2))
%24 = @_4::Core.Const(6)
%25 = Base.indexed_iterate(%4, 6, %24)::Core.PartialStruct(Tuple{Matrix{Float64}, Int64}, Any[Matrix{Float64}, Core.Const(7)])
│         (D = Core.getfield(%25, 1))
│         (@_4 = Core.getfield(%25, 2))
%28 = @_4::Core.Const(7)
%29 = Base.indexed_iterate(%4, 7, %28)::Core.Const((var"#P#statparts##3"(), 8))
│         (P = Core.getfield(%29, 1))
%31 = Main.eigen::Core.Const(LinearAlgebra.eigen)
%32 = Ω::Core.PartialStruct(Hermitian{Float64, Matrix{Float64}}, Any[Matrix{Float64}, Core.Const('U')])
%33 = (%31)(%32)::Eigen{Float64, Float64, Matrix{Float64}, Vector{Float64}}
%34 = Base.indexed_iterate(%33, 1)::Tuple{Vector{Float64}, Val{:vectors}}
│         (λ = Core.getfield(%34, 1))
│         (@_3 = Core.getfield(%34, 2))
%37 = @_3::Core.Const(Val{:vectors}())
%38 = Base.indexed_iterate(%33, 2, %37)::Tuple{Matrix{Float64}, Val{:done}}
│         (v = Core.getfield(%38, 1))
%40 = Main.:*::Core.Const(*)
%41 = v::Matrix{Float64}
%42 = Main.diagm::Core.Const(LinearAlgebra.diagm)
%43 = Main.:^::Core.Const(^)
%44 = λ::Vector{Float64}
%45 = Main.:/::Core.Const(/)
%46 = (%45)(-1, 2)::Core.Const(-0.5)
%47 = Base.broadcasted(%43, %44, %46)::Core.PartialStruct(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(^), Tuple{Vector{Float64}, Float64}}, Any[Core.Const(Base.Broadcast.DefaultArrayStyle{1}()), Core.Const(^), Core.PartialStruct(Tuple{Vector{Float64}, Float64}, Any[Vector{Float64}, Core.Const(-0.5)]), Core.Const(nothing)])
%48 = Base.materialize(%47)::Vector{Float64}
%49 = (%42)(%48)::Matrix{Float64}
%50 = Main.:var"'"::Core.Const(adjoint)
%51 = v::Matrix{Float64}
%52 = (%50)(%51)::Adjoint{Float64, Matrix{Float64}}
│         (irΩ = (%40)(%41, %49, %52))
%54 = Main.:*::Core.Const(*)
%55 = n::Int64
%56 = Main.:*::Core.Const(*)
%57 = Main.:var"'"::Core.Const(adjoint)
%58 = gn::Adjoint{Float64, Matrix{Float64}}
%59 = (%57)(%58)::Matrix{Float64}
%60 = irΩ::Matrix{Float64}
%61 = P::Core.Const(var"#P#statparts##3"())
%62 = Main.:*::Core.Const(*)
%63 = irΩ::Matrix{Float64}
%64 = D::Matrix{Float64}
%65 = (%62)(%63, %64)::Matrix{Float64}
%66 = (%61)(%65)::Matrix{Float64}
%67 = irΩ::Matrix{Float64}
%68 = gn::Adjoint{Float64, Matrix{Float64}}
%69 = (%56)(%59, %60, %66, %67, %68)::Matrix{Float64}
%70 = Base.getindex(%69, 1)::Float64
%71 = (%54)(%55, %70)::Float64
└──       return %71

Fixing these type instabilities speeds up the code by a factor of about 5.

@benchmark klm(gi)(β0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  27.542 μs …  30.530 ms  ┊ GC (min … max):  0.00% … 99.
38%
 Time  (median):     44.143 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   54.128 μs ± 473.233 μs  ┊ GC (mean ± σ):  18.62% ±  2.
22%

       ▁▁             ▂▅▆█▆▂                                    
  ▁▂▃▆███▇▇▅▄▄▄▄▃▃▃▃▄▆███████▆▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  27.5 μs         Histogram: frequency by time         72.7 μs <

 Memory estimate: 187.52 KiB, allocs estimate: 2972.

Reducing Allocations and Other Optimizations

Profiling reveals that the majority of time is spent in the innermost loop of the statparts function. This loop allocates quite a bit because indexing arrays with slices creates copies. We can avoid these allocations by using @views and more broadcasting. See “Consider using views for slices” and “More dots”.
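The difference is easy to see in isolation (a toy example, not the assignment code): indexing with a slice copies the data, while @view aliases the parent array.

```julia
# `const` so that @allocated at top level is not polluted by untyped global access
const A1 = rand(1000, 1000)

col_copy(A, j) = sum(A[:, j])        # A[:, j] allocates a fresh Vector
col_view(A, j) = sum(@view A[:, j])  # @view indexes A's memory in place

col_copy(A1, 1); col_view(A1, 1)     # warm up / compile first
@allocated(col_copy(A1, 1))  # several KiB: the column is copied
@allocated(col_view(A1, 1))  # essentially zero
```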

using ForwardDiff, Statistics

function statparts(gi::F) where {F <: Function}
  function P(A::AbstractMatrix) # projection matrix
    A*pinv(A'*A)*A'
  end
  let gi=gi
    function(θ)
      giθ = gi(θ)
      p = length(θ)
      (n, k) = size(giθ)
      Ω = Hermitian(cov(giθ))
      gn=mean(gi(θ), dims=1)'
      iΩgn = Ω \ gn
      Gi = zeros(n,k,p)
      ForwardDiff.jacobian!(Gi,gi,θ)
      Gi = reshape(Gi, n, k, p)
      G = mean(Gi, dims=1)
      Γ = zeros(eltype(Gi),p,k,k)
      D = zeros(eltype(Gi),k, p)
      @inbounds for j in 1:p
        @inbounds for i in 1:n
          @views Γ[j,:,:] .+= (Gi[i,:,j] .- G[1,:,j]) * giθ[i,:]'
        end
        Γ[j,:,:] ./= n
        @views D[:,j] .= G[1,:,j] .- Γ[j,:,:]*iΩgn
      end
      return(n,k,p,gn, Ω, D, P)
    end
  end
end
@benchmark klm(gi)(β0)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  14.327 μs …  27.651 ms  ┊ GC (min … max):  0.00% … 99.
59%
 Time  (median):     20.208 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   23.216 μs ± 276.399 μs  ┊ GC (mean ± σ):  11.86% ±  1.
00%

      ▃▇▇▇▄▁      ▁▂▄▆███▆▅▂                                    
  ▁▁▃▅██████▆▅▅▇▇████████████▇▅▄▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▄
  14.3 μs         Histogram: frequency by time         32.6 μs <

 Memory estimate: 80.09 KiB, allocs estimate: 946.

The code is now about ten times faster than the original.

Footnotes

  1. Like scalar variables, StaticArrays exist on the stack instead of the heap, so creating them is much less costly and they do not count toward the reported allocations.↩︎
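A quick way to check this, assuming the StaticArrays package is installed:

```julia
using StaticArrays

svec(x) = SVector(x, x + 1.0, x + 2.0)  # fixed-size, lives on the stack
hvec(x) = [x, x + 1.0, x + 2.0]         # ordinary heap-allocated Vector

svec(0.0); hvec(0.0)    # warm up / compile first
@allocated(svec(1.0))   # 0: no heap allocation
@allocated(hvec(1.0))   # > 0: the Vector is heap allocated
```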