Skip to content

Commit

Permalink
add possibility to regularize
Browse files Browse the repository at this point in the history
  • Loading branch information
cspollard committed Jul 10, 2019
1 parent 07dfa24 commit 30af2de
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 48 deletions.
1 change: 0 additions & 1 deletion data/test.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
]

, "Lumi" : 100
, "Reg" : "Rising 5"
}

, "ModelVars" : {
Expand Down
53 changes: 6 additions & 47 deletions src/Model.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

module Model
( Model(..)
, mBkgs, mSig, mMig, mLumi, mReg
, mBkgs, mSig, mMig, mLumi
, ModelVar(..)
, mvBkgs, mvSig, mvMig, mvLumi
, ModelParam(..)
Expand All @@ -36,51 +36,12 @@ import Data.Aeson.Types
import Data.HashMap.Strict
import Data.Text (Text)
import Data.Vector (Vector)
import qualified Data.Vector as V
import GHC.Generics
import GHC.TypeLits
import Data.Monoid (Sum(..))
import Data.Foldable (foldMap)

import Matrix
import Probability


data Reg a
= NoReg
| Falling a
| Rising a
deriving (Show, Generic, Functor)

instance FromJSON a => FromJSON (Reg a) where
parseJSON = genericParseJSON defaultOptions


neighbordiff :: Num a => [a] -> [a]
neighbordiff [] = []
neighbordiff [_] = []
neighbordiff (x:xs@(y:_)) = x - y : neighbordiff xs


regularize :: (Ord a, Num a) => Reg a -> Vector a -> a

regularize NoReg _ = 0

regularize (Falling mu) sigs =
(*mu)
. getSum
. foldMap (\x -> if x > 0 then mempty else Sum (abs x))
. neighbordiff
$ V.toList sigs

regularize (Rising mu) sigs =
(*mu)
. getSum
. foldMap (\x -> if x < 0 then mempty else Sum x)
. neighbordiff
$ V.toList sigs


-- TODO
-- TODO!
-- update this to work with Vars from atlas.git!!
Expand All @@ -90,7 +51,6 @@ data Model a =
, _mSig :: Vector a
, _mMig :: Vector (Vector a)
, _mLumi :: a
, _mReg :: Reg a
} deriving (Show, Generic, Functor)

makeLenses ''Model
Expand All @@ -100,13 +60,12 @@ instance FromJSON a => FromJSON (Model a) where


addM :: Num a => Model a -> Model a -> Model a
addM (Model b s m l r) (Model b' s' m' l' _) =
addM (Model b s m l) (Model b' s' m' l') =
Model
(liftU2 (^+^) b b')
(s ^+^ s')
(liftU2 (^+^) m m')
(l + l')
r


data ModelVar a =
Expand All @@ -131,6 +90,7 @@ data ParamPrior a =
| LogNormal a a
deriving (Generic, Functor, Show)


instance FromJSON a => FromJSON (ParamPrior a) where
parseJSON (String "Flat") = return Flat
parseJSON (String "NonNegative") = return NonNegative
Expand Down Expand Up @@ -172,19 +132,19 @@ modelLogPosterior
-> Model a
-> Vector (ModelVar a)
-> Vector (a -> a)
-> (Vector a -> a)
-> Vector a
-> Either String a
modelLogPosterior dats model mps logPriors ps =
modelLogPosterior dats model mps logPriors logReg ps =
reifyVector ps $ \ps' -> do
mps' <- toEither "incorrect length of mps" $ fromVector mps
logPriors' <- toEither "incorrect length of logPriors" $ fromVector logPriors
model' <- appVars mps' ps' model
logLike <- toEither "incorrect length of data" $ modelLogLikelihood dats model'

let logPrior = sum $ logPriors' <*> ps'
logReg = regularize (_mReg model') (_mSig model')

return $ logLike + logPrior + logReg
return $ logLike + logPrior + logReg (_mSig model')



Expand Down Expand Up @@ -269,7 +229,6 @@ varDiff ModelVar{..} x Model{..} = toEither "failed to apply model variation." $
(toVector $ maybe zero (sigDiff x sig) vSig)
(toVectorM $ maybe zero (migDiff x mig) vMig)
(maybe 0 (lumiDiff x _mLumi) _mvLumi)
_mReg


toEither :: a -> Maybe b -> Either a b
Expand Down
1 change: 1 addition & 0 deletions src/RunModel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ runModel hamParams nsamps outfile dataH model' modelparams = do
model
variations
(ppToFunc . fmap auto <$> priors)
(const 0)

gLogLH = grad logLH

Expand Down

0 comments on commit 30af2de

Please sign in to comment.