{-# LANGUAGE NoImplicitPrelude     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeFamilies          #-}

{-|
Module      : Stack.Storage.Util
Description : Utilities for other @Stack.Storage@ modules.
License     : BSD-3-Clause

Utilities for the other @Stack.Storage@ modules.
-}

module Stack.Storage.Util
  ( handleMigrationException
  , updateCollection
  , setUpdateDiff
  , listUpdateDiff
  ) where

import qualified Data.Set as Set
import           Database.Persist
                   ( BaseBackend, EntityField, Filter, PersistEntity
                   , PersistEntityBackend, PersistField, PersistQueryWrite
                   , SafeToInsert, (<-.), deleteWhere, insertMany_
                   )
import           Stack.Prelude
import           Stack.Types.Storage ( StoragePrettyException (..) )

-- | Efficiently update a collection of values with a given diff function.
--
-- Nothing is done when the old and the new collection are equal. Otherwise the
-- diff function is applied to them, yielding filters and values. If any
-- filters are returned, the records matching the extra filters together with
-- the returned filters are deleted. If any values are returned, each one is
-- converted to a record and all of them are inserted with a single call.
updateCollection ::
     ( PersistEntityBackend record ~ BaseBackend backend
     , Eq (collection rawValue)
     , PersistEntity record
     , PersistField value
     , MonadIO m
     , PersistQueryWrite backend
     , SafeToInsert record
     , Foldable collection
     )
  => (collection rawValue -> collection rawValue -> ([Filter record], [value]))
     -- ^ Diff function. Applied to the old and then the new collection, it
     -- yields the filters for the records to delete and the values to insert.
  -> (value -> record)
     -- ^ Converts a value to be inserted into a record.
  -> [Filter record]
     -- ^ Extra filters, prepended to the diff function's filters when
     -- deleting.
  -> collection rawValue
     -- ^ The old collection.
  -> collection rawValue
     -- ^ The new collection.
  -> ReaderT backend m ()
updateCollection fnDiffer recordCons extra old new =
  when (old /= new) $ do
    let (oldMinusNewFilter, newMinusOld) = fnDiffer old new
    -- The guard is on the diff function's filters only: the extra filters on
    -- their own never trigger a deletion.
    unless (null oldMinusNewFilter) $
      deleteWhere (extra ++ oldMinusNewFilter)
    unless (null newMinusOld) $
      insertMany_ (map recordCons newMinusOld)

-- | Diff function for collections stored as a 'Set'.
--
-- The first component of the result is always a list of exactly one filter,
-- built with '<-.' from the given field and the elements of the old set that
-- are absent from the new set. That list of elements may be empty. The second
-- component is the elements of the new set that are absent from the old set,
-- in ascending order.
setUpdateDiff ::
     (Ord value, PersistField value)
  => EntityField record value
  -> Set value
  -> Set value
  -> ([Filter record], [value])
setUpdateDiff indexFieldCons old new =
  ( [indexFieldCons <-. Set.toList removed]
  , Set.toList added
  )
 where
  removed = Set.difference old new
  added = Set.difference new old

-- | Diff function for collections stored as a list, where each element is
-- identified by its zero-based position.
--
-- Every element of both lists is paired with its position, and the pairs are
-- compared as sets. An element therefore counts as changed when either its
-- position or its value differs.
--
-- The first component of the result is always a list of exactly one filter,
-- built with '<-.' from the given field and the positions of the old pairs
-- that are absent from the new list. That list of positions may be empty. The
-- second component is the new pairs that are absent from the old list, in
-- ascending order.
listUpdateDiff ::
     (Ord value)
  => EntityField record Int
  -> [value]
  -> [value]
  -> ([Filter record], [(Int, value)])
listUpdateDiff indexFieldCons old new =
  ( [indexFieldCons <-. staleIndices]
  , Set.toList (Set.difference newSet oldSet)
  )
 where
  oldSet = Set.fromList (zip [0 ..] old)
  newSet = Set.fromList (zip [0 ..] new)
  staleIndices = map fst (Set.toList (Set.difference oldSet newSet))

-- | Run the given action, rethrowing a 'MigrationFailure' as a
-- 'StorageMigrationFailure' carrying the same three fields, using
-- 'prettyThrowIO'.
--
-- Any other 'PantryException' is rethrown unchanged. Only 'PantryException'
-- is caught here; exceptions of other types are not handled by this function.
handleMigrationException :: HasLogFunc env => RIO env a -> RIO env a
handleMigrationException inner = do
  eres <- try inner
  case eres of
    Left (MigrationFailure desc fp ex) ->
      prettyThrowIO $ StorageMigrationFailure desc fp ex
    Left e ->
      throwIO (e :: PantryException)
    Right result ->
      pure result