diff --git a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs index 033efb8655..936dd91cdf 100644 --- a/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs +++ b/codebase2/codebase-sqlite/U/Codebase/Sqlite/Queries.hs @@ -228,6 +228,7 @@ module U.Codebase.Sqlite.Queries expectEntity, syncToTempEntity, insertTempEntity, + insertTempEntityV2, saveTempEntityInMain, expectTempEntity, deleteTempEntity, @@ -315,6 +316,7 @@ import Data.Map.NonEmpty qualified as NEMap import Data.Maybe qualified as Maybe import Data.Sequence qualified as Seq import Data.Set qualified as Set +import Data.Set.NonEmpty (NESet) import Data.Text qualified as Text import Data.Text.Encoding qualified as Text import Data.Text.Lazy qualified as Text.Lazy @@ -532,23 +534,18 @@ countWatches = queryOneCol [sql| SELECT COUNT(*) FROM watch |] saveHash :: Hash32 -> Transaction HashId saveHash hash = do - execute - [sql| - INSERT INTO hash (base32) VALUES (:hash) - ON CONFLICT DO NOTHING - |] - expectHashId hash + loadHashId hash >>= \case + Just h -> pure h + Nothing -> do + queryOneCol + [sql| + INSERT INTO hash (base32) VALUES (:hash) + RETURNING id + |] saveHashes :: Traversable f => f Hash32 -> Transaction (f HashId) saveHashes hashes = do - for_ hashes \hash -> - execute - [sql| - INSERT INTO hash (base32) - VALUES (:hash) - ON CONFLICT DO NOTHING - |] - traverse expectHashId hashes + for hashes saveHash saveHashHash :: Hash -> Transaction HashId saveHashHash = saveHash . Hash32.fromHash @@ -623,13 +620,15 @@ expectBranchHashForCausalHash ch = do saveText :: Text -> Transaction TextId saveText t = do - execute - [sql| - INSERT INTO text (text) - VALUES (:t) - ON CONFLICT DO NOTHING - |] - expectTextId t + loadTextId t >>= \case + Just h -> pure h + Nothing -> do + queryOneCol + [sql| + INSERT INTO text (text) + VALUES (:t) + RETURNING id + |] saveTexts :: Traversable f => f Text -> Transaction (f TextId) saveTexts = @@ -686,7 +685,7 @@ saveObject :: ObjectType -> ByteString -> Transaction ObjectId -saveObject hh h t blob = do +saveObject _hh h t blob = do execute [sql| INSERT INTO object (primary_hash_id, type_id, bytes) @@ -697,9 +696,9 @@ saveObject hh h t blob = do saveHashObject h oId 2 -- todo: remove this from here, and add it to other relevant places once there are v1 and v2 hashes rowsModified >>= \case 0 -> pure () - _ -> do - hash <- expectHash32 h - tryMoveTempEntityDependents hh hash + _ -> pure () + -- hash <- expectHash32 h + -- tryMoveTempEntityDependents hh hash pure oId expectObject :: SqliteExceptionReason e => ObjectId -> (ByteString -> Either e a) -> Transaction a @@ -957,7 +956,7 @@ saveCausal :: BranchHashId -> [CausalHashId] -> Transaction () -saveCausal hh self value parents = do +saveCausal _hh self value parents = do execute [sql| INSERT INTO causal (self_hash_id, value_hash_id) @@ -973,15 +972,15 @@ saveCausal hh self value parents = do INSERT INTO causal_parent (causal_id, parent_id) VALUES (:self, :parent) |] - flushCausalDependents hh self + -- flushCausalDependents hh self -flushCausalDependents :: +_flushCausalDependents :: HashHandle -> CausalHashId -> Transaction () -flushCausalDependents hh chId = do +_flushCausalDependents hh chId = do hash <- expectHash32 (unCausalHashId chId) - tryMoveTempEntityDependents hh hash + _tryMoveTempEntityDependents hh hash -- | `tryMoveTempEntityDependents #foo` does this: -- 0. Precondition: We just inserted object #foo. @@ -989,11 +988,11 @@ flushCausalDependents hh chId = do -- 2. Delete #foo as dependency from temp_entity_missing_dependency. e.g. (#bar, #foo), (#baz, #foo) -- 3. For each like #bar and #baz with no more rows in temp_entity_missing_dependency, -- insert_entity them. -tryMoveTempEntityDependents :: +_tryMoveTempEntityDependents :: HashHandle -> Hash32 -> Transaction () -tryMoveTempEntityDependents hh dependency = do +_tryMoveTempEntityDependents hh dependency = do dependents <- queryListCol [sql| @@ -2993,6 +2992,35 @@ insertTempEntity entityHash entity missingDependencies = do entityType = Entity.entityType entity +-- | Insert a new `temp_entity` row, and its associated 1+ `temp_entity_missing_dependency` rows. +-- +-- Preconditions: +-- 1. The entity does not already exist in "main" storage (`object` / `causal`) +-- 2. The entity does not already exist in `temp_entity`. +insertTempEntityV2 :: Hash32 -> TempEntity -> NESet Hash32 -> Transaction () +insertTempEntityV2 entityHash entity missingDependencies = do + execute + [sql| + INSERT INTO temp_entity (hash, blob, type_id) + VALUES (:entityHash, :entityBlob, :entityType) + ON CONFLICT DO NOTHING + |] + + for_ missingDependencies \depHash -> + execute + [sql| + INSERT INTO temp_entity_missing_dependency (dependent, dependency) + VALUES (:entityHash, :depHash) + |] + where + entityBlob :: ByteString + entityBlob = + runPutS (Serialization.putTempEntity entity) + + entityType :: TempEntityType + entityType = + Entity.entityType entity + -- | Delete a row from the `temp_entity` table, if it exists. deleteTempEntity :: Hash32 -> Transaction () deleteTempEntity hash = diff --git a/codebase2/codebase-sqlite/sql/001-temp-entity-tables.sql b/codebase2/codebase-sqlite/sql/001-temp-entity-tables.sql index 0ae13812b1..6651d4a6fe 100644 --- a/codebase2/codebase-sqlite/sql/001-temp-entity-tables.sql +++ b/codebase2/codebase-sqlite/sql/001-temp-entity-tables.sql @@ -56,7 +56,8 @@ create table if not exists temp_entity ( create table if not exists temp_entity_missing_dependency ( dependent text not null references temp_entity(hash), dependency text not null, - dependencyJwt text not null, + -- TODO: this is just for testing + dependencyJwt text null, unique (dependent, dependency) ); create index if not exists temp_entity_missing_dependency_ix_dependent on temp_entity_missing_dependency (dependent); diff --git a/lib/unison-sqlite/package.yaml b/lib/unison-sqlite/package.yaml index 84d0201eab..b90bd2aa57 100644 --- a/lib/unison-sqlite/package.yaml +++ b/lib/unison-sqlite/package.yaml @@ -9,6 +9,7 @@ library: dependencies: - base + - containers - direct-sqlite - megaparsec - pretty-simple diff --git a/lib/unison-sqlite/src/Unison/Sqlite/Connection.hs b/lib/unison-sqlite/src/Unison/Sqlite/Connection.hs index 48167980db..726cac860e 100644 --- a/lib/unison-sqlite/src/Unison/Sqlite/Connection.hs +++ b/lib/unison-sqlite/src/Unison/Sqlite/Connection.hs @@ -58,6 +58,7 @@ module Unison.Sqlite.Connection ) where +import Data.Map qualified as Map import Database.SQLite.Simple qualified as Sqlite import Database.SQLite.Simple.FromField qualified as Sqlite import Database.SQLite3 qualified as Direct.Sqlite @@ -71,7 +72,10 @@ import Unison.Sqlite.Connection.Internal (Connection (..)) import Unison.Sqlite.Exception import Unison.Sqlite.Sql (Sql (..)) import Unison.Sqlite.Sql qualified as Sql +import UnliftIO (atomically) import UnliftIO.Exception +import UnliftIO.STM (readTVar) +import UnliftIO.STM qualified as STM -- | Perform an action with a connection to a SQLite database. -- @@ -103,19 +107,47 @@ openConnection name file = do Just "" -> file _ -> "file:" <> file <> "?mode=ro" conn0 <- Sqlite.open sqliteURI `catch` rethrowAsSqliteConnectException name file - let conn = Connection {conn = conn0, file, name} + statementCache <- STM.newTVarIO Map.empty + let conn = Connection {conn = conn0, file, name, statementCache} execute conn [Sql.sql| PRAGMA foreign_keys = ON |] execute conn [Sql.sql| PRAGMA busy_timeout = 60000 |] + execute conn [Sql.sql| PRAGMA synchronous = normal |] + execute conn [Sql.sql| PRAGMA journal_size_limit = 6144000 |] + execute conn [Sql.sql| PRAGMA cache_size = -64000 |] + execute conn [Sql.sql| PRAGMA temp_store = 2 |] + pure conn -- Close a connection opened with 'openConnection'. closeConnection :: Connection -> IO () -closeConnection (Connection _ _ conn) = +closeConnection conn@(Connection {conn = conn0}) = do -- FIXME if this throws an exception, it won't be under `SomeSqliteException` -- Possible fixes: -- 1. Add close exception to the hierarchy, e.g. `SqliteCloseException` -- 2. Always ignore exceptions thrown by `close` (Mitchell prefers this one) - Sqlite.close conn + closeAllStatements conn + Sqlite.close conn0 + +withStatement :: Connection -> Text -> (Sqlite.Statement -> IO a) -> IO a +withStatement conn sql action = do + bracket (prepareStatement conn sql) Sqlite.reset action + where + prepareStatement :: Connection -> Text -> IO Sqlite.Statement + prepareStatement Connection {conn, statementCache} sql = do + cached <- atomically $ do + cache <- STM.readTVar statementCache + pure $ Map.lookup sql cache + case cached of + Just stmt -> pure stmt + Nothing -> do + stmt <- Sqlite.openStatement conn (coerce @Text @Sqlite.Query sql) + atomically $ STM.modifyTVar statementCache (Map.insert sql stmt) + pure stmt + +closeAllStatements :: Connection -> IO () +closeAllStatements Connection {statementCache} = do + cache <- atomically $ readTVar statementCache + for_ cache Sqlite.closeStatement -- An internal type, for making prettier debug logs @@ -152,7 +184,7 @@ logQuery (Sql sql params) result = -- Without results execute :: (HasCallStack) => Connection -> Sql -> IO () -execute conn@(Connection _ _ conn0) sql@(Sql s params) = do +execute conn sql@(Sql s params) = do logQuery sql Nothing doExecute `catch` \(exception :: Sqlite.SQLError) -> throwSqliteQueryException @@ -163,16 +195,16 @@ execute conn@(Connection _ _ conn0) sql@(Sql s params) = do } where doExecute :: IO () - doExecute = - Sqlite.withStatement conn0 (coerce s) \(Sqlite.Statement statement) -> do - bindParameters statement params - void (Direct.Sqlite.step statement) + doExecute = do + withStatement conn s \statement -> do + bindParameters (coerce statement) params + void (Direct.Sqlite.step $ coerce statement) -- | Execute one or more semicolon-delimited statements. -- -- This function does not support parameters, and is mostly useful for executing DDL and migrations. executeStatements :: (HasCallStack) => Connection -> Text -> IO () -executeStatements conn@(Connection _ _ (Sqlite.Connection database _tempNameCounter)) sql = do +executeStatements conn@(Connection {conn = Sqlite.Connection database _tempNameCounter}) sql = do logQuery (Sql sql []) Nothing Direct.Sqlite.exec database sql `catch` \(exception :: Sqlite.SQLError) -> throwSqliteQueryException @@ -185,7 +217,7 @@ executeStatements conn@(Connection _ _ (Sqlite.Connection database _tempNameCoun -- With results, without checks queryStreamRow :: (HasCallStack, Sqlite.FromRow a) => Connection -> Sql -> (IO (Maybe a) -> IO r) -> IO r -queryStreamRow conn@(Connection _ _ conn0) sql@(Sql s params) callback = +queryStreamRow conn sql@(Sql s params) callback = run `catch` \(exception :: Sqlite.SQLError) -> throwSqliteQueryException SqliteQueryExceptionInfo @@ -194,8 +226,8 @@ queryStreamRow conn@(Connection _ _ conn0) sql@(Sql s params) callback = sql } where - run = - bracket (Sqlite.openStatement conn0 (coerce s)) Sqlite.closeStatement \statement -> do + run = do + withStatement conn s \statement -> do Sqlite.bind statement params callback (Sqlite.nextRow statement) @@ -213,7 +245,7 @@ queryStreamCol = queryStreamRow queryListRow :: forall a. (Sqlite.FromRow a, HasCallStack) => Connection -> Sql -> IO [a] -queryListRow conn@(Connection _ _ conn0) sql@(Sql s params) = do +queryListRow conn sql@(Sql s params) = do result <- doQuery `catch` \(exception :: Sqlite.SQLError) -> @@ -228,7 +260,7 @@ queryListRow conn@(Connection _ _ conn0) sql@(Sql s params) = do where doQuery :: IO [a] doQuery = - Sqlite.withStatement conn0 (coerce s) \statement -> do + withStatement conn (coerce s) \statement -> do bindParameters (coerce statement) params let loop :: [a] -> IO [a] loop rows = @@ -347,7 +379,7 @@ queryOneColCheck conn s check = -- Rows modified rowsModified :: Connection -> IO Int -rowsModified (Connection _ _ conn) = +rowsModified (Connection {conn}) = Sqlite.changes conn -- Vacuum diff --git a/lib/unison-sqlite/src/Unison/Sqlite/Connection/Internal.hs b/lib/unison-sqlite/src/Unison/Sqlite/Connection/Internal.hs index 5f80151f94..579c37cfb9 100644 --- a/lib/unison-sqlite/src/Unison/Sqlite/Connection/Internal.hs +++ b/lib/unison-sqlite/src/Unison/Sqlite/Connection/Internal.hs @@ -3,15 +3,19 @@ module Unison.Sqlite.Connection.Internal ) where +import Data.Map (Map) +import Data.Text (Text) import Database.SQLite.Simple qualified as Sqlite +import UnliftIO.STM (TVar) -- | A /non-thread safe/ connection to a SQLite database. data Connection = Connection { name :: String, file :: FilePath, - conn :: Sqlite.Connection + conn :: Sqlite.Connection, + statementCache :: TVar (Map Text Sqlite.Statement) } instance Show Connection where - show (Connection name file _conn) = + show (Connection name file _conn _statementCache) = "Connection { name = " ++ show name ++ ", file = " ++ show file ++ " }" diff --git a/lib/unison-sqlite/unison-sqlite.cabal b/lib/unison-sqlite/unison-sqlite.cabal index 28ea0f7c4f..329a05c5d8 100644 --- a/lib/unison-sqlite/unison-sqlite.cabal +++ b/lib/unison-sqlite/unison-sqlite.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.36.0. +-- This file has been generated from package.yaml by hpack version 0.37.0. -- -- see: https://github.com/sol/hpack @@ -64,6 +64,7 @@ library ghc-options: -Wall build-depends: base + , containers , direct-sqlite , megaparsec , pretty-simple diff --git a/unison-cli/src/Unison/Share/Sync/Util.hs b/unison-cli/src/Unison/Share/Sync/Util.hs new file mode 100644 index 0000000000..39eeb2cede --- /dev/null +++ b/unison-cli/src/Unison/Share/Sync/Util.hs @@ -0,0 +1,42 @@ +module Unison.Share.Sync.Util + ( BailT (..), + MonadBail (..), + runBailT, + mapBailT, + withError, + ) +where + +import Control.Monad.Reader (MonadReader (..), ReaderT (..), mapReaderT, withReaderT) +import Data.Data (Typeable) +import UnliftIO qualified as IO + +newtype Handler e = Handler {runHandler :: forall x. e -> IO x} + +newtype BailT e m a = BailT {unErrGroupT :: ReaderT (Handler e) m a} + deriving newtype (Functor, Applicative, Monad, IO.MonadUnliftIO, IO.MonadIO) + +newtype ExceptionWrapper e = ExceptionWrapper {unException :: e} + +instance Show (ExceptionWrapper e) where + show (ExceptionWrapper _) = "ExceptionWrapper<>" + +instance (Typeable e) => IO.Exception (ExceptionWrapper e) + +class MonadBail e m where + bail :: e -> m a + +mapBailT :: (Monad n) => (m a -> n b) -> BailT e m a -> BailT e n b +mapBailT f (BailT m) = BailT $ mapReaderT f $ m + +withError :: (Monad m) => (e' -> e) -> BailT e' m a -> BailT e m a +withError f (BailT m) = BailT $ withReaderT (\h -> Handler $ runHandler h . f) m + +instance (IO.MonadUnliftIO m, Typeable e) => MonadBail e (BailT e m) where + bail e = do + handler <- BailT ask + BailT $ IO.liftIO $ runHandler handler e + +runBailT :: (IO.MonadUnliftIO m, Typeable e) => BailT e m a -> (e -> m a) -> m a +runBailT (BailT m) handler = do + IO.handle (handler . unException) $ runReaderT m (Handler (IO.throwIO . ExceptionWrapper)) diff --git a/unison-share-api/src/Unison/Sync/EntityValidation.hs b/unison-share-api/src/Unison/Sync/EntityValidation.hs index 02ad6d8330..4e8c854407 100644 --- a/unison-share-api/src/Unison/Sync/EntityValidation.hs +++ b/unison-share-api/src/Unison/Sync/EntityValidation.hs @@ -4,6 +4,7 @@ -- | Module for validating hashes of entities received/sent via sync. module Unison.Sync.EntityValidation ( validateEntity, + validateTempEntity, ) where @@ -21,6 +22,7 @@ import U.Codebase.Sqlite.HashHandle qualified as HH import U.Codebase.Sqlite.Orphans () import U.Codebase.Sqlite.Patch.Format qualified as PatchFormat import U.Codebase.Sqlite.Serialization qualified as Serialization +import U.Codebase.Sqlite.TempEntity (TempEntity) import U.Codebase.Sqlite.Term.Format qualified as TermFormat import U.Codebase.Sqlite.V2.HashHandle (v2HashHandle) import Unison.Hash (Hash) @@ -35,7 +37,13 @@ import Unison.Sync.Types qualified as Share -- We should add more validation as more entities are shared. validateEntity :: Hash32 -> Share.Entity Text Hash32 Hash32 -> Maybe Share.EntityValidationError validateEntity expectedHash32 entity = do - case Share.entityToTempEntity id entity of + validateTempEntity expectedHash32 $ Share.entityToTempEntity id entity + +-- | Note: We currently only validate Namespace hashes. +-- We should add more validation as more entities are shared. +validateTempEntity :: Hash32 -> TempEntity -> Maybe Share.EntityValidationError +validateTempEntity expectedHash32 tempEntity = do + case tempEntity of Entity.TC (TermFormat.SyncTerm localComp) -> do validateTerm expectedHash localComp Entity.DC (DeclFormat.SyncDecl localComp) -> do diff --git a/unison-share-api/src/Unison/SyncV2/API.hs b/unison-share-api/src/Unison/SyncV2/API.hs new file mode 100644 index 0000000000..71ea8693d3 --- /dev/null +++ b/unison-share-api/src/Unison/SyncV2/API.hs @@ -0,0 +1,29 @@ +{-# LANGUAGE DataKinds #-} + +module Unison.SyncV2.API + ( API, + api, + Routes (..), + ) +where + +import Data.Proxy +import GHC.Generics (Generic) +import Servant.API +import Unison.SyncV2.Types +import Unison.Util.Servant.CBOR (CBOR) + +api :: Proxy API +api = Proxy + +type API = NamedRoutes Routes + +type DownloadEntitiesStream = + -- | The causal hash the client needs. The server should provide it and all of its dependencies + ReqBody '[CBOR, JSON] DownloadEntitiesRequest + :> StreamPost NetstringFraming CBOR (SourceIO DownloadEntitiesChunk) + +data Routes mode = Routes + { downloadEntitiesStream :: mode :- "entities" :> "download" :> DownloadEntitiesStream + } + deriving stock (Generic) diff --git a/unison-share-api/src/Unison/SyncV2/Types.hs b/unison-share-api/src/Unison/SyncV2/Types.hs new file mode 100644 index 0000000000..04ce112d8f --- /dev/null +++ b/unison-share-api/src/Unison/SyncV2/Types.hs @@ -0,0 +1,317 @@ +module Unison.SyncV2.Types + ( DownloadEntitiesRequest (..), + DownloadEntitiesChunk (..), + EntityChunk (..), + ErrorChunk (..), + StreamInitInfo (..), + SyncError (..), + DownloadEntitiesError (..), + CBORBytes (..), + EntityKind (..), + serialiseCBORBytes, + deserialiseOrFailCBORBytes, + UploadEntitiesRequest (..), + BranchRef (..), + PullError (..), + EntitySorting (..), + Version (..), + ) +where + +import Codec.CBOR.Encoding qualified as CBOR +import Codec.Serialise (Serialise (..)) +import Codec.Serialise qualified as CBOR +import Codec.Serialise.Decoding qualified as CBOR +import Control.Exception (Exception) +import Data.Aeson (FromJSON (..), ToJSON (..), object, withObject, (.:), (.=)) +import Data.Map (Map) +import Data.Map qualified as Map +import Data.Set (Set) +import Data.Text (Text) +import Data.Text qualified as Text +import Data.Word (Word16, Word64) +import U.Codebase.HashTags (CausalHash) +import U.Codebase.Sqlite.TempEntity (TempEntity) +import Unison.Core.Project (ProjectAndBranch (..), ProjectBranchName, ProjectName) +import Unison.Debug qualified as Debug +import Unison.Hash32 (Hash32) +import Unison.Prelude (From (..)) +import Unison.Server.Orphans () +import Unison.Share.API.Hash (HashJWT) +import Unison.Sync.Types qualified as SyncV1 +import Unison.Util.Servant.CBOR + +newtype BranchRef = BranchRef {unBranchRef :: Text} + deriving (Serialise, Eq, Show, Ord, ToJSON, FromJSON) via Text + +instance From (ProjectAndBranch ProjectName ProjectBranchName) BranchRef where + from pab = BranchRef $ from pab + +data GetCausalHashErrorTag + = GetCausalHashNoReadPermissionTag + | GetCausalHashUserNotFoundTag + | GetCausalHashInvalidBranchRefTag + deriving stock (Show, Eq, Ord) + +instance Serialise GetCausalHashErrorTag where + encode GetCausalHashNoReadPermissionTag = CBOR.encodeWord8 0 + encode GetCausalHashUserNotFoundTag = CBOR.encodeWord8 1 + encode GetCausalHashInvalidBranchRefTag = CBOR.encodeWord8 2 + decode = do + tag <- CBOR.decodeWord8 + case tag of + 0 -> pure GetCausalHashNoReadPermissionTag + 1 -> pure GetCausalHashUserNotFoundTag + 2 -> pure GetCausalHashInvalidBranchRefTag + _ -> fail "invalid tag" + +data DownloadEntitiesRequest = DownloadEntitiesRequest + { causalHash :: HashJWT, + branchRef :: BranchRef, + knownHashes :: Set Hash32 + } + +instance Serialise DownloadEntitiesRequest where + encode (DownloadEntitiesRequest {causalHash, branchRef, knownHashes}) = + encode causalHash <> encode branchRef <> encode knownHashes + decode = DownloadEntitiesRequest <$> decode <*> decode <*> decode + +instance FromJSON DownloadEntitiesRequest where + parseJSON = withObject "DownloadEntitiesRequest" $ \o -> do + causalHash <- o .: "causalHash" + branchRef <- o .: "branchRef" + knownHashes <- o .: "knownHashes" + pure DownloadEntitiesRequest {causalHash, branchRef, knownHashes} + +instance ToJSON DownloadEntitiesRequest where + toJSON (DownloadEntitiesRequest {causalHash, branchRef, knownHashes}) = + object + [ "causalHash" .= causalHash, + "branchRef" .= branchRef, + "knownHashes" .= knownHashes + ] + +data DownloadEntitiesError + = DownloadEntitiesNoReadPermission BranchRef + | -- | msg, branchRef + DownloadEntitiesInvalidBranchRef Text BranchRef + | -- | userHandle + DownloadEntitiesUserNotFound Text + | -- | project shorthand + DownloadEntitiesProjectNotFound Text + | DownloadEntitiesEntityValidationFailure SyncV1.EntityValidationError + deriving stock (Eq, Show, Ord) + +data DownloadEntitiesErrorTag + = NoReadPermissionTag + | InvalidBranchRefTag + | UserNotFoundTag + | ProjectNotFoundTag + | EntityValidationFailureTag + deriving stock (Eq, Show, Ord) + +instance Serialise DownloadEntitiesErrorTag where + encode = \case + NoReadPermissionTag -> CBOR.encodeWord8 0 + InvalidBranchRefTag -> CBOR.encodeWord8 1 + UserNotFoundTag -> CBOR.encodeWord8 2 + ProjectNotFoundTag -> CBOR.encodeWord8 3 + EntityValidationFailureTag -> CBOR.encodeWord8 4 + decode = do + tag <- CBOR.decodeWord8 + case tag of + 0 -> pure NoReadPermissionTag + 1 -> pure InvalidBranchRefTag + 2 -> pure UserNotFoundTag + 3 -> pure ProjectNotFoundTag + 4 -> pure EntityValidationFailureTag + _ -> fail "invalid tag" + +instance Serialise DownloadEntitiesError where + encode = \case + DownloadEntitiesNoReadPermission branchRef -> CBOR.encode NoReadPermissionTag <> CBOR.encode branchRef + DownloadEntitiesInvalidBranchRef msg branchRef -> CBOR.encode InvalidBranchRefTag <> CBOR.encode (msg, branchRef) + DownloadEntitiesUserNotFound userHandle -> CBOR.encode UserNotFoundTag <> CBOR.encode userHandle + DownloadEntitiesProjectNotFound projectShorthand -> CBOR.encode ProjectNotFoundTag <> CBOR.encode projectShorthand + DownloadEntitiesEntityValidationFailure err -> CBOR.encode EntityValidationFailureTag <> CBOR.encode err + + decode = do + tag <- CBOR.decode + case tag of + NoReadPermissionTag -> DownloadEntitiesNoReadPermission <$> CBOR.decode + InvalidBranchRefTag -> uncurry DownloadEntitiesInvalidBranchRef <$> CBOR.decode + UserNotFoundTag -> DownloadEntitiesUserNotFound <$> CBOR.decode + ProjectNotFoundTag -> DownloadEntitiesProjectNotFound <$> CBOR.decode + EntityValidationFailureTag -> DownloadEntitiesEntityValidationFailure <$> CBOR.decode + +data EntitySorting + = -- all dependencies of an entity are guaranteed to be sent before the entity itself + DependenciesFirst + | -- no guarantees. + Unsorted + deriving (Show, Eq, Ord) + +instance Serialise EntitySorting where + encode = \case + DependenciesFirst -> CBOR.encodeWord8 0 + Unsorted -> CBOR.encodeWord8 1 + decode = do + tag <- CBOR.decodeWord8 + case tag of + 0 -> pure DependenciesFirst + 1 -> pure Unsorted + _ -> fail "invalid tag" + +newtype Version = Version Word16 + deriving stock (Show) + deriving newtype (Eq, Ord, Serialise) + +data StreamInitInfo + = StreamInitInfo + { version :: Version, + entitySorting :: EntitySorting, + numEntities :: Maybe Word64, + rootCausalHash :: Hash32, + rootBranchRef :: Maybe BranchRef + } + deriving (Show, Eq, Ord) + +decodeMapKey :: (Serialise r) => Text -> Map Text UnknownCBORBytes -> CBOR.Decoder s r +decodeMapKey k m = + optionalDecodeMapKey k m >>= \case + Nothing -> fail $ "Expected key: " <> Text.unpack k + Just x -> pure x + +optionalDecodeMapKey :: (Serialise r) => Text -> Map Text UnknownCBORBytes -> CBOR.Decoder s (Maybe r) +optionalDecodeMapKey k m = + case Map.lookup k m of + Nothing -> pure Nothing + Just bs -> Just <$> decodeUnknownCBORBytes bs + +-- | Serialised as a map to allow for future expansion +instance Serialise StreamInitInfo where + encode (StreamInitInfo {version, entitySorting, numEntities, rootCausalHash, rootBranchRef}) = + CBOR.encode + ( Map.fromList $ + [ ("v" :: Text, serialiseUnknownCBORBytes version), + ("es", serialiseUnknownCBORBytes entitySorting), + ("rc", serialiseUnknownCBORBytes rootCausalHash) + ] + <> maybe [] (\ne -> [("ne", serialiseUnknownCBORBytes ne)]) numEntities + <> maybe [] (\br -> [("br", serialiseUnknownCBORBytes br)]) rootBranchRef + ) + decode = do + Debug.debugLogM Debug.Temp "Decoding StreamInitInfo" + Debug.debugLogM Debug.Temp "Decoding Map" + m <- CBOR.decode + Debug.debugLogM Debug.Temp "Decoding Version" + version <- decodeMapKey "v" m + Debug.debugLogM Debug.Temp "Decoding Entity Sorting" + entitySorting <- decodeMapKey "es" m + Debug.debugLogM Debug.Temp "Decoding Number of Entities" + numEntities <- (optionalDecodeMapKey "ne" m) + Debug.debugLogM Debug.Temp "Decoding Root Causal Hash" + rootCausalHash <- decodeMapKey "rc" m + Debug.debugLogM Debug.Temp "Decoding Branch Ref" + rootBranchRef <- optionalDecodeMapKey "br" m + pure StreamInitInfo {version, entitySorting, numEntities, rootCausalHash, rootBranchRef} + +data EntityChunk = EntityChunk + { hash :: Hash32, + entityCBOR :: CBORBytes TempEntity + } + deriving (Show, Eq, Ord) + +instance Serialise EntityChunk where + encode (EntityChunk {hash, entityCBOR}) = CBOR.encode hash <> CBOR.encode entityCBOR + decode = EntityChunk <$> CBOR.decode <*> CBOR.decode + +data ErrorChunk = ErrorChunk + { err :: DownloadEntitiesError + } + deriving (Show, Eq, Ord) + +instance Serialise ErrorChunk where + encode (ErrorChunk {err}) = CBOR.encode err + decode = ErrorChunk <$> CBOR.decode + +-- | A chunk of the download entities response stream. +data DownloadEntitiesChunk + = InitialC StreamInitInfo + | EntityC EntityChunk + | ErrorC ErrorChunk + deriving (Show, Eq, Ord) + +data DownloadEntitiesChunkTag = InitialChunkTag | EntityChunkTag | ErrorChunkTag + deriving (Show, Eq, Ord) + +instance Serialise DownloadEntitiesChunkTag where + encode InitialChunkTag = CBOR.encodeWord8 0 + encode EntityChunkTag = CBOR.encodeWord8 1 + encode ErrorChunkTag = CBOR.encodeWord8 2 + decode = do + tag <- CBOR.decodeWord8 + case tag of + 0 -> pure InitialChunkTag + 1 -> pure EntityChunkTag + 2 -> pure ErrorChunkTag + _ -> fail "invalid tag" + +instance Serialise DownloadEntitiesChunk where + encode (EntityC ec) = encode EntityChunkTag <> CBOR.encode ec + encode (ErrorC ec) = encode ErrorChunkTag <> CBOR.encode ec + encode (InitialC ic) = encode InitialChunkTag <> encode ic + decode = do + tag <- decode + case tag of + InitialChunkTag -> InitialC <$> decode + EntityChunkTag -> EntityC <$> decode + ErrorChunkTag -> ErrorC <$> decode + +-- TODO +data UploadEntitiesRequest = UploadEntitiesRequest + +instance Serialise UploadEntitiesRequest where + encode _ = mempty + decode = pure UploadEntitiesRequest + +-- | An error occurred while pulling code from Unison Share. +data PullError + = PullError'DownloadEntities DownloadEntitiesError + | PullError'Sync SyncError + deriving stock (Show, Eq, Ord) + deriving anyclass (Exception) + +data SyncError + = SyncErrorExpectedResultNotInMain CausalHash + | SyncErrorDeserializationFailure CBOR.DeserialiseFailure + | SyncErrorMissingInitialChunk + | SyncErrorMisplacedInitialChunk + | SyncErrorStreamFailure Text + | SyncErrorUnsupportedVersion Version + deriving stock (Show, Eq, Ord) + +data EntityKind + = CausalEntity + | NamespaceEntity + | TermEntity + | TypeEntity + | PatchEntity + deriving (Show, Eq, Ord) + +instance Serialise EntityKind where + encode = \case + CausalEntity -> CBOR.encodeWord8 0 + NamespaceEntity -> CBOR.encodeWord8 1 + TermEntity -> CBOR.encodeWord8 2 + TypeEntity -> CBOR.encodeWord8 3 + PatchEntity -> CBOR.encodeWord8 4 + decode = do + tag <- CBOR.decodeWord8 + case tag of + 0 -> pure CausalEntity + 1 -> pure NamespaceEntity + 2 -> pure TermEntity + 3 -> pure TypeEntity + 4 -> pure PatchEntity + _ -> fail "invalid tag" diff --git a/unison-share-api/src/Unison/Util/Servant/CBOR.hs b/unison-share-api/src/Unison/Util/Servant/CBOR.hs new file mode 100644 index 0000000000..18fd94904c --- /dev/null +++ b/unison-share-api/src/Unison/Util/Servant/CBOR.hs @@ -0,0 +1,88 @@ +-- | Servant configuration for the CBOR media type +-- +-- Adapted from https://hackage.haskell.org/package/servant-serialization-0.3/docs/Servant-API-ContentTypes-SerialiseCBOR.html via MIT license +module Unison.Util.Servant.CBOR + ( CBOR, + UnknownCBORBytes, + CBORBytes (..), + deserialiseOrFailCBORBytes, + serialiseCBORBytes, + decodeCBORBytes, + decodeUnknownCBORBytes, + serialiseUnknownCBORBytes, + ) +where + +import Codec.CBOR.Read (DeserialiseFailure (..)) +import Codec.Serialise (Serialise, deserialiseOrFail, serialise) +import Codec.Serialise qualified as CBOR +import Codec.Serialise.Decoding qualified as CBOR +import Data.ByteString.Lazy qualified as BL +import Data.List.NonEmpty qualified as NonEmpty +import Network.HTTP.Media.MediaType qualified as MediaType +import Servant + +-- | Content-type for encoding and decoding objects as their CBOR representations +data CBOR + +-- | Mime-type for CBOR and additional ones using the word "hackage" and the +-- name of the package "serialise". +instance Accept CBOR where + contentTypes Proxy = + NonEmpty.singleton ("application" MediaType.// "cbor") + +-- | +-- +-- >>> mimeRender (Proxy :: Proxy CBOR) ("Hello" :: String) +-- "eHello" +instance (Serialise a) => MimeRender CBOR a where + mimeRender Proxy = serialise + +-- | +-- +-- >>> let bsl = mimeRender (Proxy :: Proxy CBOR) (3.14 :: Float) +-- >>> mimeUnrender (Proxy :: Proxy CBOR) bsl :: Either String Float +-- Right 3.14 +-- +-- >>> mimeUnrender (Proxy :: Proxy CBOR) (bsl <> "trailing garbage") :: Either String Float +-- Right 3.14 +-- +-- >>> mimeUnrender (Proxy :: Proxy CBOR) ("preceding garbage" <> bsl) :: Either String Float +-- Left "Codec.Serialise.deserialiseOrFail: expected float at byte-offset 0" +instance (Serialise a) => MimeUnrender CBOR a where + mimeUnrender Proxy = mapLeft prettyErr . deserialiseOrFail + where + mapLeft f = either (Left . f) Right + prettyErr (DeserialiseFailure offset err) = + "Codec.Serialise.deserialiseOrFail: " ++ err ++ " at byte-offset " ++ show offset + +-- | Wrapper for CBOR data that has already been serialized. +-- In our case, we use this because we may load pre-serialized CBOR directly from the database, +-- but it's also useful in allowing us to more quickly seek through a CBOR stream, since we only need to decode the CBOR when/if we actually need to use it, and can skip past it using a byte offset otherwise. +-- +-- The 't' phantom type is the type of the data encoded in the bytestring. +newtype CBORBytes t = CBORBytes BL.ByteString + deriving (Serialise) via (BL.ByteString) + deriving (Eq, Show, Ord) + +-- | Deserialize a 'CBORBytes' value into its tagged type, throwing an error if the deserialization fails. +deserialiseOrFailCBORBytes :: (Serialise t) => CBORBytes t -> Either CBOR.DeserialiseFailure t +deserialiseOrFailCBORBytes (CBORBytes bs) = CBOR.deserialiseOrFail bs + +decodeCBORBytes :: (Serialise t) => CBORBytes t -> CBOR.Decoder s t +decodeCBORBytes (CBORBytes bs) = decodeUnknownCBORBytes (CBORBytes bs) + +decodeUnknownCBORBytes :: (Serialise t) => UnknownCBORBytes -> CBOR.Decoder s t +decodeUnknownCBORBytes (CBORBytes bs) = case deserialiseOrFailCBORBytes (CBORBytes bs) of + Left err -> fail (show err) + Right t -> pure t + +serialiseCBORBytes :: (Serialise t) => t -> CBORBytes t +serialiseCBORBytes = CBORBytes . CBOR.serialise + +serialiseUnknownCBORBytes :: (Serialise t) => t -> UnknownCBORBytes +serialiseUnknownCBORBytes = CBORBytes . CBOR.serialise + +data Unknown + +type UnknownCBORBytes = CBORBytes Unknown