diff --git a/conn.go b/conn.go index cf32fa4..e3aa63b 100644 --- a/conn.go +++ b/conn.go @@ -1,7 +1,12 @@ package libp2ptls import ( + "bytes" "crypto/tls" + "errors" + "fmt" + "io" + "net" ci "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" @@ -35,3 +40,117 @@ func (c *conn) RemotePeer() peer.ID { func (c *conn) RemotePublicKey() ci.PubKey { return c.remotePubKey } + +const ( + recordTypeHandshake byte = 22 + versionTLS13 = 0x0304 + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) +) + +var errSimultaneousConnect = errors.New("detected TCP simultaneous connect") + +type teeConn struct { + net.Conn + buf *bytes.Buffer +} + +func newTeeConn(c net.Conn, buf *bytes.Buffer) net.Conn { + return &teeConn{Conn: c, buf: buf} +} + +func (c *teeConn) Read(b []byte) (int, error) { + n, err := c.Conn.Read(b) + c.buf.Write(b[:n]) + return n, err +} + +type wrappedConn struct { + // Before reading the first handshake message, this is a *teeConn. + // After reading the first handshake message, we switch it to the rawConn. + net.Conn + + rawConn net.Conn + hasReadFirstMessage bool + raw *bytes.Buffer // contains a copy of every byte of the first handshake message we read from the wire + + hand bytes.Buffer // used to store the first handshake message until we've completely read it +} + +func newWrappedConn(c net.Conn) net.Conn { + wc := &wrappedConn{ + raw: &bytes.Buffer{}, + rawConn: c, + } + wc.Conn = newTeeConn(c, wc.raw) + return wc +} + +func (c *wrappedConn) Read(b []byte) (int, error) { + if c.hasReadFirstMessage { + return c.Conn.Read(b) + } + + // We read the first handshake message, and it was not a ClientHello. + // We now need to feed all the bytes we read from the wire into the TLS stack, + // so it can proceed with the handshake. + if c.raw.Len() > 0 { + n, err := c.raw.Read(b) + if err == io.EOF || c.raw.Len() == 0 { + c.raw = nil + c.Conn = c.rawConn + c.hasReadFirstMessage = true + err = nil + } + return n, err + } + + mes, err := c.readFirstHandshakeMessage() + if err != nil { + return 0, err + } + + switch mes[0] { + case 1: // ClientHello + return 0, errSimultaneousConnect + case 2: // ServerHello + return c.Read(b) + default: + return 0, fmt.Errorf("unexpected message type: %d", mes[0]) + } +} + +func (c *wrappedConn) readFirstHandshakeMessage() ([]byte, error) { + for c.hand.Len() < 4 { + if err := c.readRecord(); err != nil { + return nil, err + } + } + data := c.hand.Bytes() + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + return nil, fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake) + } + for c.hand.Len() < 4+n { + if err := c.readRecord(); err != nil { + return nil, err + } + } + return c.hand.Next(4 + n), nil +} + +func (c *wrappedConn) readRecord() error { + hdr := make([]byte, 5) + if _, err := io.ReadFull(c.Conn, hdr); err != nil { + return err + } + if hdr[0] != recordTypeHandshake { + return errors.New("expected a handshake record") + } + n := int(hdr[3])<<8 | int(hdr[4]) + if n > maxCiphertextTLS13 { + return fmt.Errorf("oversized record received with length %d", n) + } + _, err := io.CopyN(&c.hand, c.Conn, int64(n)) + return err +} diff --git a/crypto.go b/crypto.go index 14e1db0..40358a5 100644 --- a/crypto.go +++ b/crypto.go @@ -1,9 +1,11 @@ package libp2ptls import ( + "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/sha256" "crypto/tls" "crypto/x509" "crypto/x509/pkix" @@ -220,3 +222,11 @@ func preferServerCipherSuites() bool { ) return !hasGCMAsm } + +// Compare two peer IDs by their SHA256 hash. +// The result will be 0 if H(a) == H(b), -1 if H(a) < H(b), and +1 if H(a) > H(b). +func comparePeerIDs(p1, p2 peer.ID) int { + p1Hash := sha256.Sum256([]byte(p1)) + p2Hash := sha256.Sum256([]byte(p2)) + return bytes.Compare(p1Hash[:], p2Hash[:]) +} diff --git a/transport.go b/transport.go index 214853c..9f53de9 100644 --- a/transport.go +++ b/transport.go @@ -70,11 +70,28 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S // notice this after 1 RTT when calling Read. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) - cs, err := t.handshake(ctx, tls.Client(insecure, config), keyCh) - if err != nil { + conn, err := t.handshake(ctx, tls.Client(newWrappedConn(insecure), config), keyCh) + if err == errSimultaneousConnect { + switch comparePeerIDs(t.localPeer, p) { + case 0: + return nil, errors.New("tried to simultaneous connect to oneself") + case -1: + // SHA256(our peer ID) is smaller than SHA256(their peer ID). + // We're the client in the next connection attempt. + config, keyCh := t.identity.ConfigForPeer(p) + return t.handshake(ctx, tls.Client(insecure, config), keyCh) + case 1: + // SHA256(our peer ID) is larger than SHA256(their peer ID). + // We're the server in the next connection attempt. + config, keyCh := t.identity.ConfigForPeer(p) + return t.handshake(ctx, tls.Server(insecure, config), keyCh) + default: + panic("unexpected peer ID comparison result") + } + } else if err != nil { insecure.Close() } - return cs, err + return conn, err } func (t *Transport) handshake( diff --git a/transport_test.go b/transport_test.go index 94f2c21..ab52d21 100644 --- a/transport_test.go +++ b/transport_test.go @@ -15,6 +15,7 @@ import ( "math/big" mrand "math/rand" "net" + "reflect" "time" "github.com/onsi/gomega/gbytes" @@ -188,6 +189,62 @@ var _ = Describe("Transport", func() { Eventually(done).Should(BeClosed()) }) + It("handles simultaneous open", func() { + // Avoid confusion regarding the naming. + p1, p1Key := serverID, serverKey + p2, p2Key := clientID, clientKey + + // We use a normal dial / listen to establish the TCP connection, + // but we then start two clients. + c1raw, c2raw := connect() + + c1Transport, err := New(p1Key) + Expect(err).ToNot(HaveOccurred()) + c2Transport, err := New(p2Key) + Expect(err).ToNot(HaveOccurred()) + + c1ConnChan := make(chan sec.SecureConn, 1) + go func() { + defer GinkgoRecover() + conn, err := c1Transport.SecureOutbound(context.Background(), c1raw, p2) + Expect(err).ToNot(HaveOccurred()) + c1ConnChan <- conn + }() + + c2, err := c2Transport.SecureOutbound(context.Background(), c2raw, p1) + Expect(err).ToNot(HaveOccurred()) + defer c2.Close() + var c1 sec.SecureConn + Eventually(c1ConnChan).Should(Receive(&c1)) + defer c1.Close() + + // check that the peers are in the correct roles + isClient := func(c sec.SecureConn) bool { + // the isClient field of the tls.Conn will tell us who is client and server + return reflect.ValueOf(c.(*conn).Conn).Elem().FieldByName("isClient").Bool() + } + switch comparePeerIDs(p1, p2) { + case -1: + // H(p1) < H(p2) => p1 acts as a client, p2 as a server + Expect(isClient(c1)).To(BeTrue()) + Expect(isClient(c2)).To(BeFalse()) + case 1: + // H(p1) > H(p2) => p1 acts as a server, p2 as a client + Expect(isClient(c1)).To(BeFalse()) + Expect(isClient(c2)).To(BeTrue()) + default: + Fail("unexpected peer comparison result") + } + + // exchange some data + _, err = c1.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + b := make([]byte, 6) + _, err = c2.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(b)).To(Equal("foobar")) + }) + Context("invalid certificates", func() { invalidateCertChain := func(identity *Identity) { switch identity.config.Certificates[0].PrivateKey.(type) {