diff --git a/src/Bolt/ConnectionPool.php b/src/Bolt/ConnectionPool.php index 937cffa4..1d2c2930 100644 --- a/src/Bolt/ConnectionPool.php +++ b/src/Bolt/ConnectionPool.php @@ -53,7 +53,7 @@ public static function create( ): self { return new self( $semaphore, - BoltFactory::create($conf->getLogger()), + BoltFactory::create($conf->getLogger(), $conf->getSocketType()), new ConnectionRequestData( $uri->getHost(), $uri, diff --git a/src/Bolt/SystemWideConnectionFactory.php b/src/Bolt/SystemWideConnectionFactory.php index fe2c1598..86766b16 100644 --- a/src/Bolt/SystemWideConnectionFactory.php +++ b/src/Bolt/SystemWideConnectionFactory.php @@ -16,6 +16,7 @@ use function extension_loaded; use Laudis\Neo4j\Contracts\BasicConnectionFactoryInterface; +use Laudis\Neo4j\Enum\SocketType; /** * Singleton connection factory based on the installed extensions. @@ -35,8 +36,17 @@ private function __construct( /** * @psalm-suppress InvalidNullableReturnType */ - public static function getInstance(): SystemWideConnectionFactory + public static function getInstance(?SocketType $preferredSocket = null): SystemWideConnectionFactory { + // If a specific socket type is requested, create a new instance without caching + if ($preferredSocket === SocketType::SOCKETS() && extension_loaded('sockets')) { + return new self(new SocketConnectionFactory(new StreamConnectionFactory())); + } + + if ($preferredSocket === SocketType::STREAM()) { + return new self(new StreamConnectionFactory()); + } + if (self::$instance === null) { $factory = new StreamConnectionFactory(); if (extension_loaded('sockets')) { diff --git a/src/BoltFactory.php b/src/BoltFactory.php index afe1d744..bc3550d9 100644 --- a/src/BoltFactory.php +++ b/src/BoltFactory.php @@ -27,6 +27,7 @@ use Laudis\Neo4j\Databags\SessionConfiguration; use Laudis\Neo4j\Databags\TransactionConfiguration; use Laudis\Neo4j\Enum\ConnectionProtocol; +use Laudis\Neo4j\Enum\SocketType; /** * Small wrapper around the bolt library to easily guarantee only bolt version 3 and up will be created and authenticated. @@ -44,9 +45,9 @@ public function __construct( ) { } - public static function create(?Neo4jLogger $logger): self + public static function create(?Neo4jLogger $logger, ?SocketType $socketType = null): self { - return new self(SystemWideConnectionFactory::getInstance(), new ProtocolFactory(), new SslConfigurationFactory(), $logger); + return new self(SystemWideConnectionFactory::getInstance($socketType), new ProtocolFactory(), new SslConfigurationFactory(), $logger); } public function createConnection(ConnectionRequestData $data, SessionConfiguration $sessionConfig): BoltConnection diff --git a/src/Databags/DriverConfiguration.php b/src/Databags/DriverConfiguration.php index 68c35a87..2b5b9368 100644 --- a/src/Databags/DriverConfiguration.php +++ b/src/Databags/DriverConfiguration.php @@ -24,6 +24,7 @@ use Laudis\Neo4j\Common\Neo4jLogger; use Laudis\Neo4j\Common\SemaphoreFactory; use Laudis\Neo4j\Contracts\SemaphoreFactoryInterface; +use Laudis\Neo4j\Enum\SocketType; use Psr\Log\LoggerInterface; use Psr\Log\LogLevel; use Psr\SimpleCache\CacheInterface; @@ -44,11 +45,13 @@ final class DriverConfiguration /** @var callable():(SemaphoreFactoryInterface|null)|SemaphoreFactoryInterface|null */ private $semaphoreFactory; private ?Neo4jLogger $logger; + private ?SocketType $socketType; /** * @param callable():(CacheInterface|null)|CacheInterface|null $cache * @param callable():(SemaphoreFactoryInterface|null)|SemaphoreFactoryInterface|null $semaphore - * @param string|null $logLevel The log level to use. If null, LogLevel::INFO is used. + * @param string|null $logLevel The log level to use. If null, LogLevel::INFO is used. + * @param SocketType|null $socketType the socket type to use (SocketType::SOCKETS(), SocketType::STREAM(), or null for auto-detect) * * @psalm-external-mutation-free */ @@ -61,6 +64,7 @@ public function __construct( callable|SemaphoreFactoryInterface|null $semaphore, ?string $logLevel, ?LoggerInterface $logger, + ?SocketType $socketType = null, ) { $this->cache = $cache; $this->semaphoreFactory = $semaphore; @@ -69,6 +73,7 @@ public function __construct( } else { $this->logger = null; } + $this->socketType = $socketType; } /** @@ -83,6 +88,7 @@ public static function create( SemaphoreFactoryInterface $semaphore, ?string $logLevel, ?LoggerInterface $logger, + ?SocketType $socketType = null, ): self { return new self( $userAgent, @@ -92,7 +98,8 @@ public static function create( $acquireConnectionTimeout, $semaphore, $logLevel, - $logger + $logger, + $socketType ); } @@ -261,4 +268,23 @@ public function withLogger(?string $logLevel, ?LoggerInterface $logger): self return $tbr; } + + /** + * @psalm-immutable + */ + public function getSocketType(): ?SocketType + { + return $this->socketType; + } + + /** + * @psalm-immutable + */ + public function withSocketType(?SocketType $socketType): self + { + $tbr = clone $this; + $tbr->socketType = $socketType; + + return $tbr; + } } diff --git a/src/Enum/SocketType.php b/src/Enum/SocketType.php new file mode 100644 index 00000000..4b3badfa --- /dev/null +++ b/src/Enum/SocketType.php @@ -0,0 +1,40 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Laudis\Neo4j\Enum; + +use JsonSerializable; +use Laudis\TypedEnum\TypedEnum; + +/** + * Defines the socket type to use for connections. + * + * @method static self SOCKETS() + * @method static self STREAM() + * + * @extends TypedEnum + * + * @psalm-immutable + * + * @psalm-suppress MutableDependency + */ +final class SocketType extends TypedEnum implements JsonSerializable +{ + private const SOCKETS = 'sockets'; + private const STREAM = 'stream'; + + public function jsonSerialize(): string + { + return $this->getValue(); + } +} diff --git a/tests/Unit/BoltFactoryTest.php b/tests/Unit/BoltFactoryTest.php index 1aaf278f..4750de56 100644 --- a/tests/Unit/BoltFactoryTest.php +++ b/tests/Unit/BoltFactoryTest.php @@ -20,13 +20,16 @@ use Laudis\Neo4j\Bolt\Connection; use Laudis\Neo4j\Bolt\ProtocolFactory; use Laudis\Neo4j\Bolt\SslConfigurationFactory; +use Laudis\Neo4j\Bolt\SystemWideConnectionFactory; use Laudis\Neo4j\BoltFactory; use Laudis\Neo4j\Common\Uri; use Laudis\Neo4j\Contracts\AuthenticateInterface; use Laudis\Neo4j\Contracts\BasicConnectionFactoryInterface; use Laudis\Neo4j\Databags\ConnectionRequestData; +use Laudis\Neo4j\Databags\DriverConfiguration; use Laudis\Neo4j\Databags\SessionConfiguration; use Laudis\Neo4j\Databags\SslConfiguration; +use Laudis\Neo4j\Enum\SocketType; use PHPUnit\Framework\TestCase; final class BoltFactoryTest extends TestCase @@ -73,4 +76,35 @@ public function testCreateBasic(): void self::assertInstanceOf(Connection::class, $connection->getImplementation()[1]); } + + public function testSystemWideConnectionFactoryStreamOverride(): void + { + $factory = SystemWideConnectionFactory::getInstance(SocketType::STREAM()); + self::assertInstanceOf(SystemWideConnectionFactory::class, $factory); + } + + public function testSystemWideConnectionFactorySocketOverride(): void + { + if (!extension_loaded('sockets')) { + self::markTestSkipped('sockets extension not loaded'); + } + + $factory = SystemWideConnectionFactory::getInstance(SocketType::SOCKETS()); + self::assertInstanceOf(SystemWideConnectionFactory::class, $factory); + } + + public function testDriverConfigurationWithSocketType(): void + { + $socketType = SocketType::STREAM(); + $config = DriverConfiguration::default() + ->withSocketType($socketType); + + self::assertEquals($socketType, $config->getSocketType()); + } + + public function testBoltFactoryWithSocketTypeOverride(): void + { + $factory = BoltFactory::create(null, SocketType::STREAM()); + self::assertInstanceOf(BoltFactory::class, $factory); + } }