diff --git a/src/Type/WebMozartAssert/AssertTypeSpecifyingExtension.php b/src/Type/WebMozartAssert/AssertTypeSpecifyingExtension.php index 6480d97..6220aab 100644 --- a/src/Type/WebMozartAssert/AssertTypeSpecifyingExtension.php +++ b/src/Type/WebMozartAssert/AssertTypeSpecifyingExtension.php @@ -146,6 +146,14 @@ public function specifyTypes( ); } + if (substr($staticMethodReflection->getName(), 0, 6) === 'nullOr') { + return $this->handleNullOr( + $staticMethodReflection->getName(), + $node, + $scope + ); + } + $expression = self::createExpression($scope, $staticMethodReflection->getName(), $node->getArgs()); if ($expression === null) { return new SpecifiedTypes([], []); @@ -170,22 +178,7 @@ private static function createExpression( $trimmedName = self::trimName($name); $resolvers = self::getExpressionResolvers(); $resolver = $resolvers[$trimmedName]; - $expression = $resolver($scope, ...$args); - if ($expression === null) { - return null; - } - - if (substr($name, 0, 6) === 'nullOr') { - $expression = new BooleanOr( - $expression, - new Identical( - $args[0]->value, - new ConstFetch(new Name('null')) - ) - ); - } - - return $expression; + return $resolver($scope, ...$args); } /** @@ -799,24 +792,45 @@ private function handleAll( TypeSpecifierContext::createTruthy() ); - if (count($specifiedTypes->getSureTypes()) > 0) { - $sureTypes = $specifiedTypes->getSureTypes(); - $exprString = key($sureTypes); - [$exprNode, $type] = $sureTypes[$exprString]; + $type = $this->determineVariableTypeFromSpecifiedTypes($scope, $specifiedTypes); - return $this->arrayOrIterable( - $scope, - $exprNode, - static function () use ($type): Type { - return $type->getIterableValueType(); - } - ); - } - if (count($specifiedTypes->getSureNotTypes()) > 0) { - throw new ShouldNotHappenException(); + return $this->arrayOrIterable( + $scope, + $node->getArgs()[0]->value, + static function () use ($type): Type { + return $type->getIterableValueType(); + } + ); + } + + private function handleNullOr( + string $methodName, + StaticCall $node, + Scope $scope + ): SpecifiedTypes + { + $innerExpression = self::createExpression($scope, $methodName, $node->getArgs()); + if ($innerExpression === null) { + return new SpecifiedTypes(); } - return $specifiedTypes; + $expression = new BooleanAnd( + new NotIdentical( + $node->getArgs()[0]->value, + new ConstFetch(new Name('null')) + ), + $innerExpression + ); + + $specifiedTypes = $this->typeSpecifier->specifyTypesInCondition( + $scope, + $expression, + TypeSpecifierContext::createTruthy() + ); + + $type = $this->determineVariableTypeFromSpecifiedTypes($scope, $specifiedTypes); + + return $this->typeSpecifier->create($node->getArgs()[0]->value, TypeCombinator::addNull($type), TypeSpecifierContext::createTruthy(), false, $scope); } private function arrayOrIterable( @@ -856,6 +870,33 @@ private function arrayOrIterable( ); } + private function determineVariableTypeFromSpecifiedTypes(Scope $scope, SpecifiedTypes $specifiedTypes): Type + { + $sureTypes = $specifiedTypes->getSureTypes(); + $sureNotTypes = $specifiedTypes->getSureNotTypes(); + + if (count($sureTypes) > 0) { + $exprString = key($sureTypes); + [, $type] = $sureTypes[$exprString]; + + if (array_key_exists($exprString, $sureNotTypes)) { + $type = TypeCombinator::remove($type, $sureNotTypes[$exprString][1]); + } + + return $type; + } + + if (count($sureNotTypes) > 0) { + $sureNotTypes = $specifiedTypes->getSureNotTypes(); + $exprString = key($sureNotTypes); + [$exprNode, $type] = $sureNotTypes[$exprString]; + + return TypeCombinator::remove($scope->getType($exprNode), $type); + } + + throw new ShouldNotHappenException(); + } + /** * @param Expr[] $expressions * @param class-string $binaryOp