Commit df1304ee by Ian Lance Taylor

libgo: Update to weekly.2012-01-15.

From-SVN: r183539
parent 3be18e47
4a8268927758
354b17404643
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.
......@@ -188,7 +188,7 @@ toolexeclibgocryptoopenpgpdir = $(toolexeclibgocryptodir)/openpgp
toolexeclibgocryptoopenpgp_DATA = \
crypto/openpgp/armor.gox \
crypto/openpgp/elgamal.gox \
crypto/openpgp/error.gox \
crypto/openpgp/errors.gox \
crypto/openpgp/packet.gox \
crypto/openpgp/s2k.gox
......@@ -235,6 +235,7 @@ toolexeclibgoexp_DATA = \
exp/ebnf.gox \
$(exp_inotify_gox) \
exp/norm.gox \
exp/proxy.gox \
exp/spdy.gox \
exp/sql.gox \
exp/ssh.gox \
......@@ -669,17 +670,25 @@ endif # !LIBGO_IS_RTEMS
if LIBGO_IS_LINUX
go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go
go_net_sockopt_file = go/net/sockopt_linux.go
go_net_sockoptip_file = go/net/sockoptip_linux.go
else
if LIBGO_IS_IRIX
go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go
go_net_sockopt_file = go/net/sockopt_linux.go
go_net_sockoptip_file = go/net/sockoptip_linux.go
else
if LIBGO_IS_SOLARIS
go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go
go_net_sockopt_file = go/net/sockopt_linux.go
go_net_sockoptip_file = go/net/sockoptip_linux.go
else
go_net_cgo_file = go/net/cgo_bsd.go
go_net_sock_file = go/net/sock_bsd.go
go_net_sockopt_file = go/net/sockopt_bsd.go
go_net_sockoptip_file = go/net/sockoptip_bsd.go
endif
endif
endif
......@@ -728,6 +737,10 @@ go_net_files = \
$(go_net_sendfile_file) \
go/net/sock.go \
$(go_net_sock_file) \
go/net/sockopt.go \
$(go_net_sockopt_file) \
go/net/sockoptip.go \
$(go_net_sockoptip_file) \
go/net/tcpsock.go \
go/net/tcpsock_posix.go \
go/net/udpsock.go \
......@@ -890,8 +903,7 @@ go_syslog_c_files = \
go_testing_files = \
go/testing/benchmark.go \
go/testing/example.go \
go/testing/testing.go \
go/testing/wrapper.go
go/testing/testing.go
go_time_files = \
go/time/format.go \
......@@ -1061,8 +1073,8 @@ go_crypto_openpgp_armor_files = \
go/crypto/openpgp/armor/encode.go
go_crypto_openpgp_elgamal_files = \
go/crypto/openpgp/elgamal/elgamal.go
go_crypto_openpgp_error_files = \
go/crypto/openpgp/error/error.go
go_crypto_openpgp_errors_files = \
go/crypto/openpgp/errors/errors.go
go_crypto_openpgp_packet_files = \
go/crypto/openpgp/packet/compressed.go \
go/crypto/openpgp/packet/encrypted_key.go \
......@@ -1142,6 +1154,7 @@ go_encoding_pem_files = \
go_encoding_xml_files = \
go/encoding/xml/marshal.go \
go/encoding/xml/read.go \
go/encoding/xml/typeinfo.go \
go/encoding/xml/xml.go
go_exp_ebnf_files = \
......@@ -1157,6 +1170,11 @@ go_exp_norm_files = \
go/exp/norm/readwriter.go \
go/exp/norm/tables.go \
go/exp/norm/trie.go
go_exp_proxy_files = \
go/exp/proxy/direct.go \
go/exp/proxy/per_host.go \
go/exp/proxy/proxy.go \
go/exp/proxy/socks5.go
go_exp_spdy_files = \
go/exp/spdy/read.go \
go/exp/spdy/types.go \
......@@ -1173,7 +1191,7 @@ go_exp_ssh_files = \
go/exp/ssh/doc.go \
go/exp/ssh/messages.go \
go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \
go/exp/ssh/server_terminal.go \
go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go
......@@ -1210,7 +1228,8 @@ go_go_doc_files = \
go/go/doc/doc.go \
go/go/doc/example.go \
go/go/doc/exports.go \
go/go/doc/filter.go
go/go/doc/filter.go \
go/go/doc/reader.go
go_go_parser_files = \
go/go/parser/interface.go \
go/go/parser/parser.go
......@@ -1461,8 +1480,15 @@ endif
# Define ForkExec and Exec.
if LIBGO_IS_RTEMS
syscall_exec_file = go/syscall/exec_stubs.go
syscall_exec_os_file =
else
if LIBGO_IS_LINUX
syscall_exec_file = go/syscall/exec_unix.go
syscall_exec_os_file = go/syscall/exec_linux.go
else
syscall_exec_file = go/syscall/exec_unix.go
syscall_exec_os_file = go/syscall/exec_bsd.go
endif
endif
# Define Wait4.
......@@ -1573,6 +1599,7 @@ go_base_syscall_files = \
go/syscall/syscall.go \
$(syscall_syscall_file) \
$(syscall_exec_file) \
$(syscall_exec_os_file) \
$(syscall_wait_file) \
$(syscall_sleep_file) \
$(syscall_errstr_file) \
......@@ -1720,7 +1747,7 @@ libgo_go_objs = \
crypto/xtea.lo \
crypto/openpgp/armor.lo \
crypto/openpgp/elgamal.lo \
crypto/openpgp/error.lo \
crypto/openpgp/errors.lo \
crypto/openpgp/packet.lo \
crypto/openpgp/s2k.lo \
crypto/x509/pkix.lo \
......@@ -1743,6 +1770,7 @@ libgo_go_objs = \
encoding/xml.lo \
exp/ebnf.lo \
exp/norm.lo \
exp/proxy.lo \
exp/spdy.lo \
exp/sql.lo \
exp/ssh.lo \
......@@ -2578,15 +2606,15 @@ crypto/openpgp/elgamal/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: crypto/openpgp/elgamal/check
@go_include@ crypto/openpgp/error.lo.dep
crypto/openpgp/error.lo.dep: $(go_crypto_openpgp_error_files)
@go_include@ crypto/openpgp/errors.lo.dep
crypto/openpgp/errors.lo.dep: $(go_crypto_openpgp_errors_files)
$(BUILDDEPS)
crypto/openpgp/error.lo: $(go_crypto_openpgp_error_files)
crypto/openpgp/errors.lo: $(go_crypto_openpgp_errors_files)
$(BUILDPACKAGE)
crypto/openpgp/error/check: $(CHECK_DEPS)
@$(MKDIR_P) crypto/openpgp/error
crypto/openpgp/errors/check: $(CHECK_DEPS)
@$(MKDIR_P) crypto/openpgp/errors
@$(CHECK)
.PHONY: crypto/openpgp/error/check
.PHONY: crypto/openpgp/errors/check
@go_include@ crypto/openpgp/packet.lo.dep
crypto/openpgp/packet.lo.dep: $(go_crypto_openpgp_packet_files)
......@@ -2808,6 +2836,16 @@ exp/norm/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: exp/norm/check
@go_include@ exp/proxy.lo.dep
exp/proxy.lo.dep: $(go_exp_proxy_files)
$(BUILDDEPS)
exp/proxy.lo: $(go_exp_proxy_files)
$(BUILDPACKAGE)
exp/proxy/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/proxy
@$(CHECK)
.PHONY: exp/proxy/check
@go_include@ exp/spdy.lo.dep
exp/spdy.lo.dep: $(go_exp_spdy_files)
$(BUILDDEPS)
......@@ -3622,7 +3660,7 @@ crypto/openpgp/armor.gox: crypto/openpgp/armor.lo
$(BUILDGOX)
crypto/openpgp/elgamal.gox: crypto/openpgp/elgamal.lo
$(BUILDGOX)
crypto/openpgp/error.gox: crypto/openpgp/error.lo
crypto/openpgp/errors.gox: crypto/openpgp/errors.lo
$(BUILDGOX)
crypto/openpgp/packet.gox: crypto/openpgp/packet.lo
$(BUILDGOX)
......@@ -3674,6 +3712,8 @@ exp/inotify.gox: exp/inotify.lo
$(BUILDGOX)
exp/norm.gox: exp/norm.lo
$(BUILDGOX)
exp/proxy.gox: exp/proxy.lo
$(BUILDGOX)
exp/spdy.gox: exp/spdy.lo
$(BUILDGOX)
exp/sql.gox: exp/sql.lo
......@@ -3920,6 +3960,7 @@ TEST_PACKAGES = \
exp/ebnf/check \
$(exp_inotify_check) \
exp/norm/check \
exp/proxy/check \
exp/spdy/check \
exp/sql/check \
exp/ssh/check \
......
......@@ -74,6 +74,9 @@
/* Define to 1 if you have the <sys/mman.h> header file. */
#undef HAVE_SYS_MMAN_H
/* Define to 1 if you have the <sys/prctl.h> header file. */
#undef HAVE_SYS_PRCTL_H
/* Define to 1 if you have the <sys/ptrace.h> header file. */
#undef HAVE_SYS_PTRACE_H
......
......@@ -14505,7 +14505,7 @@ no)
;;
esac
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h sys/prctl.h
do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"
......
......@@ -451,7 +451,7 @@ no)
;;
esac
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h)
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h sys/prctl.h)
AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [],
[#ifdef HAVE_SYS_SOCKET_H
......
......@@ -97,8 +97,7 @@ func (b *Buffer) grow(n int) int {
func (b *Buffer) Write(p []byte) (n int, err error) {
b.lastRead = opInvalid
m := b.grow(len(p))
copy(b.buf[m:], p)
return len(p), nil
return copy(b.buf[m:], p), nil
}
// WriteString appends the contents of s to the buffer. The return
......@@ -200,13 +199,16 @@ func (b *Buffer) WriteRune(r rune) (n int, err error) {
// Read reads the next len(p) bytes from the buffer or until the buffer
// is drained. The return value n is the number of bytes read. If the
// buffer has no data to return, err is io.EOF even if len(p) is zero;
// buffer has no data to return, err is io.EOF (unless len(p) is zero);
// otherwise it is nil.
func (b *Buffer) Read(p []byte) (n int, err error) {
b.lastRead = opInvalid
if b.off >= len(b.buf) {
// Buffer is empty, reset to recover space.
b.Truncate(0)
if len(p) == 0 {
return
}
return 0, io.EOF
}
n = copy(p, b.buf[b.off:])
......
......@@ -373,3 +373,16 @@ func TestReadBytes(t *testing.T) {
}
}
}
// Was a bug: used to give EOF reading empty slice at EOF.
func TestReadEmptyAtEOF(t *testing.T) {
b := new(Buffer)
slice := make([]byte, 0)
n, err := b.Read(slice)
if err != nil {
t.Errorf("read error: %v", err)
}
if n != 0 {
t.Errorf("wrong count; got %d want 0", n)
}
}
......@@ -9,7 +9,7 @@ package armor
import (
"bufio"
"bytes"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"encoding/base64"
"io"
)
......@@ -35,7 +35,7 @@ type Block struct {
oReader openpgpReader
}
var ArmorCorrupt error = error_.StructuralError("armor invalid")
var ArmorCorrupt error = errors.StructuralError("armor invalid")
const crc24Init = 0xb704ce
const crc24Poly = 0x1864cfb
......
......@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package error contains common error types for the OpenPGP packages.
package error
// Package errors contains common error types for the OpenPGP packages.
package errors
import (
"strconv"
......
......@@ -7,8 +7,9 @@ package openpgp
import (
"crypto"
"crypto/openpgp/armor"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/openpgp/packet"
"crypto/rand"
"crypto/rsa"
"io"
"time"
......@@ -181,13 +182,13 @@ func (el EntityList) DecryptionKeys() (keys []Key) {
func ReadArmoredKeyRing(r io.Reader) (EntityList, error) {
block, err := armor.Decode(r)
if err == io.EOF {
return nil, error_.InvalidArgumentError("no armored data found")
return nil, errors.InvalidArgumentError("no armored data found")
}
if err != nil {
return nil, err
}
if block.Type != PublicKeyType && block.Type != PrivateKeyType {
return nil, error_.InvalidArgumentError("expected public or private key block, got: " + block.Type)
return nil, errors.InvalidArgumentError("expected public or private key block, got: " + block.Type)
}
return ReadKeyRing(block.Body)
......@@ -203,7 +204,7 @@ func ReadKeyRing(r io.Reader) (el EntityList, err error) {
var e *Entity
e, err = readEntity(packets)
if err != nil {
if _, ok := err.(error_.UnsupportedError); ok {
if _, ok := err.(errors.UnsupportedError); ok {
lastUnsupportedError = err
err = readToNextPublicKey(packets)
}
......@@ -235,7 +236,7 @@ func readToNextPublicKey(packets *packet.Reader) (err error) {
if err == io.EOF {
return
} else if err != nil {
if _, ok := err.(error_.UnsupportedError); ok {
if _, ok := err.(errors.UnsupportedError); ok {
err = nil
continue
}
......@@ -266,14 +267,14 @@ func readEntity(packets *packet.Reader) (*Entity, error) {
if e.PrimaryKey, ok = p.(*packet.PublicKey); !ok {
if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok {
packets.Unread(p)
return nil, error_.StructuralError("first packet was not a public/private key")
return nil, errors.StructuralError("first packet was not a public/private key")
} else {
e.PrimaryKey = &e.PrivateKey.PublicKey
}
}
if !e.PrimaryKey.PubKeyAlgo.CanSign() {
return nil, error_.StructuralError("primary key cannot be used for signatures")
return nil, errors.StructuralError("primary key cannot be used for signatures")
}
var current *Identity
......@@ -303,12 +304,12 @@ EachPacket:
sig, ok := p.(*packet.Signature)
if !ok {
return nil, error_.StructuralError("user ID packet not followed by self-signature")
return nil, errors.StructuralError("user ID packet not followed by self-signature")
}
if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil {
return nil, error_.StructuralError("user ID self-signature invalid: " + err.Error())
return nil, errors.StructuralError("user ID self-signature invalid: " + err.Error())
}
current.SelfSignature = sig
break
......@@ -317,7 +318,7 @@ EachPacket:
}
case *packet.Signature:
if current == nil {
return nil, error_.StructuralError("signature packet found before user id packet")
return nil, errors.StructuralError("signature packet found before user id packet")
}
current.Signatures = append(current.Signatures, pkt)
case *packet.PrivateKey:
......@@ -344,7 +345,7 @@ EachPacket:
}
if len(e.Identities) == 0 {
return nil, error_.StructuralError("entity without any identities")
return nil, errors.StructuralError("entity without any identities")
}
return e, nil
......@@ -359,19 +360,19 @@ func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *p
return io.ErrUnexpectedEOF
}
if err != nil {
return error_.StructuralError("subkey signature invalid: " + err.Error())
return errors.StructuralError("subkey signature invalid: " + err.Error())
}
var ok bool
subKey.Sig, ok = p.(*packet.Signature)
if !ok {
return error_.StructuralError("subkey packet not followed by signature")
return errors.StructuralError("subkey packet not followed by signature")
}
if subKey.Sig.SigType != packet.SigTypeSubkeyBinding {
return error_.StructuralError("subkey signature with wrong type")
return errors.StructuralError("subkey signature with wrong type")
}
err = e.PrimaryKey.VerifyKeySignature(subKey.PublicKey, subKey.Sig)
if err != nil {
return error_.StructuralError("subkey signature invalid: " + err.Error())
return errors.StructuralError("subkey signature invalid: " + err.Error())
}
e.Subkeys = append(e.Subkeys, subKey)
return nil
......@@ -385,7 +386,7 @@ const defaultRSAKeyBits = 2048
func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email string) (*Entity, error) {
uid := packet.NewUserId(name, comment, email)
if uid == nil {
return nil, error_.InvalidArgumentError("user id field contained invalid characters")
return nil, errors.InvalidArgumentError("user id field contained invalid characters")
}
signingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
if err != nil {
......@@ -397,8 +398,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
}
e := &Entity{
PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey, false /* not a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv, false /* not a subkey */ ),
PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey),
PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv),
Identities: make(map[string]*Identity),
}
isPrimaryId := true
......@@ -420,8 +421,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
e.Subkeys = make([]Subkey, 1)
e.Subkeys[0] = Subkey{
PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey, true /* is a subkey */ ),
PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv, true /* is a subkey */ ),
PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey),
PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv),
Sig: &packet.Signature{
CreationTime: currentTime,
SigType: packet.SigTypeSubkeyBinding,
......@@ -433,6 +434,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
IssuerKeyId: &e.PrimaryKey.KeyId,
},
}
e.Subkeys[0].PublicKey.IsSubkey = true
e.Subkeys[0].PrivateKey.IsSubkey = true
return e, nil
}
......@@ -450,7 +453,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) {
if err != nil {
return
}
err = ident.SelfSignature.SignUserId(ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
err = ident.SelfSignature.SignUserId(rand.Reader, ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
if err != nil {
return
}
......@@ -464,7 +467,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) {
if err != nil {
return
}
err = subkey.Sig.SignKey(subkey.PublicKey, e.PrivateKey)
err = subkey.Sig.SignKey(rand.Reader, subkey.PublicKey, e.PrivateKey)
if err != nil {
return
}
......@@ -518,14 +521,14 @@ func (e *Entity) Serialize(w io.Writer) error {
// necessary.
func (e *Entity) SignIdentity(identity string, signer *Entity) error {
if signer.PrivateKey == nil {
return error_.InvalidArgumentError("signing Entity must have a private key")
return errors.InvalidArgumentError("signing Entity must have a private key")
}
if signer.PrivateKey.Encrypted {
return error_.InvalidArgumentError("signing Entity's private key must be decrypted")
return errors.InvalidArgumentError("signing Entity's private key must be decrypted")
}
ident, ok := e.Identities[identity]
if !ok {
return error_.InvalidArgumentError("given identity string not found in Entity")
return errors.InvalidArgumentError("given identity string not found in Entity")
}
sig := &packet.Signature{
......@@ -535,7 +538,7 @@ func (e *Entity) SignIdentity(identity string, signer *Entity) error {
CreationTime: time.Now(),
IssuerKeyId: &signer.PrivateKey.KeyId,
}
if err := sig.SignKey(e.PrimaryKey, signer.PrivateKey); err != nil {
if err := sig.SignKey(rand.Reader, e.PrimaryKey, signer.PrivateKey); err != nil {
return err
}
ident.Signatures = append(ident.Signatures, sig)
......
......@@ -7,7 +7,7 @@ package packet
import (
"compress/flate"
"compress/zlib"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"io"
"strconv"
)
......@@ -31,7 +31,7 @@ func (c *Compressed) parse(r io.Reader) error {
case 2:
c.Body, err = zlib.NewReader(r)
default:
err = error_.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
err = errors.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
}
return err
......
......@@ -6,7 +6,7 @@ package packet
import (
"crypto/openpgp/elgamal"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/rand"
"crypto/rsa"
"encoding/binary"
......@@ -35,7 +35,7 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) {
return
}
if buf[0] != encryptedKeyVersion {
return error_.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
return errors.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
}
e.KeyId = binary.BigEndian.Uint64(buf[1:9])
e.Algo = PublicKeyAlgorithm(buf[9])
......@@ -77,7 +77,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error {
c2 := new(big.Int).SetBytes(e.encryptedMPI2)
b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2)
default:
err = error_.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
err = errors.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
}
if err != nil {
......@@ -89,7 +89,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error {
expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
checksum := checksumKeyMaterial(e.Key)
if checksum != expectedChecksum {
return error_.StructuralError("EncryptedKey checksum incorrect")
return errors.StructuralError("EncryptedKey checksum incorrect")
}
return nil
......@@ -116,16 +116,16 @@ func SerializeEncryptedKey(w io.Writer, rand io.Reader, pub *PublicKey, cipherFu
case PubKeyAlgoElGamal:
return serializeEncryptedKeyElGamal(w, rand, buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock)
case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
return error_.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
}
return error_.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
return errors.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
}
func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) error {
cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
if err != nil {
return error_.InvalidArgumentError("RSA encryption failed: " + err.Error())
return errors.InvalidArgumentError("RSA encryption failed: " + err.Error())
}
packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText)
......@@ -144,7 +144,7 @@ func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub
func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) error {
c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
if err != nil {
return error_.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
return errors.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
}
packetLen := 10 /* header length */
......
......@@ -6,7 +6,7 @@ package packet
import (
"crypto"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/openpgp/s2k"
"encoding/binary"
"io"
......@@ -33,13 +33,13 @@ func (ops *OnePassSignature) parse(r io.Reader) (err error) {
return
}
if buf[0] != onePassSignatureVersion {
err = error_.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
err = errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
}
var ok bool
ops.Hash, ok = s2k.HashIdToHash(buf[2])
if !ok {
return error_.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2])))
return errors.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2])))
}
ops.SigType = SignatureType(buf[1])
......@@ -57,7 +57,7 @@ func (ops *OnePassSignature) Serialize(w io.Writer) error {
var ok bool
buf[2], ok = s2k.HashToHashId(ops.Hash)
if !ok {
return error_.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
return errors.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
}
buf[3] = uint8(ops.PubKeyAlgo)
binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)
......
......@@ -10,7 +10,7 @@ import (
"crypto/aes"
"crypto/cast5"
"crypto/cipher"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"io"
"math/big"
)
......@@ -162,7 +162,7 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader,
return
}
if buf[0]&0x80 == 0 {
err = error_.StructuralError("tag byte does not have MSB set")
err = errors.StructuralError("tag byte does not have MSB set")
return
}
if buf[0]&0x40 == 0 {
......@@ -337,7 +337,7 @@ func Read(r io.Reader) (p Packet, err error) {
se.MDC = true
p = se
default:
err = error_.UnknownPacketTypeError(tag)
err = errors.UnknownPacketTypeError(tag)
}
if p != nil {
err = p.parse(contents)
......
......@@ -6,7 +6,7 @@ package packet
import (
"bytes"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"encoding/hex"
"fmt"
"io"
......@@ -152,7 +152,7 @@ func TestReadHeader(t *testing.T) {
for i, test := range readHeaderTests {
tag, length, contents, err := readHeader(readerFromHex(test.hexInput))
if test.structuralError {
if _, ok := err.(error_.StructuralError); ok {
if _, ok := err.(errors.StructuralError); ok {
continue
}
t.Errorf("%d: expected StructuralError, got:%s", i, err)
......
......@@ -9,7 +9,7 @@ import (
"crypto/cipher"
"crypto/dsa"
"crypto/openpgp/elgamal"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/openpgp/s2k"
"crypto/rsa"
"crypto/sha1"
......@@ -28,14 +28,21 @@ type PrivateKey struct {
encryptedData []byte
cipher CipherFunction
s2k func(out, in []byte)
PrivateKey interface{} // An *rsa.PrivateKey.
PrivateKey interface{} // An *rsa.PrivateKey or *dsa.PrivateKey.
sha1Checksum bool
iv []byte
}
func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey {
func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey, isSubkey)
pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey)
pk.PrivateKey = priv
return pk
}
func NewDSAPrivateKey(currentTime time.Time, priv *dsa.PrivateKey) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewDSAPublicKey(currentTime, &priv.PublicKey)
pk.PrivateKey = priv
return pk
}
......@@ -72,13 +79,13 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
pk.sha1Checksum = true
}
default:
return error_.UnsupportedError("deprecated s2k function in private key")
return errors.UnsupportedError("deprecated s2k function in private key")
}
if pk.Encrypted {
blockSize := pk.cipher.blockSize()
if blockSize == 0 {
return error_.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
return errors.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
}
pk.iv = make([]byte, blockSize)
_, err = readFull(r, pk.iv)
......@@ -121,8 +128,10 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) {
switch priv := pk.PrivateKey.(type) {
case *rsa.PrivateKey:
err = serializeRSAPrivateKey(privateKeyBuf, priv)
case *dsa.PrivateKey:
err = serializeDSAPrivateKey(privateKeyBuf, priv)
default:
err = error_.InvalidArgumentError("non-RSA private key")
err = errors.InvalidArgumentError("unknown private key type")
}
if err != nil {
return
......@@ -172,6 +181,10 @@ func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) error {
return writeBig(w, priv.Precomputed.Qinv)
}
func serializeDSAPrivateKey(w io.Writer, priv *dsa.PrivateKey) error {
return writeBig(w, priv.X)
}
// Decrypt decrypts an encrypted private key using a passphrase.
func (pk *PrivateKey) Decrypt(passphrase []byte) error {
if !pk.Encrypted {
......@@ -188,18 +201,18 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
if pk.sha1Checksum {
if len(data) < sha1.Size {
return error_.StructuralError("truncated private key data")
return errors.StructuralError("truncated private key data")
}
h := sha1.New()
h.Write(data[:len(data)-sha1.Size])
sum := h.Sum(nil)
if !bytes.Equal(sum, data[len(data)-sha1.Size:]) {
return error_.StructuralError("private key checksum failure")
return errors.StructuralError("private key checksum failure")
}
data = data[:len(data)-sha1.Size]
} else {
if len(data) < 2 {
return error_.StructuralError("truncated private key data")
return errors.StructuralError("truncated private key data")
}
var sum uint16
for i := 0; i < len(data)-2; i++ {
......@@ -207,7 +220,7 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
}
if data[len(data)-2] != uint8(sum>>8) ||
data[len(data)-1] != uint8(sum) {
return error_.StructuralError("private key checksum failure")
return errors.StructuralError("private key checksum failure")
}
data = data[:len(data)-2]
}
......
......@@ -7,7 +7,7 @@ package packet
import (
"crypto/dsa"
"crypto/openpgp/elgamal"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/rsa"
"crypto/sha1"
"encoding/binary"
......@@ -39,12 +39,11 @@ func fromBig(n *big.Int) parsedMPI {
}
// NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool) *PublicKey {
func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey) *PublicKey {
pk := &PublicKey{
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoRSA,
PublicKey: pub,
IsSubkey: isSubkey,
n: fromBig(pub.N),
e: fromBig(big.NewInt(int64(pub.E))),
}
......@@ -53,6 +52,22 @@ func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool)
return pk
}
// NewDSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
func NewDSAPublicKey(creationTime time.Time, pub *dsa.PublicKey) *PublicKey {
pk := &PublicKey{
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoDSA,
PublicKey: pub,
p: fromBig(pub.P),
q: fromBig(pub.Q),
g: fromBig(pub.G),
y: fromBig(pub.Y),
}
pk.setFingerPrintAndKeyId()
return pk
}
func (pk *PublicKey) parse(r io.Reader) (err error) {
// RFC 4880, section 5.5.2
var buf [6]byte
......@@ -61,7 +76,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
return
}
if buf[0] != 4 {
return error_.UnsupportedError("public key version")
return errors.UnsupportedError("public key version")
}
pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0)
pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5])
......@@ -73,7 +88,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
case PubKeyAlgoElGamal:
err = pk.parseElGamal(r)
default:
err = error_.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
}
if err != nil {
return
......@@ -105,7 +120,7 @@ func (pk *PublicKey) parseRSA(r io.Reader) (err error) {
}
if len(pk.e.bytes) > 3 {
err = error_.UnsupportedError("large public exponent")
err = errors.UnsupportedError("large public exponent")
return
}
rsa := &rsa.PublicKey{
......@@ -255,7 +270,7 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) {
case PubKeyAlgoElGamal:
return writeMPIs(w, pk.p, pk.g, pk.y)
}
return error_.InvalidArgumentError("bad public-key algorithm")
return errors.InvalidArgumentError("bad public-key algorithm")
}
// CanSign returns true iff this public key can generate signatures
......@@ -267,18 +282,18 @@ func (pk *PublicKey) CanSign() bool {
// public key, of the data hashed into signed. signed is mutated by this call.
func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) {
if !pk.CanSign() {
return error_.InvalidArgumentError("public key cannot generate signatures")
return errors.InvalidArgumentError("public key cannot generate signatures")
}
signed.Write(sig.HashSuffix)
hashBytes := signed.Sum(nil)
if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
return error_.SignatureError("hash tag doesn't match")
return errors.SignatureError("hash tag doesn't match")
}
if pk.PubKeyAlgo != sig.PubKeyAlgo {
return error_.InvalidArgumentError("public key and signature use different algorithms")
return errors.InvalidArgumentError("public key and signature use different algorithms")
}
switch pk.PubKeyAlgo {
......@@ -286,13 +301,18 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes)
if err != nil {
return error_.SignatureError("RSA verification failure")
return errors.SignatureError("RSA verification failure")
}
return nil
case PubKeyAlgoDSA:
dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
// Need to truncate hashBytes to match FIPS 186-3 section 4.6.
subgroupSize := (dsaPublicKey.Q.BitLen() + 7) / 8
if len(hashBytes) > subgroupSize {
hashBytes = hashBytes[:subgroupSize]
}
if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
return error_.SignatureError("DSA verification failure")
return errors.SignatureError("DSA verification failure")
}
return nil
default:
......@@ -306,7 +326,7 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
func keySignatureHash(pk, signed *PublicKey, sig *Signature) (h hash.Hash, err error) {
h = sig.Hash.New()
if h == nil {
return nil, error_.UnsupportedError("hash function")
return nil, errors.UnsupportedError("hash function")
}
// RFC 4880, section 5.2.4
......@@ -332,7 +352,7 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err
func userIdSignatureHash(id string, pk *PublicKey, sig *Signature) (h hash.Hash, err error) {
h = sig.Hash.New()
if h == nil {
return nil, error_.UnsupportedError("hash function")
return nil, errors.UnsupportedError("hash function")
}
// RFC 4880, section 5.2.4
......
......@@ -5,7 +5,7 @@
package packet
import (
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"io"
)
......@@ -34,7 +34,7 @@ func (r *Reader) Next() (p Packet, err error) {
r.readers = r.readers[:len(r.readers)-1]
continue
}
if _, ok := err.(error_.UnknownPacketTypeError); !ok {
if _, ok := err.(errors.UnknownPacketTypeError); !ok {
return nil, err
}
}
......
......@@ -7,7 +7,7 @@ package packet
import (
"bytes"
"crypto/cipher"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/openpgp/s2k"
"io"
"strconv"
......@@ -37,12 +37,12 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
return
}
if buf[0] != symmetricKeyEncryptedVersion {
return error_.UnsupportedError("SymmetricKeyEncrypted version")
return errors.UnsupportedError("SymmetricKeyEncrypted version")
}
ske.CipherFunc = CipherFunction(buf[1])
if ske.CipherFunc.KeySize() == 0 {
return error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
}
ske.s2k, err = s2k.Parse(r)
......@@ -60,7 +60,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
err = nil
if n != 0 {
if n == maxSessionKeySizeInBytes {
return error_.UnsupportedError("oversized encrypted session key")
return errors.UnsupportedError("oversized encrypted session key")
}
ske.encryptedKey = encryptedKey[:n]
}
......@@ -89,13 +89,13 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
c.XORKeyStream(ske.encryptedKey, ske.encryptedKey)
ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
if ske.CipherFunc.blockSize() == 0 {
return error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc)))
return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc)))
}
ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
ske.Key = ske.encryptedKey[1:]
if len(ske.Key)%ske.CipherFunc.blockSize() != 0 {
ske.Key = nil
return error_.StructuralError("length of decrypted key not a multiple of block size")
return errors.StructuralError("length of decrypted key not a multiple of block size")
}
}
......@@ -110,7 +110,7 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err error) {
keySize := cipherFunc.KeySize()
if keySize == 0 {
return nil, error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
}
s2kBuf := new(bytes.Buffer)
......
......@@ -6,8 +6,7 @@ package packet
import (
"crypto/cipher"
error_ "crypto/openpgp/error"
"crypto/rand"
"crypto/openpgp/errors"
"crypto/sha1"
"crypto/subtle"
"hash"
......@@ -35,7 +34,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
return err
}
if buf[0] != symmetricallyEncryptedVersion {
return error_.UnsupportedError("unknown SymmetricallyEncrypted version")
return errors.UnsupportedError("unknown SymmetricallyEncrypted version")
}
}
se.contents = r
......@@ -48,10 +47,10 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, error) {
keySize := c.KeySize()
if keySize == 0 {
return nil, error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
}
if len(key) != keySize {
return nil, error_.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
return nil, errors.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
}
if se.prefix == nil {
......@@ -61,7 +60,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
return nil, err
}
} else if len(se.prefix) != c.blockSize()+2 {
return nil, error_.InvalidArgumentError("can't try ciphers with different block lengths")
return nil, errors.InvalidArgumentError("can't try ciphers with different block lengths")
}
ocfbResync := cipher.OCFBResync
......@@ -72,7 +71,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
s := cipher.NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync)
if s == nil {
return nil, error_.KeyIncorrectError
return nil, errors.KeyIncorrectError
}
plaintext := cipher.StreamReader{S: s, R: se.contents}
......@@ -181,7 +180,7 @@ const mdcPacketTagByte = byte(0x80) | 0x40 | 19
func (ser *seMDCReader) Close() error {
if ser.error {
return error_.SignatureError("error during reading")
return errors.SignatureError("error during reading")
}
for !ser.eof {
......@@ -192,18 +191,18 @@ func (ser *seMDCReader) Close() error {
break
}
if err != nil {
return error_.SignatureError("error during reading")
return errors.SignatureError("error during reading")
}
}
if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
return error_.SignatureError("MDC packet not found")
return errors.SignatureError("MDC packet not found")
}
ser.h.Write(ser.trailer[:2])
final := ser.h.Sum(nil)
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
return error_.SignatureError("hash mismatch")
return errors.SignatureError("hash mismatch")
}
return nil
}
......@@ -253,9 +252,9 @@ func (c noOpCloser) Close() error {
// SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
// to w and returns a WriteCloser to which the to-be-encrypted packets can be
// written.
func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte) (contents io.WriteCloser, err error) {
func SerializeSymmetricallyEncrypted(w io.Writer, rand io.Reader, c CipherFunction, key []byte) (contents io.WriteCloser, err error) {
if c.KeySize() != len(key) {
return nil, error_.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
return nil, errors.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
}
writeCloser := noOpCloser{w}
ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
......@@ -271,7 +270,7 @@ func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte)
block := c.new(key)
blockSize := block.BlockSize()
iv := make([]byte, blockSize)
_, err = rand.Reader.Read(iv)
_, err = rand.Read(iv)
if err != nil {
return
}
......
......@@ -6,7 +6,8 @@ package packet
import (
"bytes"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/rand"
"crypto/sha1"
"encoding/hex"
"io"
......@@ -70,7 +71,7 @@ func testMDCReader(t *testing.T) {
err = mdcReader.Close()
if err == nil {
t.Error("corruption: no error")
} else if _, ok := err.(*error_.SignatureError); !ok {
} else if _, ok := err.(*errors.SignatureError); !ok {
t.Errorf("corruption: expected SignatureError, got: %s", err)
}
}
......@@ -82,7 +83,7 @@ func TestSerialize(t *testing.T) {
c := CipherAES128
key := make([]byte, c.KeySize())
w, err := SerializeSymmetricallyEncrypted(buf, c, key)
w, err := SerializeSymmetricallyEncrypted(buf, rand.Reader, c, key)
if err != nil {
t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
return
......
......@@ -8,7 +8,7 @@ package openpgp
import (
"crypto"
"crypto/openpgp/armor"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/openpgp/packet"
_ "crypto/sha256"
"hash"
......@@ -27,7 +27,7 @@ func readArmored(r io.Reader, expectedType string) (body io.Reader, err error) {
}
if block.Type != expectedType {
return nil, error_.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type)
return nil, errors.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type)
}
return block.Body, nil
......@@ -130,7 +130,7 @@ ParsePackets:
case *packet.Compressed, *packet.LiteralData, *packet.OnePassSignature:
// This message isn't encrypted.
if len(symKeys) != 0 || len(pubKeys) != 0 {
return nil, error_.StructuralError("key material not followed by encrypted message")
return nil, errors.StructuralError("key material not followed by encrypted message")
}
packets.Unread(p)
return readSignedMessage(packets, nil, keyring)
......@@ -161,7 +161,7 @@ FindKey:
continue
}
decrypted, err = se.Decrypt(pk.encryptedKey.CipherFunc, pk.encryptedKey.Key)
if err != nil && err != error_.KeyIncorrectError {
if err != nil && err != errors.KeyIncorrectError {
return nil, err
}
if decrypted != nil {
......@@ -179,11 +179,11 @@ FindKey:
}
if len(candidates) == 0 && len(symKeys) == 0 {
return nil, error_.KeyIncorrectError
return nil, errors.KeyIncorrectError
}
if prompt == nil {
return nil, error_.KeyIncorrectError
return nil, errors.KeyIncorrectError
}
passphrase, err := prompt(candidates, len(symKeys) != 0)
......@@ -197,7 +197,7 @@ FindKey:
err = s.Decrypt(passphrase)
if err == nil && !s.Encrypted {
decrypted, err = se.Decrypt(s.CipherFunc, s.Key)
if err != nil && err != error_.KeyIncorrectError {
if err != nil && err != errors.KeyIncorrectError {
return nil, err
}
if decrypted != nil {
......@@ -237,7 +237,7 @@ FindLiteralData:
packets.Push(p.Body)
case *packet.OnePassSignature:
if !p.IsLast {
return nil, error_.UnsupportedError("nested signatures")
return nil, errors.UnsupportedError("nested signatures")
}
h, wrappedHash, err = hashForSignature(p.Hash, p.SigType)
......@@ -281,7 +281,7 @@ FindLiteralData:
func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) {
h := hashId.New()
if h == nil {
return nil, nil, error_.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId)))
return nil, nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId)))
}
switch sigType {
......@@ -291,7 +291,7 @@ func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Ha
return h, NewCanonicalTextHash(h), nil
}
return nil, nil, error_.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
return nil, nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
}
// checkReader wraps an io.Reader from a LiteralData packet. When it sees EOF
......@@ -333,7 +333,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (n int, err error) {
var ok bool
if scr.md.Signature, ok = p.(*packet.Signature); !ok {
scr.md.SignatureError = error_.StructuralError("LiteralData not followed by Signature")
scr.md.SignatureError = errors.StructuralError("LiteralData not followed by Signature")
return
}
......@@ -363,16 +363,16 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
sig, ok := p.(*packet.Signature)
if !ok {
return nil, error_.StructuralError("non signature packet found")
return nil, errors.StructuralError("non signature packet found")
}
if sig.IssuerKeyId == nil {
return nil, error_.StructuralError("signature doesn't have an issuer")
return nil, errors.StructuralError("signature doesn't have an issuer")
}
keys := keyring.KeysById(*sig.IssuerKeyId)
if len(keys) == 0 {
return nil, error_.UnknownIssuerError
return nil, errors.UnknownIssuerError
}
h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType)
......@@ -399,7 +399,7 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
return
}
return nil, error_.UnknownIssuerError
return nil, errors.UnknownIssuerError
}
// CheckArmoredDetachedSignature performs the same actions as
......
......@@ -6,7 +6,8 @@ package openpgp
import (
"bytes"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
_ "crypto/sha512"
"encoding/hex"
"io"
"io/ioutil"
......@@ -77,6 +78,15 @@ func TestReadDSAKey(t *testing.T) {
}
}
func TestDSAHashTruncatation(t *testing.T) {
// dsaKeyWithSHA512 was generated with GnuPG and --cert-digest-algo
// SHA512 in order to require DSA hash truncation to verify correctly.
_, err := ReadKeyRing(readerFromHex(dsaKeyWithSHA512))
if err != nil {
t.Error(err)
}
}
func TestGetKeyById(t *testing.T) {
kring, _ := ReadKeyRing(readerFromHex(testKeys1And2Hex))
......@@ -151,18 +161,18 @@ func TestSignedEncryptedMessage(t *testing.T) {
prompt := func(keys []Key, symmetric bool) ([]byte, error) {
if symmetric {
t.Errorf("prompt: message was marked as symmetrically encrypted")
return nil, error_.KeyIncorrectError
return nil, errors.KeyIncorrectError
}
if len(keys) == 0 {
t.Error("prompt: no keys requested")
return nil, error_.KeyIncorrectError
return nil, errors.KeyIncorrectError
}
err := keys[0].PrivateKey.Decrypt([]byte("passphrase"))
if err != nil {
t.Errorf("prompt: error decrypting key: %s", err)
return nil, error_.KeyIncorrectError
return nil, errors.KeyIncorrectError
}
return nil, nil
......@@ -286,7 +296,7 @@ func TestReadingArmoredPrivateKey(t *testing.T) {
func TestNoArmoredData(t *testing.T) {
_, err := ReadArmoredKeyRing(bytes.NewBufferString("foo"))
if _, ok := err.(error_.InvalidArgumentError); !ok {
if _, ok := err.(errors.InvalidArgumentError); !ok {
t.Errorf("error was not an InvalidArgumentError: %s", err)
}
}
......@@ -358,3 +368,5 @@ AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL
VrM0m72/jnpKo04=
=zNCn
-----END PGP PRIVATE KEY BLOCK-----`
const dsaKeyWithSHA512 = `9901a2044f04b07f110400db244efecc7316553ee08d179972aab87bb1214de7692593fcf5b6feb1c80fba268722dd464748539b85b81d574cd2d7ad0ca2444de4d849b8756bad7768c486c83a824f9bba4af773d11742bdfb4ac3b89ef8cc9452d4aad31a37e4b630d33927bff68e879284a1672659b8b298222fc68f370f3e24dccacc4a862442b9438b00a0ea444a24088dc23e26df7daf8f43cba3bffc4fe703fe3d6cd7fdca199d54ed8ae501c30e3ec7871ea9cdd4cf63cfe6fc82281d70a5b8bb493f922cd99fba5f088935596af087c8d818d5ec4d0b9afa7f070b3d7c1dd32a84fca08d8280b4890c8da1dde334de8e3cad8450eed2a4a4fcc2db7b8e5528b869a74a7f0189e11ef097ef1253582348de072bb07a9fa8ab838e993cef0ee203ff49298723e2d1f549b00559f886cd417a41692ce58d0ac1307dc71d85a8af21b0cf6eaa14baf2922d3a70389bedf17cc514ba0febbd107675a372fe84b90162a9e88b14d4b1c6be855b96b33fb198c46f058568817780435b6936167ebb3724b680f32bf27382ada2e37a879b3d9de2abe0c3f399350afd1ad438883f4791e2e3b4184453412068617368207472756e636174696f6e207465737488620413110a002205024f04b07f021b03060b090807030206150802090a0b0416020301021e01021780000a0910ef20e0cefca131581318009e2bf3bf047a44d75a9bacd00161ee04d435522397009a03a60d51bd8a568c6c021c8d7cf1be8d990d6417b0020003`
......@@ -8,7 +8,7 @@ package s2k
import (
"crypto"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"hash"
"io"
"strconv"
......@@ -89,11 +89,11 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) {
hash, ok := HashIdToHash(buf[1])
if !ok {
return nil, error_.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1])))
return nil, errors.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1])))
}
h := hash.New()
if h == nil {
return nil, error_.UnsupportedError("hash not available: " + strconv.Itoa(int(hash)))
return nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hash)))
}
switch buf[0] {
......@@ -123,7 +123,7 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) {
return f, nil
}
return nil, error_.UnsupportedError("S2K function")
return nil, errors.UnsupportedError("S2K function")
}
// Serialize salts and stretches the given passphrase and writes the resulting
......
......@@ -7,7 +7,7 @@ package openpgp
import (
"crypto"
"crypto/openpgp/armor"
error_ "crypto/openpgp/error"
"crypto/openpgp/errors"
"crypto/openpgp/packet"
"crypto/openpgp/s2k"
"crypto/rand"
......@@ -58,10 +58,10 @@ func armoredDetachSign(w io.Writer, signer *Entity, message io.Reader, sigType p
func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.SignatureType) (err error) {
if signer.PrivateKey == nil {
return error_.InvalidArgumentError("signing key doesn't have a private key")
return errors.InvalidArgumentError("signing key doesn't have a private key")
}
if signer.PrivateKey.Encrypted {
return error_.InvalidArgumentError("signing key is encrypted")
return errors.InvalidArgumentError("signing key is encrypted")
}
sig := new(packet.Signature)
......@@ -77,7 +77,7 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S
}
io.Copy(wrappedHash, message)
err = sig.Sign(h, signer.PrivateKey)
err = sig.Sign(rand.Reader, h, signer.PrivateKey)
if err != nil {
return
}
......@@ -111,7 +111,7 @@ func SymmetricallyEncrypt(ciphertext io.Writer, passphrase []byte, hints *FileHi
if err != nil {
return
}
w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, packet.CipherAES128, key)
w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, rand.Reader, packet.CipherAES128, key)
if err != nil {
return
}
......@@ -156,7 +156,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
if signed != nil {
signer = signed.signingKey().PrivateKey
if signer == nil || signer.Encrypted {
return nil, error_.InvalidArgumentError("signing key must be decrypted")
return nil, errors.InvalidArgumentError("signing key must be decrypted")
}
}
......@@ -183,7 +183,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
for i := range to {
encryptKeys[i] = to[i].encryptionKey()
if encryptKeys[i].PublicKey == nil {
return nil, error_.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys")
return nil, errors.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys")
}
sig := to[i].primaryIdentity().SelfSignature
......@@ -201,7 +201,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
}
if len(candidateCiphers) == 0 || len(candidateHashes) == 0 {
return nil, error_.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms")
return nil, errors.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms")
}
cipher := packet.CipherFunction(candidateCiphers[0])
......@@ -217,7 +217,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
}
}
encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, cipher, symKey)
encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, rand.Reader, cipher, symKey)
if err != nil {
return
}
......@@ -287,7 +287,7 @@ func (s signatureWriter) Close() error {
IssuerKeyId: &s.signer.KeyId,
}
if err := sig.Sign(s.h, s.signer); err != nil {
if err := sig.Sign(rand.Reader, s.h, s.signer); err != nil {
return err
}
if err := s.literalData.Close(); err != nil {
......
......@@ -222,7 +222,7 @@ func TestEncryption(t *testing.T) {
if test.isSigned {
if md.SignatureError != nil {
t.Errorf("#%d: signature error: %s", i, err)
t.Errorf("#%d: signature error: %s", i, md.SignatureError)
}
if md.Signature == nil {
t.Error("signature missing")
......
......@@ -111,6 +111,18 @@ type ConnectionState struct {
VerifiedChains [][]*x509.Certificate
}
// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
const (
NoClientCert ClientAuthType = iota
RequestClientCert
RequireAnyClientCert
VerifyClientCertIfGiven
RequireAndVerifyClientCert
)
// A Config structure is used to configure a TLS client or server. After one
// has been passed to a TLS function it must not be modified.
type Config struct {
......@@ -120,7 +132,7 @@ type Config struct {
Rand io.Reader
// Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses the system time.Seconds.
// If Time is nil, TLS uses time.Now.
Time func() time.Time
// Certificates contains one or more certificate chains
......@@ -148,11 +160,14 @@ type Config struct {
// hosting.
ServerName string
// AuthenticateClient controls whether a server will request a certificate
// from the client. It does not require that the client send a
// certificate nor does it require that the certificate sent be
// anything more than self-signed.
AuthenticateClient bool
// ClientAuth determines the server's policy for
// TLS Client Authentication. The default is NoClientCert.
ClientAuth ClientAuthType
// ClientCAs defines the set of root certificate authorities
// that servers use if required to verify a client certificate
// by the policy in ClientAuth.
ClientCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name.
......@@ -259,6 +274,11 @@ type Certificate struct {
// OCSPStaple contains an optional OCSP response which will be served
// to clients that request it.
OCSPStaple []byte
// Leaf is the parsed form of the leaf certificate, which may be
// initialized using x509.ParseCertificate to reduce per-handshake
// processing for TLS clients doing client authentication. If nil, the
// leaf certificate will be parsed as needed.
Leaf *x509.Certificate
}
// A TLS record.
......
......@@ -31,7 +31,7 @@ func main() {
return
}
now := time.Seconds()
now := time.Now()
template := x509.Certificate{
SerialNumber: new(big.Int).SetInt64(0),
......@@ -39,8 +39,8 @@ func main() {
CommonName: *hostName,
Organization: []string{"Acme Co"},
},
NotBefore: time.SecondsToUTC(now - 300),
NotAfter: time.SecondsToUTC(now + 60*60*24*365), // valid for 1 year.
NotBefore: now.Add(-5 * time.Minute).UTC(),
NotAfter: now.AddDate(1, 0, 0).UTC(), // valid for 1 year.
SubjectKeyId: []byte{1, 2, 3, 4},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
......
......@@ -5,12 +5,14 @@
package tls
import (
"bytes"
"crypto"
"crypto/rsa"
"crypto/subtle"
"crypto/x509"
"errors"
"io"
"strconv"
)
func (c *Conn) clientHandshake() error {
......@@ -162,10 +164,23 @@ func (c *Conn) clientHandshake() error {
}
}
transmitCert := false
var certToSend *Certificate
certReq, ok := msg.(*certificateRequestMsg)
if ok {
// We only accept certificates with RSA keys.
// RFC 4346 on the certificateAuthorities field:
// A list of the distinguished names of acceptable certificate
// authorities. These distinguished names may specify a desired
// distinguished name for a root CA or for a subordinate CA;
// thus, this message can be used to describe both known roots
// and a desired authorization space. If the
// certificate_authorities list is empty then the client MAY
// send any certificate of the appropriate
// ClientCertificateType, unless there is some external
// arrangement to the contrary.
finishedHash.Write(certReq.marshal())
// For now, we only know how to sign challenges with RSA
rsaAvail := false
for _, certType := range certReq.certificateTypes {
if certType == certTypeRSASign {
......@@ -174,23 +189,41 @@ func (c *Conn) clientHandshake() error {
}
}
// For now, only send a certificate back if the server gives us an
// empty list of certificateAuthorities.
//
// RFC 4346 on the certificateAuthorities field:
// A list of the distinguished names of acceptable certificate
// authorities. These distinguished names may specify a desired
// distinguished name for a root CA or for a subordinate CA; thus,
// this message can be used to describe both known roots and a
// desired authorization space. If the certificate_authorities
// list is empty then the client MAY send any certificate of the
// appropriate ClientCertificateType, unless there is some
// external arrangement to the contrary.
if rsaAvail && len(certReq.certificateAuthorities) == 0 {
transmitCert = true
}
// We need to search our list of client certs for one
// where SignatureAlgorithm is RSA and the Issuer is in
// certReq.certificateAuthorities
findCert:
for i, cert := range c.config.Certificates {
if !rsaAvail {
continue
}
finishedHash.Write(certReq.marshal())
leaf := cert.Leaf
if leaf == nil {
if leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
}
}
if leaf.PublicKeyAlgorithm != x509.RSA {
continue
}
if len(certReq.certificateAuthorities) == 0 {
// they gave us an empty list, so just take the
// first RSA cert from c.config.Certificates
certToSend = &cert
break
}
for _, ca := range certReq.certificateAuthorities {
if bytes.Equal(leaf.RawIssuer, ca) {
certToSend = &cert
break findCert
}
}
}
msg, err = c.readHandshake()
if err != nil {
......@@ -204,17 +237,9 @@ func (c *Conn) clientHandshake() error {
}
finishedHash.Write(shd.marshal())
var cert *x509.Certificate
if transmitCert {
if certToSend != nil {
certMsg = new(certificateMsg)
if len(c.config.Certificates) > 0 {
cert, err = x509.ParseCertificate(c.config.Certificates[0].Certificate[0])
if err == nil && cert.PublicKeyAlgorithm == x509.RSA {
certMsg.certificates = c.config.Certificates[0].Certificate
} else {
cert = nil
}
}
certMsg.certificates = certToSend.Certificate
finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal())
}
......@@ -229,7 +254,7 @@ func (c *Conn) clientHandshake() error {
c.writeRecord(recordTypeHandshake, ckx.marshal())
}
if cert != nil {
if certToSend != nil {
certVerify := new(certificateVerifyMsg)
digest := make([]byte, 0, 36)
digest = finishedHash.serverMD5.Sum(digest)
......
......@@ -881,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
// See http://tools.ietf.org/html/rfc4346#section-7.4.4
length := 1 + len(m.certificateTypes) + 2
casLength := 0
for _, ca := range m.certificateAuthorities {
length += 2 + len(ca)
casLength += 2 + len(ca)
}
length += casLength
x = make([]byte, 4+length)
x[0] = typeCertificateRequest
......@@ -895,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):]
numCA := len(m.certificateAuthorities)
y[0] = uint8(numCA >> 8)
y[1] = uint8(numCA)
y[0] = uint8(casLength >> 8)
y[1] = uint8(casLength)
y = y[2:]
for _, ca := range m.certificateAuthorities {
y[0] = uint8(len(ca) >> 8)
......@@ -909,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
}
m.raw = x
return
}
......@@ -937,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
}
data = data[numCertTypes:]
if len(data) < 2 {
return false
}
numCAs := uint16(data[0])<<16 | uint16(data[1])
casLength := uint16(data[0])<<8 | uint16(data[1])
data = data[2:]
if len(data) < int(casLength) {
return false
}
cas := make([]byte, casLength)
copy(cas, data)
data = data[casLength:]
m.certificateAuthorities = make([][]byte, numCAs)
for i := uint16(0); i < numCAs; i++ {
if len(data) < 2 {
m.certificateAuthorities = nil
for len(cas) > 0 {
if len(cas) < 2 {
return false
}
caLen := uint16(data[0])<<16 | uint16(data[1])
caLen := uint16(cas[0])<<8 | uint16(cas[1])
cas = cas[2:]
data = data[2:]
if len(data) < int(caLen) {
if len(cas) < int(caLen) {
return false
}
ca := make([]byte, caLen)
copy(ca, data)
m.certificateAuthorities[i] = ca
data = data[caLen:]
m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
cas = cas[caLen:]
}
if len(data) > 0 {
return false
}
......
......@@ -150,14 +150,19 @@ FindCipherSuite:
c.writeRecord(recordTypeHandshake, skx.marshal())
}
if config.AuthenticateClient {
if config.ClientAuth >= RequestClientCert {
// Request a client certificate
certReq := new(certificateRequestMsg)
certReq.certificateTypes = []byte{certTypeRSASign}
// An empty list of certificateAuthorities signals to
// the client that it may send any certificate in response
// to our request.
// to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose
// an appropriate certificate to give to us.
if config.ClientCAs != nil {
certReq.certificateAuthorities = config.ClientCAs.Subjects()
}
finishedHash.Write(certReq.marshal())
c.writeRecord(recordTypeHandshake, certReq.marshal())
}
......@@ -166,52 +171,87 @@ FindCipherSuite:
finishedHash.Write(helloDone.marshal())
c.writeRecord(recordTypeHandshake, helloDone.marshal())
var pub *rsa.PublicKey
if config.AuthenticateClient {
// Get client certificate
msg, err = c.readHandshake()
if err != nil {
return err
}
certMsg, ok = msg.(*certificateMsg)
if !ok {
return c.sendAlert(alertUnexpectedMessage)
var pub *rsa.PublicKey // public key for client auth, if any
msg, err = c.readHandshake()
if err != nil {
return err
}
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
if config.ClientAuth >= RequestClientCert {
if certMsg, ok = msg.(*certificateMsg); !ok {
return c.sendAlert(alertHandshakeFailure)
}
finishedHash.Write(certMsg.marshal())
if len(certMsg.certificates) == 0 {
// The client didn't actually send a certificate
switch config.ClientAuth {
case RequireAnyClientCert, RequireAndVerifyClientCert:
c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate")
}
}
certs := make([]*x509.Certificate, len(certMsg.certificates))
for i, asn1Data := range certMsg.certificates {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("could not parse client's certificate: " + err.Error())
return errors.New("tls: failed to parse client certificate: " + err.Error())
}
certs[i] = cert
}
// TODO(agl): do better validation of certs: max path length, name restrictions etc.
for i := 1; i < len(certs); i++ {
if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil {
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
opts := x509.VerifyOptions{
Roots: c.config.ClientCAs,
CurrentTime: c.config.time(),
Intermediates: x509.NewCertPool(),
}
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("could not validate certificate signature: " + err.Error())
return errors.New("tls: failed to verify client's certificate: " + err.Error())
}
ok := false
for _, ku := range certs[0].ExtKeyUsage {
if ku == x509.ExtKeyUsageClientAuth {
ok = true
break
}
}
if !ok {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's certificate's extended key usage doesn't permit it to be used for client authentication")
}
c.verifiedChains = chains
}
if len(certs) > 0 {
key, ok := certs[0].PublicKey.(*rsa.PublicKey)
if !ok {
if pub, ok = certs[0].PublicKey.(*rsa.PublicKey); !ok {
return c.sendAlert(alertUnsupportedCertificate)
}
pub = key
c.peerCertificates = certs
}
msg, err = c.readHandshake()
if err != nil {
return err
}
}
// Get client key exchange
msg, err = c.readHandshake()
if err != nil {
return err
}
ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok {
return c.sendAlert(alertUnexpectedMessage)
......
......@@ -120,7 +120,7 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
// LoadX509KeyPair reads and parses a public/private key pair from a pair of
// files. The files must contain PEM encoded data.
func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err error) {
func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
certPEMBlock, err := ioutil.ReadFile(certFile)
if err != nil {
return
......
......@@ -101,3 +101,13 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
return
}
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
func (s *CertPool) Subjects() (res [][]byte) {
res = make([][]byte, len(s.certs))
for i, c := range s.certs {
res[i] = c.RawSubject
}
return
}
......@@ -7,14 +7,14 @@ package gosym
import (
"debug/elf"
"os"
"syscall"
"runtime"
"testing"
)
func dotest() bool {
// For now, only works on ELF platforms.
// TODO: convert to work with new go tool
return false && syscall.OS == "linux" && os.Getenv("GOARCH") == "amd64"
return false && runtime.GOOS == "linux" && runtime.GOARCH == "amd64"
}
func getTable(t *testing.T) *Table {
......
......@@ -786,7 +786,8 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
// Because Unmarshal uses the reflect package, the structs
// being written to must use upper case field names.
//
// An ASN.1 INTEGER can be written to an int, int32 or int64.
// An ASN.1 INTEGER can be written to an int, int32, int64,
// or *big.Int (from the math/big package).
// If the encoded value does not fit in the Go type,
// Unmarshal returns a parse error.
//
......
......@@ -6,6 +6,7 @@ package asn1
import (
"bytes"
"math/big"
"reflect"
"testing"
"time"
......@@ -351,6 +352,10 @@ type TestElementsAfterString struct {
A, B int
}
type TestBigInt struct {
X *big.Int
}
var unmarshalTestData = []struct {
in []byte
out interface{}
......@@ -369,6 +374,7 @@ var unmarshalTestData = []struct {
{[]byte{0x01, 0x01, 0x00}, newBool(false)},
{[]byte{0x01, 0x01, 0x01}, newBool(true)},
{[]byte{0x30, 0x0b, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x01, 0x22, 0x02, 0x01, 0x33}, &TestElementsAfterString{"foo", 0x22, 0x33}},
{[]byte{0x30, 0x05, 0x02, 0x03, 0x12, 0x34, 0x56}, &TestBigInt{big.NewInt(0x123456)}},
}
func TestUnmarshal(t *testing.T) {
......
......@@ -7,6 +7,7 @@ package asn1
import (
"bytes"
"encoding/hex"
"math/big"
"testing"
"time"
)
......@@ -20,6 +21,10 @@ type twoIntStruct struct {
B int
}
type bigIntStruct struct {
A *big.Int
}
type nestedStruct struct {
A intStruct
}
......@@ -65,6 +70,7 @@ var marshalTests = []marshalTest{
{-128, "020180"},
{-129, "0202ff7f"},
{intStruct{64}, "3003020140"},
{bigIntStruct{big.NewInt(0x123456)}, "30050203123456"},
{twoIntStruct{64, 65}, "3006020140020141"},
{nestedStruct{intStruct{127}}, "3005300302017f"},
{[]byte{1, 2, 3}, "0403010203"},
......
......@@ -1039,9 +1039,9 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re
// Extract and compare element types.
var sw *sliceType
if tt, ok := builtinIdToType[fw]; ok {
sw = tt.(*sliceType)
} else {
sw = dec.wireType[fw].SliceT
sw, _ = tt.(*sliceType)
} else if wire != nil {
sw = wire.SliceT
}
elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
......
......@@ -678,3 +678,11 @@ func TestUnexportedChan(t *testing.T) {
t.Fatalf("error encoding unexported channel: %s", err)
}
}
func TestSliceIncompatibility(t *testing.T) {
var in = []byte{1, 2, 3}
var out []int
if err := encAndDec(in, &out); err == nil {
t.Error("expected compatibility error")
}
}
......@@ -10,6 +10,7 @@ package json
import (
"encoding/base64"
"errors"
"fmt"
"reflect"
"runtime"
"strconv"
......@@ -538,7 +539,7 @@ func (d *decodeState) object(v reflect.Value) {
// Read value.
if destring {
d.value(reflect.ValueOf(&d.tempstr))
d.literalStore([]byte(d.tempstr), subv)
d.literalStore([]byte(d.tempstr), subv, true)
} else {
d.value(subv)
}
......@@ -571,11 +572,15 @@ func (d *decodeState) literal(v reflect.Value) {
d.off--
d.scan.undo(op)
d.literalStore(d.data[start:d.off], v)
d.literalStore(d.data[start:d.off], v, false)
}
// literalStore decodes a literal stored in item into v.
func (d *decodeState) literalStore(item []byte, v reflect.Value) {
//
// fromQuoted indicates whether this literal came from unwrapping a
// string from the ",string" struct tag option. this is used only to
// produce more helpful error messages.
func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) {
// Check for unmarshaler.
wantptr := item[0] == 'n' // null
unmarshaler, pv := d.indirect(v, wantptr)
......@@ -601,7 +606,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
value := c == 't'
switch v.Kind() {
default:
d.saveError(&UnmarshalTypeError{"bool", v.Type()})
if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.saveError(&UnmarshalTypeError{"bool", v.Type()})
}
case reflect.Bool:
v.SetBool(value)
case reflect.Interface:
......@@ -611,7 +620,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
case '"': // string
s, ok := unquoteBytes(item)
if !ok {
d.error(errPhase)
if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(errPhase)
}
}
switch v.Kind() {
default:
......@@ -636,12 +649,20 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
default: // number
if c != '-' && (c < '0' || c > '9') {
d.error(errPhase)
if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(errPhase)
}
}
s := string(item)
switch v.Kind() {
default:
d.error(&UnmarshalTypeError{"number", v.Type()})
if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(&UnmarshalTypeError{"number", v.Type()})
}
case reflect.Interface:
n, err := strconv.ParseFloat(s, 64)
if err != nil {
......
......@@ -258,13 +258,10 @@ type wrongStringTest struct {
in, err string
}
// TODO(bradfitz): as part of Issue 2331, fix these tests' expected
// error values to be helpful, rather than the confusing messages they
// are now.
var wrongStringTests = []wrongStringTest{
{`{"result":"x"}`, "JSON decoder out of sync - data changing underfoot?"},
{`{"result":"foo"}`, "json: cannot unmarshal bool into Go value of type string"},
{`{"result":"123"}`, "json: cannot unmarshal number into Go value of type string"},
{`{"result":"x"}`, `json: invalid use of ,string struct tag, trying to unmarshal "x" into string`},
{`{"result":"foo"}`, `json: invalid use of ,string struct tag, trying to unmarshal "foo" into string`},
{`{"result":"123"}`, `json: invalid use of ,string struct tag, trying to unmarshal "123" into string`},
}
// If people misuse the ,string modifier, the error message should be
......
......@@ -12,6 +12,7 @@ package json
import (
"bytes"
"encoding/base64"
"math"
"reflect"
"runtime"
"sort"
......@@ -170,6 +171,15 @@ func (e *UnsupportedTypeError) Error() string {
return "json: unsupported type: " + e.Type.String()
}
type UnsupportedValueError struct {
Value reflect.Value
Str string
}
func (e *UnsupportedValueError) Error() string {
return "json: unsupported value: " + e.Str
}
type InvalidUTF8Error struct {
S string
}
......@@ -290,7 +300,11 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
e.Write(b)
}
case reflect.Float32, reflect.Float64:
b := strconv.AppendFloat(e.scratch[:0], v.Float(), 'g', -1, v.Type().Bits())
f := v.Float()
if math.IsInf(f, 0) || math.IsNaN(f) {
e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, v.Type().Bits())})
}
b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, v.Type().Bits())
if quoted {
writeString(e, string(b))
} else {
......
......@@ -6,6 +6,7 @@ package json
import (
"bytes"
"math"
"reflect"
"testing"
)
......@@ -107,3 +108,21 @@ func TestEncodeRenamedByteSlice(t *testing.T) {
t.Errorf(" got %s want %s", result, expect)
}
}
var unsupportedValues = []interface{}{
math.NaN(),
math.Inf(-1),
math.Inf(1),
}
func TestUnsupportedValues(t *testing.T) {
for _, v := range unsupportedValues {
if _, err := Marshal(v); err != nil {
if _, ok := err.(*UnsupportedValueError); !ok {
t.Errorf("for %v, got %T want UnsupportedValueError", v, err)
}
} else {
t.Errorf("for %v, expected error", v)
}
}
}
......@@ -5,6 +5,7 @@
package xml
var atomValue = &Feed{
XMLName: Name{"http://www.w3.org/2005/Atom", "feed"},
Title: "Example Feed",
Link: []Link{{Href: "http://example.org/"}},
Updated: ParseTime("2003-12-13T18:30:02Z"),
......@@ -24,19 +25,19 @@ var atomValue = &Feed{
var atomXml = `` +
`<feed xmlns="http://www.w3.org/2005/Atom">` +
`<Title>Example Feed</Title>` +
`<Id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</Id>` +
`<Link href="http://example.org/"></Link>` +
`<Updated>2003-12-13T18:30:02Z</Updated>` +
`<Author><Name>John Doe</Name><URI></URI><Email></Email></Author>` +
`<Entry>` +
`<Title>Atom-Powered Robots Run Amok</Title>` +
`<Id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</Id>` +
`<Link href="http://example.org/2003/12/13/atom03"></Link>` +
`<Updated>2003-12-13T18:30:02Z</Updated>` +
`<Author><Name></Name><URI></URI><Email></Email></Author>` +
`<Summary>Some text.</Summary>` +
`</Entry>` +
`<title>Example Feed</title>` +
`<id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</id>` +
`<link href="http://example.org/"></link>` +
`<updated>2003-12-13T18:30:02Z</updated>` +
`<author><name>John Doe</name><uri></uri><email></email></author>` +
`<entry>` +
`<title>Atom-Powered Robots Run Amok</title>` +
`<id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</id>` +
`<link href="http://example.org/2003/12/13/atom03"></link>` +
`<updated>2003-12-13T18:30:02Z</updated>` +
`<author><name></name><uri></uri><email></email></author>` +
`<summary>Some text.</summary>` +
`</entry>` +
`</feed>`
func ParseTime(str string) Time {
......
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import "testing"
type C struct {
Name string
Open bool
}
type A struct {
XMLName Name `xml:"http://domain a"`
C
B B
FieldA string
}
type B struct {
XMLName Name `xml:"b"`
C
FieldB string
}
const _1a = `
<?xml version="1.0" encoding="UTF-8"?>
<a xmlns="http://domain">
<name>KmlFile</name>
<open>1</open>
<b>
<name>Absolute</name>
<open>0</open>
<fieldb>bar</fieldb>
</b>
<fielda>foo</fielda>
</a>
`
// Tests that embedded structs are marshalled.
func TestEmbedded1(t *testing.T) {
var a A
if e := Unmarshal(StringReader(_1a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.FieldA != "foo" {
t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.FieldA)
}
if a.Name != "KmlFile" {
t.Fatalf("Unmarshal: expected 'KmlFile' but found '%s'", a.Name)
}
if !a.Open {
t.Fatal("Unmarshal: expected 'true' but found otherwise")
}
if a.B.FieldB != "bar" {
t.Fatalf("Unmarshal: expected 'bar' but found '%s'", a.B.FieldB)
}
if a.B.Name != "Absolute" {
t.Fatalf("Unmarshal: expected 'Absolute' but found '%s'", a.B.Name)
}
if a.B.Open {
t.Fatal("Unmarshal: expected 'false' but found otherwise")
}
}
type A2 struct {
XMLName Name `xml:"http://domain a"`
XY string
Xy string
}
const _2a = `
<?xml version="1.0" encoding="UTF-8"?>
<a xmlns="http://domain">
<xy>foo</xy>
</a>
`
// Tests that conflicting field names get excluded.
func TestEmbedded2(t *testing.T) {
var a A2
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.XY != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.XY)
}
if a.Xy != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.Xy)
}
}
type A3 struct {
XMLName Name `xml:"http://domain a"`
xy string
}
// Tests that private fields are not set.
func TestEmbedded3(t *testing.T) {
var a A3
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.xy != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.xy)
}
}
type A4 struct {
XMLName Name `xml:"http://domain a"`
Any string
}
// Tests that private fields are not set.
func TestEmbedded4(t *testing.T) {
var a A4
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.Any != "foo" {
t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.Any)
}
}
......@@ -6,6 +6,8 @@ package xml
import (
"bufio"
"bytes"
"fmt"
"io"
"reflect"
"strconv"
......@@ -42,20 +44,26 @@ type printer struct {
// elements containing the data.
//
// The name for the XML elements is taken from, in order of preference:
// - the tag on an XMLName field, if the data is a struct
// - the value of an XMLName field of type xml.Name
// - the tag on the XMLName field, if the data is a struct
// - the value of the XMLName field of type xml.Name
// - the tag of the struct field used to obtain the data
// - the name of the struct field used to obtain the data
// - the name '???'.
// - the name of the marshalled type
//
// The XML element for a struct contains marshalled elements for each of the
// exported fields of the struct, with these exceptions:
// - the XMLName field, described above, is omitted.
// - a field with tag "attr" becomes an attribute in the XML element.
// - a field with tag "chardata" is written as character data,
// not as an XML element.
// - a field with tag "innerxml" is written verbatim,
// not subject to the usual marshalling procedure.
// - a field with tag "name,attr" becomes an attribute with
// the given name in the XML element.
// - a field with tag ",attr" becomes an attribute with the
// field name in the in the XML element.
// - a field with tag ",chardata" is written as character data,
// not as an XML element.
// - a field with tag ",innerxml" is written verbatim, not subject
// to the usual marshalling procedure.
// - a field with tag ",comment" is written as an XML comment, not
// subject to the usual marshalling procedure. It must not contain
// the "--" string within it.
//
// If a field uses a tag "a>b>c", then the element c will be nested inside
// parent elements a and b. Fields that appear next to each other that name
......@@ -63,17 +71,18 @@ type printer struct {
//
// type Result struct {
// XMLName xml.Name `xml:"result"`
// Id int `xml:"id,attr"`
// FirstName string `xml:"person>name>first"`
// LastName string `xml:"person>name>last"`
// Age int `xml:"person>age"`
// }
//
// xml.Marshal(w, &Result{FirstName: "John", LastName: "Doe", Age: 42})
// xml.Marshal(w, &Result{Id: 13, FirstName: "John", LastName: "Doe", Age: 42})
//
// would be marshalled as:
//
// <result>
// <person>
// <person id="13">
// <name>
// <first>John</first>
// <last>Doe</last>
......@@ -85,12 +94,12 @@ type printer struct {
// Marshal will return an error if asked to marshal a channel, function, or map.
func Marshal(w io.Writer, v interface{}) (err error) {
p := &printer{bufio.NewWriter(w)}
err = p.marshalValue(reflect.ValueOf(v), "???")
err = p.marshalValue(reflect.ValueOf(v), nil)
p.Flush()
return err
}
func (p *printer) marshalValue(val reflect.Value, name string) error {
func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
if !val.IsValid() {
return nil
}
......@@ -115,58 +124,75 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
if val.IsNil() {
return nil
}
return p.marshalValue(val.Elem(), name)
return p.marshalValue(val.Elem(), finfo)
}
// Slices and arrays iterate over the elements. They do not have an enclosing tag.
if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 {
for i, n := 0, val.Len(); i < n; i++ {
if err := p.marshalValue(val.Index(i), name); err != nil {
if err := p.marshalValue(val.Index(i), finfo); err != nil {
return err
}
}
return nil
}
// Find XML name
xmlns := ""
if kind == reflect.Struct {
if f, ok := typ.FieldByName("XMLName"); ok {
if tag := f.Tag.Get("xml"); tag != "" {
if i := strings.Index(tag, " "); i >= 0 {
xmlns, name = tag[:i], tag[i+1:]
} else {
name = tag
}
} else if v, ok := val.FieldByIndex(f.Index).Interface().(Name); ok && v.Local != "" {
xmlns, name = v.Space, v.Local
}
tinfo, err := getTypeInfo(typ)
if err != nil {
return err
}
// Precedence for the XML element name is:
// 1. XMLName field in underlying struct;
// 2. field name/tag in the struct field; and
// 3. type name
var xmlns, name string
if tinfo.xmlname != nil {
xmlname := tinfo.xmlname
if xmlname.name != "" {
xmlns, name = xmlname.xmlns, xmlname.name
} else if v, ok := val.FieldByIndex(xmlname.idx).Interface().(Name); ok && v.Local != "" {
xmlns, name = v.Space, v.Local
}
}
if name == "" && finfo != nil {
xmlns, name = finfo.xmlns, finfo.name
}
if name == "" {
name = typ.Name()
if name == "" {
return &UnsupportedTypeError{typ}
}
}
p.WriteByte('<')
p.WriteString(name)
if xmlns != "" {
p.WriteString(` xmlns="`)
// TODO: EscapeString, to avoid the allocation.
Escape(p, []byte(xmlns))
p.WriteByte('"')
}
// Attributes
if kind == reflect.Struct {
if len(xmlns) > 0 {
p.WriteString(` xmlns="`)
Escape(p, []byte(xmlns))
p.WriteByte('"')
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&fAttr == 0 {
continue
}
for i, n := 0, typ.NumField(); i < n; i++ {
if f := typ.Field(i); f.PkgPath == "" && f.Tag.Get("xml") == "attr" {
if f.Type.Kind() == reflect.String {
if str := val.Field(i).String(); str != "" {
p.WriteByte(' ')
p.WriteString(strings.ToLower(f.Name))
p.WriteString(`="`)
Escape(p, []byte(str))
p.WriteByte('"')
}
}
}
var str string
if fv := val.FieldByIndex(finfo.idx); fv.Kind() == reflect.String {
str = fv.String()
} else {
str = fmt.Sprint(fv.Interface())
}
if str != "" {
p.WriteByte(' ')
p.WriteString(finfo.name)
p.WriteString(`="`)
Escape(p, []byte(str))
p.WriteByte('"')
}
}
p.WriteByte('>')
......@@ -194,58 +220,9 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
bytes := val.Interface().([]byte)
Escape(p, bytes)
case reflect.Struct:
s := parentStack{printer: p}
for i, n := 0, val.NumField(); i < n; i++ {
if f := typ.Field(i); f.Name != "XMLName" && f.PkgPath == "" {
name := f.Name
vf := val.Field(i)
switch tag := f.Tag.Get("xml"); tag {
case "":
s.trim(nil)
case "chardata":
if tk := f.Type.Kind(); tk == reflect.String {
Escape(p, []byte(vf.String()))
} else if tk == reflect.Slice {
if elem, ok := vf.Interface().([]byte); ok {
Escape(p, elem)
}
}
continue
case "innerxml":
iface := vf.Interface()
switch raw := iface.(type) {
case []byte:
p.Write(raw)
continue
case string:
p.WriteString(raw)
continue
}
case "attr":
continue
default:
parents := strings.Split(tag, ">")
if len(parents) == 1 {
parents, name = nil, tag
} else {
parents, name = parents[:len(parents)-1], parents[len(parents)-1]
if parents[0] == "" {
parents[0] = f.Name
}
}
s.trim(parents)
if !(vf.Kind() == reflect.Ptr || vf.Kind() == reflect.Interface) || !vf.IsNil() {
s.push(parents[len(s.stack):])
}
}
if err := p.marshalValue(vf, name); err != nil {
return err
}
}
if err := p.marshalStruct(tinfo, val); err != nil {
return err
}
s.trim(nil)
default:
return &UnsupportedTypeError{typ}
}
......@@ -258,6 +235,94 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
return nil
}
var ddBytes = []byte("--")
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
s := parentStack{printer: p}
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&(fAttr|fAny) != 0 {
continue
}
vf := val.FieldByIndex(finfo.idx)
switch finfo.flags & fMode {
case fCharData:
switch vf.Kind() {
case reflect.String:
Escape(p, []byte(vf.String()))
case reflect.Slice:
if elem, ok := vf.Interface().([]byte); ok {
Escape(p, elem)
}
}
continue
case fComment:
k := vf.Kind()
if !(k == reflect.String || k == reflect.Slice && vf.Type().Elem().Kind() == reflect.Uint8) {
return fmt.Errorf("xml: bad type for comment field of %s", val.Type())
}
if vf.Len() == 0 {
continue
}
p.WriteString("<!--")
dashDash := false
dashLast := false
switch k {
case reflect.String:
s := vf.String()
dashDash = strings.Index(s, "--") >= 0
dashLast = s[len(s)-1] == '-'
if !dashDash {
p.WriteString(s)
}
case reflect.Slice:
b := vf.Bytes()
dashDash = bytes.Index(b, ddBytes) >= 0
dashLast = b[len(b)-1] == '-'
if !dashDash {
p.Write(b)
}
default:
panic("can't happen")
}
if dashDash {
return fmt.Errorf(`xml: comments must not contain "--"`)
}
if dashLast {
// "--->" is invalid grammar. Make it "- -->"
p.WriteByte(' ')
}
p.WriteString("-->")
continue
case fInnerXml:
iface := vf.Interface()
switch raw := iface.(type) {
case []byte:
p.Write(raw)
continue
case string:
p.WriteString(raw)
continue
}
case fElement:
s.trim(finfo.parents)
if len(finfo.parents) > len(s.stack) {
if vf.Kind() != reflect.Ptr && vf.Kind() != reflect.Interface || !vf.IsNil() {
s.push(finfo.parents[len(s.stack):])
}
}
}
if err := p.marshalValue(vf, finfo); err != nil {
return err
}
}
s.trim(nil)
return nil
}
type parentStack struct {
*printer
stack []string
......
......@@ -6,6 +6,7 @@ package xml
import (
"reflect"
"strings"
"testing"
)
......@@ -13,7 +14,7 @@ import (
func TestUnmarshalFeed(t *testing.T) {
var f Feed
if err := Unmarshal(StringReader(atomFeedString), &f); err != nil {
if err := Unmarshal(strings.NewReader(atomFeedString), &f); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if !reflect.DeepEqual(f, atomFeed) {
......@@ -24,8 +25,8 @@ func TestUnmarshalFeed(t *testing.T) {
// hget http://codereview.appspot.com/rss/mine/rsc
const atomFeedString = `
<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><li-nk href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></li-nk><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub
</title><link hre-f="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html">
<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><link href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></link><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub
</title><link href="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html">
An attempt at adding pubsubhubbub support to Rietveld.
http://code.google.com/p/pubsubhubbub
http://code.google.com/p/rietveld/issues/detail?id=155
......@@ -78,39 +79,39 @@ not being used from outside intra_region_diff.py.
</summary></entry></feed> `
type Feed struct {
XMLName Name `xml:"http://www.w3.org/2005/Atom feed"`
Title string
Id string
Link []Link
Updated Time
Author Person
Entry []Entry
XMLName Name `xml:"http://www.w3.org/2005/Atom feed"`
Title string `xml:"title"`
Id string `xml:"id"`
Link []Link `xml:"link"`
Updated Time `xml:"updated"`
Author Person `xml:"author"`
Entry []Entry `xml:"entry"`
}
type Entry struct {
Title string
Id string
Link []Link
Updated Time
Author Person
Summary Text
Title string `xml:"title"`
Id string `xml:"id"`
Link []Link `xml:"link"`
Updated Time `xml:"updated"`
Author Person `xml:"author"`
Summary Text `xml:"summary"`
}
type Link struct {
Rel string `xml:"attr"`
Href string `xml:"attr"`
Rel string `xml:"rel,attr"`
Href string `xml:"href,attr"`
}
type Person struct {
Name string
URI string
Email string
InnerXML string `xml:"innerxml"`
Name string `xml:"name"`
URI string `xml:"uri"`
Email string `xml:"email"`
InnerXML string `xml:",innerxml"`
}
type Text struct {
Type string `xml:"attr"`
Body string `xml:"chardata"`
Type string `xml:"type,attr"`
Body string `xml:",chardata"`
}
type Time string
......@@ -213,44 +214,26 @@ not being used from outside intra_region_diff.py.
},
}
type FieldNameTest struct {
in, out string
}
var FieldNameTests = []FieldNameTest{
{"Profile-Image", "profileimage"},
{"_score", "score"},
}
func TestFieldName(t *testing.T) {
for _, tt := range FieldNameTests {
a := fieldName(tt.in)
if a != tt.out {
t.Fatalf("have %#v\nwant %#v\n\n", a, tt.out)
}
}
}
const pathTestString = `
<result>
<before>1</before>
<items>
<item1>
<value>A</value>
</item1>
<item2>
<value>B</value>
</item2>
<Result>
<Before>1</Before>
<Items>
<Item1>
<Value>A</Value>
</Item1>
<Item2>
<Value>B</Value>
</Item2>
<Item1>
<Value>C</Value>
<Value>D</Value>
</Item1>
<_>
<value>E</value>
<Value>E</Value>
</_>
</items>
<after>2</after>
</result>
</Items>
<After>2</After>
</Result>
`
type PathTestItem struct {
......@@ -258,18 +241,18 @@ type PathTestItem struct {
}
type PathTestA struct {
Items []PathTestItem `xml:">item1"`
Items []PathTestItem `xml:">Item1"`
Before, After string
}
type PathTestB struct {
Other []PathTestItem `xml:"items>Item1"`
Other []PathTestItem `xml:"Items>Item1"`
Before, After string
}
type PathTestC struct {
Values1 []string `xml:"items>item1>value"`
Values2 []string `xml:"items>item2>value"`
Values1 []string `xml:"Items>Item1>Value"`
Values2 []string `xml:"Items>Item2>Value"`
Before, After string
}
......@@ -278,12 +261,12 @@ type PathTestSet struct {
}
type PathTestD struct {
Other PathTestSet `xml:"items>"`
Other PathTestSet `xml:"Items"`
Before, After string
}
type PathTestE struct {
Underline string `xml:"items>_>value"`
Underline string `xml:"Items>_>Value"`
Before, After string
}
......@@ -298,7 +281,7 @@ var pathTests = []interface{}{
func TestUnmarshalPaths(t *testing.T) {
for _, pt := range pathTests {
v := reflect.New(reflect.TypeOf(pt).Elem()).Interface()
if err := Unmarshal(StringReader(pathTestString), v); err != nil {
if err := Unmarshal(strings.NewReader(pathTestString), v); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if !reflect.DeepEqual(v, pt) {
......@@ -310,7 +293,7 @@ func TestUnmarshalPaths(t *testing.T) {
type BadPathTestA struct {
First string `xml:"items>item1"`
Other string `xml:"items>item2"`
Second string `xml:"items>"`
Second string `xml:"items"`
}
type BadPathTestB struct {
......@@ -319,81 +302,55 @@ type BadPathTestB struct {
Second string `xml:"items>item1>value"`
}
type BadPathTestC struct {
First string
Second string `xml:"First"`
}
type BadPathTestD struct {
BadPathEmbeddedA
BadPathEmbeddedB
}
type BadPathEmbeddedA struct {
First string
}
type BadPathEmbeddedB struct {
Second string `xml:"First"`
}
var badPathTests = []struct {
v, e interface{}
}{
{&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items>"}},
{&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items"}},
{&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}},
{&BadPathTestC{}, &TagPathError{reflect.TypeOf(BadPathTestC{}), "First", "", "Second", "First"}},
{&BadPathTestD{}, &TagPathError{reflect.TypeOf(BadPathTestD{}), "First", "", "Second", "First"}},
}
func TestUnmarshalBadPaths(t *testing.T) {
for _, tt := range badPathTests {
err := Unmarshal(StringReader(pathTestString), tt.v)
err := Unmarshal(strings.NewReader(pathTestString), tt.v)
if !reflect.DeepEqual(err, tt.e) {
t.Fatalf("Unmarshal with %#v didn't fail properly: %#v", tt.v, err)
t.Fatalf("Unmarshal with %#v didn't fail properly:\nhave %#v,\nwant %#v", tt.v, err, tt.e)
}
}
}
func TestUnmarshalAttrs(t *testing.T) {
var f AttrTest
if err := Unmarshal(StringReader(attrString), &f); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if !reflect.DeepEqual(f, attrStruct) {
t.Fatalf("have %#v\nwant %#v", f, attrStruct)
}
}
type AttrTest struct {
Test1 Test1
Test2 Test2
}
type Test1 struct {
Int int `xml:"attr"`
Float float64 `xml:"attr"`
Uint8 uint8 `xml:"attr"`
}
type Test2 struct {
Bool bool `xml:"attr"`
}
const attrString = `
<?xml version="1.0" charset="utf-8"?>
<attrtest>
<test1 int="8" float="23.5" uint8="255"/>
<test2 bool="true"/>
</attrtest>
`
var attrStruct = AttrTest{
Test1: Test1{
Int: 8,
Float: 23.5,
Uint8: 255,
},
Test2: Test2{
Bool: true,
},
}
// test data for TestUnmarshalWithoutNameType
const OK = "OK"
const withoutNameTypeData = `
<?xml version="1.0" charset="utf-8"?>
<Test3 attr="OK" />`
<Test3 Attr="OK" />`
type TestThree struct {
XMLName bool `xml:"Test3"` // XMLName field without an xml.Name type
Attr string `xml:"attr"`
XMLName Name `xml:"Test3"`
Attr string `xml:",attr"`
}
func TestUnmarshalWithoutNameType(t *testing.T) {
var x TestThree
if err := Unmarshal(StringReader(withoutNameTypeData), &x); err != nil {
if err := Unmarshal(strings.NewReader(withoutNameTypeData), &x); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if x.Attr != OK {
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"fmt"
"reflect"
"strings"
"sync"
)
// typeInfo holds details for the xml representation of a type.
type typeInfo struct {
xmlname *fieldInfo
fields []fieldInfo
}
// fieldInfo holds details for the xml representation of a single field.
type fieldInfo struct {
idx []int
name string
xmlns string
flags fieldFlags
parents []string
}
type fieldFlags int
const (
fElement fieldFlags = 1 << iota
fAttr
fCharData
fInnerXml
fComment
fAny
// TODO:
//fIgnore
//fOmitEmpty
fMode = fElement | fAttr | fCharData | fInnerXml | fComment | fAny
)
var tinfoMap = make(map[reflect.Type]*typeInfo)
var tinfoLock sync.RWMutex
// getTypeInfo returns the typeInfo structure with details necessary
// for marshalling and unmarshalling typ.
func getTypeInfo(typ reflect.Type) (*typeInfo, error) {
tinfoLock.RLock()
tinfo, ok := tinfoMap[typ]
tinfoLock.RUnlock()
if ok {
return tinfo, nil
}
tinfo = &typeInfo{}
if typ.Kind() == reflect.Struct {
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
if f.PkgPath != "" {
continue // Private field
}
// For embedded structs, embed its fields.
if f.Anonymous {
if f.Type.Kind() != reflect.Struct {
continue
}
inner, err := getTypeInfo(f.Type)
if err != nil {
return nil, err
}
for _, finfo := range inner.fields {
finfo.idx = append([]int{i}, finfo.idx...)
if err := addFieldInfo(typ, tinfo, &finfo); err != nil {
return nil, err
}
}
continue
}
finfo, err := structFieldInfo(typ, &f)
if err != nil {
return nil, err
}
if f.Name == "XMLName" {
tinfo.xmlname = finfo
continue
}
// Add the field if it doesn't conflict with other fields.
if err := addFieldInfo(typ, tinfo, finfo); err != nil {
return nil, err
}
}
}
tinfoLock.Lock()
tinfoMap[typ] = tinfo
tinfoLock.Unlock()
return tinfo, nil
}
// structFieldInfo builds and returns a fieldInfo for f.
func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) {
finfo := &fieldInfo{idx: f.Index}
// Split the tag from the xml namespace if necessary.
tag := f.Tag.Get("xml")
if i := strings.Index(tag, " "); i >= 0 {
finfo.xmlns, tag = tag[:i], tag[i+1:]
}
// Parse flags.
tokens := strings.Split(tag, ",")
if len(tokens) == 1 {
finfo.flags = fElement
} else {
tag = tokens[0]
for _, flag := range tokens[1:] {
switch flag {
case "attr":
finfo.flags |= fAttr
case "chardata":
finfo.flags |= fCharData
case "innerxml":
finfo.flags |= fInnerXml
case "comment":
finfo.flags |= fComment
case "any":
finfo.flags |= fAny
}
}
// Validate the flags used.
switch mode := finfo.flags & fMode; mode {
case 0:
finfo.flags |= fElement
case fAttr, fCharData, fInnerXml, fComment, fAny:
if f.Name != "XMLName" && (tag == "" || mode == fAttr) {
break
}
fallthrough
default:
// This will also catch multiple modes in a single field.
return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
}
// Use of xmlns without a name is not allowed.
if finfo.xmlns != "" && tag == "" {
return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
if f.Name == "XMLName" {
// The XMLName field records the XML element name. Don't
// process it as usual because its name should default to
// empty rather than to the field name.
finfo.name = tag
return finfo, nil
}
if tag == "" {
// If the name part of the tag is completely empty, get
// default from XMLName of underlying struct if feasible,
// or field name otherwise.
if xmlname := lookupXMLName(f.Type); xmlname != nil {
finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name
} else {
finfo.name = f.Name
}
return finfo, nil
}
// Prepare field name and parents.
tokens = strings.Split(tag, ">")
if tokens[0] == "" {
tokens[0] = f.Name
}
if tokens[len(tokens)-1] == "" {
return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ)
}
finfo.name = tokens[len(tokens)-1]
if len(tokens) > 1 {
finfo.parents = tokens[:len(tokens)-1]
}
// If the field type has an XMLName field, the names must match
// so that the behavior of both marshalling and unmarshalling
// is straighforward and unambiguous.
if finfo.flags&fElement != 0 {
ftyp := f.Type
xmlname := lookupXMLName(ftyp)
if xmlname != nil && xmlname.name != finfo.name {
return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName",
finfo.name, typ, f.Name, xmlname.name, ftyp)
}
}
return finfo, nil
}
// lookupXMLName returns the fieldInfo for typ's XMLName field
// in case it exists and has a valid xml field tag, otherwise
// it returns nil.
func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) {
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return nil
}
for i, n := 0, typ.NumField(); i < n; i++ {
f := typ.Field(i)
if f.Name != "XMLName" {
continue
}
finfo, err := structFieldInfo(typ, &f)
if finfo.name != "" && err == nil {
return finfo
}
// Also consider errors as a non-existent field tag
// and let getTypeInfo itself report the error.
break
}
return nil
}
func min(a, b int) int {
if a <= b {
return a
}
return b
}
// addFieldInfo adds finfo to tinfo.fields if there are no
// conflicts, or if conflicts arise from previous fields that were
// obtained from deeper embedded structures than finfo. In the latter
// case, the conflicting entries are dropped.
// A conflict occurs when the path (parent + name) to a field is
// itself a prefix of another path, or when two paths match exactly.
// It is okay for field paths to share a common, shorter prefix.
func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error {
var conflicts []int
Loop:
// First, figure all conflicts. Most working code will have none.
for i := range tinfo.fields {
oldf := &tinfo.fields[i]
if oldf.flags&fMode != newf.flags&fMode {
continue
}
minl := min(len(newf.parents), len(oldf.parents))
for p := 0; p < minl; p++ {
if oldf.parents[p] != newf.parents[p] {
continue Loop
}
}
if len(oldf.parents) > len(newf.parents) {
if oldf.parents[len(newf.parents)] == newf.name {
conflicts = append(conflicts, i)
}
} else if len(oldf.parents) < len(newf.parents) {
if newf.parents[len(oldf.parents)] == oldf.name {
conflicts = append(conflicts, i)
}
} else {
if newf.name == oldf.name {
conflicts = append(conflicts, i)
}
}
}
// Without conflicts, add the new field and return.
if conflicts == nil {
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
// If any conflict is shallower, ignore the new field.
// This matches the Go field resolution on embedding.
for _, i := range conflicts {
if len(tinfo.fields[i].idx) < len(newf.idx) {
return nil
}
}
// Otherwise, if any of them is at the same depth level, it's an error.
for _, i := range conflicts {
oldf := &tinfo.fields[i]
if len(oldf.idx) == len(newf.idx) {
f1 := typ.FieldByIndex(oldf.idx)
f2 := typ.FieldByIndex(newf.idx)
return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
}
}
// Otherwise, the new field is shallower, and thus takes precedence,
// so drop the conflicting fields from tinfo and append the new one.
for c := len(conflicts) - 1; c >= 0; c-- {
i := conflicts[c]
copy(tinfo.fields[i:], tinfo.fields[i+1:])
tinfo.fields = tinfo.fields[:len(tinfo.fields)-1]
}
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
// A TagPathError represents an error in the unmarshalling process
// caused by the use of field tags with conflicting paths.
type TagPathError struct {
Struct reflect.Type
Field1, Tag1 string
Field2, Tag2 string
}
func (e *TagPathError) Error() string {
return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
}
......@@ -154,36 +154,8 @@ var xmlInput = []string{
"<t>cdata]]></t>",
}
type stringReader struct {
s string
off int
}
func (r *stringReader) Read(b []byte) (n int, err error) {
if r.off >= len(r.s) {
return 0, io.EOF
}
for r.off < len(r.s) && n < len(b) {
b[n] = r.s[r.off]
n++
r.off++
}
return
}
func (r *stringReader) ReadByte() (b byte, err error) {
if r.off >= len(r.s) {
return 0, io.EOF
}
b = r.s[r.off]
r.off++
return
}
func StringReader(s string) io.Reader { return &stringReader{s, 0} }
func TestRawToken(t *testing.T) {
p := NewParser(StringReader(testInput))
p := NewParser(strings.NewReader(testInput))
testRawToken(t, p, rawTokens)
}
......@@ -207,7 +179,7 @@ func (d *downCaser) Read(p []byte) (int, error) {
func TestRawTokenAltEncoding(t *testing.T) {
sawEncoding := ""
p := NewParser(StringReader(testInputAltEncoding))
p := NewParser(strings.NewReader(testInputAltEncoding))
p.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
sawEncoding = charset
if charset != "x-testing-uppercase" {
......@@ -219,7 +191,7 @@ func TestRawTokenAltEncoding(t *testing.T) {
}
func TestRawTokenAltEncodingNoConverter(t *testing.T) {
p := NewParser(StringReader(testInputAltEncoding))
p := NewParser(strings.NewReader(testInputAltEncoding))
token, err := p.RawToken()
if token == nil {
t.Fatalf("expected a token on first RawToken call")
......@@ -286,7 +258,7 @@ var nestedDirectivesTokens = []Token{
}
func TestNestedDirectives(t *testing.T) {
p := NewParser(StringReader(nestedDirectivesInput))
p := NewParser(strings.NewReader(nestedDirectivesInput))
for i, want := range nestedDirectivesTokens {
have, err := p.Token()
......@@ -300,7 +272,7 @@ func TestNestedDirectives(t *testing.T) {
}
func TestToken(t *testing.T) {
p := NewParser(StringReader(testInput))
p := NewParser(strings.NewReader(testInput))
for i, want := range cookedTokens {
have, err := p.Token()
......@@ -315,7 +287,7 @@ func TestToken(t *testing.T) {
func TestSyntax(t *testing.T) {
for i := range xmlInput {
p := NewParser(StringReader(xmlInput[i]))
p := NewParser(strings.NewReader(xmlInput[i]))
var err error
for _, err = p.Token(); err == nil; _, err = p.Token() {
}
......@@ -372,26 +344,26 @@ var all = allScalars{
var sixteen = "16"
const testScalarsInput = `<allscalars>
<true1>true</true1>
<true2>1</true2>
<false1>false</false1>
<false2>0</false2>
<int>1</int>
<int8>-2</int8>
<int16>3</int16>
<int32>-4</int32>
<int64>5</int64>
<uint>6</uint>
<uint8>7</uint8>
<uint16>8</uint16>
<uint32>9</uint32>
<uint64>10</uint64>
<uintptr>11</uintptr>
<float>12.0</float>
<float32>13.0</float32>
<float64>14.0</float64>
<string>15</string>
<ptrstring>16</ptrstring>
<True1>true</True1>
<True2>1</True2>
<False1>false</False1>
<False2>0</False2>
<Int>1</Int>
<Int8>-2</Int8>
<Int16>3</Int16>
<Int32>-4</Int32>
<Int64>5</Int64>
<Uint>6</Uint>
<Uint8>7</Uint8>
<Uint16>8</Uint16>
<Uint32>9</Uint32>
<Uint64>10</Uint64>
<Uintptr>11</Uintptr>
<Float>12.0</Float>
<Float32>13.0</Float32>
<Float64>14.0</Float64>
<String>15</String>
<PtrString>16</PtrString>
</allscalars>`
func TestAllScalars(t *testing.T) {
......@@ -412,7 +384,7 @@ type item struct {
}
func TestIssue569(t *testing.T) {
data := `<item><field_a>abcd</field_a></item>`
data := `<item><Field_a>abcd</Field_a></item>`
var i item
buf := bytes.NewBufferString(data)
err := Unmarshal(buf, &i)
......@@ -424,7 +396,7 @@ func TestIssue569(t *testing.T) {
func TestUnquotedAttrs(t *testing.T) {
data := "<tag attr=azAZ09:-_\t>"
p := NewParser(StringReader(data))
p := NewParser(strings.NewReader(data))
p.Strict = false
token, err := p.Token()
if _, ok := err.(*SyntaxError); ok {
......@@ -450,7 +422,7 @@ func TestValuelessAttrs(t *testing.T) {
{"<input checked />", "input", "checked"},
}
for _, test := range tests {
p := NewParser(StringReader(test[0]))
p := NewParser(strings.NewReader(test[0]))
p.Strict = false
token, err := p.Token()
if _, ok := err.(*SyntaxError); ok {
......@@ -500,7 +472,7 @@ func TestCopyTokenStartElement(t *testing.T) {
func TestSyntaxErrorLineNum(t *testing.T) {
testInput := "<P>Foo<P>\n\n<P>Bar</>\n"
p := NewParser(StringReader(testInput))
p := NewParser(strings.NewReader(testInput))
var err error
for _, err = p.Token(); err == nil; _, err = p.Token() {
}
......@@ -515,7 +487,7 @@ func TestSyntaxErrorLineNum(t *testing.T) {
func TestTrailingRawToken(t *testing.T) {
input := `<FOO></FOO> `
p := NewParser(StringReader(input))
p := NewParser(strings.NewReader(input))
var err error
for _, err = p.RawToken(); err == nil; _, err = p.RawToken() {
}
......@@ -526,7 +498,7 @@ func TestTrailingRawToken(t *testing.T) {
func TestTrailingToken(t *testing.T) {
input := `<FOO></FOO> `
p := NewParser(StringReader(input))
p := NewParser(strings.NewReader(input))
var err error
for _, err = p.Token(); err == nil; _, err = p.Token() {
}
......@@ -537,7 +509,7 @@ func TestTrailingToken(t *testing.T) {
func TestEntityInsideCDATA(t *testing.T) {
input := `<test><![CDATA[ &val=foo ]]></test>`
p := NewParser(StringReader(input))
p := NewParser(strings.NewReader(input))
var err error
for _, err = p.Token(); err == nil; _, err = p.Token() {
}
......@@ -569,7 +541,7 @@ var characterTests = []struct {
func TestDisallowedCharacters(t *testing.T) {
for i, tt := range characterTests {
p := NewParser(StringReader(tt.in))
p := NewParser(strings.NewReader(tt.in))
var err error
for err == nil {
......
......@@ -8,7 +8,7 @@ import "unicode/utf8"
type input interface {
skipASCII(p int) int
skipNonStarter() int
skipNonStarter(p int) int
appendSlice(buf []byte, s, e int) []byte
copySlice(buf []byte, s, e int)
charinfo(p int) (uint16, int)
......@@ -25,8 +25,7 @@ func (s inputString) skipASCII(p int) int {
return p
}
func (s inputString) skipNonStarter() int {
p := 0
func (s inputString) skipNonStarter(p int) int {
for ; p < len(s) && !utf8.RuneStart(s[p]); p++ {
}
return p
......@@ -71,8 +70,7 @@ func (s inputBytes) skipASCII(p int) int {
return p
}
func (s inputBytes) skipNonStarter() int {
p := 0
func (s inputBytes) skipNonStarter(p int) int {
for ; p < len(s) && !utf8.RuneStart(s[p]); p++ {
}
return p
......
......@@ -34,24 +34,28 @@ const (
// Bytes returns f(b). May return b if f(b) = b.
func (f Form) Bytes(b []byte) []byte {
n := f.QuickSpan(b)
rb := reorderBuffer{}
rb.init(f, b)
n := quickSpan(&rb, 0)
if n == len(b) {
return b
}
out := make([]byte, n, len(b))
copy(out, b[0:n])
return f.Append(out, b[n:]...)
return doAppend(&rb, out, n)
}
// String returns f(s).
func (f Form) String(s string) string {
n := f.QuickSpanString(s)
rb := reorderBuffer{}
rb.initString(f, s)
n := quickSpan(&rb, 0)
if n == len(s) {
return s
}
out := make([]byte, 0, len(s))
out := make([]byte, n, len(s))
copy(out, s[0:n])
return string(f.AppendString(out, s[n:]))
return string(doAppend(&rb, out, n))
}
// IsNormal returns true if b == f(b).
......@@ -122,23 +126,27 @@ func (f Form) IsNormalString(s string) bool {
// patchTail fixes a case where a rune may be incorrectly normalized
// if it is followed by illegal continuation bytes. It returns the
// patched buffer and the number of trailing continuation bytes that
// have been dropped.
func patchTail(rb *reorderBuffer, buf []byte) ([]byte, int) {
// patched buffer and whether there were trailing continuation bytes.
func patchTail(rb *reorderBuffer, buf []byte) ([]byte, bool) {
info, p := lastRuneStart(&rb.f, buf)
if p == -1 || info.size == 0 {
return buf, 0
return buf, false
}
end := p + int(info.size)
extra := len(buf) - end
if extra > 0 {
// Potentially allocating memory. However, this only
// happens with ill-formed UTF-8.
x := make([]byte, 0)
x = append(x, buf[len(buf)-extra:]...)
buf = decomposeToLastBoundary(rb, buf[:end])
if rb.f.composing {
rb.compose()
}
return rb.flush(buf), extra
buf = rb.flush(buf)
return append(buf, x...), true
}
return buf, 0
return buf, false
}
func appendQuick(rb *reorderBuffer, dst []byte, i int) ([]byte, int) {
......@@ -157,23 +165,23 @@ func (f Form) Append(out []byte, src ...byte) []byte {
}
rb := reorderBuffer{}
rb.init(f, src)
return doAppend(&rb, out)
return doAppend(&rb, out, 0)
}
func doAppend(rb *reorderBuffer, out []byte) []byte {
func doAppend(rb *reorderBuffer, out []byte, p int) []byte {
src, n := rb.src, rb.nsrc
doMerge := len(out) > 0
p := 0
if p = src.skipNonStarter(); p > 0 {
if q := src.skipNonStarter(p); q > p {
// Move leading non-starters to destination.
out = src.appendSlice(out, 0, p)
buf, ndropped := patchTail(rb, out)
if ndropped > 0 {
out = src.appendSlice(buf, p-ndropped, p)
out = src.appendSlice(out, p, q)
buf, endsInError := patchTail(rb, out)
if endsInError {
out = buf
doMerge = false // no need to merge, ends with illegal UTF-8
} else {
out = decomposeToLastBoundary(rb, buf) // force decomposition
}
p = q
}
fd := &rb.f
if doMerge {
......@@ -217,7 +225,7 @@ func (f Form) AppendString(out []byte, src string) []byte {
}
rb := reorderBuffer{}
rb.initString(f, src)
return doAppend(&rb, out)
return doAppend(&rb, out, 0)
}
// QuickSpan returns a boundary n such that b[0:n] == f(b[0:n]).
......@@ -225,7 +233,8 @@ func (f Form) AppendString(out []byte, src string) []byte {
func (f Form) QuickSpan(b []byte) int {
rb := reorderBuffer{}
rb.init(f, b)
return quickSpan(&rb, 0)
n := quickSpan(&rb, 0)
return n
}
func quickSpan(rb *reorderBuffer, i int) int {
......@@ -301,7 +310,7 @@ func (f Form) FirstBoundary(b []byte) int {
func firstBoundary(rb *reorderBuffer) int {
src, nsrc := rb.src, rb.nsrc
i := src.skipNonStarter()
i := src.skipNonStarter(0)
if i >= nsrc {
return -1
}
......
......@@ -253,7 +253,7 @@ var quickSpanNFDTests = []PositionTest{
{"\u0316\u0300cd", 6, ""},
{"\u043E\u0308b", 5, ""},
// incorrectly ordered combining characters
{"ab\u0300\u0316", 1, ""}, // TODO(mpvl): we could skip 'b' as well.
{"ab\u0300\u0316", 1, ""}, // TODO: we could skip 'b' as well.
{"ab\u0300\u0316cd", 1, ""},
// Hangul
{"같은", 0, ""},
......@@ -465,6 +465,7 @@ var appendTests = []AppendTest{
{"\u0300", "\xFC\x80\x80\x80\x80\x80\u0300", "\u0300\xFC\x80\x80\x80\x80\x80\u0300"},
{"\xF8\x80\x80\x80\x80\u0300", "\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"},
{"\xFC\x80\x80\x80\x80\x80\u0300", "\u0300", "\xFC\x80\x80\x80\x80\x80\u0300\u0300"},
{"\xF8\x80\x80\x80", "\x80\u0300\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"},
}
func appendF(f Form, out []byte, s string) []byte {
......@@ -475,9 +476,23 @@ func appendStringF(f Form, out []byte, s string) []byte {
return f.AppendString(out, s)
}
func bytesF(f Form, out []byte, s string) []byte {
buf := []byte{}
buf = append(buf, out...)
buf = append(buf, s...)
return f.Bytes(buf)
}
func stringF(f Form, out []byte, s string) []byte {
outs := string(out) + s
return []byte(f.String(outs))
}
func TestAppend(t *testing.T) {
runAppendTests(t, "TestAppend", NFKC, appendF, appendTests)
runAppendTests(t, "TestAppendString", NFKC, appendStringF, appendTests)
runAppendTests(t, "TestBytes", NFKC, bytesF, appendTests)
runAppendTests(t, "TestString", NFKC, stringF, appendTests)
}
func doFormBenchmark(b *testing.B, f Form, s string) {
......
......@@ -27,7 +27,7 @@ func (w *normWriter) Write(data []byte) (n int, err error) {
}
w.rb.src = inputBytes(data[:m])
w.rb.nsrc = m
w.buf = doAppend(&w.rb, w.buf)
w.buf = doAppend(&w.rb, w.buf, 0)
data = data[m:]
n += m
......@@ -101,7 +101,7 @@ func (r *normReader) Read(p []byte) (int, error) {
r.rb.src = inputBytes(r.inbuf[0:n])
r.rb.nsrc, r.err = n, err
if n > 0 {
r.outbuf = doAppend(&r.rb, r.outbuf)
r.outbuf = doAppend(&r.rb, r.outbuf, 0)
}
if err == io.EOF {
r.lastBoundary = len(r.outbuf)
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
)
type direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var Direct = direct{}
func (direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
"strings"
)
// A PerHost directs connections to a default Dailer unless the hostname
// requested matches one of a number of exceptions.
type PerHost struct {
def, bypass Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
return &PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the network net through either
// defaultDialer or bypass.
func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *PerHost) dialerForRequest(host string) Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone "example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifing hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a hostname
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddIP specifies an IP range that will use the bypass proxy. Note that this
// will only take effect if a literal IP address is dialed. A connection to a
// named host will never match.
func (p *PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a hostname that will use the bypass proxy.
func (p *PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"errors"
"net"
"reflect"
"testing"
)
type recordingProxy struct {
addrs []string
}
func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {
r.addrs = append(r.addrs, addr)
return nil, errors.New("recordingProxy")
}
func TestPerHost(t *testing.T) {
var def, bypass recordingProxy
perHost := NewPerHost(&def, &bypass)
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
expectedDef := []string{
"example.com:123",
"1.2.3.4:123",
"[1001::]:123",
}
expectedBypass := []string{
"localhost:123",
"zone:123",
"foo.zone:123",
"127.0.0.1:123",
"10.1.2.3:123",
"[1000::]:123",
}
for _, addr := range expectedDef {
perHost.Dial("tcp", addr)
}
for _, addr := range expectedBypass {
perHost.Dial("tcp", addr)
}
if !reflect.DeepEqual(expectedDef, def.addrs) {
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
}
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package proxy provides support for a variety of protocols to proxy network
// data.
package proxy
import (
"errors"
"net"
"net/url"
"os"
"strings"
)
// A Dialer is a means to establish a connection.
type Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type Auth struct {
User, Password string
}
// DefaultDialer returns the dialer specified by the proxy related variables in
// the environment.
func FromEnvironment() Dialer {
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
return Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return Direct
}
proxy, err := FromURL(proxyURL, Direct)
if err != nil {
return Direct
}
noProxy := os.Getenv("no_proxy")
if len(noProxy) == 0 {
return proxy
}
perHost := NewPerHost(proxy, Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
if proxySchemes == nil {
proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
}
proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
var auth *Auth
if len(u.RawUserinfo) > 0 {
auth = new(Auth)
parts := strings.SplitN(u.RawUserinfo, ":", 1)
if len(parts) == 1 {
auth.User = parts[0]
} else if len(parts) >= 2 {
auth.User = parts[0]
auth.Password = parts[1]
}
}
switch u.Scheme {
case "socks5":
return SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxySchemes != nil {
if f, ok := proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
"net/url"
"testing"
)
type testDialer struct {
network, addr string
}
func (t *testDialer) Dial(network, addr string) (net.Conn, error) {
t.network = network
t.addr = addr
return nil, t
}
func (t *testDialer) Error() string {
return "testDialer " + t.network + " " + t.addr
}
func TestFromURL(t *testing.T) {
u, err := url.Parse("socks5://user:password@1.2.3.4:5678")
if err != nil {
t.Fatalf("failed to parse URL: %s", err)
}
tp := &testDialer{}
proxy, err := FromURL(u, tp)
if err != nil {
t.Fatalf("FromURL failed: %s", err)
}
conn, err := proxy.Dial("tcp", "example.com:80")
if conn != nil {
t.Error("Dial unexpected didn't return an error")
}
if tp, ok := err.(*testDialer); ok {
if tp.network != "tcp" || tp.addr != "1.2.3.4:5678" {
t.Errorf("Dialer connected to wrong host. Wanted 1.2.3.4:5678, got: %v", tp)
}
} else {
t.Errorf("Unexpected error from Dial: %s", err)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"errors"
"io"
"net"
"strconv"
)
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928.
func SOCKS5(network, addr string, auth *Auth, forward Dialer) (Dialer, error) {
s := &socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type socks5 struct {
user, password string
network, addr string
forward Dialer
}
const socks5Version = 5
const (
socks5AuthNone = 0
socks5AuthPassword = 2
)
const socks5Connect = 1
const (
socks5IP4 = 1
socks5Domain = 3
socks5IP6 = 4
)
var socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the network net via the SOCKS5 proxy.
func (s *socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
break
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
closeConn := &conn
defer func() {
if closeConn != nil {
(*closeConn).Close()
}
}()
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return nil, errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2, /* num auth methods */ socks5AuthNone, socks5AuthPassword)
} else {
buf = append(buf, 1, /* num auth methods */ socks5AuthNone)
}
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
if buf[1] == socks5AuthPassword {
buf = buf[:0]
buf = append(buf, socks5Version)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, socks5Version, socks5Connect, 0 /* reserved */ )
if ip := net.ParseIP(host); ip != nil {
if len(ip) == 4 {
buf = append(buf, socks5IP4)
} else {
buf = append(buf, socks5IP6)
}
buf = append(buf, []byte(ip)...)
} else {
buf = append(buf, socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:4]); err != nil {
return nil, errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(socks5Errors) {
failure = socks5Errors[buf[1]]
}
if len(failure) > 0 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case socks5IP4:
bytesToDiscard = 4
case socks5IP6:
bytesToDiscard = 16
case socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return nil, errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return nil, errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err = io.ReadFull(conn, buf); err != nil {
return nil, errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
closeConn = nil
return conn, nil
}
......@@ -8,8 +8,11 @@ import (
"fmt"
"reflect"
"testing"
"time"
)
var someTime = time.Unix(123, 0)
type conversionTest struct {
s, d interface{} // source and destination
......@@ -19,6 +22,7 @@ type conversionTest struct {
wantstr string
wantf32 float32
wantf64 float64
wanttime time.Time
wantbool bool // used if d is of type *bool
wanterr string
}
......@@ -35,12 +39,14 @@ var (
scanbool bool
scanf32 float32
scanf64 float64
scantime time.Time
)
var conversionTests = []conversionTest{
// Exact conversions (destination pointer type matches source type)
{s: "foo", d: &scanstr, wantstr: "foo"},
{s: 123, d: &scanint, wantint: 123},
{s: someTime, d: &scantime, wanttime: someTime},
// To strings
{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
......@@ -106,6 +112,10 @@ func float32Value(ptr interface{}) float32 {
return *(ptr.(*float32))
}
func timeValue(ptr interface{}) time.Time {
return *(ptr.(*time.Time))
}
func TestConversions(t *testing.T) {
for n, ct := range conversionTests {
err := convertAssign(ct.d, ct.s)
......@@ -138,6 +148,9 @@ func TestConversions(t *testing.T) {
if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
errf("want bool %v, got %v", ct.wantbool, *bp)
}
if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment