-
A Distributed Method for Fitting Laplacian Regularized Stratified Models
Jonathan Tuck Shane Barratt Stephen Boyd
International Conference on Statistical Optimization and Learning
Beijing Jiaotong University, June 19, 2019
-
Outline
Stratified models
Data models
Regularization graphs
Solution method
Examples
Conclusions
-
Data records
• Data records have the form $(z, x, y) \in \mathcal{Z} \times \mathcal{X} \times \mathcal{Y}$
• $z \in \mathcal{Z} = \{1, \ldots, K\}$ are the categorical features (the ones we'll stratify over)
• $x \in \mathcal{X}$ are the other features
• $y \in \mathcal{Y}$ is the label/dependent variable
-
Stratified model
• Fit a model to data $(z, x, y)$
• Stratified model: fit a different model for each value of $z$
• $\theta_k$ is the model parameter for $z = k$
• $\theta = (\theta_1, \ldots, \theta_K) \in \Theta \subseteq \mathbf{R}^{Kn}$ parameterizes the stratified model
• Old idea [Kernan 99]
• Example: stratified regression model $\hat{y} = x^T \theta_z$
-
Example
[Figure: plot over $x$ from 0.0 to 1.0, vertical axis from −10 to 10, with series $z = 1$ and $z = -1$]
$$\hat{y} = \theta_1 + \theta_2 x + \theta_3 z, \quad z \in \{-1, 1\}$$
-
Example
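[Figure: plot over $x$ from 0.0 to 1.0, vertical axis from −10 to 10, with series $z = 1$ and $z = -1$]
$$\hat{y} = \begin{cases} \theta_{-1,1} + \theta_{-1,2}\, x & z = -1 \\ \theta_{1,1} + \theta_{1,2}\, x & z = 1 \end{cases}$$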
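The two example models can be compared numerically. The sketch below is only an illustration: the synthetic data, the slopes and intercepts, and all variable names are assumptions, not taken from the talk. It fits the model with $z$ as an ordinary feature and the stratified model (one intercept and slope per value of $z$) by least squares.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data (assumed for illustration): two groups with different slopes.
N = 200
x = rng.uniform(0.0, 1.0, size=N)
z = rng.choice([-1, 1], size=N)
y = np.where(z == 1, 2.0 + 8.0 * x, -1.0 - 6.0 * x) + 0.1 * rng.standard_normal(N)

# Model with z as a feature: y_hat = theta1 + theta2 * x + theta3 * z.
A_common = np.column_stack([np.ones(N), x, z])
theta_common, *_ = np.linalg.lstsq(A_common, y, rcond=None)
rms_common = np.sqrt(np.mean((A_common @ theta_common - y) ** 2))

# Stratified model: a separate (intercept, slope) pair for each value of z.
theta_strat = {}
resid = np.empty(N)
for k in (-1, 1):
    mask = z == k
    A_k = np.column_stack([np.ones(mask.sum()), x[mask]])
    theta_strat[k], *_ = np.linalg.lstsq(A_k, y[mask], rcond=None)
    resid[mask] = A_k @ theta_strat[k] - y[mask]
rms_strat = np.sqrt(np.mean(resid ** 2))
```

Because the first model forces a single slope on both groups, its training error cannot be lower than that of the stratified fit; on data like this the gap is large.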
-
Stratified model
• Stratified model is a simple function of $x$ (e.g., linear) and an arbitrary function of $z$
• Examples: separate models for
  – $z \in \{\text{Male}, \text{Female}\}$
  – $z \in \{\text{Monday}, \ldots, \text{Sunday}\}$
• If $K$ is large, we might not have enough training data to fit each $\theta_k$
• Extreme case: no training data for some values of $z$
• We'll add regularization so nearby $\theta_k$'s are close
-
Stratified model with Laplacian regularization
• Choose $\theta = (\theta_1, \ldots, \theta_K)$ to minimize
$$\sum_{i=1}^{N} l(\theta_{z_i}, x_i, y_i) + \sum_{k=1}^{K} r(\theta_k) + \sum_{u,v=1}^{K} W_{uv} \|\theta_u - \theta_v\|^2$$
• $l$ is the loss function, $r$ is (local) regularization
• Last term is Laplacian regularization
• $W_{uv} \geq 0$ are edge weights of a graph with node set $\mathcal{Z}$
• Convex problem when $l$, $r$ are convex in $\theta$
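To make the three terms concrete, here is a minimal sketch that evaluates the objective for one particular choice of loss and regularizer. The square loss, the sum-of-squares regularizer $(\lambda/2)\|\theta_k\|_2^2$, the 0-based strata indices, and the function name are all assumptions made for illustration. It also checks numerically that, for symmetric $W$, the double sum equals $2\,\mathbf{tr}(\Theta^T \Lambda \Theta)$ with $\Lambda = \mathbf{diag}(W\mathbf{1}) - W$ the graph Laplacian, since each edge is counted twice.

```python
import numpy as np

def laplacian_objective(Theta, W, X, y, z, lam):
    """Objective with square loss and (lam/2)*||theta_k||^2 local regularization.

    Theta: (K, n) array whose row k is theta_k.
    W: (K, K) symmetric nonnegative edge weights.
    X: (N, n) features, y: (N,) labels, z: (N,) integer strata in {0, ..., K-1}.
    """
    preds = np.sum(X * Theta[z], axis=1)            # x_i^T theta_{z_i}
    loss = np.sum((preds - y) ** 2)
    local_reg = 0.5 * lam * np.sum(Theta ** 2)
    diffs = Theta[:, None, :] - Theta[None, :, :]   # theta_u - theta_v for all pairs
    lap_reg = np.sum(W * np.sum(diffs ** 2, axis=2))
    return loss + local_reg + lap_reg

rng = np.random.default_rng(0)
K, n, N = 4, 3, 50
W = np.zeros((K, K))
for k in range(K - 1):                              # path graph 0 - 1 - 2 - 3, unit weights
    W[k, k + 1] = W[k + 1, k] = 1.0
Theta = rng.standard_normal((K, n))
X = rng.standard_normal((N, n))
z = rng.integers(0, K, size=N)
y = rng.standard_normal(N)

value = laplacian_objective(Theta, W, X, y, z, lam=0.1)

# The double sum over (u, v) versus the Laplacian quadratic form.
Lap = np.diag(W.sum(axis=1)) - W
pairwise = np.sum(W * np.sum((Theta[:, None] - Theta[None]) ** 2, axis=2))
quad_form = 2.0 * np.trace(Theta.T @ Lap @ Theta)
```

When all rows of `Theta` are equal, the Laplacian term vanishes and only the loss and local regularization remain.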
-
Stratified model with Laplacian regularization
• Graph encodes prior that nearby values of $z$ should have similar models
• Can be used to capture periodicities, other structure
• Examples:
  – $\theta_{\text{male}}$ and $\theta_{\text{female}}$ should be close
  – $\theta_{\text{jan}}$ should be close to $\theta_{\text{feb}}$ and $\theta_{\text{dec}}$
• Model for each value of $z$ 'borrows strength' from its neighbors
• Works even when there's no data for some values of $z$
• As $W_{uv} \to 0$, get traditional (unregularized) stratified model
• As $W_{uv} \to \infty$, get common model ($\theta_1 = \cdots = \theta_K$)
-
Related work
• Network lasso [Hallac 15]
• Pliable lasso [Tibshirani 17]
• Varying-coefficient models [Hastie 93, Fan 08]
• Multi-task learning [Caruana 97]
-
Data models
-
Point estimate: Predict y
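• Regression: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \mathbf{R}$
  – $l(\theta, x, y) = p(x^T \theta - y)$, $p$ is a penalty function
  – $\hat{y} = x^T \theta_z$
• Classification: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{-1, 1\}$
  – $l(\theta, x, y) = p(y x^T \theta)$
  – $\hat{y} = \mathbf{sign}(x^T \theta_z)$
• $M$-class classification: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{1, \ldots, M\}$
  – $l(\theta, x, y) = p_y(x^T \theta)$, $\theta \in \mathbf{R}^{n \times M}$, $p_y$ is the penalty function for class $y$
  – $\hat{y} = \operatorname{argmax}_j \, (x^T \theta_z)_j$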
-
Conditional distribution estimate: Predict prob(y | x, z)
• Multinomial logistic regression: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{1, \ldots, M\}$
• Conditional distribution:
$$\mathbf{prob}(y \mid x, z) = \frac{\exp\,(x^T \theta_z)_y}{\sum_{j=1}^{M} \exp\,(x^T \theta_z)_j}, \quad y = 1, \ldots, M$$
• Loss function (convex in $\theta$):
$$l(\theta, x, y) = \log\left(\sum_{j=1}^{M} \exp\,(x^T \theta)_j\right) - (x^T \theta)_y$$
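The loss above is the negative log of the conditional probability of the observed class. A minimal sketch follows, with hypothetical function names and 0-based class indices (the slides use $1, \ldots, M$); `logsumexp` and `softmax` come from `scipy.special`.

```python
import numpy as np
from scipy.special import logsumexp, softmax

def multinomial_logistic_loss(theta, x, y):
    """l(theta, x, y) = log(sum_j exp((x^T theta)_j)) - (x^T theta)_y.

    theta: (n, M) parameter matrix, x: (n,) feature vector,
    y: class index in {0, ..., M-1}.
    """
    scores = x @ theta                   # the M class scores x^T theta
    return logsumexp(scores) - scores[y]

def class_probabilities(theta, x):
    """prob(y | x) for every class: softmax of the class scores."""
    return softmax(x @ theta)

rng = np.random.default_rng(0)
theta = rng.standard_normal((3, 4))      # n = 3 features, M = 4 classes
x = rng.standard_normal(3)
y = 2
probs = class_probabilities(theta, x)
loss = multinomial_logistic_loss(theta, x, y)
```

Using `logsumexp` rather than `np.log(np.sum(np.exp(...)))` avoids overflow when the scores are large.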
-
Distribution estimate: Predict p(y | z)
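• Gaussian distribution: $\mathcal{Y} = \mathbf{R}^m$
• Density:
$$p(y \mid z) = (2\pi)^{-m/2} \det(\Sigma)^{-1/2} \exp\left(-\tfrac{1}{2} (y - \mu)^T \Sigma^{-1} (y - \mu)\right)$$
• Use natural parameter $\theta = (S, \nu) = (\Sigma^{-1}, \Sigma^{-1}\mu)$ (so $\Sigma = S^{-1}$, $\mu = S^{-1}\nu$)
• Loss function (convex in $\theta$):
$$l(\theta, y) = -\log\det S + y^T S y - 2 y^T \nu + \nu^T S^{-1} \nu$$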
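Expanding $(y - \mu)^T S (y - \mu)$ with $\nu = S\mu$ shows that this loss is twice the negative log-density minus the constant $m \log(2\pi)$. The sketch below checks that relation numerically; the function name and the random test data are assumptions for illustration.

```python
import numpy as np
from scipy.stats import multivariate_normal

def gaussian_natural_loss(S, nu, y):
    """l(theta, y) = -log det S + y^T S y - 2 y^T nu + nu^T S^{-1} nu, theta = (S, nu)."""
    _, logdet = np.linalg.slogdet(S)
    return -logdet + y @ S @ y - 2.0 * y @ nu + nu @ np.linalg.solve(S, nu)

rng = np.random.default_rng(0)
m = 3
A = rng.standard_normal((m, m))
Sigma = A @ A.T + np.eye(m)              # a positive definite covariance
mu = rng.standard_normal(m)
y = rng.standard_normal(m)

S = np.linalg.inv(Sigma)                 # natural parameters
nu = S @ mu
loss = gaussian_natural_loss(S, nu, y)

# Negative log-density from SciPy, for comparison.
nll = -multivariate_normal(mean=mu, cov=Sigma).logpdf(y)
```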
-
Regularization graphs
-
Path graph
[Figure: path graph with nodes 0, 1, 2, 3, 4, 5]
• Models time, distance, ...
• Yields time-varying, distance-varying models
-
Cycle graph
[Figure: cycle graph with nodes M, Tu, W, Th, F, Sa, Su]
• Yields diurnal, weekly, seasonally-varying models
-
Tree graph
[Figure: tree graph with nodes CEO, CTO, COO, SVP1, SVP2, SVP3, VP1, VP2, VP3]
• Yields hierarchical models
-
Grid graph
[Figure: 4 × 4 grid graph with nodes (0,0) through (3,3)]
• Yields (2D) space-varying models
-
Products of graphs
[Figure: product graph with nodes (M,1), (F,1), (M,2), (F,2), ..., (M,99), (F,99), (M,100), (F,100)]
• $\mathcal{Z} = \{\text{Male}, \text{Female}\} \times \{1, 2, \ldots, 99, 100\}$ (sex × age)
• Yields sex, age-varying models
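These graphs are easy to build with networkx generators. The sketch below is one possible construction, not the talk's code; the edge-weight values `w_sex` and `w_age` are hypothetical hyperparameters, and the tree is omitted because its exact shape is not recoverable from the slide.

```python
import networkx as nx

# Path, cycle, and grid graphs.
path = nx.path_graph(6)                                    # nodes 0, ..., 5
week = nx.cycle_graph(["M", "Tu", "W", "Th", "F", "Sa", "Su"])
grid = nx.grid_2d_graph(4, 4)                              # nodes (0, 0), ..., (3, 3)

# Product graph for Z = {Male, Female} x {1, ..., 100}.
sex = nx.Graph([("Male", "Female")])
age = nx.path_graph(range(1, 101))
sex_age = nx.cartesian_product(sex, age)                   # nodes are (sex, age) pairs

# One weight per factor graph (hypothetical values).
w_sex, w_age = 10.0, 1.0
for (s1, a1), (s2, a2) in list(sex_age.edges()):
    # An edge of the product changes exactly one coordinate.
    sex_age[(s1, a1)][(s2, a2)]["weight"] = w_sex if s1 != s2 else w_age
```

In the Cartesian product, two nodes are adjacent when they agree in one coordinate and are adjacent in the other factor, so the product here has 100 sex edges and 2 × 99 age edges.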
-
Solution method
-
Fitting problem
• To fit stratified model, minimize $\ell(\theta) + r(\theta) + \mathcal{L}(\theta)$
• $\ell(\theta) = \sum_{k=1}^{K} \ell_k(\theta_k)$ is the loss, $\ell_k(\theta_k) = \sum_{i : z_i = k} l(\theta_k, x_i, y_i)$
• $r(\theta) = \sum_{k=1}^{K} r(\theta_k)$ is (local) regularization
• $\mathcal{L}(\theta) = \sum_{u,v=1}^{K} W_{uv} \|\theta_u - \theta_v\|^2$ is Laplacian regularization
• $\ell$, $r$ are separable in $\theta_k$
• $\mathcal{L}$ is quadratic, separable in components of $\theta_k$
• We'll use an operator splitting method (ADMM)
-
Reformulation
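• Replicate variables:
$$\begin{array}{ll} \text{minimize} & \ell(\theta) + r(\tilde\theta) + \mathcal{L}(\hat\theta) \\ \text{subject to} & \theta = \tilde\theta = \hat\theta \end{array}$$
• Augmented Lagrangian
$$L(\theta, \tilde\theta, \hat\theta, u, \tilde u) = \ell(\theta) + r(\tilde\theta) + \mathcal{L}(\hat\theta) + \frac{1}{2t}\|\theta - \hat\theta + u\|_2^2 + \frac{1}{2t}\|\tilde\theta - \hat\theta + \tilde u\|_2^2$$
• $u$, $\tilde u$ are dual variables for the two constraints, $t > 0$ is an algorithm parameter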
-
ADMM
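• Augmented Lagrangian
$$L(\theta, \tilde\theta, \hat\theta, u, \tilde u) = \ell(\theta) + r(\tilde\theta) + \mathcal{L}(\hat\theta) + \frac{1}{2t}\|\theta - \hat\theta + u\|_2^2 + \frac{1}{2t}\|\tilde\theta - \hat\theta + \tilde u\|_2^2$$
• ADMM: for $i = 1, 2, \ldots$
$$\begin{aligned}
\theta^{i+1}, \tilde\theta^{i+1} &:= \operatorname{argmin}_{\theta, \tilde\theta} \; L(\theta, \tilde\theta, \hat\theta^{i}, u^{i}, \tilde u^{i}) \\
\hat\theta^{i+1} &:= \operatorname{argmin}_{\hat\theta} \; L(\theta^{i+1}, \tilde\theta^{i+1}, \hat\theta, u^{i}, \tilde u^{i}) \\
u^{i+1} &:= u^{i} + \theta^{i+1} - \hat\theta^{i+1} \\
\tilde u^{i+1} &:= \tilde u^{i} + \tilde\theta^{i+1} - \hat\theta^{i+1}
\end{aligned}$$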
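The iteration can be sketched end to end for one concrete case. The code below is an illustrative, dense-algebra version for square loss $\ell_k(\theta_k) = \|X_k\theta_k - y_k\|_2^2$ and $r(\theta_k) = (\lambda/2)\|\theta_k\|_2^2$; the function name, the synthetic data, and the choices of $t$ and iteration count are assumptions, and this is not the authors' implementation. The $\hat\theta$ step uses the freshly updated $\theta$ and $\tilde\theta$, as in standard ADMM, and reduces to solving $(I + 2t\Lambda)\hat\theta = \tfrac{1}{2}(\theta + u + \tilde\theta + \tilde u)$ with $\Lambda$ the graph Laplacian.

```python
import numpy as np

def admm_stratified_ridge(Xs, ys, W, lam, t=0.5, iters=1000):
    """ADMM sketch for sum_k ||X_k th_k - y_k||^2 + (lam/2) sum_k ||th_k||^2
    + sum_{u,v} W_uv ||th_u - th_v||^2, with one (X_k, y_k) pair per stratum."""
    K, n = len(Xs), Xs[0].shape[1]
    Lap = np.diag(W.sum(axis=1)) - W                 # graph Laplacian
    # Quantities for the loss prox, formed once and reused every iteration.
    loss_mats = [np.eye(n) + 2.0 * t * X.T @ X for X in Xs]
    loss_rhs = [2.0 * t * X.T @ y for X, y in zip(Xs, ys)]
    lap_mat = np.eye(K) + 2.0 * t * Lap

    theta = np.zeros((K, n))
    theta_tilde = np.zeros((K, n))
    theta_hat = np.zeros((K, n))
    u = np.zeros((K, n))
    u_tilde = np.zeros((K, n))
    for _ in range(iters):
        # Step 1: 2K independent proximal operators (could run in parallel).
        for k in range(K):
            theta[k] = np.linalg.solve(loss_mats[k], theta_hat[k] - u[k] + loss_rhs[k])
        theta_tilde = (theta_hat - u_tilde) / (t * lam + 1.0)
        # Step 2: Laplacian step; each column is an independent K x K system.
        theta_hat = np.linalg.solve(lap_mat, 0.5 * (theta + u + theta_tilde + u_tilde))
        # Step 3: dual updates.
        u = u + theta - theta_hat
        u_tilde = u_tilde + theta_tilde - theta_hat
    return theta, theta_tilde, theta_hat

rng = np.random.default_rng(0)
K, n = 5, 2
W = np.zeros((K, K))
for k in range(K - 1):                               # path graph with unit weights
    W[k, k + 1] = W[k + 1, k] = 1.0
Xs = [rng.standard_normal((30, n)) for _ in range(K)]
ys = [X @ np.array([1.0, -2.0]) + 0.1 * rng.standard_normal(30) for X in Xs]
lam = 1.0
theta, theta_tilde, theta_hat = admm_stratified_ridge(Xs, ys, W, lam)
```

At convergence the three copies agree, and the gradient of the original objective, $2X_k^T(X_k\theta_k - y_k) + \lambda\theta_k + 4(\Lambda\theta)_k$, vanishes for every $k$.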
-
ADMM
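• First step can be expressed as
$$\theta_k^{i+1} = \mathbf{prox}_{t\ell_k}(\hat\theta_k^{i} - u_k^{i}), \quad \tilde\theta_k^{i+1} = \mathbf{prox}_{tr}(\hat\theta_k^{i} - \tilde u_k^{i}), \quad k = 1, \ldots, K$$
• Proximal operator of $tg$ is
$$\mathbf{prox}_{tg}(\nu) = \operatorname{argmin}_{\theta} \left( t g(\theta) + (1/2)\|\theta - \nu\|_2^2 \right)$$
• Can evaluate these $2K$ proximal operators in parallel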
-
Loss proximal operators
• Often has closed-form expression or efficient implementation
• Square loss: $\ell_k(\theta_k) = \|X_k \theta_k - y_k\|_2^2$
  – $\mathbf{prox}_{t\ell_k}(\nu) = (I + 2t X_k^T X_k)^{-1}(\nu + 2t X_k^T y_k)$
  – Cache factorization or warm-start CG
• Logistic loss: $\ell_k(\theta_k) = \sum_i \log(1 + \exp(-y_{ki}\, \theta_k^T x_{ki}))$
  – Use L-BFGS to evaluate $\mathbf{prox}_{t\ell_k}(\nu)$
  – Warm-start
• Many others ...
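For the square loss as written, $\ell_k(\theta_k) = \|X_k\theta_k - y_k\|_2^2$, setting the gradient of $t\|X_k\theta - y_k\|_2^2 + \tfrac{1}{2}\|\theta - \nu\|_2^2$ to zero gives the linear system $(I + 2tX_k^TX_k)\,\theta = \nu + 2tX_k^Ty_k$. The sketch below caches a Cholesky factorization so that repeated evaluations with different $\nu$ are cheap; the class name is hypothetical.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

class SquareLossProx:
    """prox of t * ||X theta - y||_2^2, with the Cholesky factorization cached."""

    def __init__(self, X, y, t):
        n = X.shape[1]
        # I + 2t X^T X is positive definite, so a Cholesky factorization exists.
        self.factor = cho_factor(np.eye(n) + 2.0 * t * X.T @ X)
        self.offset = 2.0 * t * X.T @ y

    def __call__(self, nu):
        return cho_solve(self.factor, nu + self.offset)

rng = np.random.default_rng(0)
X = rng.standard_normal((40, 5))
y = rng.standard_normal(40)
t = 0.3
prox = SquareLossProx(X, y, t)

nu = rng.standard_normal(5)
theta = prox(nu)
# Optimality residual: gradient of t*||X theta - y||^2 + 0.5*||theta - nu||^2.
residual = 2.0 * t * X.T @ (X @ theta - y) + theta - nu
```

The factorization costs $O(n^3)$ once; each later call is two triangular solves.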
-
Regularizer proximal operators
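• Often has closed-form expression or efficient implementation

Regularizer | $r(\theta_k)$ | $\mathbf{prox}_{tr}(\nu)$
Sum of squares ($\ell_2$) | $(\lambda/2)\|\theta_k\|_2^2$ | $\nu/(t\lambda + 1)$
$\ell_1$ norm | $\lambda\|\theta_k\|_1$ | $(\nu - t\lambda)_+ - (-\nu - t\lambda)_+$
Nonnegative | $I_+(\theta_k)$ | $(\nu)_+$

• $\lambda > 0$ is the local regularization parameter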
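Each of these proximal operators is a one-line elementwise function. A minimal numpy sketch follows, with hypothetical function names; the nonnegative case is the projection of the argument $\nu$ onto the nonnegative orthant.

```python
import numpy as np

def prox_sum_squares(nu, t, lam):
    """prox of t * (lam/2) * ||theta||_2^2: shrink toward zero."""
    return nu / (t * lam + 1.0)

def prox_l1(nu, t, lam):
    """prox of t * lam * ||theta||_1: soft thresholding at t * lam."""
    return np.maximum(nu - t * lam, 0.0) - np.maximum(-nu - t * lam, 0.0)

def prox_nonneg(nu):
    """prox of the indicator of theta >= 0: projection onto the nonnegative orthant."""
    return np.maximum(nu, 0.0)

nu = np.array([-2.0, -0.3, 0.0, 0.4, 1.5])
shrunk = prox_sum_squares(nu, t=1.0, lam=0.5)
thresholded = prox_l1(nu, t=1.0, lam=0.5)
projected = prox_nonneg(nu)
```

With $t\lambda = 0.5$, soft thresholding zeroes every entry with magnitude at most 0.5 and moves the rest toward zero by 0.5.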
-
Laplacian proximal operator
• Separable across each component of $\theta_k$
• To find each component $(\theta_k)_i$, need to solve a Laplacian system
• Many efficient ways to solve, e.g., diagonally preconditioned CG
• These $n$ systems can be solved in parallel
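A sketch of this step with SciPy's conjugate gradient solver and a diagonal (Jacobi) preconditioner is shown below. The function name is hypothetical, and the constant `c` stands for whatever multiple of $t$ the chosen scaling of the Laplacian term produces (for the double sum over all $u, v$ used in these slides, the ADMM step leads to $c = 2t$); the default CG tolerance is used.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import cg

def solve_laplacian_systems(Lap, V, c):
    """Solve (I + c * Lap) Theta = V, one column at a time.

    Lap: (K, K) sparse graph Laplacian, V: (K, n) right-hand sides, c > 0.
    Column i of the result holds component i of theta_k for all k.
    """
    K, n = V.shape
    A = (sp.identity(K, format="csr") + c * Lap).tocsr()
    M = sp.diags(1.0 / A.diagonal())          # diagonal preconditioner
    Theta = np.empty((K, n))
    for i in range(n):                        # the n systems are independent
        x, info = cg(A, V[:, i], M=M)
        if info != 0:
            raise RuntimeError(f"CG did not converge for column {i}")
        Theta[:, i] = x
    return Theta

# Example: path graph on K nodes with unit weights.
K, n, c = 50, 3, 2.0
W = sp.diags([np.ones(K - 1), np.ones(K - 1)], [-1, 1])
degrees = np.asarray(W.sum(axis=1)).ravel()
Lap = sp.diags(degrees) - W

rng = np.random.default_rng(0)
V = rng.standard_normal((K, n))
Theta = solve_laplacian_systems(Lap, V, c)
A_check = sp.identity(K) + c * Lap
rel_residual = np.linalg.norm(A_check @ Theta - V) / np.linalg.norm(V)
```

The loop over columns is where parallelism would enter; each column can also be warm-started from the previous ADMM iterate by passing `x0` to `cg`.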
-
Software implementation
• Available at www.github.com/cvxgrp/strat_models
• numpy, scipy for matrix operations
• networkx for handling graphs and graph operations
• torch for L-BFGS and GPU computation
• multiprocessing for parallelism
• model.fit(X,Y,Z,G) (writes $\theta_k$ on graph nodes)
-
Examples
-
House price prediction
• $N \approx 22000$ records $(z, x, y)$ from King County, WA
• Split into train/test ≈ 16200/5400
• $z$ = (latitude bin, longitude bin)
• $x \in \mathbf{R}^{10}$ = features of house, $y$ = log of house price
• Graph is 50 × 50 grid with all edge weights the same; $K = 2500$
• Stratified ridge regression model with two hyperparameters (one for local regularization, one for Laplacian regularization)
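One way such a stratification could be set up is sketched below. The equal-width binning, the function name, the synthetic coordinates, and the unit edge weight are all assumptions for illustration; the slides do not say how the bins were chosen.

```python
import numpy as np
import networkx as nx

def bin_coordinates(lat, lon, bins=50):
    """Map latitude/longitude to (row, col) nodes of a bins x bins grid (equal-width bins)."""
    lat_edges = np.linspace(lat.min(), lat.max(), bins + 1)
    lon_edges = np.linspace(lon.min(), lon.max(), bins + 1)
    # Use interior edges only, so indices fall in 0..bins-1 and the maximum lands in the last bin.
    rows = np.digitize(lat, lat_edges[1:-1])
    cols = np.digitize(lon, lon_edges[1:-1])
    return list(zip(rows.tolist(), cols.tolist()))

# Synthetic coordinates standing in for the real data set.
rng = np.random.default_rng(0)
lat = rng.uniform(47.1, 47.8, size=1000)
lon = rng.uniform(-122.5, -121.3, size=1000)
z = bin_coordinates(lat, lon)

# 50 x 50 grid graph with the same weight on every edge.
G = nx.grid_2d_graph(50, 50)
nx.set_edge_attributes(G, 1.0, "weight")

n_features = 10
n_parameters = G.number_of_nodes() * n_features
```

With 1000 synthetic points and 2500 bins, most nodes receive no data at all, which is exactly the situation the Laplacian term is meant to handle.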
-
House price prediction: Results
• Compare stratified, common, and random forest with 50 trees

Model | Parameters | RMS test error
Stratified | 25000 | 0.181
Common | 10 | 0.316
Random forest | 985888 | 0.184
-
House price prediction: Parameters
[Figure: ten panels, one per model parameter: bedrooms, bathrooms, sqft living, sqft lot, floors, waterfront, condition, grade, yr built, intercept]
-
Chicago crime prediction
• $N \approx 7\,000\,000$ $(z, y)$ pairs for 2017, 2018
• Train on 2017, test on 2018
• $y$ = number of crimes
• $z$ = (location bin, week of year, day of week, hour of day); $K \approx 3\,500\,000$
• Graph is Cartesian product of grid, three cycles; four graph edge weights
• Stratified Poisson model with four hyperparameters (one for each graph edge weight)
-
Chicago crime prediction
• Compare three models, average negative log likelihood on test data

Model | Train | Test
Separate | 0.068 | 0.740
Stratified | 0.221 | 0.234
Common | 0.279 | 0.278
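The metric in this table can be computed as in the sketch below, assuming each stratum's model is a Poisson distribution parameterized directly by its rate; that parameterization, the function name, and the toy numbers are assumptions, not details given in the slides.

```python
import numpy as np
from scipy.stats import poisson

def average_poisson_nll(rates, z, y):
    """Average negative log-likelihood of counts y under per-stratum Poisson rates.

    rates: (K,) positive rates, z: (N,) integer strata in {0, ..., K-1}, y: (N,) counts.
    """
    return -np.mean(poisson.logpmf(y, rates[z]))

rates = np.array([0.5, 2.0])        # hypothetical per-stratum rates
z = np.array([0, 0, 1, 1])
y = np.array([0, 1, 2, 3])
avg_nll = average_poisson_nll(rates, z, y)
```

For a single record the negative log-likelihood is $\theta_z - y\log\theta_z + \log y!$, so a separately fitted model that assigns a very small rate to a stratum pays heavily when test counts there are nonzero, which is consistent with the gap between the Separate model's train and test values.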
-
Chicago crime prediction
Crime rate as a function of latitude/longitude
[Figure: map over longitude (−87.85 to −87.52) and latitude (41.64 to 42.02), with color scale from 0.02 to 0.12]
-
Chicago crime prediction
Crime rate as a function of week of year
[Figure: plot with horizontal axis Jan through Dec and vertical axis from 0.036 to 0.044]
-
Chicago crime prediction
Crime rate as a function of hour of week
[Figure: plot with horizontal axis M, Tu, W, Th, F, Sa, Su and vertical axis from 0.020 to 0.055]
-
Conclusions
-
Conclusions
• Stratified models combine
  – simple dependence on some features ($x$)
  – complex dependence on others ($z$)
• Often interpretable
• Laplacian regularization encodes prior on values of $z$, so models can borrow strength from their neighbors
• Effective method to build time-varying, space-varying, seasonally-varying models
• Efficient, distributed ADMM-based implementation for large-scale data