diff --git a/plugins/tactics/src/Ide/Plugin/Tactic.hs b/plugins/tactics/src/Ide/Plugin/Tactic.hs index 25874bf242..c5744e811a 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic.hs @@ -22,9 +22,11 @@ import Control.Monad.Trans import Control.Monad.Trans.Maybe import Data.Aeson import Data.Coerce +import Data.Functor ((<&>)) import Data.Generics.Aliases (mkQ) import Data.Generics.Schemes (everything) import Data.List +import Data.Map (Map) import qualified Data.Map as M import Data.Maybe import Data.Monoid @@ -214,7 +216,7 @@ filterBindingType filterBindingType p tp dflags plId uri range jdg = let hy = jHypothesis jdg g = jGoal jdg - in fmap join $ for (M.toList hy) $ \(occ, CType ty) -> + in fmap join $ for (M.toList hy) $ \(occ, hi_type -> CType ty) -> case p (unCType g) ty of True -> tp occ ty dflags plId uri range jdg False -> pure [] @@ -264,23 +266,28 @@ judgementForHole state nfp range = do (mapMaybe (sequenceA . (occName *** coerce)) $ getDefiningBindings binds rss) tcg - hyps = hypothesisFromBindings rss binds - ambient = M.fromList $ contextMethodHypothesis ctx + top_provs = getRhsPosVals rss tcs + local_hy = spliceProvenance top_provs + $ hypothesisFromBindings rss binds + cls_hy = contextMethodHypothesis ctx pure ( resulting_range , mkFirstJudgement - hyps - ambient + (local_hy <> cls_hy) (isRhsHole rss tcs) - (maybe - mempty - (uncurry M.singleton . fmap pure) - $ getRhsPosVals rss tcs) goal , ctx , dflags ) +spliceProvenance + :: Map OccName Provenance + -> Map OccName (HyInfo a) + -> Map OccName (HyInfo a) +spliceProvenance provs = + M.mapWithKey $ \name hi -> + overProvenance (maybe id const $ M.lookup name provs) hi + tacticCmd :: (OccName -> TacticsM ()) -> CommandFunction TacticParams tacticCmd tac lf state (TacticParams uri range var_name) @@ -334,17 +341,22 @@ isRhsHole rss tcs = everything (||) (mkQ False $ \case ------------------------------------------------------------------------------ -- | Compute top-level position vals of a function -getRhsPosVals :: RealSrcSpan -> TypecheckedSource -> Maybe (OccName, [OccName]) -getRhsPosVals rss tcs = getFirst $ everything (<>) (mkQ mempty $ \case - TopLevelRHS name ps - (L (RealSrcSpan span) -- body with no guards and a single defn - (HsVar _ (L _ hole))) - | containsSpan rss span -- which contains our span - , isHole $ occName hole -- and the span is a hole - -> First $ do - patnames <- traverse getPatName ps - pure (occName name, patnames) - _ -> mempty +getRhsPosVals :: RealSrcSpan -> TypecheckedSource -> Map OccName Provenance +getRhsPosVals rss tcs + = M.fromList + $ join + $ maybeToList + $ getFirst + $ everything (<>) (mkQ mempty $ \case + TopLevelRHS name ps + (L (RealSrcSpan span) -- body with no guards and a single defn + (HsVar _ (L _ hole))) + | containsSpan rss span -- which contains our span + , isHole $ occName hole -- and the span is a hole + -> First $ do + patnames <- traverse getPatName ps + pure $ zip patnames $ [0..] <&> TopLevelArgPrv name + _ -> mempty ) tcs diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs index 5ce8605e70..e07aa1dfb2 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs @@ -23,6 +23,6 @@ auto = do commit knownStrategies . tracing "auto" . localTactic (auto' 4) - . disallowing + . disallowing RecursiveCall $ fmap fst current diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs index db20420ede..ec8caf8a9b 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs @@ -43,8 +43,8 @@ useOccName jdg name = ------------------------------------------------------------------------------ -- | Doing recursion incurs a small penalty in the score. -penalizeRecursion :: MonadState TacticState m => m () -penalizeRecursion = modify $ field @"ts_recursion_penality" +~ 1 +countRecursiveCall :: TacticState -> TacticState +countRecursiveCall = field @"ts_recursion_count" +~ 1 ------------------------------------------------------------------------------ @@ -57,14 +57,14 @@ addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals destructMatches :: (DataCon -> Judgement -> Rule) -- ^ How to construct each match - -> ([(OccName, CType)] -> Judgement -> Judgement) - -- ^ How to derive each match judgement + -> Maybe OccName + -- ^ Scrutinee -> CType -- ^ Type being destructed -> Judgement -> RuleM (Trace, [RawMatch]) -destructMatches f f2 t jdg = do - let hy = jHypothesis jdg +destructMatches f scrut t jdg = do + let hy = jEntireHypothesis jdg g = jGoal jdg case splitTyConApp_maybe $ unCType t of Nothing -> throwError $ GoalMismatch "destruct" g @@ -76,11 +76,7 @@ destructMatches f f2 t jdg = do let args = dataConInstOrigArgTys' dc apps names <- mkManyGoodNames hy args let hy' = zip names $ coerce args - dcon_name = nameOccName $ dataConName dc - - let j = f2 hy' - $ withPositionMapping dcon_name names - $ introducingPat hy' + j = introducingPat scrut dc hy' $ withNewGoal g jdg (tr, sg) <- f dc j modify $ withIntroducedVals $ mappend $ S.fromList names @@ -142,14 +138,14 @@ destruct' f term jdg = do let hy = jHypothesis jdg case find ((== term) . fst) $ toList hy of Nothing -> throwError $ UndefinedHypothesis term - Just (_, t) -> do + Just (_, hi_type -> t) -> do useOccName jdg term (tr, ms) <- destructMatches f - (\cs -> setParents term (fmap fst cs) . destructing term) + (Just term) t - jdg + $ disallowing AlreadyDestructed [term] jdg pure ( rose ("destruct " <> show term) $ pure tr , noLoc $ case' (var' term) ms ) @@ -165,7 +161,7 @@ destructLambdaCase' f jdg = do case splitFunTy_maybe (unCType g) of Just (arg, _) | isAlgType arg -> fmap (fmap noLoc $ lambdaCase) <$> - destructMatches f (const id) (CType arg) jdg + destructMatches f Nothing (CType arg) jdg _ -> throwError $ GoalMismatch "destructLambdaCase'" g @@ -178,12 +174,11 @@ buildDataCon -> RuleM (Trace, LHsExpr GhcPs) buildDataCon jdg dc apps = do let args = dataConInstOrigArgTys' dc apps - dcon_name = nameOccName $ dataConName dc (tr, sgs) <- fmap unzipTrace $ traverse ( \(arg, n) -> newSubgoal - . filterSameTypeFromOtherPositions dcon_name n + . filterSameTypeFromOtherPositions dc n . blacklistingDestruct . flip withNewGoal jdg $ CType arg diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Context.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Context.hs index 1621c36393..8522e0ddc4 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Context.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Context.hs @@ -7,6 +7,8 @@ import Bag import Control.Arrow import Control.Monad.Reader import Data.List +import Data.Map (Map) +import qualified Data.Map as M import Data.Maybe (mapMaybe) import Data.Set (Set) import qualified Data.Set as S @@ -33,9 +35,10 @@ mkContext locals tcg = Context ------------------------------------------------------------------------------ -- | Find all of the class methods that exist from the givens in the context. -contextMethodHypothesis :: Context -> [(OccName, CType)] +contextMethodHypothesis :: Context -> Map OccName (HyInfo CType) contextMethodHypothesis ctx - = excludeForbiddenMethods + = M.fromList + . excludeForbiddenMethods . join . concatMap ( mapMaybe methodHypothesis @@ -51,7 +54,7 @@ contextMethodHypothesis ctx -- | Many operations are defined in typeclasses for performance reasons, rather -- than being a true part of the class. This function filters out those, in -- order to keep our hypothesis space small. -excludeForbiddenMethods :: [(OccName, CType)] -> [(OccName, CType)] +excludeForbiddenMethods :: [(OccName, a)] -> [(OccName, a)] excludeForbiddenMethods = filter (not . flip S.member forbiddenMethods . fst) where forbiddenMethods :: Set OccName diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs index 3beb40daa4..d4fdc6fa00 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs @@ -1,10 +1,35 @@ -{-# LANGUAGE TupleSections #-} {-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} -module Ide.Plugin.Tactic.Judgements where +module Ide.Plugin.Tactic.Judgements + ( blacklistingDestruct + , unwhitelistingSplit + , introducingLambda + , introducingRecursively + , introducingPat + , jGoal + , jHypothesis + , jEntireHypothesis + , jPatHypothesis + , substJdg + , unsetIsTopHole + , filterSameTypeFromOtherPositions + , isDestructBlacklisted + , withNewGoal + , jLocalHypothesis + , isSplitWhitelisted + , isPatternMatch + , filterPosition + , isTopHole + , disallowing + , mkFirstJudgement + , hypothesisFromBindings + , isTopLevel + ) where import Control.Lens hiding (Context) import Data.Bool @@ -14,7 +39,9 @@ import Data.Generics.Product (field) import Data.Map (Map) import qualified Data.Map as M import Data.Maybe +import Data.Set (Set) import qualified Data.Set as S +import DataCon (DataCon) import Development.IDE.Spans.LocalBindings import Ide.Plugin.Tactic.Types import OccName @@ -24,30 +51,23 @@ import Type ------------------------------------------------------------------------------ -- | Given a 'SrcSpan' and a 'Bindings', create a hypothesis. -hypothesisFromBindings :: RealSrcSpan -> Bindings -> Map OccName CType +hypothesisFromBindings :: RealSrcSpan -> Bindings -> Map OccName (HyInfo CType) hypothesisFromBindings span bs = buildHypothesis $ getLocalScope bs span + ------------------------------------------------------------------------------ -- | Convert a @Set Id@ into a hypothesis. -buildHypothesis :: [(Name, Maybe Type)] -> Map OccName CType +buildHypothesis :: [(Name, Maybe Type)] -> Map OccName (HyInfo CType) buildHypothesis = M.fromList . mapMaybe go where go (occName -> occ, t) | Just ty <- t - , isAlpha . head . occNameString $ occ = Just (occ, CType ty) + , isAlpha . head . occNameString $ occ = Just (occ, HyInfo UserPrv $ CType ty) | otherwise = Nothing -hasDestructed :: Judgement -> OccName -> Bool -hasDestructed j n = S.member n $ _jDestructed j - - -destructing :: OccName -> Judgement -> Judgement -destructing n = field @"_jDestructed" <>~ S.singleton n - - blacklistingDestruct :: Judgement -> Judgement blacklistingDestruct = field @"_jBlacklistDestruct" .~ True @@ -70,72 +90,146 @@ withNewGoal :: a -> Judgement' a -> Judgement' a withNewGoal t = field @"_jGoal" .~ t -introducing :: [(OccName, a)] -> Judgement' a -> Judgement' a -introducing ns = - field @"_jHypothesis" <>~ M.fromList ns - - ------------------------------------------------------------------------------ --- | Add some terms to the ambient hypothesis -introducingAmbient :: [(OccName, a)] -> Judgement' a -> Judgement' a -introducingAmbient ns = - field @"_jAmbientHypothesis" <>~ M.fromList ns +-- | Helper function for implementing functions which introduce new hypotheses. +introducing + :: (Int -> Provenance) -- ^ A function from the position of the arg to its + -- provenance. + -> [(OccName, a)] + -> Judgement' a + -> Judgement' a +introducing f ns = + field @"_jHypothesis" <>~ M.fromList (zip [0..] ns <&> + \(pos, (name, ty)) -> (name, HyInfo (f pos) ty)) -filterPosition :: OccName -> Int -> Judgement -> Judgement -filterPosition defn pos jdg = - withHypothesis (M.filterWithKey go) jdg - where - go name _ = isJust $ hasPositionalAncestry jdg defn pos name +------------------------------------------------------------------------------ +-- | Introduce bindings in the context of a lamba. +introducingLambda + :: Maybe OccName -- ^ The name of the top level function. For any other + -- function, this should be 'Nothing'. + -> [(OccName, a)] + -> Judgement' a + -> Judgement' a +introducingLambda func = introducing $ \pos -> + maybe UserPrv (\x -> TopLevelArgPrv x pos) func -filterSameTypeFromOtherPositions :: OccName -> Int -> Judgement -> Judgement -filterSameTypeFromOtherPositions defn pos jdg = - let hy = jHypothesis $ filterPosition defn pos jdg - tys = S.fromList $ fmap snd $ M.toList hy - in withHypothesis (\hy2 -> M.filter (not . flip S.member tys) hy2 <> hy) jdg +------------------------------------------------------------------------------ +-- | Introduce a binding in a recursive context. +introducingRecursively :: [(OccName, a)] -> Judgement' a -> Judgement' a +introducingRecursively = introducing $ const RecursivePrv +------------------------------------------------------------------------------ +-- | Check whether any of the given occnames are an ancestor of the term. hasPositionalAncestry - :: Judgement - -> OccName -- ^ defining fn - -> Int -- ^ position - -> OccName -- ^ thing to check ancestry + :: Foldable t + => t OccName -- ^ Desired ancestors. + -> Judgement + -> OccName -- ^ Potential child -> Maybe Bool -- ^ Just True if the result is the oldest positional ancestor -- just false if it's a descendent -- otherwise nothing -hasPositionalAncestry jdg defn n name +hasPositionalAncestry ancestors jdg name | not $ null ancestors = case any (== name) ancestors of True -> Just True False -> - case M.lookup name $ _jAncestry jdg of + case M.lookup name $ jAncestryMap jdg of Just ancestry -> bool Nothing (Just False) $ any (flip S.member ancestry) ancestors Nothing -> Nothing | otherwise = Nothing - where - ancestors = toListOf (_Just . traversed . ix n) - $ M.lookup defn - $ _jPositionMaps jdg -setParents - :: OccName -- ^ parent - -> [OccName] -- ^ children +------------------------------------------------------------------------------ +-- | Helper function for disallowing hypotheses that have the wrong ancestry. +filterAncestry + :: Foldable t + => t OccName + -> DisallowReason -> Judgement -> Judgement -setParents p cs jdg = - let ancestry = mappend (S.singleton p) - $ fromMaybe mempty - $ M.lookup p - $ _jAncestry jdg - in jdg & field @"_jAncestry" <>~ M.fromList (fmap (, ancestry) cs) +filterAncestry ancestry reason jdg = + disallowing reason (M.keys $ M.filterWithKey go $ jHypothesis jdg) jdg + where + go name _ + = not + . isJust + $ hasPositionalAncestry ancestry jdg name + + +------------------------------------------------------------------------------ +-- | @filter defn pos@ removes any hypotheses which are bound in @defn@ to +-- a position other than @pos@. Any terms whose ancestry doesn't include @defn@ +-- remain. +filterPosition :: OccName -> Int -> Judgement -> Judgement +filterPosition defn pos jdg = + filterAncestry (findPositionVal jdg defn pos) (WrongBranch pos) jdg + + +------------------------------------------------------------------------------ +-- | Helper function for determining the ancestry list for 'filterPosition'. +findPositionVal :: Judgement' a -> OccName -> Int -> Maybe OccName +findPositionVal jdg defn pos = listToMaybe $ do + -- It's important to inspect the entire hypothesis here, as we need to trace + -- ancstry through potentially disallowed terms in the hypothesis. + (name, hi) <- M.toList $ M.map (overProvenance expandDisallowed) $ jEntireHypothesis jdg + case hi_provenance hi of + TopLevelArgPrv defn' pos' + | defn == defn' + , pos == pos' -> pure name + PatternMatchPrv pv + | pv_scrutinee pv == Just defn + , pv_position pv == pos -> pure name + _ -> [] + + +------------------------------------------------------------------------------ +-- | Helper function for determining the ancestry list for +-- 'filterSameTypeFromOtherPositions'. +findDconPositionVals :: Judgement' a -> DataCon -> Int -> [OccName] +findDconPositionVals jdg dcon pos = do + (name, hi) <- M.toList $ jHypothesis jdg + case hi_provenance hi of + PatternMatchPrv pv + | pv_datacon pv == Uniquely dcon + , pv_position pv == pos -> pure name + _ -> [] + + +------------------------------------------------------------------------------ +-- | Disallow any hypotheses who have the same type as anything bound by the +-- given position for the datacon. Used to ensure recursive functions like +-- 'fmap' preserve the relative ordering of their arguments by eliminating any +-- other term which might match. +filterSameTypeFromOtherPositions :: DataCon -> Int -> Judgement -> Judgement +filterSameTypeFromOtherPositions dcon pos jdg = + let hy = jHypothesis + $ filterAncestry + (findDconPositionVals jdg dcon pos) + (WrongBranch pos) + jdg + tys = S.fromList $ fmap (hi_type . snd) $ M.toList hy + to_remove = + M.filter (flip S.member tys . hi_type) (jHypothesis jdg) + M.\\ hy + in disallowing Shadowed (M.keys to_remove) jdg -withPositionMapping :: OccName -> [OccName] -> Judgement -> Judgement -withPositionMapping defn names = - field @"_jPositionMaps" . at defn <>~ Just [names] +------------------------------------------------------------------------------ +-- | Return the ancestry of a 'PatVal', or 'mempty' otherwise. +getAncestry :: Judgement' a -> OccName -> Set OccName +getAncestry jdg name = + case M.lookup name $ jPatHypothesis jdg of + Just pv -> pv_ancestry pv + Nothing -> mempty + + +jAncestryMap :: Judgement' a -> Map OccName (Set OccName) +jAncestryMap jdg = + flip M.map (jPatHypothesis jdg) pv_ancestry ------------------------------------------------------------------------------ @@ -149,44 +243,65 @@ extremelyStupid__definingFunction = fst . head . ctxDefiningFuncs -withHypothesis - :: (Map OccName a -> Map OccName a) +------------------------------------------------------------------------------ +-- | Pattern vals are currently tracked in jHypothesis, with an extra piece of +-- data sitting around in jPatternVals. +introducingPat + :: Maybe OccName + -> DataCon + -> [(OccName, a)] -> Judgement' a -> Judgement' a -withHypothesis f = - field @"_jHypothesis" %~ f +introducingPat scrutinee dc ns jdg + = introducing (\pos -> + PatternMatchPrv $ + PatVal + scrutinee + (maybe mempty + (\scrut -> S.singleton scrut <> getAncestry jdg scrut) + scrutinee) + (Uniquely dc) + pos + ) ns jdg + ------------------------------------------------------------------------------ --- | Pattern vals are currently tracked in jHypothesis, with an extra piece of data sitting around in jPatternVals. -introducingPat :: [(OccName, a)] -> Judgement' a -> Judgement' a -introducingPat ns jdg = jdg - & field @"_jHypothesis" <>~ M.fromList ns - & field @"_jPatternVals" <>~ S.fromList (fmap fst ns) +-- | Prevent some occnames from being used in the hypothesis. This will hide +-- them from 'jHypothesis', but not from 'jEntireHypothesis'. +disallowing :: DisallowReason -> [OccName] -> Judgement' a -> Judgement' a +disallowing reason (S.fromList -> ns) = + field @"_jHypothesis" %~ (M.mapWithKey $ \name hi -> + case S.member name ns of + True -> overProvenance (DisallowedPrv reason) hi + False -> hi + ) -disallowing :: [OccName] -> Judgement' a -> Judgement' a -disallowing ns = - field @"_jHypothesis" %~ flip M.withoutKeys (S.fromList ns) +------------------------------------------------------------------------------ +-- | The hypothesis, consisting of local terms and the ambient environment +-- (impors and class methods.) Hides disallowed values. +jHypothesis :: Judgement' a -> Map OccName (HyInfo a) +jHypothesis = M.filter (not . isDisallowed . hi_provenance) . jEntireHypothesis ------------------------------------------------------------------------------ --- | The hypothesis, consisting of local terms and the ambient environment --- (includes and class methods.) -jHypothesis :: Judgement' a -> Map OccName a -jHypothesis = _jHypothesis <> _jAmbientHypothesis +-- | The whole hypothesis, including things disallowed. +jEntireHypothesis :: Judgement' a -> Map OccName (HyInfo a) +jEntireHypothesis = _jHypothesis ------------------------------------------------------------------------------ -- | Just the local hypothesis. -jLocalHypothesis :: Judgement' a -> Map OccName a -jLocalHypothesis = _jHypothesis +jLocalHypothesis :: Judgement' a -> Map OccName (HyInfo a) +jLocalHypothesis = M.filter (isLocalHypothesis . hi_provenance) . jHypothesis -isPatVal :: Judgement' a -> OccName -> Bool -isPatVal j n = S.member n $ _jPatternVals j +------------------------------------------------------------------------------ +-- | If we're in a top hole, the name of the defining function. +isTopHole :: Context -> Judgement' a -> Maybe OccName +isTopHole ctx = + bool Nothing (Just $ extremelyStupid__definingFunction ctx) . _jIsTopHole -isTopHole :: Judgement' a -> Bool -isTopHole = _jIsTopHole unsetIsTopHole :: Judgement' a -> Judgement' a unsetIsTopHole = field @"_jIsTopHole" .~ False @@ -194,9 +309,15 @@ unsetIsTopHole = field @"_jIsTopHole" .~ False ------------------------------------------------------------------------------ -- | Only the hypothesis members which are pattern vals -jPatHypothesis :: Judgement' a -> Map OccName a -jPatHypothesis jdg - = M.restrictKeys (jHypothesis jdg) $ _jPatternVals jdg +jPatHypothesis :: Judgement' a -> Map OccName PatVal +jPatHypothesis = M.mapMaybe (getPatVal . hi_provenance) . jHypothesis + + +getPatVal :: Provenance-> Maybe PatVal +getPatVal prov = + case prov of + PatternMatchPrv pv -> Just pv + _ -> Nothing jGoal :: Judgement' a -> a @@ -206,23 +327,54 @@ jGoal = _jGoal substJdg :: TCvSubst -> Judgement -> Judgement substJdg subst = fmap $ coerce . substTy subst . coerce + mkFirstJudgement - :: M.Map OccName CType -- ^ local hypothesis - -> M.Map OccName CType -- ^ ambient hypothesis + :: M.Map OccName (HyInfo CType) -> Bool -- ^ are we in the top level rhs hole? - -> M.Map OccName [[OccName]] -- ^ existing pos vals -> Type -> Judgement' CType -mkFirstJudgement hy ambient top posvals goal = Judgement +mkFirstJudgement hy top goal = Judgement { _jHypothesis = hy - , _jAmbientHypothesis = ambient - , _jDestructed = mempty - , _jPatternVals = mempty , _jBlacklistDestruct = False , _jWhitelistSplit = True - , _jPositionMaps = posvals - , _jAncestry = mempty , _jIsTopHole = top , _jGoal = CType goal } + +------------------------------------------------------------------------------ +-- | Is this a top level function binding? +isTopLevel :: Provenance -> Bool +isTopLevel TopLevelArgPrv{} = True +isTopLevel _ = False + + +------------------------------------------------------------------------------ +-- | Is this a local function argument, pattern match or user val? +isLocalHypothesis :: Provenance -> Bool +isLocalHypothesis UserPrv{} = True +isLocalHypothesis PatternMatchPrv{} = True +isLocalHypothesis TopLevelArgPrv{} = True +isLocalHypothesis _ = False + + +------------------------------------------------------------------------------ +-- | Is this a pattern match? +isPatternMatch :: Provenance -> Bool +isPatternMatch PatternMatchPrv{} = True +isPatternMatch _ = False + + +------------------------------------------------------------------------------ +-- | Was this term ever disallowed? +isDisallowed :: Provenance -> Bool +isDisallowed DisallowedPrv{} = True +isDisallowed _ = False + + +------------------------------------------------------------------------------ +-- | Eliminates 'DisallowedPrv' provenances. +expandDisallowed :: Provenance -> Provenance +expandDisallowed (DisallowedPrv _ prv) = expandDisallowed prv +expandDisallowed prv = prv + diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs index f3e41c0061..22ef2b6b5e 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs @@ -25,13 +25,16 @@ import Control.Monad.Reader import Control.Monad.State (MonadState(..)) import Control.Monad.State.Class (gets, modify) import Control.Monad.State.Strict (StateT (..)) +import Data.Bool (bool) import Data.Coerce import Data.Either import Data.Foldable import Data.Functor ((<&>)) import Data.Generics (mkQ, everything, gcount) -import Data.List (nub, sortBy) +import Data.List (sortBy) +import qualified Data.Map as M import Data.Ord (comparing, Down(..)) +import Data.Set (Set) import qualified Data.Set as S import Development.IDE.GHC.Compat import Ide.Plugin.Tactic.Judgements @@ -71,15 +74,19 @@ runTactic -> TacticsM () -- ^ Tactic to use -> Either [TacticError] RunTacticResults runTactic ctx jdg t = - let skolems = nub + let skolems = S.fromList $ foldMap (tyCoVarsOfTypeWellScoped . unCType) - $ jGoal jdg - : (toList $ jHypothesis jdg) - unused_topvals = nub $ join $ join $ toList $ _jPositionMaps jdg + $ (:) (jGoal jdg) + $ fmap hi_type + $ toList + $ jHypothesis jdg + unused_topvals = M.keysSet + $ M.filter (isTopLevel . hi_provenance) + $ jHypothesis jdg tacticState = defaultTacticState { ts_skolems = skolems - , ts_unused_top_vals = S.fromList unused_topvals + , ts_unused_top_vals = unused_topvals } in case partitionEithers . flip runReader ctx @@ -118,20 +125,28 @@ tracing s (TacticT m) mapExtract' (first $ rose s . pure) $ runStateT m jdg -recursiveCleanup +------------------------------------------------------------------------------ +-- | Recursion is allowed only when we can prove it is on a structurally +-- smaller argument. The top of the 'ts_recursion_stack' witnesses the smaller +-- pattern val. +guardStructurallySmallerRecursion :: TacticState -> Maybe TacticError -recursiveCleanup s = - let r = head $ ts_recursion_stack s - in case r of - True -> Nothing - False -> Just NoProgress +guardStructurallySmallerRecursion s = + case head $ ts_recursion_stack s of + Just _ -> Nothing + Nothing -> Just NoProgress -setRecursionFrameData :: MonadState TacticState m => Bool -> m () -setRecursionFrameData b = do +------------------------------------------------------------------------------ +-- | Mark that the current recursive call is structurally smaller, due to +-- having been matched on a pattern value. +-- +-- Implemented by setting the top of the 'ts_recursion_stack'. +markStructuralySmallerRecursion :: MonadState TacticState m => PatVal -> m () +markStructuralySmallerRecursion pv = do modify $ withRecursionStack $ \case - (_ : bs) -> b : bs + (_ : bs) -> Just pv : bs [] -> [] @@ -159,7 +174,7 @@ scoreSolution ext TacticState{..} holes , Penalize $ S.size ts_unused_top_vals , Penalize $ S.size ts_intro_vals , Reward $ S.size ts_used_vals - , Penalize $ ts_recursion_penality + , Penalize $ ts_recursion_count , Penalize $ solutionSize ext ) @@ -181,21 +196,16 @@ newtype Reward a = Reward a ------------------------------------------------------------------------------ -- | Like 'tcUnifyTy', but takes a list of skolems to prevent unification of. -tryUnifyUnivarsButNotSkolems :: [TyVar] -> CType -> CType -> Maybe TCvSubst +tryUnifyUnivarsButNotSkolems :: Set TyVar -> CType -> CType -> Maybe TCvSubst tryUnifyUnivarsButNotSkolems skolems goal inst = - case tcUnifyTysFG (skolemsOf skolems) [unCType inst] [unCType goal] of + case tcUnifyTysFG + (bool BindMe Skolem . flip S.member skolems) + [unCType inst] + [unCType goal] of Unifiable subst -> pure subst _ -> Nothing ------------------------------------------------------------------------------- --- | Helper method for 'tryUnifyUnivarsButNotSkolems' -skolemsOf :: [TyVar] -> TyVar -> BindFlag -skolemsOf tvs tv = - case elem tv tvs of - True -> Skolem - False -> BindMe - ------------------------------------------------------------------------------ -- | Attempt to unify two types. @@ -213,7 +223,7 @@ unify goal inst = do ------------------------------------------------------------------------------ -- | Get the class methods of a 'PredType', correctly dealing with -- instantiation of quantified class types. -methodHypothesis :: PredType -> Maybe [(OccName, CType)] +methodHypothesis :: PredType -> Maybe [(OccName, HyInfo CType)] methodHypothesis ty = do (tc, apps) <- splitTyConApp_maybe ty cls <- tyConClass_maybe tc @@ -225,7 +235,9 @@ methodHypothesis ty = do $ classSCTheta cls pure $ mappend sc_methods $ methods <&> \method -> let (_, _, ty) = tcSplitSigmaTy $ idType method - in (occName method, CType $ substTy subst ty) + in ( occName method + , HyInfo (ClassMethodPrv $ Uniquely cls) $ CType $ substTy subst ty + ) ------------------------------------------------------------------------------ @@ -234,7 +246,7 @@ methodHypothesis ty = do requireConcreteHole :: TacticsM a -> TacticsM a requireConcreteHole m = do jdg <- goal - skolems <- gets $ S.fromList . ts_skolems + skolems <- gets ts_skolems let vars = S.fromList $ tyCoVarsOfTypeWellScoped $ unCType $ jGoal jdg case S.size $ vars S.\\ skolems of 0 -> m diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs index f1c2a6d220..8f6de3ebce 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs @@ -20,6 +20,7 @@ import Control.Monad.Reader.Class (MonadReader(ask)) import Control.Monad.State.Class import Control.Monad.State.Strict (StateT(..), runStateT) import Data.Bool (bool) +import Data.Foldable import Data.List import qualified Data.Map as M import Data.Maybe @@ -55,10 +56,9 @@ assume :: OccName -> TacticsM () assume name = rule $ \jdg -> do let g = jGoal jdg case M.lookup name $ jHypothesis jdg of - Just ty -> do + Just (hi_type -> ty) -> do unify ty $ jGoal jdg - when (M.member name $ jPatHypothesis jdg) $ - setRecursionFrameData True + for_ (M.lookup name $ jPatHypothesis jdg) markStructuralySmallerRecursion useOccName jdg name pure $ (tracePrim $ "assume " <> occNameString name, ) $ noLoc $ var' name Nothing -> throwError $ UndefinedHypothesis name @@ -68,10 +68,9 @@ recursion :: TacticsM () recursion = requireConcreteHole $ tracing "recursion" $ do defs <- getCurrentDefinitions attemptOn (const $ fmap fst defs) $ \name -> do - modify $ withRecursionStack (False :) - penalizeRecursion - ensure recursiveCleanup (withRecursionStack tail) $ do - (localTactic (apply name) $ introducingAmbient defs) + modify $ pushRecursionStack . countRecursiveCall + ensure guardStructurallySmallerRecursion popRecursionStack $ do + (localTactic (apply name) $ introducingRecursively defs) <@> fmap (localTactic assumption . filterPosition name) [0..] @@ -85,19 +84,13 @@ intros = rule $ \jdg -> do case tcSplitFunTys $ unCType g of ([], _) -> throwError $ GoalMismatch "intros" g (as, b) -> do - vs <- mkManyGoodNames hy as - let jdg' = introducing (zip vs $ coerce as) + vs <- mkManyGoodNames (jEntireHypothesis jdg) as + let top_hole = isTopHole ctx jdg + jdg' = introducingLambda top_hole (zip vs $ coerce as) $ withNewGoal (CType b) jdg modify $ withIntroducedVals $ mappend $ S.fromList vs - when (isTopHole jdg) $ addUnusedTopVals $ S.fromList vs - (tr, sg) - <- newSubgoal - $ bool - id - (withPositionMapping - (extremelyStupid__definingFunction ctx) vs) - (isTopHole jdg) - $ jdg' + when (isJust top_hole) $ addUnusedTopVals $ S.fromList vs + (tr, sg) <- newSubgoal jdg' pure . (rose ("intros {" <> intercalate ", " (fmap show vs) <> "}") $ pure tr, ) . noLoc @@ -110,29 +103,27 @@ intros = rule $ \jdg -> do destructAuto :: OccName -> TacticsM () destructAuto name = requireConcreteHole $ tracing "destruct(auto)" $ do jdg <- goal - case hasDestructed jdg name of - True -> throwError $ AlreadyDestructed name - False -> + case M.lookup name $ jHypothesis jdg of + Nothing -> throwError $ NotInScope name + Just hi -> let subtactic = rule $ destruct' (const subgoal) name - in case isPatVal jdg name of + in case isPatternMatch $ hi_provenance hi of True -> pruning subtactic $ \jdgs -> - let getHyTypes = S.fromList . fmap snd . M.toList . jHypothesis + let getHyTypes = S.fromList . fmap (hi_type . snd) . M.toList . jHypothesis new_hy = foldMap getHyTypes jdgs old_hy = getHyTypes jdg - in case S.null $ new_hy S.\\ old_hy of + in case S.null $ new_hy S.\\ old_hy of True -> Just $ UnhelpfulDestruct name False -> Nothing False -> subtactic + ------------------------------------------------------------------------------ -- | Case split, and leave holes in the matches. destruct :: OccName -> TacticsM () -destruct name = requireConcreteHole $ tracing "destruct(user)" $ do - jdg <- goal - case hasDestructed jdg name of - True -> throwError $ AlreadyDestructed name - False -> rule $ \jdg -> destruct' (const subgoal) name jdg +destruct name = requireConcreteHole $ tracing "destruct(user)" $ + rule $ destruct' (const subgoal) name ------------------------------------------------------------------------------ @@ -167,7 +158,7 @@ apply func = requireConcreteHole $ tracing ("apply' " <> show func) $ do let hy = jHypothesis jdg g = jGoal jdg case M.lookup func hy of - Just (CType ty) -> do + Just (hi_type -> CType ty) -> do ty' <- freshTyvars ty let (_, _, args, ret) = tacticsSplitFunTy ty' requireNewHoles $ rule $ \jdg -> do @@ -283,12 +274,12 @@ auto' n = do overFunctions :: (OccName -> TacticsM ()) -> TacticsM () overFunctions = - attemptOn $ M.keys . M.filter (isFunction . unCType) . jHypothesis + attemptOn $ M.keys . M.filter (isFunction . unCType . hi_type) . jHypothesis overAlgebraicTerms :: (OccName -> TacticsM ()) -> TacticsM () overAlgebraicTerms = attemptOn $ - M.keys . M.filter (isJust . algebraicTyCon . unCType) . jHypothesis + M.keys . M.filter (isJust . algebraicTyCon . unCType . hi_type) . jHypothesis allNames :: Judgement -> [OccName] diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs index 6b4201b49a..95e80f004f 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DerivingVia #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE TypeApplications #-} @@ -39,7 +40,7 @@ import Refinery.Tactic import System.IO.Unsafe (unsafePerformIO) import Type import UniqSupply (takeUniqFromSupply, mkSplitUniqSupply, UniqSupply) -import Unique (Unique) +import Unique (nonDetCmpUnique, Uniquable, getUnique, Unique) ------------------------------------------------------------------------------ @@ -70,10 +71,13 @@ instance Show (LHsExpr GhcPs) where instance Show DataCon where show = unsafeRender +instance Show Class where + show = unsafeRender + ------------------------------------------------------------------------------ data TacticState = TacticState - { ts_skolems :: !([TyVar]) + { ts_skolems :: !(Set TyVar) -- ^ The known skolems. , ts_unifier :: !(TCvSubst) -- ^ The current substitution of univars. @@ -83,11 +87,11 @@ data TacticState = TacticState -- ^ Set of values introduced by tactics. , ts_unused_top_vals :: !(Set OccName) -- ^ Set of currently unused arguments to the function being defined. - , ts_recursion_stack :: ![Bool] + , ts_recursion_stack :: ![Maybe PatVal] -- ^ Stack for tracking whether or not the current recursive call has -- used at least one smaller pat val. Recursive calls for which this -- value is 'False' are guaranteed to loop, and must be pruned. - , ts_recursion_penality :: !Int + , ts_recursion_count :: !Int -- ^ Number of calls to recursion. We penalize each. , ts_unique_gen :: !UniqSupply } deriving stock (Show, Generic) @@ -113,7 +117,7 @@ defaultTacticState = , ts_intro_vals = mempty , ts_unused_top_vals = mempty , ts_recursion_stack = mempty - , ts_recursion_penality = 0 + , ts_recursion_count = 0 , ts_unique_gen = unsafeDefaultUniqueSupply } @@ -128,10 +132,16 @@ freshUnique = do withRecursionStack - :: ([Bool] -> [Bool]) -> TacticState -> TacticState + :: ([Maybe PatVal] -> [Maybe PatVal]) -> TacticState -> TacticState withRecursionStack f = field @"ts_recursion_stack" %~ f +pushRecursionStack :: TacticState -> TacticState +pushRecursionStack = withRecursionStack (Nothing :) + +popRecursionStack :: TacticState -> TacticState +popRecursionStack = withRecursionStack tail + withUsedVals :: (Set OccName -> Set OccName) -> TacticState -> TacticState withUsedVals f = @@ -143,26 +153,97 @@ withIntroducedVals f = field @"ts_intro_vals" %~ f +------------------------------------------------------------------------------ +-- | Describes where hypotheses came from. Used extensively to prune stupid +-- solutions from the search space. +data Provenance + = -- | An argument given to the topmost function that contains the current + -- hole. Recursive calls are restricted to values whose provenance lines up + -- with the same argument. + TopLevelArgPrv + OccName -- ^ Binding function + Int -- ^ Argument Position + -- | A binding created in a pattern match. + | PatternMatchPrv PatVal + -- | A class method from the given context. + | ClassMethodPrv + (Uniquely Class) -- ^ Class + -- | A binding explicitly written by the user. + | UserPrv + -- | The recursive hypothesis. Present only in the context of the recursion + -- tactic. + | RecursivePrv + -- | A hypothesis which has been disallowed for some reason. It's important + -- to keep these in the hypothesis set, rather than filtering it, in order + -- to continue tracking downstream provenance. + | DisallowedPrv DisallowReason Provenance + deriving stock (Eq, Show, Generic, Ord) + + +------------------------------------------------------------------------------ +-- | Why was a hypothesis disallowed? +data DisallowReason + = WrongBranch Int + | Shadowed + | RecursiveCall + | AlreadyDestructed + deriving stock (Eq, Show, Generic, Ord) + + +------------------------------------------------------------------------------ +-- | Provenance of a pattern value. +data PatVal = PatVal + { pv_scrutinee :: Maybe OccName + -- ^ Original scrutinee which created this PatVal. Nothing, for lambda + -- case. + , pv_ancestry :: Set OccName + -- ^ The set of values which had to be destructed to discover this term. + -- Always contains the scrutinee. + , pv_datacon :: Uniquely DataCon + -- ^ The datacon which introduced this term. + , pv_position :: Int + -- ^ The position of this binding in the datacon's arguments. + } deriving stock (Eq, Show, Generic, Ord) + + +------------------------------------------------------------------------------ +-- | A wrapper which uses a 'Uniquable' constraint for providing 'Eq' and 'Ord' +-- instances. +newtype Uniquely a = Uniquely { getViaUnique :: a } + deriving Show via a + +instance Uniquable a => Eq (Uniquely a) where + (==) = (==) `on` getUnique . getViaUnique + +instance Uniquable a => Ord (Uniquely a) where + compare = nonDetCmpUnique `on` getUnique . getViaUnique + + +------------------------------------------------------------------------------ +-- | The provenance and type of a hypothesis term. +data HyInfo a = HyInfo + { hi_provenance :: Provenance + , hi_type :: a + } + deriving stock (Functor, Eq, Show, Generic, Ord) + + +------------------------------------------------------------------------------ +-- | Map a function over the provenance. +overProvenance :: (Provenance -> Provenance) -> HyInfo a -> HyInfo a +overProvenance f (HyInfo prv ty) = HyInfo (f prv) ty + ------------------------------------------------------------------------------ -- | The current bindings and goal for a hole to be filled by refinery. data Judgement' a = Judgement - { _jHypothesis :: !(Map OccName a) - , _jAmbientHypothesis :: !(Map OccName a) - -- ^ Things in the hypothesis that were imported. Solutions don't get - -- points for using the ambient hypothesis. - , _jDestructed :: !(Set OccName) - -- ^ These should align with keys of _jHypothesis - , _jPatternVals :: !(Set OccName) - -- ^ These should align with keys of _jHypothesis + { _jHypothesis :: !(Map OccName (HyInfo a)) , _jBlacklistDestruct :: !(Bool) , _jWhitelistSplit :: !(Bool) - , _jPositionMaps :: !(Map OccName [[OccName]]) - , _jAncestry :: !(Map OccName (Set OccName)) , _jIsTopHole :: !Bool , _jGoal :: !(a) } - deriving stock (Eq, Ord, Generic, Functor, Show) + deriving stock (Eq, Generic, Functor, Show) type Judgement = Judgement' CType @@ -185,12 +266,12 @@ data TacticError | UnificationError CType CType | NoProgress | NoApplicableTactic - | AlreadyDestructed OccName | IncorrectDataCon DataCon | RecursionOnWrongParam OccName Int OccName | UnhelpfulDestruct OccName | UnhelpfulSplit OccName | TooPolymorphic + | NotInScope OccName deriving stock (Eq) instance Show TacticError where @@ -216,8 +297,6 @@ instance Show TacticError where "Unable to make progress" show NoApplicableTactic = "No tactic could be applied" - show (AlreadyDestructed name) = - "Already destructed " <> unsafeRender name show (IncorrectDataCon dcon) = "Data con doesn't align with goal type (" <> unsafeRender dcon <> ")" show (RecursionOnWrongParam call p arg) = @@ -229,6 +308,8 @@ instance Show TacticError where "Splitting constructor " <> show n <> " leads to no new goals" show TooPolymorphic = "The tactic isn't applicable because the goal is too polymorphic" + show (NotInScope name) = + "Tried to do something with the out of scope name " <> show name ------------------------------------------------------------------------------ diff --git a/plugins/tactics/test/AutoTupleSpec.hs b/plugins/tactics/test/AutoTupleSpec.hs index 9b73c7c2f9..94125c06c4 100644 --- a/plugins/tactics/test/AutoTupleSpec.hs +++ b/plugins/tactics/test/AutoTupleSpec.hs @@ -40,12 +40,10 @@ spec = describe "auto for tuple" $ do pure $ -- We should always be able to find a solution runTactic - (Context [] []) + emptyContext (mkFirstJudgement - (M.singleton (mkVarOcc "x") $ CType in_type) - mempty + (M.singleton (mkVarOcc "x") $ HyInfo UserPrv $ CType in_type) True - mempty out_type) (auto' $ n * 2) `shouldSatisfy` isRight diff --git a/plugins/tactics/test/UnificationSpec.hs b/plugins/tactics/test/UnificationSpec.hs index 9351725036..a5a21567ae 100644 --- a/plugins/tactics/test/UnificationSpec.hs +++ b/plugins/tactics/test/UnificationSpec.hs @@ -3,21 +3,22 @@ module UnificationSpec where -import Control.Arrow -import Data.Bool (bool) -import Data.Functor ((<&>)) -import Data.Maybe (mapMaybe) -import Data.Traversable -import Data.Tuple (swap) -import Ide.Plugin.Tactic.Debug -import Ide.Plugin.Tactic.Machinery -import Ide.Plugin.Tactic.Types -import TcType (tcGetTyVar_maybe, substTy) -import Test.Hspec -import Test.QuickCheck -import Type (mkTyVarTy) -import TysPrim (alphaTyVars) -import TysWiredIn (mkBoxedTupleTy) +import Control.Arrow +import Data.Bool (bool) +import Data.Functor ((<&>)) +import Data.Maybe (mapMaybe) +import qualified Data.Set as S +import Data.Traversable +import Data.Tuple (swap) +import Ide.Plugin.Tactic.Debug +import Ide.Plugin.Tactic.Machinery +import Ide.Plugin.Tactic.Types +import TcType (tcGetTyVar_maybe, substTy) +import Test.Hspec +import Test.QuickCheck +import Type (mkTyVarTy) +import TysPrim (alphaTyVars) +import TysWiredIn (mkBoxedTupleTy) instance Show Type where @@ -42,7 +43,7 @@ spec = describe "unification" $ do counterexample (show lhs) $ counterexample (show rhs) $ case tryUnifyUnivarsButNotSkolems - (mapMaybe tcGetTyVar_maybe skolems) + (S.fromList $ mapMaybe tcGetTyVar_maybe skolems) (CType lhs) (CType rhs) of Just subst ->