Commit df1304ee by Ian Lance Taylor

libgo: Update to weekly.2012-01-15.

From-SVN: r183539
parent 3be18e47
4a8268927758 354b17404643
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.
...@@ -188,7 +188,7 @@ toolexeclibgocryptoopenpgpdir = $(toolexeclibgocryptodir)/openpgp ...@@ -188,7 +188,7 @@ toolexeclibgocryptoopenpgpdir = $(toolexeclibgocryptodir)/openpgp
toolexeclibgocryptoopenpgp_DATA = \ toolexeclibgocryptoopenpgp_DATA = \
crypto/openpgp/armor.gox \ crypto/openpgp/armor.gox \
crypto/openpgp/elgamal.gox \ crypto/openpgp/elgamal.gox \
crypto/openpgp/error.gox \ crypto/openpgp/errors.gox \
crypto/openpgp/packet.gox \ crypto/openpgp/packet.gox \
crypto/openpgp/s2k.gox crypto/openpgp/s2k.gox
...@@ -235,6 +235,7 @@ toolexeclibgoexp_DATA = \ ...@@ -235,6 +235,7 @@ toolexeclibgoexp_DATA = \
exp/ebnf.gox \ exp/ebnf.gox \
$(exp_inotify_gox) \ $(exp_inotify_gox) \
exp/norm.gox \ exp/norm.gox \
exp/proxy.gox \
exp/spdy.gox \ exp/spdy.gox \
exp/sql.gox \ exp/sql.gox \
exp/ssh.gox \ exp/ssh.gox \
...@@ -669,17 +670,25 @@ endif # !LIBGO_IS_RTEMS ...@@ -669,17 +670,25 @@ endif # !LIBGO_IS_RTEMS
if LIBGO_IS_LINUX if LIBGO_IS_LINUX
go_net_cgo_file = go/net/cgo_linux.go go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go go_net_sock_file = go/net/sock_linux.go
go_net_sockopt_file = go/net/sockopt_linux.go
go_net_sockoptip_file = go/net/sockoptip_linux.go
else else
if LIBGO_IS_IRIX if LIBGO_IS_IRIX
go_net_cgo_file = go/net/cgo_linux.go go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go go_net_sock_file = go/net/sock_linux.go
go_net_sockopt_file = go/net/sockopt_linux.go
go_net_sockoptip_file = go/net/sockoptip_linux.go
else else
if LIBGO_IS_SOLARIS if LIBGO_IS_SOLARIS
go_net_cgo_file = go/net/cgo_linux.go go_net_cgo_file = go/net/cgo_linux.go
go_net_sock_file = go/net/sock_linux.go go_net_sock_file = go/net/sock_linux.go
go_net_sockopt_file = go/net/sockopt_linux.go
go_net_sockoptip_file = go/net/sockoptip_linux.go
else else
go_net_cgo_file = go/net/cgo_bsd.go go_net_cgo_file = go/net/cgo_bsd.go
go_net_sock_file = go/net/sock_bsd.go go_net_sock_file = go/net/sock_bsd.go
go_net_sockopt_file = go/net/sockopt_bsd.go
go_net_sockoptip_file = go/net/sockoptip_bsd.go
endif endif
endif endif
endif endif
...@@ -728,6 +737,10 @@ go_net_files = \ ...@@ -728,6 +737,10 @@ go_net_files = \
$(go_net_sendfile_file) \ $(go_net_sendfile_file) \
go/net/sock.go \ go/net/sock.go \
$(go_net_sock_file) \ $(go_net_sock_file) \
go/net/sockopt.go \
$(go_net_sockopt_file) \
go/net/sockoptip.go \
$(go_net_sockoptip_file) \
go/net/tcpsock.go \ go/net/tcpsock.go \
go/net/tcpsock_posix.go \ go/net/tcpsock_posix.go \
go/net/udpsock.go \ go/net/udpsock.go \
...@@ -890,8 +903,7 @@ go_syslog_c_files = \ ...@@ -890,8 +903,7 @@ go_syslog_c_files = \
go_testing_files = \ go_testing_files = \
go/testing/benchmark.go \ go/testing/benchmark.go \
go/testing/example.go \ go/testing/example.go \
go/testing/testing.go \ go/testing/testing.go
go/testing/wrapper.go
go_time_files = \ go_time_files = \
go/time/format.go \ go/time/format.go \
...@@ -1061,8 +1073,8 @@ go_crypto_openpgp_armor_files = \ ...@@ -1061,8 +1073,8 @@ go_crypto_openpgp_armor_files = \
go/crypto/openpgp/armor/encode.go go/crypto/openpgp/armor/encode.go
go_crypto_openpgp_elgamal_files = \ go_crypto_openpgp_elgamal_files = \
go/crypto/openpgp/elgamal/elgamal.go go/crypto/openpgp/elgamal/elgamal.go
go_crypto_openpgp_error_files = \ go_crypto_openpgp_errors_files = \
go/crypto/openpgp/error/error.go go/crypto/openpgp/errors/errors.go
go_crypto_openpgp_packet_files = \ go_crypto_openpgp_packet_files = \
go/crypto/openpgp/packet/compressed.go \ go/crypto/openpgp/packet/compressed.go \
go/crypto/openpgp/packet/encrypted_key.go \ go/crypto/openpgp/packet/encrypted_key.go \
...@@ -1142,6 +1154,7 @@ go_encoding_pem_files = \ ...@@ -1142,6 +1154,7 @@ go_encoding_pem_files = \
go_encoding_xml_files = \ go_encoding_xml_files = \
go/encoding/xml/marshal.go \ go/encoding/xml/marshal.go \
go/encoding/xml/read.go \ go/encoding/xml/read.go \
go/encoding/xml/typeinfo.go \
go/encoding/xml/xml.go go/encoding/xml/xml.go
go_exp_ebnf_files = \ go_exp_ebnf_files = \
...@@ -1157,6 +1170,11 @@ go_exp_norm_files = \ ...@@ -1157,6 +1170,11 @@ go_exp_norm_files = \
go/exp/norm/readwriter.go \ go/exp/norm/readwriter.go \
go/exp/norm/tables.go \ go/exp/norm/tables.go \
go/exp/norm/trie.go go/exp/norm/trie.go
go_exp_proxy_files = \
go/exp/proxy/direct.go \
go/exp/proxy/per_host.go \
go/exp/proxy/proxy.go \
go/exp/proxy/socks5.go
go_exp_spdy_files = \ go_exp_spdy_files = \
go/exp/spdy/read.go \ go/exp/spdy/read.go \
go/exp/spdy/types.go \ go/exp/spdy/types.go \
...@@ -1173,7 +1191,7 @@ go_exp_ssh_files = \ ...@@ -1173,7 +1191,7 @@ go_exp_ssh_files = \
go/exp/ssh/doc.go \ go/exp/ssh/doc.go \
go/exp/ssh/messages.go \ go/exp/ssh/messages.go \
go/exp/ssh/server.go \ go/exp/ssh/server.go \
go/exp/ssh/server_shell.go \ go/exp/ssh/server_terminal.go \
go/exp/ssh/session.go \ go/exp/ssh/session.go \
go/exp/ssh/tcpip.go \ go/exp/ssh/tcpip.go \
go/exp/ssh/transport.go go/exp/ssh/transport.go
...@@ -1210,7 +1228,8 @@ go_go_doc_files = \ ...@@ -1210,7 +1228,8 @@ go_go_doc_files = \
go/go/doc/doc.go \ go/go/doc/doc.go \
go/go/doc/example.go \ go/go/doc/example.go \
go/go/doc/exports.go \ go/go/doc/exports.go \
go/go/doc/filter.go go/go/doc/filter.go \
go/go/doc/reader.go
go_go_parser_files = \ go_go_parser_files = \
go/go/parser/interface.go \ go/go/parser/interface.go \
go/go/parser/parser.go go/go/parser/parser.go
...@@ -1461,8 +1480,15 @@ endif ...@@ -1461,8 +1480,15 @@ endif
# Define ForkExec and Exec. # Define ForkExec and Exec.
if LIBGO_IS_RTEMS if LIBGO_IS_RTEMS
syscall_exec_file = go/syscall/exec_stubs.go syscall_exec_file = go/syscall/exec_stubs.go
syscall_exec_os_file =
else
if LIBGO_IS_LINUX
syscall_exec_file = go/syscall/exec_unix.go
syscall_exec_os_file = go/syscall/exec_linux.go
else else
syscall_exec_file = go/syscall/exec_unix.go syscall_exec_file = go/syscall/exec_unix.go
syscall_exec_os_file = go/syscall/exec_bsd.go
endif
endif endif
# Define Wait4. # Define Wait4.
...@@ -1573,6 +1599,7 @@ go_base_syscall_files = \ ...@@ -1573,6 +1599,7 @@ go_base_syscall_files = \
go/syscall/syscall.go \ go/syscall/syscall.go \
$(syscall_syscall_file) \ $(syscall_syscall_file) \
$(syscall_exec_file) \ $(syscall_exec_file) \
$(syscall_exec_os_file) \
$(syscall_wait_file) \ $(syscall_wait_file) \
$(syscall_sleep_file) \ $(syscall_sleep_file) \
$(syscall_errstr_file) \ $(syscall_errstr_file) \
...@@ -1720,7 +1747,7 @@ libgo_go_objs = \ ...@@ -1720,7 +1747,7 @@ libgo_go_objs = \
crypto/xtea.lo \ crypto/xtea.lo \
crypto/openpgp/armor.lo \ crypto/openpgp/armor.lo \
crypto/openpgp/elgamal.lo \ crypto/openpgp/elgamal.lo \
crypto/openpgp/error.lo \ crypto/openpgp/errors.lo \
crypto/openpgp/packet.lo \ crypto/openpgp/packet.lo \
crypto/openpgp/s2k.lo \ crypto/openpgp/s2k.lo \
crypto/x509/pkix.lo \ crypto/x509/pkix.lo \
...@@ -1743,6 +1770,7 @@ libgo_go_objs = \ ...@@ -1743,6 +1770,7 @@ libgo_go_objs = \
encoding/xml.lo \ encoding/xml.lo \
exp/ebnf.lo \ exp/ebnf.lo \
exp/norm.lo \ exp/norm.lo \
exp/proxy.lo \
exp/spdy.lo \ exp/spdy.lo \
exp/sql.lo \ exp/sql.lo \
exp/ssh.lo \ exp/ssh.lo \
...@@ -2578,15 +2606,15 @@ crypto/openpgp/elgamal/check: $(CHECK_DEPS) ...@@ -2578,15 +2606,15 @@ crypto/openpgp/elgamal/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: crypto/openpgp/elgamal/check .PHONY: crypto/openpgp/elgamal/check
@go_include@ crypto/openpgp/error.lo.dep @go_include@ crypto/openpgp/errors.lo.dep
crypto/openpgp/error.lo.dep: $(go_crypto_openpgp_error_files) crypto/openpgp/errors.lo.dep: $(go_crypto_openpgp_errors_files)
$(BUILDDEPS) $(BUILDDEPS)
crypto/openpgp/error.lo: $(go_crypto_openpgp_error_files) crypto/openpgp/errors.lo: $(go_crypto_openpgp_errors_files)
$(BUILDPACKAGE) $(BUILDPACKAGE)
crypto/openpgp/error/check: $(CHECK_DEPS) crypto/openpgp/errors/check: $(CHECK_DEPS)
@$(MKDIR_P) crypto/openpgp/error @$(MKDIR_P) crypto/openpgp/errors
@$(CHECK) @$(CHECK)
.PHONY: crypto/openpgp/error/check .PHONY: crypto/openpgp/errors/check
@go_include@ crypto/openpgp/packet.lo.dep @go_include@ crypto/openpgp/packet.lo.dep
crypto/openpgp/packet.lo.dep: $(go_crypto_openpgp_packet_files) crypto/openpgp/packet.lo.dep: $(go_crypto_openpgp_packet_files)
...@@ -2808,6 +2836,16 @@ exp/norm/check: $(CHECK_DEPS) ...@@ -2808,6 +2836,16 @@ exp/norm/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: exp/norm/check .PHONY: exp/norm/check
@go_include@ exp/proxy.lo.dep
exp/proxy.lo.dep: $(go_exp_proxy_files)
$(BUILDDEPS)
exp/proxy.lo: $(go_exp_proxy_files)
$(BUILDPACKAGE)
exp/proxy/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/proxy
@$(CHECK)
.PHONY: exp/proxy/check
@go_include@ exp/spdy.lo.dep @go_include@ exp/spdy.lo.dep
exp/spdy.lo.dep: $(go_exp_spdy_files) exp/spdy.lo.dep: $(go_exp_spdy_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -3622,7 +3660,7 @@ crypto/openpgp/armor.gox: crypto/openpgp/armor.lo ...@@ -3622,7 +3660,7 @@ crypto/openpgp/armor.gox: crypto/openpgp/armor.lo
$(BUILDGOX) $(BUILDGOX)
crypto/openpgp/elgamal.gox: crypto/openpgp/elgamal.lo crypto/openpgp/elgamal.gox: crypto/openpgp/elgamal.lo
$(BUILDGOX) $(BUILDGOX)
crypto/openpgp/error.gox: crypto/openpgp/error.lo crypto/openpgp/errors.gox: crypto/openpgp/errors.lo
$(BUILDGOX) $(BUILDGOX)
crypto/openpgp/packet.gox: crypto/openpgp/packet.lo crypto/openpgp/packet.gox: crypto/openpgp/packet.lo
$(BUILDGOX) $(BUILDGOX)
...@@ -3674,6 +3712,8 @@ exp/inotify.gox: exp/inotify.lo ...@@ -3674,6 +3712,8 @@ exp/inotify.gox: exp/inotify.lo
$(BUILDGOX) $(BUILDGOX)
exp/norm.gox: exp/norm.lo exp/norm.gox: exp/norm.lo
$(BUILDGOX) $(BUILDGOX)
exp/proxy.gox: exp/proxy.lo
$(BUILDGOX)
exp/spdy.gox: exp/spdy.lo exp/spdy.gox: exp/spdy.lo
$(BUILDGOX) $(BUILDGOX)
exp/sql.gox: exp/sql.lo exp/sql.gox: exp/sql.lo
...@@ -3920,6 +3960,7 @@ TEST_PACKAGES = \ ...@@ -3920,6 +3960,7 @@ TEST_PACKAGES = \
exp/ebnf/check \ exp/ebnf/check \
$(exp_inotify_check) \ $(exp_inotify_check) \
exp/norm/check \ exp/norm/check \
exp/proxy/check \
exp/spdy/check \ exp/spdy/check \
exp/sql/check \ exp/sql/check \
exp/ssh/check \ exp/ssh/check \
......
...@@ -74,6 +74,9 @@ ...@@ -74,6 +74,9 @@
/* Define to 1 if you have the <sys/mman.h> header file. */ /* Define to 1 if you have the <sys/mman.h> header file. */
#undef HAVE_SYS_MMAN_H #undef HAVE_SYS_MMAN_H
/* Define to 1 if you have the <sys/prctl.h> header file. */
#undef HAVE_SYS_PRCTL_H
/* Define to 1 if you have the <sys/ptrace.h> header file. */ /* Define to 1 if you have the <sys/ptrace.h> header file. */
#undef HAVE_SYS_PTRACE_H #undef HAVE_SYS_PTRACE_H
......
...@@ -14505,7 +14505,7 @@ no) ...@@ -14505,7 +14505,7 @@ no)
;; ;;
esac esac
for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h for ac_header in sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h sys/prctl.h
do : do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh` as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default" ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"
......
...@@ -451,7 +451,7 @@ no) ...@@ -451,7 +451,7 @@ no)
;; ;;
esac esac
AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h) AC_CHECK_HEADERS(sys/mman.h syscall.h sys/epoll.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h sys/prctl.h)
AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [], AC_CHECK_HEADERS([linux/filter.h linux/netlink.h linux/rtnetlink.h], [], [],
[#ifdef HAVE_SYS_SOCKET_H [#ifdef HAVE_SYS_SOCKET_H
......
...@@ -97,8 +97,7 @@ func (b *Buffer) grow(n int) int { ...@@ -97,8 +97,7 @@ func (b *Buffer) grow(n int) int {
func (b *Buffer) Write(p []byte) (n int, err error) { func (b *Buffer) Write(p []byte) (n int, err error) {
b.lastRead = opInvalid b.lastRead = opInvalid
m := b.grow(len(p)) m := b.grow(len(p))
copy(b.buf[m:], p) return copy(b.buf[m:], p), nil
return len(p), nil
} }
// WriteString appends the contents of s to the buffer. The return // WriteString appends the contents of s to the buffer. The return
...@@ -200,13 +199,16 @@ func (b *Buffer) WriteRune(r rune) (n int, err error) { ...@@ -200,13 +199,16 @@ func (b *Buffer) WriteRune(r rune) (n int, err error) {
// Read reads the next len(p) bytes from the buffer or until the buffer // Read reads the next len(p) bytes from the buffer or until the buffer
// is drained. The return value n is the number of bytes read. If the // is drained. The return value n is the number of bytes read. If the
// buffer has no data to return, err is io.EOF even if len(p) is zero; // buffer has no data to return, err is io.EOF (unless len(p) is zero);
// otherwise it is nil. // otherwise it is nil.
func (b *Buffer) Read(p []byte) (n int, err error) { func (b *Buffer) Read(p []byte) (n int, err error) {
b.lastRead = opInvalid b.lastRead = opInvalid
if b.off >= len(b.buf) { if b.off >= len(b.buf) {
// Buffer is empty, reset to recover space. // Buffer is empty, reset to recover space.
b.Truncate(0) b.Truncate(0)
if len(p) == 0 {
return
}
return 0, io.EOF return 0, io.EOF
} }
n = copy(p, b.buf[b.off:]) n = copy(p, b.buf[b.off:])
......
...@@ -373,3 +373,16 @@ func TestReadBytes(t *testing.T) { ...@@ -373,3 +373,16 @@ func TestReadBytes(t *testing.T) {
} }
} }
} }
// Was a bug: used to give EOF reading empty slice at EOF.
func TestReadEmptyAtEOF(t *testing.T) {
b := new(Buffer)
slice := make([]byte, 0)
n, err := b.Read(slice)
if err != nil {
t.Errorf("read error: %v", err)
}
if n != 0 {
t.Errorf("wrong count; got %d want 0", n)
}
}
...@@ -9,7 +9,7 @@ package armor ...@@ -9,7 +9,7 @@ package armor
import ( import (
"bufio" "bufio"
"bytes" "bytes"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"encoding/base64" "encoding/base64"
"io" "io"
) )
...@@ -35,7 +35,7 @@ type Block struct { ...@@ -35,7 +35,7 @@ type Block struct {
oReader openpgpReader oReader openpgpReader
} }
var ArmorCorrupt error = error_.StructuralError("armor invalid") var ArmorCorrupt error = errors.StructuralError("armor invalid")
const crc24Init = 0xb704ce const crc24Init = 0xb704ce
const crc24Poly = 0x1864cfb const crc24Poly = 0x1864cfb
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package error contains common error types for the OpenPGP packages. // Package errors contains common error types for the OpenPGP packages.
package error package errors
import ( import (
"strconv" "strconv"
......
...@@ -7,8 +7,9 @@ package openpgp ...@@ -7,8 +7,9 @@ package openpgp
import ( import (
"crypto" "crypto"
"crypto/openpgp/armor" "crypto/openpgp/armor"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/packet" "crypto/openpgp/packet"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"io" "io"
"time" "time"
...@@ -181,13 +182,13 @@ func (el EntityList) DecryptionKeys() (keys []Key) { ...@@ -181,13 +182,13 @@ func (el EntityList) DecryptionKeys() (keys []Key) {
func ReadArmoredKeyRing(r io.Reader) (EntityList, error) { func ReadArmoredKeyRing(r io.Reader) (EntityList, error) {
block, err := armor.Decode(r) block, err := armor.Decode(r)
if err == io.EOF { if err == io.EOF {
return nil, error_.InvalidArgumentError("no armored data found") return nil, errors.InvalidArgumentError("no armored data found")
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
if block.Type != PublicKeyType && block.Type != PrivateKeyType { if block.Type != PublicKeyType && block.Type != PrivateKeyType {
return nil, error_.InvalidArgumentError("expected public or private key block, got: " + block.Type) return nil, errors.InvalidArgumentError("expected public or private key block, got: " + block.Type)
} }
return ReadKeyRing(block.Body) return ReadKeyRing(block.Body)
...@@ -203,7 +204,7 @@ func ReadKeyRing(r io.Reader) (el EntityList, err error) { ...@@ -203,7 +204,7 @@ func ReadKeyRing(r io.Reader) (el EntityList, err error) {
var e *Entity var e *Entity
e, err = readEntity(packets) e, err = readEntity(packets)
if err != nil { if err != nil {
if _, ok := err.(error_.UnsupportedError); ok { if _, ok := err.(errors.UnsupportedError); ok {
lastUnsupportedError = err lastUnsupportedError = err
err = readToNextPublicKey(packets) err = readToNextPublicKey(packets)
} }
...@@ -235,7 +236,7 @@ func readToNextPublicKey(packets *packet.Reader) (err error) { ...@@ -235,7 +236,7 @@ func readToNextPublicKey(packets *packet.Reader) (err error) {
if err == io.EOF { if err == io.EOF {
return return
} else if err != nil { } else if err != nil {
if _, ok := err.(error_.UnsupportedError); ok { if _, ok := err.(errors.UnsupportedError); ok {
err = nil err = nil
continue continue
} }
...@@ -266,14 +267,14 @@ func readEntity(packets *packet.Reader) (*Entity, error) { ...@@ -266,14 +267,14 @@ func readEntity(packets *packet.Reader) (*Entity, error) {
if e.PrimaryKey, ok = p.(*packet.PublicKey); !ok { if e.PrimaryKey, ok = p.(*packet.PublicKey); !ok {
if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok { if e.PrivateKey, ok = p.(*packet.PrivateKey); !ok {
packets.Unread(p) packets.Unread(p)
return nil, error_.StructuralError("first packet was not a public/private key") return nil, errors.StructuralError("first packet was not a public/private key")
} else { } else {
e.PrimaryKey = &e.PrivateKey.PublicKey e.PrimaryKey = &e.PrivateKey.PublicKey
} }
} }
if !e.PrimaryKey.PubKeyAlgo.CanSign() { if !e.PrimaryKey.PubKeyAlgo.CanSign() {
return nil, error_.StructuralError("primary key cannot be used for signatures") return nil, errors.StructuralError("primary key cannot be used for signatures")
} }
var current *Identity var current *Identity
...@@ -303,12 +304,12 @@ EachPacket: ...@@ -303,12 +304,12 @@ EachPacket:
sig, ok := p.(*packet.Signature) sig, ok := p.(*packet.Signature)
if !ok { if !ok {
return nil, error_.StructuralError("user ID packet not followed by self-signature") return nil, errors.StructuralError("user ID packet not followed by self-signature")
} }
if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId { if (sig.SigType == packet.SigTypePositiveCert || sig.SigType == packet.SigTypeGenericCert) && sig.IssuerKeyId != nil && *sig.IssuerKeyId == e.PrimaryKey.KeyId {
if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil { if err = e.PrimaryKey.VerifyUserIdSignature(pkt.Id, sig); err != nil {
return nil, error_.StructuralError("user ID self-signature invalid: " + err.Error()) return nil, errors.StructuralError("user ID self-signature invalid: " + err.Error())
} }
current.SelfSignature = sig current.SelfSignature = sig
break break
...@@ -317,7 +318,7 @@ EachPacket: ...@@ -317,7 +318,7 @@ EachPacket:
} }
case *packet.Signature: case *packet.Signature:
if current == nil { if current == nil {
return nil, error_.StructuralError("signature packet found before user id packet") return nil, errors.StructuralError("signature packet found before user id packet")
} }
current.Signatures = append(current.Signatures, pkt) current.Signatures = append(current.Signatures, pkt)
case *packet.PrivateKey: case *packet.PrivateKey:
...@@ -344,7 +345,7 @@ EachPacket: ...@@ -344,7 +345,7 @@ EachPacket:
} }
if len(e.Identities) == 0 { if len(e.Identities) == 0 {
return nil, error_.StructuralError("entity without any identities") return nil, errors.StructuralError("entity without any identities")
} }
return e, nil return e, nil
...@@ -359,19 +360,19 @@ func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *p ...@@ -359,19 +360,19 @@ func addSubkey(e *Entity, packets *packet.Reader, pub *packet.PublicKey, priv *p
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
if err != nil { if err != nil {
return error_.StructuralError("subkey signature invalid: " + err.Error()) return errors.StructuralError("subkey signature invalid: " + err.Error())
} }
var ok bool var ok bool
subKey.Sig, ok = p.(*packet.Signature) subKey.Sig, ok = p.(*packet.Signature)
if !ok { if !ok {
return error_.StructuralError("subkey packet not followed by signature") return errors.StructuralError("subkey packet not followed by signature")
} }
if subKey.Sig.SigType != packet.SigTypeSubkeyBinding { if subKey.Sig.SigType != packet.SigTypeSubkeyBinding {
return error_.StructuralError("subkey signature with wrong type") return errors.StructuralError("subkey signature with wrong type")
} }
err = e.PrimaryKey.VerifyKeySignature(subKey.PublicKey, subKey.Sig) err = e.PrimaryKey.VerifyKeySignature(subKey.PublicKey, subKey.Sig)
if err != nil { if err != nil {
return error_.StructuralError("subkey signature invalid: " + err.Error()) return errors.StructuralError("subkey signature invalid: " + err.Error())
} }
e.Subkeys = append(e.Subkeys, subKey) e.Subkeys = append(e.Subkeys, subKey)
return nil return nil
...@@ -385,7 +386,7 @@ const defaultRSAKeyBits = 2048 ...@@ -385,7 +386,7 @@ const defaultRSAKeyBits = 2048
func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email string) (*Entity, error) { func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email string) (*Entity, error) {
uid := packet.NewUserId(name, comment, email) uid := packet.NewUserId(name, comment, email)
if uid == nil { if uid == nil {
return nil, error_.InvalidArgumentError("user id field contained invalid characters") return nil, errors.InvalidArgumentError("user id field contained invalid characters")
} }
signingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits) signingPriv, err := rsa.GenerateKey(rand, defaultRSAKeyBits)
if err != nil { if err != nil {
...@@ -397,8 +398,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin ...@@ -397,8 +398,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
} }
e := &Entity{ e := &Entity{
PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey, false /* not a subkey */ ), PrimaryKey: packet.NewRSAPublicKey(currentTime, &signingPriv.PublicKey),
PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv, false /* not a subkey */ ), PrivateKey: packet.NewRSAPrivateKey(currentTime, signingPriv),
Identities: make(map[string]*Identity), Identities: make(map[string]*Identity),
} }
isPrimaryId := true isPrimaryId := true
...@@ -420,8 +421,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin ...@@ -420,8 +421,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
e.Subkeys = make([]Subkey, 1) e.Subkeys = make([]Subkey, 1)
e.Subkeys[0] = Subkey{ e.Subkeys[0] = Subkey{
PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey, true /* is a subkey */ ), PublicKey: packet.NewRSAPublicKey(currentTime, &encryptingPriv.PublicKey),
PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv, true /* is a subkey */ ), PrivateKey: packet.NewRSAPrivateKey(currentTime, encryptingPriv),
Sig: &packet.Signature{ Sig: &packet.Signature{
CreationTime: currentTime, CreationTime: currentTime,
SigType: packet.SigTypeSubkeyBinding, SigType: packet.SigTypeSubkeyBinding,
...@@ -433,6 +434,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin ...@@ -433,6 +434,8 @@ func NewEntity(rand io.Reader, currentTime time.Time, name, comment, email strin
IssuerKeyId: &e.PrimaryKey.KeyId, IssuerKeyId: &e.PrimaryKey.KeyId,
}, },
} }
e.Subkeys[0].PublicKey.IsSubkey = true
e.Subkeys[0].PrivateKey.IsSubkey = true
return e, nil return e, nil
} }
...@@ -450,7 +453,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) { ...@@ -450,7 +453,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) {
if err != nil { if err != nil {
return return
} }
err = ident.SelfSignature.SignUserId(ident.UserId.Id, e.PrimaryKey, e.PrivateKey) err = ident.SelfSignature.SignUserId(rand.Reader, ident.UserId.Id, e.PrimaryKey, e.PrivateKey)
if err != nil { if err != nil {
return return
} }
...@@ -464,7 +467,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) { ...@@ -464,7 +467,7 @@ func (e *Entity) SerializePrivate(w io.Writer) (err error) {
if err != nil { if err != nil {
return return
} }
err = subkey.Sig.SignKey(subkey.PublicKey, e.PrivateKey) err = subkey.Sig.SignKey(rand.Reader, subkey.PublicKey, e.PrivateKey)
if err != nil { if err != nil {
return return
} }
...@@ -518,14 +521,14 @@ func (e *Entity) Serialize(w io.Writer) error { ...@@ -518,14 +521,14 @@ func (e *Entity) Serialize(w io.Writer) error {
// necessary. // necessary.
func (e *Entity) SignIdentity(identity string, signer *Entity) error { func (e *Entity) SignIdentity(identity string, signer *Entity) error {
if signer.PrivateKey == nil { if signer.PrivateKey == nil {
return error_.InvalidArgumentError("signing Entity must have a private key") return errors.InvalidArgumentError("signing Entity must have a private key")
} }
if signer.PrivateKey.Encrypted { if signer.PrivateKey.Encrypted {
return error_.InvalidArgumentError("signing Entity's private key must be decrypted") return errors.InvalidArgumentError("signing Entity's private key must be decrypted")
} }
ident, ok := e.Identities[identity] ident, ok := e.Identities[identity]
if !ok { if !ok {
return error_.InvalidArgumentError("given identity string not found in Entity") return errors.InvalidArgumentError("given identity string not found in Entity")
} }
sig := &packet.Signature{ sig := &packet.Signature{
...@@ -535,7 +538,7 @@ func (e *Entity) SignIdentity(identity string, signer *Entity) error { ...@@ -535,7 +538,7 @@ func (e *Entity) SignIdentity(identity string, signer *Entity) error {
CreationTime: time.Now(), CreationTime: time.Now(),
IssuerKeyId: &signer.PrivateKey.KeyId, IssuerKeyId: &signer.PrivateKey.KeyId,
} }
if err := sig.SignKey(e.PrimaryKey, signer.PrivateKey); err != nil { if err := sig.SignKey(rand.Reader, e.PrimaryKey, signer.PrivateKey); err != nil {
return err return err
} }
ident.Signatures = append(ident.Signatures, sig) ident.Signatures = append(ident.Signatures, sig)
......
...@@ -7,7 +7,7 @@ package packet ...@@ -7,7 +7,7 @@ package packet
import ( import (
"compress/flate" "compress/flate"
"compress/zlib" "compress/zlib"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"io" "io"
"strconv" "strconv"
) )
...@@ -31,7 +31,7 @@ func (c *Compressed) parse(r io.Reader) error { ...@@ -31,7 +31,7 @@ func (c *Compressed) parse(r io.Reader) error {
case 2: case 2:
c.Body, err = zlib.NewReader(r) c.Body, err = zlib.NewReader(r)
default: default:
err = error_.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0]))) err = errors.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
} }
return err return err
......
...@@ -6,7 +6,7 @@ package packet ...@@ -6,7 +6,7 @@ package packet
import ( import (
"crypto/openpgp/elgamal" "crypto/openpgp/elgamal"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/binary" "encoding/binary"
...@@ -35,7 +35,7 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) { ...@@ -35,7 +35,7 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) {
return return
} }
if buf[0] != encryptedKeyVersion { if buf[0] != encryptedKeyVersion {
return error_.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0]))) return errors.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
} }
e.KeyId = binary.BigEndian.Uint64(buf[1:9]) e.KeyId = binary.BigEndian.Uint64(buf[1:9])
e.Algo = PublicKeyAlgorithm(buf[9]) e.Algo = PublicKeyAlgorithm(buf[9])
...@@ -77,7 +77,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error { ...@@ -77,7 +77,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error {
c2 := new(big.Int).SetBytes(e.encryptedMPI2) c2 := new(big.Int).SetBytes(e.encryptedMPI2)
b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2) b, err = elgamal.Decrypt(priv.PrivateKey.(*elgamal.PrivateKey), c1, c2)
default: default:
err = error_.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo))) err = errors.InvalidArgumentError("cannot decrypted encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
} }
if err != nil { if err != nil {
...@@ -89,7 +89,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error { ...@@ -89,7 +89,7 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey) error {
expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1]) expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
checksum := checksumKeyMaterial(e.Key) checksum := checksumKeyMaterial(e.Key)
if checksum != expectedChecksum { if checksum != expectedChecksum {
return error_.StructuralError("EncryptedKey checksum incorrect") return errors.StructuralError("EncryptedKey checksum incorrect")
} }
return nil return nil
...@@ -116,16 +116,16 @@ func SerializeEncryptedKey(w io.Writer, rand io.Reader, pub *PublicKey, cipherFu ...@@ -116,16 +116,16 @@ func SerializeEncryptedKey(w io.Writer, rand io.Reader, pub *PublicKey, cipherFu
case PubKeyAlgoElGamal: case PubKeyAlgoElGamal:
return serializeEncryptedKeyElGamal(w, rand, buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock) return serializeEncryptedKeyElGamal(w, rand, buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock)
case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly: case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
return error_.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo))) return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
} }
return error_.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo))) return errors.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
} }
func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) error { func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) error {
cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock) cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
if err != nil { if err != nil {
return error_.InvalidArgumentError("RSA encryption failed: " + err.Error()) return errors.InvalidArgumentError("RSA encryption failed: " + err.Error())
} }
packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText) packetLen := 10 /* header length */ + 2 /* mpi size */ + len(cipherText)
...@@ -144,7 +144,7 @@ func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub ...@@ -144,7 +144,7 @@ func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub
func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) error { func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) error {
c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock) c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
if err != nil { if err != nil {
return error_.InvalidArgumentError("ElGamal encryption failed: " + err.Error()) return errors.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
} }
packetLen := 10 /* header length */ packetLen := 10 /* header length */
......
...@@ -6,7 +6,7 @@ package packet ...@@ -6,7 +6,7 @@ package packet
import ( import (
"crypto" "crypto"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"encoding/binary" "encoding/binary"
"io" "io"
...@@ -33,13 +33,13 @@ func (ops *OnePassSignature) parse(r io.Reader) (err error) { ...@@ -33,13 +33,13 @@ func (ops *OnePassSignature) parse(r io.Reader) (err error) {
return return
} }
if buf[0] != onePassSignatureVersion { if buf[0] != onePassSignatureVersion {
err = error_.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0]))) err = errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
} }
var ok bool var ok bool
ops.Hash, ok = s2k.HashIdToHash(buf[2]) ops.Hash, ok = s2k.HashIdToHash(buf[2])
if !ok { if !ok {
return error_.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2]))) return errors.UnsupportedError("hash function: " + strconv.Itoa(int(buf[2])))
} }
ops.SigType = SignatureType(buf[1]) ops.SigType = SignatureType(buf[1])
...@@ -57,7 +57,7 @@ func (ops *OnePassSignature) Serialize(w io.Writer) error { ...@@ -57,7 +57,7 @@ func (ops *OnePassSignature) Serialize(w io.Writer) error {
var ok bool var ok bool
buf[2], ok = s2k.HashToHashId(ops.Hash) buf[2], ok = s2k.HashToHashId(ops.Hash)
if !ok { if !ok {
return error_.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash))) return errors.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
} }
buf[3] = uint8(ops.PubKeyAlgo) buf[3] = uint8(ops.PubKeyAlgo)
binary.BigEndian.PutUint64(buf[4:12], ops.KeyId) binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)
......
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cast5" "crypto/cast5"
"crypto/cipher" "crypto/cipher"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"io" "io"
"math/big" "math/big"
) )
...@@ -162,7 +162,7 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader, ...@@ -162,7 +162,7 @@ func readHeader(r io.Reader) (tag packetType, length int64, contents io.Reader,
return return
} }
if buf[0]&0x80 == 0 { if buf[0]&0x80 == 0 {
err = error_.StructuralError("tag byte does not have MSB set") err = errors.StructuralError("tag byte does not have MSB set")
return return
} }
if buf[0]&0x40 == 0 { if buf[0]&0x40 == 0 {
...@@ -337,7 +337,7 @@ func Read(r io.Reader) (p Packet, err error) { ...@@ -337,7 +337,7 @@ func Read(r io.Reader) (p Packet, err error) {
se.MDC = true se.MDC = true
p = se p = se
default: default:
err = error_.UnknownPacketTypeError(tag) err = errors.UnknownPacketTypeError(tag)
} }
if p != nil { if p != nil {
err = p.parse(contents) err = p.parse(contents)
......
...@@ -6,7 +6,7 @@ package packet ...@@ -6,7 +6,7 @@ package packet
import ( import (
"bytes" "bytes"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
...@@ -152,7 +152,7 @@ func TestReadHeader(t *testing.T) { ...@@ -152,7 +152,7 @@ func TestReadHeader(t *testing.T) {
for i, test := range readHeaderTests { for i, test := range readHeaderTests {
tag, length, contents, err := readHeader(readerFromHex(test.hexInput)) tag, length, contents, err := readHeader(readerFromHex(test.hexInput))
if test.structuralError { if test.structuralError {
if _, ok := err.(error_.StructuralError); ok { if _, ok := err.(errors.StructuralError); ok {
continue continue
} }
t.Errorf("%d: expected StructuralError, got:%s", i, err) t.Errorf("%d: expected StructuralError, got:%s", i, err)
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/dsa" "crypto/dsa"
"crypto/openpgp/elgamal" "crypto/openpgp/elgamal"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
...@@ -28,14 +28,21 @@ type PrivateKey struct { ...@@ -28,14 +28,21 @@ type PrivateKey struct {
encryptedData []byte encryptedData []byte
cipher CipherFunction cipher CipherFunction
s2k func(out, in []byte) s2k func(out, in []byte)
PrivateKey interface{} // An *rsa.PrivateKey. PrivateKey interface{} // An *rsa.PrivateKey or *dsa.PrivateKey.
sha1Checksum bool sha1Checksum bool
iv []byte iv []byte
} }
func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey, isSubkey bool) *PrivateKey { func NewRSAPrivateKey(currentTime time.Time, priv *rsa.PrivateKey) *PrivateKey {
pk := new(PrivateKey) pk := new(PrivateKey)
pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey, isSubkey) pk.PublicKey = *NewRSAPublicKey(currentTime, &priv.PublicKey)
pk.PrivateKey = priv
return pk
}
func NewDSAPrivateKey(currentTime time.Time, priv *dsa.PrivateKey) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewDSAPublicKey(currentTime, &priv.PublicKey)
pk.PrivateKey = priv pk.PrivateKey = priv
return pk return pk
} }
...@@ -72,13 +79,13 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) { ...@@ -72,13 +79,13 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
pk.sha1Checksum = true pk.sha1Checksum = true
} }
default: default:
return error_.UnsupportedError("deprecated s2k function in private key") return errors.UnsupportedError("deprecated s2k function in private key")
} }
if pk.Encrypted { if pk.Encrypted {
blockSize := pk.cipher.blockSize() blockSize := pk.cipher.blockSize()
if blockSize == 0 { if blockSize == 0 {
return error_.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher))) return errors.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
} }
pk.iv = make([]byte, blockSize) pk.iv = make([]byte, blockSize)
_, err = readFull(r, pk.iv) _, err = readFull(r, pk.iv)
...@@ -121,8 +128,10 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) { ...@@ -121,8 +128,10 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) {
switch priv := pk.PrivateKey.(type) { switch priv := pk.PrivateKey.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
err = serializeRSAPrivateKey(privateKeyBuf, priv) err = serializeRSAPrivateKey(privateKeyBuf, priv)
case *dsa.PrivateKey:
err = serializeDSAPrivateKey(privateKeyBuf, priv)
default: default:
err = error_.InvalidArgumentError("non-RSA private key") err = errors.InvalidArgumentError("unknown private key type")
} }
if err != nil { if err != nil {
return return
...@@ -172,6 +181,10 @@ func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) error { ...@@ -172,6 +181,10 @@ func serializeRSAPrivateKey(w io.Writer, priv *rsa.PrivateKey) error {
return writeBig(w, priv.Precomputed.Qinv) return writeBig(w, priv.Precomputed.Qinv)
} }
func serializeDSAPrivateKey(w io.Writer, priv *dsa.PrivateKey) error {
return writeBig(w, priv.X)
}
// Decrypt decrypts an encrypted private key using a passphrase. // Decrypt decrypts an encrypted private key using a passphrase.
func (pk *PrivateKey) Decrypt(passphrase []byte) error { func (pk *PrivateKey) Decrypt(passphrase []byte) error {
if !pk.Encrypted { if !pk.Encrypted {
...@@ -188,18 +201,18 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error { ...@@ -188,18 +201,18 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
if pk.sha1Checksum { if pk.sha1Checksum {
if len(data) < sha1.Size { if len(data) < sha1.Size {
return error_.StructuralError("truncated private key data") return errors.StructuralError("truncated private key data")
} }
h := sha1.New() h := sha1.New()
h.Write(data[:len(data)-sha1.Size]) h.Write(data[:len(data)-sha1.Size])
sum := h.Sum(nil) sum := h.Sum(nil)
if !bytes.Equal(sum, data[len(data)-sha1.Size:]) { if !bytes.Equal(sum, data[len(data)-sha1.Size:]) {
return error_.StructuralError("private key checksum failure") return errors.StructuralError("private key checksum failure")
} }
data = data[:len(data)-sha1.Size] data = data[:len(data)-sha1.Size]
} else { } else {
if len(data) < 2 { if len(data) < 2 {
return error_.StructuralError("truncated private key data") return errors.StructuralError("truncated private key data")
} }
var sum uint16 var sum uint16
for i := 0; i < len(data)-2; i++ { for i := 0; i < len(data)-2; i++ {
...@@ -207,7 +220,7 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error { ...@@ -207,7 +220,7 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
} }
if data[len(data)-2] != uint8(sum>>8) || if data[len(data)-2] != uint8(sum>>8) ||
data[len(data)-1] != uint8(sum) { data[len(data)-1] != uint8(sum) {
return error_.StructuralError("private key checksum failure") return errors.StructuralError("private key checksum failure")
} }
data = data[:len(data)-2] data = data[:len(data)-2]
} }
......
...@@ -7,7 +7,7 @@ package packet ...@@ -7,7 +7,7 @@ package packet
import ( import (
"crypto/dsa" "crypto/dsa"
"crypto/openpgp/elgamal" "crypto/openpgp/elgamal"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"encoding/binary" "encoding/binary"
...@@ -39,12 +39,11 @@ func fromBig(n *big.Int) parsedMPI { ...@@ -39,12 +39,11 @@ func fromBig(n *big.Int) parsedMPI {
} }
// NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey. // NewRSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool) *PublicKey { func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey) *PublicKey {
pk := &PublicKey{ pk := &PublicKey{
CreationTime: creationTime, CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoRSA, PubKeyAlgo: PubKeyAlgoRSA,
PublicKey: pub, PublicKey: pub,
IsSubkey: isSubkey,
n: fromBig(pub.N), n: fromBig(pub.N),
e: fromBig(big.NewInt(int64(pub.E))), e: fromBig(big.NewInt(int64(pub.E))),
} }
...@@ -53,6 +52,22 @@ func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool) ...@@ -53,6 +52,22 @@ func NewRSAPublicKey(creationTime time.Time, pub *rsa.PublicKey, isSubkey bool)
return pk return pk
} }
// NewDSAPublicKey returns a PublicKey that wraps the given rsa.PublicKey.
func NewDSAPublicKey(creationTime time.Time, pub *dsa.PublicKey) *PublicKey {
pk := &PublicKey{
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoDSA,
PublicKey: pub,
p: fromBig(pub.P),
q: fromBig(pub.Q),
g: fromBig(pub.G),
y: fromBig(pub.Y),
}
pk.setFingerPrintAndKeyId()
return pk
}
func (pk *PublicKey) parse(r io.Reader) (err error) { func (pk *PublicKey) parse(r io.Reader) (err error) {
// RFC 4880, section 5.5.2 // RFC 4880, section 5.5.2
var buf [6]byte var buf [6]byte
...@@ -61,7 +76,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) { ...@@ -61,7 +76,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
return return
} }
if buf[0] != 4 { if buf[0] != 4 {
return error_.UnsupportedError("public key version") return errors.UnsupportedError("public key version")
} }
pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0) pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0)
pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5]) pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5])
...@@ -73,7 +88,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) { ...@@ -73,7 +88,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
case PubKeyAlgoElGamal: case PubKeyAlgoElGamal:
err = pk.parseElGamal(r) err = pk.parseElGamal(r)
default: default:
err = error_.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo))) err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
} }
if err != nil { if err != nil {
return return
...@@ -105,7 +120,7 @@ func (pk *PublicKey) parseRSA(r io.Reader) (err error) { ...@@ -105,7 +120,7 @@ func (pk *PublicKey) parseRSA(r io.Reader) (err error) {
} }
if len(pk.e.bytes) > 3 { if len(pk.e.bytes) > 3 {
err = error_.UnsupportedError("large public exponent") err = errors.UnsupportedError("large public exponent")
return return
} }
rsa := &rsa.PublicKey{ rsa := &rsa.PublicKey{
...@@ -255,7 +270,7 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) { ...@@ -255,7 +270,7 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) {
case PubKeyAlgoElGamal: case PubKeyAlgoElGamal:
return writeMPIs(w, pk.p, pk.g, pk.y) return writeMPIs(w, pk.p, pk.g, pk.y)
} }
return error_.InvalidArgumentError("bad public-key algorithm") return errors.InvalidArgumentError("bad public-key algorithm")
} }
// CanSign returns true iff this public key can generate signatures // CanSign returns true iff this public key can generate signatures
...@@ -267,18 +282,18 @@ func (pk *PublicKey) CanSign() bool { ...@@ -267,18 +282,18 @@ func (pk *PublicKey) CanSign() bool {
// public key, of the data hashed into signed. signed is mutated by this call. // public key, of the data hashed into signed. signed is mutated by this call.
func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) { func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) {
if !pk.CanSign() { if !pk.CanSign() {
return error_.InvalidArgumentError("public key cannot generate signatures") return errors.InvalidArgumentError("public key cannot generate signatures")
} }
signed.Write(sig.HashSuffix) signed.Write(sig.HashSuffix)
hashBytes := signed.Sum(nil) hashBytes := signed.Sum(nil)
if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] { if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
return error_.SignatureError("hash tag doesn't match") return errors.SignatureError("hash tag doesn't match")
} }
if pk.PubKeyAlgo != sig.PubKeyAlgo { if pk.PubKeyAlgo != sig.PubKeyAlgo {
return error_.InvalidArgumentError("public key and signature use different algorithms") return errors.InvalidArgumentError("public key and signature use different algorithms")
} }
switch pk.PubKeyAlgo { switch pk.PubKeyAlgo {
...@@ -286,13 +301,18 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro ...@@ -286,13 +301,18 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey) rsaPublicKey, _ := pk.PublicKey.(*rsa.PublicKey)
err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes) err = rsa.VerifyPKCS1v15(rsaPublicKey, sig.Hash, hashBytes, sig.RSASignature.bytes)
if err != nil { if err != nil {
return error_.SignatureError("RSA verification failure") return errors.SignatureError("RSA verification failure")
} }
return nil return nil
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey) dsaPublicKey, _ := pk.PublicKey.(*dsa.PublicKey)
// Need to truncate hashBytes to match FIPS 186-3 section 4.6.
subgroupSize := (dsaPublicKey.Q.BitLen() + 7) / 8
if len(hashBytes) > subgroupSize {
hashBytes = hashBytes[:subgroupSize]
}
if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) { if !dsa.Verify(dsaPublicKey, hashBytes, new(big.Int).SetBytes(sig.DSASigR.bytes), new(big.Int).SetBytes(sig.DSASigS.bytes)) {
return error_.SignatureError("DSA verification failure") return errors.SignatureError("DSA verification failure")
} }
return nil return nil
default: default:
...@@ -306,7 +326,7 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro ...@@ -306,7 +326,7 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
func keySignatureHash(pk, signed *PublicKey, sig *Signature) (h hash.Hash, err error) { func keySignatureHash(pk, signed *PublicKey, sig *Signature) (h hash.Hash, err error) {
h = sig.Hash.New() h = sig.Hash.New()
if h == nil { if h == nil {
return nil, error_.UnsupportedError("hash function") return nil, errors.UnsupportedError("hash function")
} }
// RFC 4880, section 5.2.4 // RFC 4880, section 5.2.4
...@@ -332,7 +352,7 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err ...@@ -332,7 +352,7 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) (err
func userIdSignatureHash(id string, pk *PublicKey, sig *Signature) (h hash.Hash, err error) { func userIdSignatureHash(id string, pk *PublicKey, sig *Signature) (h hash.Hash, err error) {
h = sig.Hash.New() h = sig.Hash.New()
if h == nil { if h == nil {
return nil, error_.UnsupportedError("hash function") return nil, errors.UnsupportedError("hash function")
} }
// RFC 4880, section 5.2.4 // RFC 4880, section 5.2.4
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
package packet package packet
import ( import (
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"io" "io"
) )
...@@ -34,7 +34,7 @@ func (r *Reader) Next() (p Packet, err error) { ...@@ -34,7 +34,7 @@ func (r *Reader) Next() (p Packet, err error) {
r.readers = r.readers[:len(r.readers)-1] r.readers = r.readers[:len(r.readers)-1]
continue continue
} }
if _, ok := err.(error_.UnknownPacketTypeError); !ok { if _, ok := err.(errors.UnknownPacketTypeError); !ok {
return nil, err return nil, err
} }
} }
......
...@@ -7,7 +7,7 @@ package packet ...@@ -7,7 +7,7 @@ package packet
import ( import (
"bytes" "bytes"
"crypto/cipher" "crypto/cipher"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"io" "io"
"strconv" "strconv"
...@@ -37,12 +37,12 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) { ...@@ -37,12 +37,12 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
return return
} }
if buf[0] != symmetricKeyEncryptedVersion { if buf[0] != symmetricKeyEncryptedVersion {
return error_.UnsupportedError("SymmetricKeyEncrypted version") return errors.UnsupportedError("SymmetricKeyEncrypted version")
} }
ske.CipherFunc = CipherFunction(buf[1]) ske.CipherFunc = CipherFunction(buf[1])
if ske.CipherFunc.KeySize() == 0 { if ske.CipherFunc.KeySize() == 0 {
return error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1]))) return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[1])))
} }
ske.s2k, err = s2k.Parse(r) ske.s2k, err = s2k.Parse(r)
...@@ -60,7 +60,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) { ...@@ -60,7 +60,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) (err error) {
err = nil err = nil
if n != 0 { if n != 0 {
if n == maxSessionKeySizeInBytes { if n == maxSessionKeySizeInBytes {
return error_.UnsupportedError("oversized encrypted session key") return errors.UnsupportedError("oversized encrypted session key")
} }
ske.encryptedKey = encryptedKey[:n] ske.encryptedKey = encryptedKey[:n]
} }
...@@ -89,13 +89,13 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error { ...@@ -89,13 +89,13 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
c.XORKeyStream(ske.encryptedKey, ske.encryptedKey) c.XORKeyStream(ske.encryptedKey, ske.encryptedKey)
ske.CipherFunc = CipherFunction(ske.encryptedKey[0]) ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
if ske.CipherFunc.blockSize() == 0 { if ske.CipherFunc.blockSize() == 0 {
return error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc))) return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(ske.CipherFunc)))
} }
ske.CipherFunc = CipherFunction(ske.encryptedKey[0]) ske.CipherFunc = CipherFunction(ske.encryptedKey[0])
ske.Key = ske.encryptedKey[1:] ske.Key = ske.encryptedKey[1:]
if len(ske.Key)%ske.CipherFunc.blockSize() != 0 { if len(ske.Key)%ske.CipherFunc.blockSize() != 0 {
ske.Key = nil ske.Key = nil
return error_.StructuralError("length of decrypted key not a multiple of block size") return errors.StructuralError("length of decrypted key not a multiple of block size")
} }
} }
...@@ -110,7 +110,7 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error { ...@@ -110,7 +110,7 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) error {
func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err error) { func SerializeSymmetricKeyEncrypted(w io.Writer, rand io.Reader, passphrase []byte, cipherFunc CipherFunction) (key []byte, err error) {
keySize := cipherFunc.KeySize() keySize := cipherFunc.KeySize()
if keySize == 0 { if keySize == 0 {
return nil, error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc))) return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(cipherFunc)))
} }
s2kBuf := new(bytes.Buffer) s2kBuf := new(bytes.Buffer)
......
...@@ -6,8 +6,7 @@ package packet ...@@ -6,8 +6,7 @@ package packet
import ( import (
"crypto/cipher" "crypto/cipher"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/rand"
"crypto/sha1" "crypto/sha1"
"crypto/subtle" "crypto/subtle"
"hash" "hash"
...@@ -35,7 +34,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error { ...@@ -35,7 +34,7 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
return err return err
} }
if buf[0] != symmetricallyEncryptedVersion { if buf[0] != symmetricallyEncryptedVersion {
return error_.UnsupportedError("unknown SymmetricallyEncrypted version") return errors.UnsupportedError("unknown SymmetricallyEncrypted version")
} }
} }
se.contents = r se.contents = r
...@@ -48,10 +47,10 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error { ...@@ -48,10 +47,10 @@ func (se *SymmetricallyEncrypted) parse(r io.Reader) error {
func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, error) { func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.ReadCloser, error) {
keySize := c.KeySize() keySize := c.KeySize()
if keySize == 0 { if keySize == 0 {
return nil, error_.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c))) return nil, errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(c)))
} }
if len(key) != keySize { if len(key) != keySize {
return nil, error_.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length") return nil, errors.InvalidArgumentError("SymmetricallyEncrypted: incorrect key length")
} }
if se.prefix == nil { if se.prefix == nil {
...@@ -61,7 +60,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read ...@@ -61,7 +60,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
return nil, err return nil, err
} }
} else if len(se.prefix) != c.blockSize()+2 { } else if len(se.prefix) != c.blockSize()+2 {
return nil, error_.InvalidArgumentError("can't try ciphers with different block lengths") return nil, errors.InvalidArgumentError("can't try ciphers with different block lengths")
} }
ocfbResync := cipher.OCFBResync ocfbResync := cipher.OCFBResync
...@@ -72,7 +71,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read ...@@ -72,7 +71,7 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
s := cipher.NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync) s := cipher.NewOCFBDecrypter(c.new(key), se.prefix, ocfbResync)
if s == nil { if s == nil {
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
plaintext := cipher.StreamReader{S: s, R: se.contents} plaintext := cipher.StreamReader{S: s, R: se.contents}
...@@ -181,7 +180,7 @@ const mdcPacketTagByte = byte(0x80) | 0x40 | 19 ...@@ -181,7 +180,7 @@ const mdcPacketTagByte = byte(0x80) | 0x40 | 19
func (ser *seMDCReader) Close() error { func (ser *seMDCReader) Close() error {
if ser.error { if ser.error {
return error_.SignatureError("error during reading") return errors.SignatureError("error during reading")
} }
for !ser.eof { for !ser.eof {
...@@ -192,18 +191,18 @@ func (ser *seMDCReader) Close() error { ...@@ -192,18 +191,18 @@ func (ser *seMDCReader) Close() error {
break break
} }
if err != nil { if err != nil {
return error_.SignatureError("error during reading") return errors.SignatureError("error during reading")
} }
} }
if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size { if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
return error_.SignatureError("MDC packet not found") return errors.SignatureError("MDC packet not found")
} }
ser.h.Write(ser.trailer[:2]) ser.h.Write(ser.trailer[:2])
final := ser.h.Sum(nil) final := ser.h.Sum(nil)
if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 { if subtle.ConstantTimeCompare(final, ser.trailer[2:]) != 1 {
return error_.SignatureError("hash mismatch") return errors.SignatureError("hash mismatch")
} }
return nil return nil
} }
...@@ -253,9 +252,9 @@ func (c noOpCloser) Close() error { ...@@ -253,9 +252,9 @@ func (c noOpCloser) Close() error {
// SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet // SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
// to w and returns a WriteCloser to which the to-be-encrypted packets can be // to w and returns a WriteCloser to which the to-be-encrypted packets can be
// written. // written.
func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte) (contents io.WriteCloser, err error) { func SerializeSymmetricallyEncrypted(w io.Writer, rand io.Reader, c CipherFunction, key []byte) (contents io.WriteCloser, err error) {
if c.KeySize() != len(key) { if c.KeySize() != len(key) {
return nil, error_.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length") return nil, errors.InvalidArgumentError("SymmetricallyEncrypted.Serialize: bad key length")
} }
writeCloser := noOpCloser{w} writeCloser := noOpCloser{w}
ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC) ciphertext, err := serializeStreamHeader(writeCloser, packetTypeSymmetricallyEncryptedMDC)
...@@ -271,7 +270,7 @@ func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte) ...@@ -271,7 +270,7 @@ func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, key []byte)
block := c.new(key) block := c.new(key)
blockSize := block.BlockSize() blockSize := block.BlockSize()
iv := make([]byte, blockSize) iv := make([]byte, blockSize)
_, err = rand.Reader.Read(iv) _, err = rand.Read(iv)
if err != nil { if err != nil {
return return
} }
......
...@@ -6,7 +6,8 @@ package packet ...@@ -6,7 +6,8 @@ package packet
import ( import (
"bytes" "bytes"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"io" "io"
...@@ -70,7 +71,7 @@ func testMDCReader(t *testing.T) { ...@@ -70,7 +71,7 @@ func testMDCReader(t *testing.T) {
err = mdcReader.Close() err = mdcReader.Close()
if err == nil { if err == nil {
t.Error("corruption: no error") t.Error("corruption: no error")
} else if _, ok := err.(*error_.SignatureError); !ok { } else if _, ok := err.(*errors.SignatureError); !ok {
t.Errorf("corruption: expected SignatureError, got: %s", err) t.Errorf("corruption: expected SignatureError, got: %s", err)
} }
} }
...@@ -82,7 +83,7 @@ func TestSerialize(t *testing.T) { ...@@ -82,7 +83,7 @@ func TestSerialize(t *testing.T) {
c := CipherAES128 c := CipherAES128
key := make([]byte, c.KeySize()) key := make([]byte, c.KeySize())
w, err := SerializeSymmetricallyEncrypted(buf, c, key) w, err := SerializeSymmetricallyEncrypted(buf, rand.Reader, c, key)
if err != nil { if err != nil {
t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err) t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
return return
......
...@@ -8,7 +8,7 @@ package openpgp ...@@ -8,7 +8,7 @@ package openpgp
import ( import (
"crypto" "crypto"
"crypto/openpgp/armor" "crypto/openpgp/armor"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/packet" "crypto/openpgp/packet"
_ "crypto/sha256" _ "crypto/sha256"
"hash" "hash"
...@@ -27,7 +27,7 @@ func readArmored(r io.Reader, expectedType string) (body io.Reader, err error) { ...@@ -27,7 +27,7 @@ func readArmored(r io.Reader, expectedType string) (body io.Reader, err error) {
} }
if block.Type != expectedType { if block.Type != expectedType {
return nil, error_.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type) return nil, errors.InvalidArgumentError("expected '" + expectedType + "', got: " + block.Type)
} }
return block.Body, nil return block.Body, nil
...@@ -130,7 +130,7 @@ ParsePackets: ...@@ -130,7 +130,7 @@ ParsePackets:
case *packet.Compressed, *packet.LiteralData, *packet.OnePassSignature: case *packet.Compressed, *packet.LiteralData, *packet.OnePassSignature:
// This message isn't encrypted. // This message isn't encrypted.
if len(symKeys) != 0 || len(pubKeys) != 0 { if len(symKeys) != 0 || len(pubKeys) != 0 {
return nil, error_.StructuralError("key material not followed by encrypted message") return nil, errors.StructuralError("key material not followed by encrypted message")
} }
packets.Unread(p) packets.Unread(p)
return readSignedMessage(packets, nil, keyring) return readSignedMessage(packets, nil, keyring)
...@@ -161,7 +161,7 @@ FindKey: ...@@ -161,7 +161,7 @@ FindKey:
continue continue
} }
decrypted, err = se.Decrypt(pk.encryptedKey.CipherFunc, pk.encryptedKey.Key) decrypted, err = se.Decrypt(pk.encryptedKey.CipherFunc, pk.encryptedKey.Key)
if err != nil && err != error_.KeyIncorrectError { if err != nil && err != errors.KeyIncorrectError {
return nil, err return nil, err
} }
if decrypted != nil { if decrypted != nil {
...@@ -179,11 +179,11 @@ FindKey: ...@@ -179,11 +179,11 @@ FindKey:
} }
if len(candidates) == 0 && len(symKeys) == 0 { if len(candidates) == 0 && len(symKeys) == 0 {
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
if prompt == nil { if prompt == nil {
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
passphrase, err := prompt(candidates, len(symKeys) != 0) passphrase, err := prompt(candidates, len(symKeys) != 0)
...@@ -197,7 +197,7 @@ FindKey: ...@@ -197,7 +197,7 @@ FindKey:
err = s.Decrypt(passphrase) err = s.Decrypt(passphrase)
if err == nil && !s.Encrypted { if err == nil && !s.Encrypted {
decrypted, err = se.Decrypt(s.CipherFunc, s.Key) decrypted, err = se.Decrypt(s.CipherFunc, s.Key)
if err != nil && err != error_.KeyIncorrectError { if err != nil && err != errors.KeyIncorrectError {
return nil, err return nil, err
} }
if decrypted != nil { if decrypted != nil {
...@@ -237,7 +237,7 @@ FindLiteralData: ...@@ -237,7 +237,7 @@ FindLiteralData:
packets.Push(p.Body) packets.Push(p.Body)
case *packet.OnePassSignature: case *packet.OnePassSignature:
if !p.IsLast { if !p.IsLast {
return nil, error_.UnsupportedError("nested signatures") return nil, errors.UnsupportedError("nested signatures")
} }
h, wrappedHash, err = hashForSignature(p.Hash, p.SigType) h, wrappedHash, err = hashForSignature(p.Hash, p.SigType)
...@@ -281,7 +281,7 @@ FindLiteralData: ...@@ -281,7 +281,7 @@ FindLiteralData:
func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) { func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) {
h := hashId.New() h := hashId.New()
if h == nil { if h == nil {
return nil, nil, error_.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId))) return nil, nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hashId)))
} }
switch sigType { switch sigType {
...@@ -291,7 +291,7 @@ func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Ha ...@@ -291,7 +291,7 @@ func hashForSignature(hashId crypto.Hash, sigType packet.SignatureType) (hash.Ha
return h, NewCanonicalTextHash(h), nil return h, NewCanonicalTextHash(h), nil
} }
return nil, nil, error_.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType))) return nil, nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
} }
// checkReader wraps an io.Reader from a LiteralData packet. When it sees EOF // checkReader wraps an io.Reader from a LiteralData packet. When it sees EOF
...@@ -333,7 +333,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (n int, err error) { ...@@ -333,7 +333,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (n int, err error) {
var ok bool var ok bool
if scr.md.Signature, ok = p.(*packet.Signature); !ok { if scr.md.Signature, ok = p.(*packet.Signature); !ok {
scr.md.SignatureError = error_.StructuralError("LiteralData not followed by Signature") scr.md.SignatureError = errors.StructuralError("LiteralData not followed by Signature")
return return
} }
...@@ -363,16 +363,16 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe ...@@ -363,16 +363,16 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
sig, ok := p.(*packet.Signature) sig, ok := p.(*packet.Signature)
if !ok { if !ok {
return nil, error_.StructuralError("non signature packet found") return nil, errors.StructuralError("non signature packet found")
} }
if sig.IssuerKeyId == nil { if sig.IssuerKeyId == nil {
return nil, error_.StructuralError("signature doesn't have an issuer") return nil, errors.StructuralError("signature doesn't have an issuer")
} }
keys := keyring.KeysById(*sig.IssuerKeyId) keys := keyring.KeysById(*sig.IssuerKeyId)
if len(keys) == 0 { if len(keys) == 0 {
return nil, error_.UnknownIssuerError return nil, errors.UnknownIssuerError
} }
h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType) h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType)
...@@ -399,7 +399,7 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe ...@@ -399,7 +399,7 @@ func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader) (signe
return return
} }
return nil, error_.UnknownIssuerError return nil, errors.UnknownIssuerError
} }
// CheckArmoredDetachedSignature performs the same actions as // CheckArmoredDetachedSignature performs the same actions as
......
...@@ -6,7 +6,8 @@ package openpgp ...@@ -6,7 +6,8 @@ package openpgp
import ( import (
"bytes" "bytes"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
_ "crypto/sha512"
"encoding/hex" "encoding/hex"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -77,6 +78,15 @@ func TestReadDSAKey(t *testing.T) { ...@@ -77,6 +78,15 @@ func TestReadDSAKey(t *testing.T) {
} }
} }
func TestDSAHashTruncatation(t *testing.T) {
// dsaKeyWithSHA512 was generated with GnuPG and --cert-digest-algo
// SHA512 in order to require DSA hash truncation to verify correctly.
_, err := ReadKeyRing(readerFromHex(dsaKeyWithSHA512))
if err != nil {
t.Error(err)
}
}
func TestGetKeyById(t *testing.T) { func TestGetKeyById(t *testing.T) {
kring, _ := ReadKeyRing(readerFromHex(testKeys1And2Hex)) kring, _ := ReadKeyRing(readerFromHex(testKeys1And2Hex))
...@@ -151,18 +161,18 @@ func TestSignedEncryptedMessage(t *testing.T) { ...@@ -151,18 +161,18 @@ func TestSignedEncryptedMessage(t *testing.T) {
prompt := func(keys []Key, symmetric bool) ([]byte, error) { prompt := func(keys []Key, symmetric bool) ([]byte, error) {
if symmetric { if symmetric {
t.Errorf("prompt: message was marked as symmetrically encrypted") t.Errorf("prompt: message was marked as symmetrically encrypted")
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
if len(keys) == 0 { if len(keys) == 0 {
t.Error("prompt: no keys requested") t.Error("prompt: no keys requested")
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
err := keys[0].PrivateKey.Decrypt([]byte("passphrase")) err := keys[0].PrivateKey.Decrypt([]byte("passphrase"))
if err != nil { if err != nil {
t.Errorf("prompt: error decrypting key: %s", err) t.Errorf("prompt: error decrypting key: %s", err)
return nil, error_.KeyIncorrectError return nil, errors.KeyIncorrectError
} }
return nil, nil return nil, nil
...@@ -286,7 +296,7 @@ func TestReadingArmoredPrivateKey(t *testing.T) { ...@@ -286,7 +296,7 @@ func TestReadingArmoredPrivateKey(t *testing.T) {
func TestNoArmoredData(t *testing.T) { func TestNoArmoredData(t *testing.T) {
_, err := ReadArmoredKeyRing(bytes.NewBufferString("foo")) _, err := ReadArmoredKeyRing(bytes.NewBufferString("foo"))
if _, ok := err.(error_.InvalidArgumentError); !ok { if _, ok := err.(errors.InvalidArgumentError); !ok {
t.Errorf("error was not an InvalidArgumentError: %s", err) t.Errorf("error was not an InvalidArgumentError: %s", err)
} }
} }
...@@ -358,3 +368,5 @@ AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL ...@@ -358,3 +368,5 @@ AHcVnXjtxrULkQFGbGvhKURLvS9WnzD/m1K2zzwxzkPTzT9/Yf06O6Mal5AdugPL
VrM0m72/jnpKo04= VrM0m72/jnpKo04=
=zNCn =zNCn
-----END PGP PRIVATE KEY BLOCK-----` -----END PGP PRIVATE KEY BLOCK-----`
const dsaKeyWithSHA512 = `9901a2044f04b07f110400db244efecc7316553ee08d179972aab87bb1214de7692593fcf5b6feb1c80fba268722dd464748539b85b81d574cd2d7ad0ca2444de4d849b8756bad7768c486c83a824f9bba4af773d11742bdfb4ac3b89ef8cc9452d4aad31a37e4b630d33927bff68e879284a1672659b8b298222fc68f370f3e24dccacc4a862442b9438b00a0ea444a24088dc23e26df7daf8f43cba3bffc4fe703fe3d6cd7fdca199d54ed8ae501c30e3ec7871ea9cdd4cf63cfe6fc82281d70a5b8bb493f922cd99fba5f088935596af087c8d818d5ec4d0b9afa7f070b3d7c1dd32a84fca08d8280b4890c8da1dde334de8e3cad8450eed2a4a4fcc2db7b8e5528b869a74a7f0189e11ef097ef1253582348de072bb07a9fa8ab838e993cef0ee203ff49298723e2d1f549b00559f886cd417a41692ce58d0ac1307dc71d85a8af21b0cf6eaa14baf2922d3a70389bedf17cc514ba0febbd107675a372fe84b90162a9e88b14d4b1c6be855b96b33fb198c46f058568817780435b6936167ebb3724b680f32bf27382ada2e37a879b3d9de2abe0c3f399350afd1ad438883f4791e2e3b4184453412068617368207472756e636174696f6e207465737488620413110a002205024f04b07f021b03060b090807030206150802090a0b0416020301021e01021780000a0910ef20e0cefca131581318009e2bf3bf047a44d75a9bacd00161ee04d435522397009a03a60d51bd8a568c6c021c8d7cf1be8d990d6417b0020003`
...@@ -8,7 +8,7 @@ package s2k ...@@ -8,7 +8,7 @@ package s2k
import ( import (
"crypto" "crypto"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"hash" "hash"
"io" "io"
"strconv" "strconv"
...@@ -89,11 +89,11 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) { ...@@ -89,11 +89,11 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) {
hash, ok := HashIdToHash(buf[1]) hash, ok := HashIdToHash(buf[1])
if !ok { if !ok {
return nil, error_.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1]))) return nil, errors.UnsupportedError("hash for S2K function: " + strconv.Itoa(int(buf[1])))
} }
h := hash.New() h := hash.New()
if h == nil { if h == nil {
return nil, error_.UnsupportedError("hash not available: " + strconv.Itoa(int(hash))) return nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hash)))
} }
switch buf[0] { switch buf[0] {
...@@ -123,7 +123,7 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) { ...@@ -123,7 +123,7 @@ func Parse(r io.Reader) (f func(out, in []byte), err error) {
return f, nil return f, nil
} }
return nil, error_.UnsupportedError("S2K function") return nil, errors.UnsupportedError("S2K function")
} }
// Serialize salts and stretches the given passphrase and writes the resulting // Serialize salts and stretches the given passphrase and writes the resulting
......
...@@ -7,7 +7,7 @@ package openpgp ...@@ -7,7 +7,7 @@ package openpgp
import ( import (
"crypto" "crypto"
"crypto/openpgp/armor" "crypto/openpgp/armor"
error_ "crypto/openpgp/error" "crypto/openpgp/errors"
"crypto/openpgp/packet" "crypto/openpgp/packet"
"crypto/openpgp/s2k" "crypto/openpgp/s2k"
"crypto/rand" "crypto/rand"
...@@ -58,10 +58,10 @@ func armoredDetachSign(w io.Writer, signer *Entity, message io.Reader, sigType p ...@@ -58,10 +58,10 @@ func armoredDetachSign(w io.Writer, signer *Entity, message io.Reader, sigType p
func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.SignatureType) (err error) { func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.SignatureType) (err error) {
if signer.PrivateKey == nil { if signer.PrivateKey == nil {
return error_.InvalidArgumentError("signing key doesn't have a private key") return errors.InvalidArgumentError("signing key doesn't have a private key")
} }
if signer.PrivateKey.Encrypted { if signer.PrivateKey.Encrypted {
return error_.InvalidArgumentError("signing key is encrypted") return errors.InvalidArgumentError("signing key is encrypted")
} }
sig := new(packet.Signature) sig := new(packet.Signature)
...@@ -77,7 +77,7 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S ...@@ -77,7 +77,7 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S
} }
io.Copy(wrappedHash, message) io.Copy(wrappedHash, message)
err = sig.Sign(h, signer.PrivateKey) err = sig.Sign(rand.Reader, h, signer.PrivateKey)
if err != nil { if err != nil {
return return
} }
...@@ -111,7 +111,7 @@ func SymmetricallyEncrypt(ciphertext io.Writer, passphrase []byte, hints *FileHi ...@@ -111,7 +111,7 @@ func SymmetricallyEncrypt(ciphertext io.Writer, passphrase []byte, hints *FileHi
if err != nil { if err != nil {
return return
} }
w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, packet.CipherAES128, key) w, err := packet.SerializeSymmetricallyEncrypted(ciphertext, rand.Reader, packet.CipherAES128, key)
if err != nil { if err != nil {
return return
} }
...@@ -156,7 +156,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint ...@@ -156,7 +156,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
if signed != nil { if signed != nil {
signer = signed.signingKey().PrivateKey signer = signed.signingKey().PrivateKey
if signer == nil || signer.Encrypted { if signer == nil || signer.Encrypted {
return nil, error_.InvalidArgumentError("signing key must be decrypted") return nil, errors.InvalidArgumentError("signing key must be decrypted")
} }
} }
...@@ -183,7 +183,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint ...@@ -183,7 +183,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
for i := range to { for i := range to {
encryptKeys[i] = to[i].encryptionKey() encryptKeys[i] = to[i].encryptionKey()
if encryptKeys[i].PublicKey == nil { if encryptKeys[i].PublicKey == nil {
return nil, error_.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys") return nil, errors.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no encryption keys")
} }
sig := to[i].primaryIdentity().SelfSignature sig := to[i].primaryIdentity().SelfSignature
...@@ -201,7 +201,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint ...@@ -201,7 +201,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
} }
if len(candidateCiphers) == 0 || len(candidateHashes) == 0 { if len(candidateCiphers) == 0 || len(candidateHashes) == 0 {
return nil, error_.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms") return nil, errors.InvalidArgumentError("cannot encrypt because recipient set shares no common algorithms")
} }
cipher := packet.CipherFunction(candidateCiphers[0]) cipher := packet.CipherFunction(candidateCiphers[0])
...@@ -217,7 +217,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint ...@@ -217,7 +217,7 @@ func Encrypt(ciphertext io.Writer, to []*Entity, signed *Entity, hints *FileHint
} }
} }
encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, cipher, symKey) encryptedData, err := packet.SerializeSymmetricallyEncrypted(ciphertext, rand.Reader, cipher, symKey)
if err != nil { if err != nil {
return return
} }
...@@ -287,7 +287,7 @@ func (s signatureWriter) Close() error { ...@@ -287,7 +287,7 @@ func (s signatureWriter) Close() error {
IssuerKeyId: &s.signer.KeyId, IssuerKeyId: &s.signer.KeyId,
} }
if err := sig.Sign(s.h, s.signer); err != nil { if err := sig.Sign(rand.Reader, s.h, s.signer); err != nil {
return err return err
} }
if err := s.literalData.Close(); err != nil { if err := s.literalData.Close(); err != nil {
......
...@@ -222,7 +222,7 @@ func TestEncryption(t *testing.T) { ...@@ -222,7 +222,7 @@ func TestEncryption(t *testing.T) {
if test.isSigned { if test.isSigned {
if md.SignatureError != nil { if md.SignatureError != nil {
t.Errorf("#%d: signature error: %s", i, err) t.Errorf("#%d: signature error: %s", i, md.SignatureError)
} }
if md.Signature == nil { if md.Signature == nil {
t.Error("signature missing") t.Error("signature missing")
......
...@@ -111,6 +111,18 @@ type ConnectionState struct { ...@@ -111,6 +111,18 @@ type ConnectionState struct {
VerifiedChains [][]*x509.Certificate VerifiedChains [][]*x509.Certificate
} }
// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
const (
NoClientCert ClientAuthType = iota
RequestClientCert
RequireAnyClientCert
VerifyClientCertIfGiven
RequireAndVerifyClientCert
)
// A Config structure is used to configure a TLS client or server. After one // A Config structure is used to configure a TLS client or server. After one
// has been passed to a TLS function it must not be modified. // has been passed to a TLS function it must not be modified.
type Config struct { type Config struct {
...@@ -120,7 +132,7 @@ type Config struct { ...@@ -120,7 +132,7 @@ type Config struct {
Rand io.Reader Rand io.Reader
// Time returns the current time as the number of seconds since the epoch. // Time returns the current time as the number of seconds since the epoch.
// If Time is nil, TLS uses the system time.Seconds. // If Time is nil, TLS uses time.Now.
Time func() time.Time Time func() time.Time
// Certificates contains one or more certificate chains // Certificates contains one or more certificate chains
...@@ -148,11 +160,14 @@ type Config struct { ...@@ -148,11 +160,14 @@ type Config struct {
// hosting. // hosting.
ServerName string ServerName string
// AuthenticateClient controls whether a server will request a certificate // ClientAuth determines the server's policy for
// from the client. It does not require that the client send a // TLS Client Authentication. The default is NoClientCert.
// certificate nor does it require that the certificate sent be ClientAuth ClientAuthType
// anything more than self-signed.
AuthenticateClient bool // ClientCAs defines the set of root certificate authorities
// that servers use if required to verify a client certificate
// by the policy in ClientAuth.
ClientCAs *x509.CertPool
// InsecureSkipVerify controls whether a client verifies the // InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name. // server's certificate chain and host name.
...@@ -259,6 +274,11 @@ type Certificate struct { ...@@ -259,6 +274,11 @@ type Certificate struct {
// OCSPStaple contains an optional OCSP response which will be served // OCSPStaple contains an optional OCSP response which will be served
// to clients that request it. // to clients that request it.
OCSPStaple []byte OCSPStaple []byte
// Leaf is the parsed form of the leaf certificate, which may be
// initialized using x509.ParseCertificate to reduce per-handshake
// processing for TLS clients doing client authentication. If nil, the
// leaf certificate will be parsed as needed.
Leaf *x509.Certificate
} }
// A TLS record. // A TLS record.
......
...@@ -31,7 +31,7 @@ func main() { ...@@ -31,7 +31,7 @@ func main() {
return return
} }
now := time.Seconds() now := time.Now()
template := x509.Certificate{ template := x509.Certificate{
SerialNumber: new(big.Int).SetInt64(0), SerialNumber: new(big.Int).SetInt64(0),
...@@ -39,8 +39,8 @@ func main() { ...@@ -39,8 +39,8 @@ func main() {
CommonName: *hostName, CommonName: *hostName,
Organization: []string{"Acme Co"}, Organization: []string{"Acme Co"},
}, },
NotBefore: time.SecondsToUTC(now - 300), NotBefore: now.Add(-5 * time.Minute).UTC(),
NotAfter: time.SecondsToUTC(now + 60*60*24*365), // valid for 1 year. NotAfter: now.AddDate(1, 0, 0).UTC(), // valid for 1 year.
SubjectKeyId: []byte{1, 2, 3, 4}, SubjectKeyId: []byte{1, 2, 3, 4},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
......
...@@ -5,12 +5,14 @@ ...@@ -5,12 +5,14 @@
package tls package tls
import ( import (
"bytes"
"crypto" "crypto"
"crypto/rsa" "crypto/rsa"
"crypto/subtle" "crypto/subtle"
"crypto/x509" "crypto/x509"
"errors" "errors"
"io" "io"
"strconv"
) )
func (c *Conn) clientHandshake() error { func (c *Conn) clientHandshake() error {
...@@ -162,10 +164,23 @@ func (c *Conn) clientHandshake() error { ...@@ -162,10 +164,23 @@ func (c *Conn) clientHandshake() error {
} }
} }
transmitCert := false var certToSend *Certificate
certReq, ok := msg.(*certificateRequestMsg) certReq, ok := msg.(*certificateRequestMsg)
if ok { if ok {
// We only accept certificates with RSA keys. // RFC 4346 on the certificateAuthorities field:
// A list of the distinguished names of acceptable certificate
// authorities. These distinguished names may specify a desired
// distinguished name for a root CA or for a subordinate CA;
// thus, this message can be used to describe both known roots
// and a desired authorization space. If the
// certificate_authorities list is empty then the client MAY
// send any certificate of the appropriate
// ClientCertificateType, unless there is some external
// arrangement to the contrary.
finishedHash.Write(certReq.marshal())
// For now, we only know how to sign challenges with RSA
rsaAvail := false rsaAvail := false
for _, certType := range certReq.certificateTypes { for _, certType := range certReq.certificateTypes {
if certType == certTypeRSASign { if certType == certTypeRSASign {
...@@ -174,23 +189,41 @@ func (c *Conn) clientHandshake() error { ...@@ -174,23 +189,41 @@ func (c *Conn) clientHandshake() error {
} }
} }
// For now, only send a certificate back if the server gives us an // We need to search our list of client certs for one
// empty list of certificateAuthorities. // where SignatureAlgorithm is RSA and the Issuer is in
// // certReq.certificateAuthorities
// RFC 4346 on the certificateAuthorities field: findCert:
// A list of the distinguished names of acceptable certificate for i, cert := range c.config.Certificates {
// authorities. These distinguished names may specify a desired if !rsaAvail {
// distinguished name for a root CA or for a subordinate CA; thus, continue
// this message can be used to describe both known roots and a }
// desired authorization space. If the certificate_authorities
// list is empty then the client MAY send any certificate of the
// appropriate ClientCertificateType, unless there is some
// external arrangement to the contrary.
if rsaAvail && len(certReq.certificateAuthorities) == 0 {
transmitCert = true
}
finishedHash.Write(certReq.marshal()) leaf := cert.Leaf
if leaf == nil {
if leaf, err = x509.ParseCertificate(cert.Certificate[0]); err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
}
}
if leaf.PublicKeyAlgorithm != x509.RSA {
continue
}
if len(certReq.certificateAuthorities) == 0 {
// they gave us an empty list, so just take the
// first RSA cert from c.config.Certificates
certToSend = &cert
break
}
for _, ca := range certReq.certificateAuthorities {
if bytes.Equal(leaf.RawIssuer, ca) {
certToSend = &cert
break findCert
}
}
}
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {
...@@ -204,17 +237,9 @@ func (c *Conn) clientHandshake() error { ...@@ -204,17 +237,9 @@ func (c *Conn) clientHandshake() error {
} }
finishedHash.Write(shd.marshal()) finishedHash.Write(shd.marshal())
var cert *x509.Certificate if certToSend != nil {
if transmitCert {
certMsg = new(certificateMsg) certMsg = new(certificateMsg)
if len(c.config.Certificates) > 0 { certMsg.certificates = certToSend.Certificate
cert, err = x509.ParseCertificate(c.config.Certificates[0].Certificate[0])
if err == nil && cert.PublicKeyAlgorithm == x509.RSA {
certMsg.certificates = c.config.Certificates[0].Certificate
} else {
cert = nil
}
}
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal()) c.writeRecord(recordTypeHandshake, certMsg.marshal())
} }
...@@ -229,7 +254,7 @@ func (c *Conn) clientHandshake() error { ...@@ -229,7 +254,7 @@ func (c *Conn) clientHandshake() error {
c.writeRecord(recordTypeHandshake, ckx.marshal()) c.writeRecord(recordTypeHandshake, ckx.marshal())
} }
if cert != nil { if certToSend != nil {
certVerify := new(certificateVerifyMsg) certVerify := new(certificateVerifyMsg)
digest := make([]byte, 0, 36) digest := make([]byte, 0, 36)
digest = finishedHash.serverMD5.Sum(digest) digest = finishedHash.serverMD5.Sum(digest)
......
...@@ -881,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -881,9 +881,11 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
// See http://tools.ietf.org/html/rfc4346#section-7.4.4 // See http://tools.ietf.org/html/rfc4346#section-7.4.4
length := 1 + len(m.certificateTypes) + 2 length := 1 + len(m.certificateTypes) + 2
casLength := 0
for _, ca := range m.certificateAuthorities { for _, ca := range m.certificateAuthorities {
length += 2 + len(ca) casLength += 2 + len(ca)
} }
length += casLength
x = make([]byte, 4+length) x = make([]byte, 4+length)
x[0] = typeCertificateRequest x[0] = typeCertificateRequest
...@@ -895,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -895,10 +897,8 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
copy(x[5:], m.certificateTypes) copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):] y := x[5+len(m.certificateTypes):]
y[0] = uint8(casLength >> 8)
numCA := len(m.certificateAuthorities) y[1] = uint8(casLength)
y[0] = uint8(numCA >> 8)
y[1] = uint8(numCA)
y = y[2:] y = y[2:]
for _, ca := range m.certificateAuthorities { for _, ca := range m.certificateAuthorities {
y[0] = uint8(len(ca) >> 8) y[0] = uint8(len(ca) >> 8)
...@@ -909,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) { ...@@ -909,7 +909,6 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
} }
m.raw = x m.raw = x
return return
} }
...@@ -937,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { ...@@ -937,31 +936,34 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
} }
data = data[numCertTypes:] data = data[numCertTypes:]
if len(data) < 2 { if len(data) < 2 {
return false return false
} }
casLength := uint16(data[0])<<8 | uint16(data[1])
numCAs := uint16(data[0])<<16 | uint16(data[1])
data = data[2:] data = data[2:]
if len(data) < int(casLength) {
return false
}
cas := make([]byte, casLength)
copy(cas, data)
data = data[casLength:]
m.certificateAuthorities = make([][]byte, numCAs) m.certificateAuthorities = nil
for i := uint16(0); i < numCAs; i++ { for len(cas) > 0 {
if len(data) < 2 { if len(cas) < 2 {
return false return false
} }
caLen := uint16(data[0])<<16 | uint16(data[1]) caLen := uint16(cas[0])<<8 | uint16(cas[1])
cas = cas[2:]
data = data[2:] if len(cas) < int(caLen) {
if len(data) < int(caLen) {
return false return false
} }
ca := make([]byte, caLen) m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
copy(ca, data) cas = cas[caLen:]
m.certificateAuthorities[i] = ca
data = data[caLen:]
} }
if len(data) > 0 { if len(data) > 0 {
return false return false
} }
......
...@@ -150,14 +150,19 @@ FindCipherSuite: ...@@ -150,14 +150,19 @@ FindCipherSuite:
c.writeRecord(recordTypeHandshake, skx.marshal()) c.writeRecord(recordTypeHandshake, skx.marshal())
} }
if config.AuthenticateClient { if config.ClientAuth >= RequestClientCert {
// Request a client certificate // Request a client certificate
certReq := new(certificateRequestMsg) certReq := new(certificateRequestMsg)
certReq.certificateTypes = []byte{certTypeRSASign} certReq.certificateTypes = []byte{certTypeRSASign}
// An empty list of certificateAuthorities signals to // An empty list of certificateAuthorities signals to
// the client that it may send any certificate in response // the client that it may send any certificate in response
// to our request. // to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose
// an appropriate certificate to give to us.
if config.ClientCAs != nil {
certReq.certificateAuthorities = config.ClientCAs.Subjects()
}
finishedHash.Write(certReq.marshal()) finishedHash.Write(certReq.marshal())
c.writeRecord(recordTypeHandshake, certReq.marshal()) c.writeRecord(recordTypeHandshake, certReq.marshal())
} }
...@@ -166,52 +171,87 @@ FindCipherSuite: ...@@ -166,52 +171,87 @@ FindCipherSuite:
finishedHash.Write(helloDone.marshal()) finishedHash.Write(helloDone.marshal())
c.writeRecord(recordTypeHandshake, helloDone.marshal()) c.writeRecord(recordTypeHandshake, helloDone.marshal())
var pub *rsa.PublicKey var pub *rsa.PublicKey // public key for client auth, if any
if config.AuthenticateClient {
// Get client certificate msg, err = c.readHandshake()
msg, err = c.readHandshake() if err != nil {
if err != nil { return err
return err }
}
certMsg, ok = msg.(*certificateMsg) // If we requested a client certificate, then the client must send a
if !ok { // certificate message, even if it's empty.
return c.sendAlert(alertUnexpectedMessage) if config.ClientAuth >= RequestClientCert {
if certMsg, ok = msg.(*certificateMsg); !ok {
return c.sendAlert(alertHandshakeFailure)
} }
finishedHash.Write(certMsg.marshal()) finishedHash.Write(certMsg.marshal())
if len(certMsg.certificates) == 0 {
// The client didn't actually send a certificate
switch config.ClientAuth {
case RequireAnyClientCert, RequireAndVerifyClientCert:
c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate")
}
}
certs := make([]*x509.Certificate, len(certMsg.certificates)) certs := make([]*x509.Certificate, len(certMsg.certificates))
for i, asn1Data := range certMsg.certificates { for i, asn1Data := range certMsg.certificates {
cert, err := x509.ParseCertificate(asn1Data) if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
if err != nil {
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
return errors.New("could not parse client's certificate: " + err.Error()) return errors.New("tls: failed to parse client certificate: " + err.Error())
} }
certs[i] = cert
} }
// TODO(agl): do better validation of certs: max path length, name restrictions etc. if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
for i := 1; i < len(certs); i++ { opts := x509.VerifyOptions{
if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil { Roots: c.config.ClientCAs,
CurrentTime: c.config.time(),
Intermediates: x509.NewCertPool(),
}
for i, cert := range certs {
if i == 0 {
continue
}
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
c.sendAlert(alertBadCertificate) c.sendAlert(alertBadCertificate)
return errors.New("could not validate certificate signature: " + err.Error()) return errors.New("tls: failed to verify client's certificate: " + err.Error())
} }
ok := false
for _, ku := range certs[0].ExtKeyUsage {
if ku == x509.ExtKeyUsageClientAuth {
ok = true
break
}
}
if !ok {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's certificate's extended key usage doesn't permit it to be used for client authentication")
}
c.verifiedChains = chains
} }
if len(certs) > 0 { if len(certs) > 0 {
key, ok := certs[0].PublicKey.(*rsa.PublicKey) if pub, ok = certs[0].PublicKey.(*rsa.PublicKey); !ok {
if !ok {
return c.sendAlert(alertUnsupportedCertificate) return c.sendAlert(alertUnsupportedCertificate)
} }
pub = key
c.peerCertificates = certs c.peerCertificates = certs
} }
msg, err = c.readHandshake()
if err != nil {
return err
}
} }
// Get client key exchange // Get client key exchange
msg, err = c.readHandshake()
if err != nil {
return err
}
ckx, ok := msg.(*clientKeyExchangeMsg) ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok { if !ok {
return c.sendAlert(alertUnexpectedMessage) return c.sendAlert(alertUnexpectedMessage)
......
...@@ -120,7 +120,7 @@ func Dial(network, addr string, config *Config) (*Conn, error) { ...@@ -120,7 +120,7 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
// LoadX509KeyPair reads and parses a public/private key pair from a pair of // LoadX509KeyPair reads and parses a public/private key pair from a pair of
// files. The files must contain PEM encoded data. // files. The files must contain PEM encoded data.
func LoadX509KeyPair(certFile string, keyFile string) (cert Certificate, err error) { func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
certPEMBlock, err := ioutil.ReadFile(certFile) certPEMBlock, err := ioutil.ReadFile(certFile)
if err != nil { if err != nil {
return return
......
...@@ -101,3 +101,13 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) { ...@@ -101,3 +101,13 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
return return
} }
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
func (s *CertPool) Subjects() (res [][]byte) {
res = make([][]byte, len(s.certs))
for i, c := range s.certs {
res[i] = c.RawSubject
}
return
}
...@@ -7,14 +7,14 @@ package gosym ...@@ -7,14 +7,14 @@ package gosym
import ( import (
"debug/elf" "debug/elf"
"os" "os"
"syscall" "runtime"
"testing" "testing"
) )
func dotest() bool { func dotest() bool {
// For now, only works on ELF platforms. // For now, only works on ELF platforms.
// TODO: convert to work with new go tool // TODO: convert to work with new go tool
return false && syscall.OS == "linux" && os.Getenv("GOARCH") == "amd64" return false && runtime.GOOS == "linux" && runtime.GOARCH == "amd64"
} }
func getTable(t *testing.T) *Table { func getTable(t *testing.T) *Table {
......
...@@ -786,7 +786,8 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) { ...@@ -786,7 +786,8 @@ func setDefaultValue(v reflect.Value, params fieldParameters) (ok bool) {
// Because Unmarshal uses the reflect package, the structs // Because Unmarshal uses the reflect package, the structs
// being written to must use upper case field names. // being written to must use upper case field names.
// //
// An ASN.1 INTEGER can be written to an int, int32 or int64. // An ASN.1 INTEGER can be written to an int, int32, int64,
// or *big.Int (from the math/big package).
// If the encoded value does not fit in the Go type, // If the encoded value does not fit in the Go type,
// Unmarshal returns a parse error. // Unmarshal returns a parse error.
// //
......
...@@ -6,6 +6,7 @@ package asn1 ...@@ -6,6 +6,7 @@ package asn1
import ( import (
"bytes" "bytes"
"math/big"
"reflect" "reflect"
"testing" "testing"
"time" "time"
...@@ -351,6 +352,10 @@ type TestElementsAfterString struct { ...@@ -351,6 +352,10 @@ type TestElementsAfterString struct {
A, B int A, B int
} }
type TestBigInt struct {
X *big.Int
}
var unmarshalTestData = []struct { var unmarshalTestData = []struct {
in []byte in []byte
out interface{} out interface{}
...@@ -369,6 +374,7 @@ var unmarshalTestData = []struct { ...@@ -369,6 +374,7 @@ var unmarshalTestData = []struct {
{[]byte{0x01, 0x01, 0x00}, newBool(false)}, {[]byte{0x01, 0x01, 0x00}, newBool(false)},
{[]byte{0x01, 0x01, 0x01}, newBool(true)}, {[]byte{0x01, 0x01, 0x01}, newBool(true)},
{[]byte{0x30, 0x0b, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x01, 0x22, 0x02, 0x01, 0x33}, &TestElementsAfterString{"foo", 0x22, 0x33}}, {[]byte{0x30, 0x0b, 0x13, 0x03, 0x66, 0x6f, 0x6f, 0x02, 0x01, 0x22, 0x02, 0x01, 0x33}, &TestElementsAfterString{"foo", 0x22, 0x33}},
{[]byte{0x30, 0x05, 0x02, 0x03, 0x12, 0x34, 0x56}, &TestBigInt{big.NewInt(0x123456)}},
} }
func TestUnmarshal(t *testing.T) { func TestUnmarshal(t *testing.T) {
......
...@@ -7,6 +7,7 @@ package asn1 ...@@ -7,6 +7,7 @@ package asn1
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"math/big"
"testing" "testing"
"time" "time"
) )
...@@ -20,6 +21,10 @@ type twoIntStruct struct { ...@@ -20,6 +21,10 @@ type twoIntStruct struct {
B int B int
} }
type bigIntStruct struct {
A *big.Int
}
type nestedStruct struct { type nestedStruct struct {
A intStruct A intStruct
} }
...@@ -65,6 +70,7 @@ var marshalTests = []marshalTest{ ...@@ -65,6 +70,7 @@ var marshalTests = []marshalTest{
{-128, "020180"}, {-128, "020180"},
{-129, "0202ff7f"}, {-129, "0202ff7f"},
{intStruct{64}, "3003020140"}, {intStruct{64}, "3003020140"},
{bigIntStruct{big.NewInt(0x123456)}, "30050203123456"},
{twoIntStruct{64, 65}, "3006020140020141"}, {twoIntStruct{64, 65}, "3006020140020141"},
{nestedStruct{intStruct{127}}, "3005300302017f"}, {nestedStruct{intStruct{127}}, "3005300302017f"},
{[]byte{1, 2, 3}, "0403010203"}, {[]byte{1, 2, 3}, "0403010203"},
......
...@@ -1039,9 +1039,9 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re ...@@ -1039,9 +1039,9 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re
// Extract and compare element types. // Extract and compare element types.
var sw *sliceType var sw *sliceType
if tt, ok := builtinIdToType[fw]; ok { if tt, ok := builtinIdToType[fw]; ok {
sw = tt.(*sliceType) sw, _ = tt.(*sliceType)
} else { } else if wire != nil {
sw = dec.wireType[fw].SliceT sw = wire.SliceT
} }
elem := userType(t.Elem()).base elem := userType(t.Elem()).base
return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress) return sw != nil && dec.compatibleType(elem, sw.Elem, inProgress)
......
...@@ -678,3 +678,11 @@ func TestUnexportedChan(t *testing.T) { ...@@ -678,3 +678,11 @@ func TestUnexportedChan(t *testing.T) {
t.Fatalf("error encoding unexported channel: %s", err) t.Fatalf("error encoding unexported channel: %s", err)
} }
} }
func TestSliceIncompatibility(t *testing.T) {
var in = []byte{1, 2, 3}
var out []int
if err := encAndDec(in, &out); err == nil {
t.Error("expected compatibility error")
}
}
...@@ -10,6 +10,7 @@ package json ...@@ -10,6 +10,7 @@ package json
import ( import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt"
"reflect" "reflect"
"runtime" "runtime"
"strconv" "strconv"
...@@ -538,7 +539,7 @@ func (d *decodeState) object(v reflect.Value) { ...@@ -538,7 +539,7 @@ func (d *decodeState) object(v reflect.Value) {
// Read value. // Read value.
if destring { if destring {
d.value(reflect.ValueOf(&d.tempstr)) d.value(reflect.ValueOf(&d.tempstr))
d.literalStore([]byte(d.tempstr), subv) d.literalStore([]byte(d.tempstr), subv, true)
} else { } else {
d.value(subv) d.value(subv)
} }
...@@ -571,11 +572,15 @@ func (d *decodeState) literal(v reflect.Value) { ...@@ -571,11 +572,15 @@ func (d *decodeState) literal(v reflect.Value) {
d.off-- d.off--
d.scan.undo(op) d.scan.undo(op)
d.literalStore(d.data[start:d.off], v) d.literalStore(d.data[start:d.off], v, false)
} }
// literalStore decodes a literal stored in item into v. // literalStore decodes a literal stored in item into v.
func (d *decodeState) literalStore(item []byte, v reflect.Value) { //
// fromQuoted indicates whether this literal came from unwrapping a
// string from the ",string" struct tag option. this is used only to
// produce more helpful error messages.
func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) {
// Check for unmarshaler. // Check for unmarshaler.
wantptr := item[0] == 'n' // null wantptr := item[0] == 'n' // null
unmarshaler, pv := d.indirect(v, wantptr) unmarshaler, pv := d.indirect(v, wantptr)
...@@ -601,7 +606,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) { ...@@ -601,7 +606,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
value := c == 't' value := c == 't'
switch v.Kind() { switch v.Kind() {
default: default:
d.saveError(&UnmarshalTypeError{"bool", v.Type()}) if fromQuoted {
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.saveError(&UnmarshalTypeError{"bool", v.Type()})
}
case reflect.Bool: case reflect.Bool:
v.SetBool(value) v.SetBool(value)
case reflect.Interface: case reflect.Interface:
...@@ -611,7 +620,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) { ...@@ -611,7 +620,11 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
case '"': // string case '"': // string
s, ok := unquoteBytes(item) s, ok := unquoteBytes(item)
if !ok { if !ok {
d.error(errPhase) if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(errPhase)
}
} }
switch v.Kind() { switch v.Kind() {
default: default:
...@@ -636,12 +649,20 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) { ...@@ -636,12 +649,20 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value) {
default: // number default: // number
if c != '-' && (c < '0' || c > '9') { if c != '-' && (c < '0' || c > '9') {
d.error(errPhase) if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(errPhase)
}
} }
s := string(item) s := string(item)
switch v.Kind() { switch v.Kind() {
default: default:
d.error(&UnmarshalTypeError{"number", v.Type()}) if fromQuoted {
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
} else {
d.error(&UnmarshalTypeError{"number", v.Type()})
}
case reflect.Interface: case reflect.Interface:
n, err := strconv.ParseFloat(s, 64) n, err := strconv.ParseFloat(s, 64)
if err != nil { if err != nil {
......
...@@ -258,13 +258,10 @@ type wrongStringTest struct { ...@@ -258,13 +258,10 @@ type wrongStringTest struct {
in, err string in, err string
} }
// TODO(bradfitz): as part of Issue 2331, fix these tests' expected
// error values to be helpful, rather than the confusing messages they
// are now.
var wrongStringTests = []wrongStringTest{ var wrongStringTests = []wrongStringTest{
{`{"result":"x"}`, "JSON decoder out of sync - data changing underfoot?"}, {`{"result":"x"}`, `json: invalid use of ,string struct tag, trying to unmarshal "x" into string`},
{`{"result":"foo"}`, "json: cannot unmarshal bool into Go value of type string"}, {`{"result":"foo"}`, `json: invalid use of ,string struct tag, trying to unmarshal "foo" into string`},
{`{"result":"123"}`, "json: cannot unmarshal number into Go value of type string"}, {`{"result":"123"}`, `json: invalid use of ,string struct tag, trying to unmarshal "123" into string`},
} }
// If people misuse the ,string modifier, the error message should be // If people misuse the ,string modifier, the error message should be
......
...@@ -12,6 +12,7 @@ package json ...@@ -12,6 +12,7 @@ package json
import ( import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"math"
"reflect" "reflect"
"runtime" "runtime"
"sort" "sort"
...@@ -170,6 +171,15 @@ func (e *UnsupportedTypeError) Error() string { ...@@ -170,6 +171,15 @@ func (e *UnsupportedTypeError) Error() string {
return "json: unsupported type: " + e.Type.String() return "json: unsupported type: " + e.Type.String()
} }
type UnsupportedValueError struct {
Value reflect.Value
Str string
}
func (e *UnsupportedValueError) Error() string {
return "json: unsupported value: " + e.Str
}
type InvalidUTF8Error struct { type InvalidUTF8Error struct {
S string S string
} }
...@@ -290,7 +300,11 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) { ...@@ -290,7 +300,11 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
e.Write(b) e.Write(b)
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
b := strconv.AppendFloat(e.scratch[:0], v.Float(), 'g', -1, v.Type().Bits()) f := v.Float()
if math.IsInf(f, 0) || math.IsNaN(f) {
e.error(&UnsupportedValueError{v, strconv.FormatFloat(f, 'g', -1, v.Type().Bits())})
}
b := strconv.AppendFloat(e.scratch[:0], f, 'g', -1, v.Type().Bits())
if quoted { if quoted {
writeString(e, string(b)) writeString(e, string(b))
} else { } else {
......
...@@ -6,6 +6,7 @@ package json ...@@ -6,6 +6,7 @@ package json
import ( import (
"bytes" "bytes"
"math"
"reflect" "reflect"
"testing" "testing"
) )
...@@ -107,3 +108,21 @@ func TestEncodeRenamedByteSlice(t *testing.T) { ...@@ -107,3 +108,21 @@ func TestEncodeRenamedByteSlice(t *testing.T) {
t.Errorf(" got %s want %s", result, expect) t.Errorf(" got %s want %s", result, expect)
} }
} }
var unsupportedValues = []interface{}{
math.NaN(),
math.Inf(-1),
math.Inf(1),
}
func TestUnsupportedValues(t *testing.T) {
for _, v := range unsupportedValues {
if _, err := Marshal(v); err != nil {
if _, ok := err.(*UnsupportedValueError); !ok {
t.Errorf("for %v, got %T want UnsupportedValueError", v, err)
}
} else {
t.Errorf("for %v, expected error", v)
}
}
}
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package xml package xml
var atomValue = &Feed{ var atomValue = &Feed{
XMLName: Name{"http://www.w3.org/2005/Atom", "feed"},
Title: "Example Feed", Title: "Example Feed",
Link: []Link{{Href: "http://example.org/"}}, Link: []Link{{Href: "http://example.org/"}},
Updated: ParseTime("2003-12-13T18:30:02Z"), Updated: ParseTime("2003-12-13T18:30:02Z"),
...@@ -24,19 +25,19 @@ var atomValue = &Feed{ ...@@ -24,19 +25,19 @@ var atomValue = &Feed{
var atomXml = `` + var atomXml = `` +
`<feed xmlns="http://www.w3.org/2005/Atom">` + `<feed xmlns="http://www.w3.org/2005/Atom">` +
`<Title>Example Feed</Title>` + `<title>Example Feed</title>` +
`<Id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</Id>` + `<id>urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6</id>` +
`<Link href="http://example.org/"></Link>` + `<link href="http://example.org/"></link>` +
`<Updated>2003-12-13T18:30:02Z</Updated>` + `<updated>2003-12-13T18:30:02Z</updated>` +
`<Author><Name>John Doe</Name><URI></URI><Email></Email></Author>` + `<author><name>John Doe</name><uri></uri><email></email></author>` +
`<Entry>` + `<entry>` +
`<Title>Atom-Powered Robots Run Amok</Title>` + `<title>Atom-Powered Robots Run Amok</title>` +
`<Id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</Id>` + `<id>urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a</id>` +
`<Link href="http://example.org/2003/12/13/atom03"></Link>` + `<link href="http://example.org/2003/12/13/atom03"></link>` +
`<Updated>2003-12-13T18:30:02Z</Updated>` + `<updated>2003-12-13T18:30:02Z</updated>` +
`<Author><Name></Name><URI></URI><Email></Email></Author>` + `<author><name></name><uri></uri><email></email></author>` +
`<Summary>Some text.</Summary>` + `<summary>Some text.</summary>` +
`</Entry>` + `</entry>` +
`</feed>` `</feed>`
func ParseTime(str string) Time { func ParseTime(str string) Time {
......
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import "testing"
type C struct {
Name string
Open bool
}
type A struct {
XMLName Name `xml:"http://domain a"`
C
B B
FieldA string
}
type B struct {
XMLName Name `xml:"b"`
C
FieldB string
}
const _1a = `
<?xml version="1.0" encoding="UTF-8"?>
<a xmlns="http://domain">
<name>KmlFile</name>
<open>1</open>
<b>
<name>Absolute</name>
<open>0</open>
<fieldb>bar</fieldb>
</b>
<fielda>foo</fielda>
</a>
`
// Tests that embedded structs are marshalled.
func TestEmbedded1(t *testing.T) {
var a A
if e := Unmarshal(StringReader(_1a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.FieldA != "foo" {
t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.FieldA)
}
if a.Name != "KmlFile" {
t.Fatalf("Unmarshal: expected 'KmlFile' but found '%s'", a.Name)
}
if !a.Open {
t.Fatal("Unmarshal: expected 'true' but found otherwise")
}
if a.B.FieldB != "bar" {
t.Fatalf("Unmarshal: expected 'bar' but found '%s'", a.B.FieldB)
}
if a.B.Name != "Absolute" {
t.Fatalf("Unmarshal: expected 'Absolute' but found '%s'", a.B.Name)
}
if a.B.Open {
t.Fatal("Unmarshal: expected 'false' but found otherwise")
}
}
type A2 struct {
XMLName Name `xml:"http://domain a"`
XY string
Xy string
}
const _2a = `
<?xml version="1.0" encoding="UTF-8"?>
<a xmlns="http://domain">
<xy>foo</xy>
</a>
`
// Tests that conflicting field names get excluded.
func TestEmbedded2(t *testing.T) {
var a A2
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.XY != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.XY)
}
if a.Xy != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.Xy)
}
}
type A3 struct {
XMLName Name `xml:"http://domain a"`
xy string
}
// Tests that private fields are not set.
func TestEmbedded3(t *testing.T) {
var a A3
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.xy != "" {
t.Fatalf("Unmarshal: expected empty string but found '%s'", a.xy)
}
}
type A4 struct {
XMLName Name `xml:"http://domain a"`
Any string
}
// Tests that private fields are not set.
func TestEmbedded4(t *testing.T) {
var a A4
if e := Unmarshal(StringReader(_2a), &a); e != nil {
t.Fatalf("Unmarshal: %s", e)
}
if a.Any != "foo" {
t.Fatalf("Unmarshal: expected 'foo' but found '%s'", a.Any)
}
}
...@@ -6,6 +6,8 @@ package xml ...@@ -6,6 +6,8 @@ package xml
import ( import (
"bufio" "bufio"
"bytes"
"fmt"
"io" "io"
"reflect" "reflect"
"strconv" "strconv"
...@@ -42,20 +44,26 @@ type printer struct { ...@@ -42,20 +44,26 @@ type printer struct {
// elements containing the data. // elements containing the data.
// //
// The name for the XML elements is taken from, in order of preference: // The name for the XML elements is taken from, in order of preference:
// - the tag on an XMLName field, if the data is a struct // - the tag on the XMLName field, if the data is a struct
// - the value of an XMLName field of type xml.Name // - the value of the XMLName field of type xml.Name
// - the tag of the struct field used to obtain the data // - the tag of the struct field used to obtain the data
// - the name of the struct field used to obtain the data // - the name of the struct field used to obtain the data
// - the name '???'. // - the name of the marshalled type
// //
// The XML element for a struct contains marshalled elements for each of the // The XML element for a struct contains marshalled elements for each of the
// exported fields of the struct, with these exceptions: // exported fields of the struct, with these exceptions:
// - the XMLName field, described above, is omitted. // - the XMLName field, described above, is omitted.
// - a field with tag "attr" becomes an attribute in the XML element. // - a field with tag "name,attr" becomes an attribute with
// - a field with tag "chardata" is written as character data, // the given name in the XML element.
// not as an XML element. // - a field with tag ",attr" becomes an attribute with the
// - a field with tag "innerxml" is written verbatim, // field name in the in the XML element.
// not subject to the usual marshalling procedure. // - a field with tag ",chardata" is written as character data,
// not as an XML element.
// - a field with tag ",innerxml" is written verbatim, not subject
// to the usual marshalling procedure.
// - a field with tag ",comment" is written as an XML comment, not
// subject to the usual marshalling procedure. It must not contain
// the "--" string within it.
// //
// If a field uses a tag "a>b>c", then the element c will be nested inside // If a field uses a tag "a>b>c", then the element c will be nested inside
// parent elements a and b. Fields that appear next to each other that name // parent elements a and b. Fields that appear next to each other that name
...@@ -63,17 +71,18 @@ type printer struct { ...@@ -63,17 +71,18 @@ type printer struct {
// //
// type Result struct { // type Result struct {
// XMLName xml.Name `xml:"result"` // XMLName xml.Name `xml:"result"`
// Id int `xml:"id,attr"`
// FirstName string `xml:"person>name>first"` // FirstName string `xml:"person>name>first"`
// LastName string `xml:"person>name>last"` // LastName string `xml:"person>name>last"`
// Age int `xml:"person>age"` // Age int `xml:"person>age"`
// } // }
// //
// xml.Marshal(w, &Result{FirstName: "John", LastName: "Doe", Age: 42}) // xml.Marshal(w, &Result{Id: 13, FirstName: "John", LastName: "Doe", Age: 42})
// //
// would be marshalled as: // would be marshalled as:
// //
// <result> // <result>
// <person> // <person id="13">
// <name> // <name>
// <first>John</first> // <first>John</first>
// <last>Doe</last> // <last>Doe</last>
...@@ -85,12 +94,12 @@ type printer struct { ...@@ -85,12 +94,12 @@ type printer struct {
// Marshal will return an error if asked to marshal a channel, function, or map. // Marshal will return an error if asked to marshal a channel, function, or map.
func Marshal(w io.Writer, v interface{}) (err error) { func Marshal(w io.Writer, v interface{}) (err error) {
p := &printer{bufio.NewWriter(w)} p := &printer{bufio.NewWriter(w)}
err = p.marshalValue(reflect.ValueOf(v), "???") err = p.marshalValue(reflect.ValueOf(v), nil)
p.Flush() p.Flush()
return err return err
} }
func (p *printer) marshalValue(val reflect.Value, name string) error { func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
if !val.IsValid() { if !val.IsValid() {
return nil return nil
} }
...@@ -115,58 +124,75 @@ func (p *printer) marshalValue(val reflect.Value, name string) error { ...@@ -115,58 +124,75 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
if val.IsNil() { if val.IsNil() {
return nil return nil
} }
return p.marshalValue(val.Elem(), name) return p.marshalValue(val.Elem(), finfo)
} }
// Slices and arrays iterate over the elements. They do not have an enclosing tag. // Slices and arrays iterate over the elements. They do not have an enclosing tag.
if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 { if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 {
for i, n := 0, val.Len(); i < n; i++ { for i, n := 0, val.Len(); i < n; i++ {
if err := p.marshalValue(val.Index(i), name); err != nil { if err := p.marshalValue(val.Index(i), finfo); err != nil {
return err return err
} }
} }
return nil return nil
} }
// Find XML name tinfo, err := getTypeInfo(typ)
xmlns := "" if err != nil {
if kind == reflect.Struct { return err
if f, ok := typ.FieldByName("XMLName"); ok { }
if tag := f.Tag.Get("xml"); tag != "" {
if i := strings.Index(tag, " "); i >= 0 { // Precedence for the XML element name is:
xmlns, name = tag[:i], tag[i+1:] // 1. XMLName field in underlying struct;
} else { // 2. field name/tag in the struct field; and
name = tag // 3. type name
} var xmlns, name string
} else if v, ok := val.FieldByIndex(f.Index).Interface().(Name); ok && v.Local != "" { if tinfo.xmlname != nil {
xmlns, name = v.Space, v.Local xmlname := tinfo.xmlname
} if xmlname.name != "" {
xmlns, name = xmlname.xmlns, xmlname.name
} else if v, ok := val.FieldByIndex(xmlname.idx).Interface().(Name); ok && v.Local != "" {
xmlns, name = v.Space, v.Local
}
}
if name == "" && finfo != nil {
xmlns, name = finfo.xmlns, finfo.name
}
if name == "" {
name = typ.Name()
if name == "" {
return &UnsupportedTypeError{typ}
} }
} }
p.WriteByte('<') p.WriteByte('<')
p.WriteString(name) p.WriteString(name)
if xmlns != "" {
p.WriteString(` xmlns="`)
// TODO: EscapeString, to avoid the allocation.
Escape(p, []byte(xmlns))
p.WriteByte('"')
}
// Attributes // Attributes
if kind == reflect.Struct { for i := range tinfo.fields {
if len(xmlns) > 0 { finfo := &tinfo.fields[i]
p.WriteString(` xmlns="`) if finfo.flags&fAttr == 0 {
Escape(p, []byte(xmlns)) continue
p.WriteByte('"')
} }
var str string
for i, n := 0, typ.NumField(); i < n; i++ { if fv := val.FieldByIndex(finfo.idx); fv.Kind() == reflect.String {
if f := typ.Field(i); f.PkgPath == "" && f.Tag.Get("xml") == "attr" { str = fv.String()
if f.Type.Kind() == reflect.String { } else {
if str := val.Field(i).String(); str != "" { str = fmt.Sprint(fv.Interface())
p.WriteByte(' ') }
p.WriteString(strings.ToLower(f.Name)) if str != "" {
p.WriteString(`="`) p.WriteByte(' ')
Escape(p, []byte(str)) p.WriteString(finfo.name)
p.WriteByte('"') p.WriteString(`="`)
} Escape(p, []byte(str))
} p.WriteByte('"')
}
} }
} }
p.WriteByte('>') p.WriteByte('>')
...@@ -194,58 +220,9 @@ func (p *printer) marshalValue(val reflect.Value, name string) error { ...@@ -194,58 +220,9 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
bytes := val.Interface().([]byte) bytes := val.Interface().([]byte)
Escape(p, bytes) Escape(p, bytes)
case reflect.Struct: case reflect.Struct:
s := parentStack{printer: p} if err := p.marshalStruct(tinfo, val); err != nil {
for i, n := 0, val.NumField(); i < n; i++ { return err
if f := typ.Field(i); f.Name != "XMLName" && f.PkgPath == "" {
name := f.Name
vf := val.Field(i)
switch tag := f.Tag.Get("xml"); tag {
case "":
s.trim(nil)
case "chardata":
if tk := f.Type.Kind(); tk == reflect.String {
Escape(p, []byte(vf.String()))
} else if tk == reflect.Slice {
if elem, ok := vf.Interface().([]byte); ok {
Escape(p, elem)
}
}
continue
case "innerxml":
iface := vf.Interface()
switch raw := iface.(type) {
case []byte:
p.Write(raw)
continue
case string:
p.WriteString(raw)
continue
}
case "attr":
continue
default:
parents := strings.Split(tag, ">")
if len(parents) == 1 {
parents, name = nil, tag
} else {
parents, name = parents[:len(parents)-1], parents[len(parents)-1]
if parents[0] == "" {
parents[0] = f.Name
}
}
s.trim(parents)
if !(vf.Kind() == reflect.Ptr || vf.Kind() == reflect.Interface) || !vf.IsNil() {
s.push(parents[len(s.stack):])
}
}
if err := p.marshalValue(vf, name); err != nil {
return err
}
}
} }
s.trim(nil)
default: default:
return &UnsupportedTypeError{typ} return &UnsupportedTypeError{typ}
} }
...@@ -258,6 +235,94 @@ func (p *printer) marshalValue(val reflect.Value, name string) error { ...@@ -258,6 +235,94 @@ func (p *printer) marshalValue(val reflect.Value, name string) error {
return nil return nil
} }
var ddBytes = []byte("--")
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
s := parentStack{printer: p}
for i := range tinfo.fields {
finfo := &tinfo.fields[i]
if finfo.flags&(fAttr|fAny) != 0 {
continue
}
vf := val.FieldByIndex(finfo.idx)
switch finfo.flags & fMode {
case fCharData:
switch vf.Kind() {
case reflect.String:
Escape(p, []byte(vf.String()))
case reflect.Slice:
if elem, ok := vf.Interface().([]byte); ok {
Escape(p, elem)
}
}
continue
case fComment:
k := vf.Kind()
if !(k == reflect.String || k == reflect.Slice && vf.Type().Elem().Kind() == reflect.Uint8) {
return fmt.Errorf("xml: bad type for comment field of %s", val.Type())
}
if vf.Len() == 0 {
continue
}
p.WriteString("<!--")
dashDash := false
dashLast := false
switch k {
case reflect.String:
s := vf.String()
dashDash = strings.Index(s, "--") >= 0
dashLast = s[len(s)-1] == '-'
if !dashDash {
p.WriteString(s)
}
case reflect.Slice:
b := vf.Bytes()
dashDash = bytes.Index(b, ddBytes) >= 0
dashLast = b[len(b)-1] == '-'
if !dashDash {
p.Write(b)
}
default:
panic("can't happen")
}
if dashDash {
return fmt.Errorf(`xml: comments must not contain "--"`)
}
if dashLast {
// "--->" is invalid grammar. Make it "- -->"
p.WriteByte(' ')
}
p.WriteString("-->")
continue
case fInnerXml:
iface := vf.Interface()
switch raw := iface.(type) {
case []byte:
p.Write(raw)
continue
case string:
p.WriteString(raw)
continue
}
case fElement:
s.trim(finfo.parents)
if len(finfo.parents) > len(s.stack) {
if vf.Kind() != reflect.Ptr && vf.Kind() != reflect.Interface || !vf.IsNil() {
s.push(finfo.parents[len(s.stack):])
}
}
}
if err := p.marshalValue(vf, finfo); err != nil {
return err
}
}
s.trim(nil)
return nil
}
type parentStack struct { type parentStack struct {
*printer *printer
stack []string stack []string
......
...@@ -6,6 +6,7 @@ package xml ...@@ -6,6 +6,7 @@ package xml
import ( import (
"reflect" "reflect"
"strings"
"testing" "testing"
) )
...@@ -13,7 +14,7 @@ import ( ...@@ -13,7 +14,7 @@ import (
func TestUnmarshalFeed(t *testing.T) { func TestUnmarshalFeed(t *testing.T) {
var f Feed var f Feed
if err := Unmarshal(StringReader(atomFeedString), &f); err != nil { if err := Unmarshal(strings.NewReader(atomFeedString), &f); err != nil {
t.Fatalf("Unmarshal: %s", err) t.Fatalf("Unmarshal: %s", err)
} }
if !reflect.DeepEqual(f, atomFeed) { if !reflect.DeepEqual(f, atomFeed) {
...@@ -24,8 +25,8 @@ func TestUnmarshalFeed(t *testing.T) { ...@@ -24,8 +25,8 @@ func TestUnmarshalFeed(t *testing.T) {
// hget http://codereview.appspot.com/rss/mine/rsc // hget http://codereview.appspot.com/rss/mine/rsc
const atomFeedString = ` const atomFeedString = `
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><li-nk href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></li-nk><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub <feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en-us"><title>Code Review - My issues</title><link href="http://codereview.appspot.com/" rel="alternate"></link><link href="http://codereview.appspot.com/rss/mine/rsc" rel="self"></link><id>http://codereview.appspot.com/</id><updated>2009-10-04T01:35:58+00:00</updated><author><name>rietveld&lt;&gt;</name></author><entry><title>rietveld: an attempt at pubsubhubbub
</title><link hre-f="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html"> </title><link href="http://codereview.appspot.com/126085" rel="alternate"></link><updated>2009-10-04T01:35:58+00:00</updated><author><name>email-address-removed</name></author><id>urn:md5:134d9179c41f806be79b3a5f7877d19a</id><summary type="html">
An attempt at adding pubsubhubbub support to Rietveld. An attempt at adding pubsubhubbub support to Rietveld.
http://code.google.com/p/pubsubhubbub http://code.google.com/p/pubsubhubbub
http://code.google.com/p/rietveld/issues/detail?id=155 http://code.google.com/p/rietveld/issues/detail?id=155
...@@ -78,39 +79,39 @@ not being used from outside intra_region_diff.py. ...@@ -78,39 +79,39 @@ not being used from outside intra_region_diff.py.
</summary></entry></feed> ` </summary></entry></feed> `
type Feed struct { type Feed struct {
XMLName Name `xml:"http://www.w3.org/2005/Atom feed"` XMLName Name `xml:"http://www.w3.org/2005/Atom feed"`
Title string Title string `xml:"title"`
Id string Id string `xml:"id"`
Link []Link Link []Link `xml:"link"`
Updated Time Updated Time `xml:"updated"`
Author Person Author Person `xml:"author"`
Entry []Entry Entry []Entry `xml:"entry"`
} }
type Entry struct { type Entry struct {
Title string Title string `xml:"title"`
Id string Id string `xml:"id"`
Link []Link Link []Link `xml:"link"`
Updated Time Updated Time `xml:"updated"`
Author Person Author Person `xml:"author"`
Summary Text Summary Text `xml:"summary"`
} }
type Link struct { type Link struct {
Rel string `xml:"attr"` Rel string `xml:"rel,attr"`
Href string `xml:"attr"` Href string `xml:"href,attr"`
} }
type Person struct { type Person struct {
Name string Name string `xml:"name"`
URI string URI string `xml:"uri"`
Email string Email string `xml:"email"`
InnerXML string `xml:"innerxml"` InnerXML string `xml:",innerxml"`
} }
type Text struct { type Text struct {
Type string `xml:"attr"` Type string `xml:"type,attr"`
Body string `xml:"chardata"` Body string `xml:",chardata"`
} }
type Time string type Time string
...@@ -213,44 +214,26 @@ not being used from outside intra_region_diff.py. ...@@ -213,44 +214,26 @@ not being used from outside intra_region_diff.py.
}, },
} }
type FieldNameTest struct {
in, out string
}
var FieldNameTests = []FieldNameTest{
{"Profile-Image", "profileimage"},
{"_score", "score"},
}
func TestFieldName(t *testing.T) {
for _, tt := range FieldNameTests {
a := fieldName(tt.in)
if a != tt.out {
t.Fatalf("have %#v\nwant %#v\n\n", a, tt.out)
}
}
}
const pathTestString = ` const pathTestString = `
<result> <Result>
<before>1</before> <Before>1</Before>
<items> <Items>
<item1> <Item1>
<value>A</value> <Value>A</Value>
</item1> </Item1>
<item2> <Item2>
<value>B</value> <Value>B</Value>
</item2> </Item2>
<Item1> <Item1>
<Value>C</Value> <Value>C</Value>
<Value>D</Value> <Value>D</Value>
</Item1> </Item1>
<_> <_>
<value>E</value> <Value>E</Value>
</_> </_>
</items> </Items>
<after>2</after> <After>2</After>
</result> </Result>
` `
type PathTestItem struct { type PathTestItem struct {
...@@ -258,18 +241,18 @@ type PathTestItem struct { ...@@ -258,18 +241,18 @@ type PathTestItem struct {
} }
type PathTestA struct { type PathTestA struct {
Items []PathTestItem `xml:">item1"` Items []PathTestItem `xml:">Item1"`
Before, After string Before, After string
} }
type PathTestB struct { type PathTestB struct {
Other []PathTestItem `xml:"items>Item1"` Other []PathTestItem `xml:"Items>Item1"`
Before, After string Before, After string
} }
type PathTestC struct { type PathTestC struct {
Values1 []string `xml:"items>item1>value"` Values1 []string `xml:"Items>Item1>Value"`
Values2 []string `xml:"items>item2>value"` Values2 []string `xml:"Items>Item2>Value"`
Before, After string Before, After string
} }
...@@ -278,12 +261,12 @@ type PathTestSet struct { ...@@ -278,12 +261,12 @@ type PathTestSet struct {
} }
type PathTestD struct { type PathTestD struct {
Other PathTestSet `xml:"items>"` Other PathTestSet `xml:"Items"`
Before, After string Before, After string
} }
type PathTestE struct { type PathTestE struct {
Underline string `xml:"items>_>value"` Underline string `xml:"Items>_>Value"`
Before, After string Before, After string
} }
...@@ -298,7 +281,7 @@ var pathTests = []interface{}{ ...@@ -298,7 +281,7 @@ var pathTests = []interface{}{
func TestUnmarshalPaths(t *testing.T) { func TestUnmarshalPaths(t *testing.T) {
for _, pt := range pathTests { for _, pt := range pathTests {
v := reflect.New(reflect.TypeOf(pt).Elem()).Interface() v := reflect.New(reflect.TypeOf(pt).Elem()).Interface()
if err := Unmarshal(StringReader(pathTestString), v); err != nil { if err := Unmarshal(strings.NewReader(pathTestString), v); err != nil {
t.Fatalf("Unmarshal: %s", err) t.Fatalf("Unmarshal: %s", err)
} }
if !reflect.DeepEqual(v, pt) { if !reflect.DeepEqual(v, pt) {
...@@ -310,7 +293,7 @@ func TestUnmarshalPaths(t *testing.T) { ...@@ -310,7 +293,7 @@ func TestUnmarshalPaths(t *testing.T) {
type BadPathTestA struct { type BadPathTestA struct {
First string `xml:"items>item1"` First string `xml:"items>item1"`
Other string `xml:"items>item2"` Other string `xml:"items>item2"`
Second string `xml:"items>"` Second string `xml:"items"`
} }
type BadPathTestB struct { type BadPathTestB struct {
...@@ -319,81 +302,55 @@ type BadPathTestB struct { ...@@ -319,81 +302,55 @@ type BadPathTestB struct {
Second string `xml:"items>item1>value"` Second string `xml:"items>item1>value"`
} }
type BadPathTestC struct {
First string
Second string `xml:"First"`
}
type BadPathTestD struct {
BadPathEmbeddedA
BadPathEmbeddedB
}
type BadPathEmbeddedA struct {
First string
}
type BadPathEmbeddedB struct {
Second string `xml:"First"`
}
var badPathTests = []struct { var badPathTests = []struct {
v, e interface{} v, e interface{}
}{ }{
{&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items>"}}, {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items"}},
{&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, {&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}},
{&BadPathTestC{}, &TagPathError{reflect.TypeOf(BadPathTestC{}), "First", "", "Second", "First"}},
{&BadPathTestD{}, &TagPathError{reflect.TypeOf(BadPathTestD{}), "First", "", "Second", "First"}},
} }
func TestUnmarshalBadPaths(t *testing.T) { func TestUnmarshalBadPaths(t *testing.T) {
for _, tt := range badPathTests { for _, tt := range badPathTests {
err := Unmarshal(StringReader(pathTestString), tt.v) err := Unmarshal(strings.NewReader(pathTestString), tt.v)
if !reflect.DeepEqual(err, tt.e) { if !reflect.DeepEqual(err, tt.e) {
t.Fatalf("Unmarshal with %#v didn't fail properly: %#v", tt.v, err) t.Fatalf("Unmarshal with %#v didn't fail properly:\nhave %#v,\nwant %#v", tt.v, err, tt.e)
} }
} }
} }
func TestUnmarshalAttrs(t *testing.T) {
var f AttrTest
if err := Unmarshal(StringReader(attrString), &f); err != nil {
t.Fatalf("Unmarshal: %s", err)
}
if !reflect.DeepEqual(f, attrStruct) {
t.Fatalf("have %#v\nwant %#v", f, attrStruct)
}
}
type AttrTest struct {
Test1 Test1
Test2 Test2
}
type Test1 struct {
Int int `xml:"attr"`
Float float64 `xml:"attr"`
Uint8 uint8 `xml:"attr"`
}
type Test2 struct {
Bool bool `xml:"attr"`
}
const attrString = `
<?xml version="1.0" charset="utf-8"?>
<attrtest>
<test1 int="8" float="23.5" uint8="255"/>
<test2 bool="true"/>
</attrtest>
`
var attrStruct = AttrTest{
Test1: Test1{
Int: 8,
Float: 23.5,
Uint8: 255,
},
Test2: Test2{
Bool: true,
},
}
// test data for TestUnmarshalWithoutNameType
const OK = "OK" const OK = "OK"
const withoutNameTypeData = ` const withoutNameTypeData = `
<?xml version="1.0" charset="utf-8"?> <?xml version="1.0" charset="utf-8"?>
<Test3 attr="OK" />` <Test3 Attr="OK" />`
type TestThree struct { type TestThree struct {
XMLName bool `xml:"Test3"` // XMLName field without an xml.Name type XMLName Name `xml:"Test3"`
Attr string `xml:"attr"` Attr string `xml:",attr"`
} }
func TestUnmarshalWithoutNameType(t *testing.T) { func TestUnmarshalWithoutNameType(t *testing.T) {
var x TestThree var x TestThree
if err := Unmarshal(StringReader(withoutNameTypeData), &x); err != nil { if err := Unmarshal(strings.NewReader(withoutNameTypeData), &x); err != nil {
t.Fatalf("Unmarshal: %s", err) t.Fatalf("Unmarshal: %s", err)
} }
if x.Attr != OK { if x.Attr != OK {
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xml
import (
"fmt"
"reflect"
"strings"
"sync"
)
// typeInfo holds details for the xml representation of a type.
type typeInfo struct {
xmlname *fieldInfo
fields []fieldInfo
}
// fieldInfo holds details for the xml representation of a single field.
type fieldInfo struct {
idx []int
name string
xmlns string
flags fieldFlags
parents []string
}
type fieldFlags int
const (
fElement fieldFlags = 1 << iota
fAttr
fCharData
fInnerXml
fComment
fAny
// TODO:
//fIgnore
//fOmitEmpty
fMode = fElement | fAttr | fCharData | fInnerXml | fComment | fAny
)
var tinfoMap = make(map[reflect.Type]*typeInfo)
var tinfoLock sync.RWMutex
// getTypeInfo returns the typeInfo structure with details necessary
// for marshalling and unmarshalling typ.
func getTypeInfo(typ reflect.Type) (*typeInfo, error) {
tinfoLock.RLock()
tinfo, ok := tinfoMap[typ]
tinfoLock.RUnlock()
if ok {
return tinfo, nil
}
tinfo = &typeInfo{}
if typ.Kind() == reflect.Struct {
n := typ.NumField()
for i := 0; i < n; i++ {
f := typ.Field(i)
if f.PkgPath != "" {
continue // Private field
}
// For embedded structs, embed its fields.
if f.Anonymous {
if f.Type.Kind() != reflect.Struct {
continue
}
inner, err := getTypeInfo(f.Type)
if err != nil {
return nil, err
}
for _, finfo := range inner.fields {
finfo.idx = append([]int{i}, finfo.idx...)
if err := addFieldInfo(typ, tinfo, &finfo); err != nil {
return nil, err
}
}
continue
}
finfo, err := structFieldInfo(typ, &f)
if err != nil {
return nil, err
}
if f.Name == "XMLName" {
tinfo.xmlname = finfo
continue
}
// Add the field if it doesn't conflict with other fields.
if err := addFieldInfo(typ, tinfo, finfo); err != nil {
return nil, err
}
}
}
tinfoLock.Lock()
tinfoMap[typ] = tinfo
tinfoLock.Unlock()
return tinfo, nil
}
// structFieldInfo builds and returns a fieldInfo for f.
func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) {
finfo := &fieldInfo{idx: f.Index}
// Split the tag from the xml namespace if necessary.
tag := f.Tag.Get("xml")
if i := strings.Index(tag, " "); i >= 0 {
finfo.xmlns, tag = tag[:i], tag[i+1:]
}
// Parse flags.
tokens := strings.Split(tag, ",")
if len(tokens) == 1 {
finfo.flags = fElement
} else {
tag = tokens[0]
for _, flag := range tokens[1:] {
switch flag {
case "attr":
finfo.flags |= fAttr
case "chardata":
finfo.flags |= fCharData
case "innerxml":
finfo.flags |= fInnerXml
case "comment":
finfo.flags |= fComment
case "any":
finfo.flags |= fAny
}
}
// Validate the flags used.
switch mode := finfo.flags & fMode; mode {
case 0:
finfo.flags |= fElement
case fAttr, fCharData, fInnerXml, fComment, fAny:
if f.Name != "XMLName" && (tag == "" || mode == fAttr) {
break
}
fallthrough
default:
// This will also catch multiple modes in a single field.
return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
}
// Use of xmlns without a name is not allowed.
if finfo.xmlns != "" && tag == "" {
return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q",
f.Name, typ, f.Tag.Get("xml"))
}
if f.Name == "XMLName" {
// The XMLName field records the XML element name. Don't
// process it as usual because its name should default to
// empty rather than to the field name.
finfo.name = tag
return finfo, nil
}
if tag == "" {
// If the name part of the tag is completely empty, get
// default from XMLName of underlying struct if feasible,
// or field name otherwise.
if xmlname := lookupXMLName(f.Type); xmlname != nil {
finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name
} else {
finfo.name = f.Name
}
return finfo, nil
}
// Prepare field name and parents.
tokens = strings.Split(tag, ">")
if tokens[0] == "" {
tokens[0] = f.Name
}
if tokens[len(tokens)-1] == "" {
return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ)
}
finfo.name = tokens[len(tokens)-1]
if len(tokens) > 1 {
finfo.parents = tokens[:len(tokens)-1]
}
// If the field type has an XMLName field, the names must match
// so that the behavior of both marshalling and unmarshalling
// is straighforward and unambiguous.
if finfo.flags&fElement != 0 {
ftyp := f.Type
xmlname := lookupXMLName(ftyp)
if xmlname != nil && xmlname.name != finfo.name {
return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName",
finfo.name, typ, f.Name, xmlname.name, ftyp)
}
}
return finfo, nil
}
// lookupXMLName returns the fieldInfo for typ's XMLName field
// in case it exists and has a valid xml field tag, otherwise
// it returns nil.
func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) {
for typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return nil
}
for i, n := 0, typ.NumField(); i < n; i++ {
f := typ.Field(i)
if f.Name != "XMLName" {
continue
}
finfo, err := structFieldInfo(typ, &f)
if finfo.name != "" && err == nil {
return finfo
}
// Also consider errors as a non-existent field tag
// and let getTypeInfo itself report the error.
break
}
return nil
}
func min(a, b int) int {
if a <= b {
return a
}
return b
}
// addFieldInfo adds finfo to tinfo.fields if there are no
// conflicts, or if conflicts arise from previous fields that were
// obtained from deeper embedded structures than finfo. In the latter
// case, the conflicting entries are dropped.
// A conflict occurs when the path (parent + name) to a field is
// itself a prefix of another path, or when two paths match exactly.
// It is okay for field paths to share a common, shorter prefix.
func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error {
var conflicts []int
Loop:
// First, figure all conflicts. Most working code will have none.
for i := range tinfo.fields {
oldf := &tinfo.fields[i]
if oldf.flags&fMode != newf.flags&fMode {
continue
}
minl := min(len(newf.parents), len(oldf.parents))
for p := 0; p < minl; p++ {
if oldf.parents[p] != newf.parents[p] {
continue Loop
}
}
if len(oldf.parents) > len(newf.parents) {
if oldf.parents[len(newf.parents)] == newf.name {
conflicts = append(conflicts, i)
}
} else if len(oldf.parents) < len(newf.parents) {
if newf.parents[len(oldf.parents)] == oldf.name {
conflicts = append(conflicts, i)
}
} else {
if newf.name == oldf.name {
conflicts = append(conflicts, i)
}
}
}
// Without conflicts, add the new field and return.
if conflicts == nil {
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
// If any conflict is shallower, ignore the new field.
// This matches the Go field resolution on embedding.
for _, i := range conflicts {
if len(tinfo.fields[i].idx) < len(newf.idx) {
return nil
}
}
// Otherwise, if any of them is at the same depth level, it's an error.
for _, i := range conflicts {
oldf := &tinfo.fields[i]
if len(oldf.idx) == len(newf.idx) {
f1 := typ.FieldByIndex(oldf.idx)
f2 := typ.FieldByIndex(newf.idx)
return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
}
}
// Otherwise, the new field is shallower, and thus takes precedence,
// so drop the conflicting fields from tinfo and append the new one.
for c := len(conflicts) - 1; c >= 0; c-- {
i := conflicts[c]
copy(tinfo.fields[i:], tinfo.fields[i+1:])
tinfo.fields = tinfo.fields[:len(tinfo.fields)-1]
}
tinfo.fields = append(tinfo.fields, *newf)
return nil
}
// A TagPathError represents an error in the unmarshalling process
// caused by the use of field tags with conflicting paths.
type TagPathError struct {
Struct reflect.Type
Field1, Tag1 string
Field2, Tag2 string
}
func (e *TagPathError) Error() string {
return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
}
...@@ -154,36 +154,8 @@ var xmlInput = []string{ ...@@ -154,36 +154,8 @@ var xmlInput = []string{
"<t>cdata]]></t>", "<t>cdata]]></t>",
} }
type stringReader struct {
s string
off int
}
func (r *stringReader) Read(b []byte) (n int, err error) {
if r.off >= len(r.s) {
return 0, io.EOF
}
for r.off < len(r.s) && n < len(b) {
b[n] = r.s[r.off]
n++
r.off++
}
return
}
func (r *stringReader) ReadByte() (b byte, err error) {
if r.off >= len(r.s) {
return 0, io.EOF
}
b = r.s[r.off]
r.off++
return
}
func StringReader(s string) io.Reader { return &stringReader{s, 0} }
func TestRawToken(t *testing.T) { func TestRawToken(t *testing.T) {
p := NewParser(StringReader(testInput)) p := NewParser(strings.NewReader(testInput))
testRawToken(t, p, rawTokens) testRawToken(t, p, rawTokens)
} }
...@@ -207,7 +179,7 @@ func (d *downCaser) Read(p []byte) (int, error) { ...@@ -207,7 +179,7 @@ func (d *downCaser) Read(p []byte) (int, error) {
func TestRawTokenAltEncoding(t *testing.T) { func TestRawTokenAltEncoding(t *testing.T) {
sawEncoding := "" sawEncoding := ""
p := NewParser(StringReader(testInputAltEncoding)) p := NewParser(strings.NewReader(testInputAltEncoding))
p.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) { p.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
sawEncoding = charset sawEncoding = charset
if charset != "x-testing-uppercase" { if charset != "x-testing-uppercase" {
...@@ -219,7 +191,7 @@ func TestRawTokenAltEncoding(t *testing.T) { ...@@ -219,7 +191,7 @@ func TestRawTokenAltEncoding(t *testing.T) {
} }
func TestRawTokenAltEncodingNoConverter(t *testing.T) { func TestRawTokenAltEncodingNoConverter(t *testing.T) {
p := NewParser(StringReader(testInputAltEncoding)) p := NewParser(strings.NewReader(testInputAltEncoding))
token, err := p.RawToken() token, err := p.RawToken()
if token == nil { if token == nil {
t.Fatalf("expected a token on first RawToken call") t.Fatalf("expected a token on first RawToken call")
...@@ -286,7 +258,7 @@ var nestedDirectivesTokens = []Token{ ...@@ -286,7 +258,7 @@ var nestedDirectivesTokens = []Token{
} }
func TestNestedDirectives(t *testing.T) { func TestNestedDirectives(t *testing.T) {
p := NewParser(StringReader(nestedDirectivesInput)) p := NewParser(strings.NewReader(nestedDirectivesInput))
for i, want := range nestedDirectivesTokens { for i, want := range nestedDirectivesTokens {
have, err := p.Token() have, err := p.Token()
...@@ -300,7 +272,7 @@ func TestNestedDirectives(t *testing.T) { ...@@ -300,7 +272,7 @@ func TestNestedDirectives(t *testing.T) {
} }
func TestToken(t *testing.T) { func TestToken(t *testing.T) {
p := NewParser(StringReader(testInput)) p := NewParser(strings.NewReader(testInput))
for i, want := range cookedTokens { for i, want := range cookedTokens {
have, err := p.Token() have, err := p.Token()
...@@ -315,7 +287,7 @@ func TestToken(t *testing.T) { ...@@ -315,7 +287,7 @@ func TestToken(t *testing.T) {
func TestSyntax(t *testing.T) { func TestSyntax(t *testing.T) {
for i := range xmlInput { for i := range xmlInput {
p := NewParser(StringReader(xmlInput[i])) p := NewParser(strings.NewReader(xmlInput[i]))
var err error var err error
for _, err = p.Token(); err == nil; _, err = p.Token() { for _, err = p.Token(); err == nil; _, err = p.Token() {
} }
...@@ -372,26 +344,26 @@ var all = allScalars{ ...@@ -372,26 +344,26 @@ var all = allScalars{
var sixteen = "16" var sixteen = "16"
const testScalarsInput = `<allscalars> const testScalarsInput = `<allscalars>
<true1>true</true1> <True1>true</True1>
<true2>1</true2> <True2>1</True2>
<false1>false</false1> <False1>false</False1>
<false2>0</false2> <False2>0</False2>
<int>1</int> <Int>1</Int>
<int8>-2</int8> <Int8>-2</Int8>
<int16>3</int16> <Int16>3</Int16>
<int32>-4</int32> <Int32>-4</Int32>
<int64>5</int64> <Int64>5</Int64>
<uint>6</uint> <Uint>6</Uint>
<uint8>7</uint8> <Uint8>7</Uint8>
<uint16>8</uint16> <Uint16>8</Uint16>
<uint32>9</uint32> <Uint32>9</Uint32>
<uint64>10</uint64> <Uint64>10</Uint64>
<uintptr>11</uintptr> <Uintptr>11</Uintptr>
<float>12.0</float> <Float>12.0</Float>
<float32>13.0</float32> <Float32>13.0</Float32>
<float64>14.0</float64> <Float64>14.0</Float64>
<string>15</string> <String>15</String>
<ptrstring>16</ptrstring> <PtrString>16</PtrString>
</allscalars>` </allscalars>`
func TestAllScalars(t *testing.T) { func TestAllScalars(t *testing.T) {
...@@ -412,7 +384,7 @@ type item struct { ...@@ -412,7 +384,7 @@ type item struct {
} }
func TestIssue569(t *testing.T) { func TestIssue569(t *testing.T) {
data := `<item><field_a>abcd</field_a></item>` data := `<item><Field_a>abcd</Field_a></item>`
var i item var i item
buf := bytes.NewBufferString(data) buf := bytes.NewBufferString(data)
err := Unmarshal(buf, &i) err := Unmarshal(buf, &i)
...@@ -424,7 +396,7 @@ func TestIssue569(t *testing.T) { ...@@ -424,7 +396,7 @@ func TestIssue569(t *testing.T) {
func TestUnquotedAttrs(t *testing.T) { func TestUnquotedAttrs(t *testing.T) {
data := "<tag attr=azAZ09:-_\t>" data := "<tag attr=azAZ09:-_\t>"
p := NewParser(StringReader(data)) p := NewParser(strings.NewReader(data))
p.Strict = false p.Strict = false
token, err := p.Token() token, err := p.Token()
if _, ok := err.(*SyntaxError); ok { if _, ok := err.(*SyntaxError); ok {
...@@ -450,7 +422,7 @@ func TestValuelessAttrs(t *testing.T) { ...@@ -450,7 +422,7 @@ func TestValuelessAttrs(t *testing.T) {
{"<input checked />", "input", "checked"}, {"<input checked />", "input", "checked"},
} }
for _, test := range tests { for _, test := range tests {
p := NewParser(StringReader(test[0])) p := NewParser(strings.NewReader(test[0]))
p.Strict = false p.Strict = false
token, err := p.Token() token, err := p.Token()
if _, ok := err.(*SyntaxError); ok { if _, ok := err.(*SyntaxError); ok {
...@@ -500,7 +472,7 @@ func TestCopyTokenStartElement(t *testing.T) { ...@@ -500,7 +472,7 @@ func TestCopyTokenStartElement(t *testing.T) {
func TestSyntaxErrorLineNum(t *testing.T) { func TestSyntaxErrorLineNum(t *testing.T) {
testInput := "<P>Foo<P>\n\n<P>Bar</>\n" testInput := "<P>Foo<P>\n\n<P>Bar</>\n"
p := NewParser(StringReader(testInput)) p := NewParser(strings.NewReader(testInput))
var err error var err error
for _, err = p.Token(); err == nil; _, err = p.Token() { for _, err = p.Token(); err == nil; _, err = p.Token() {
} }
...@@ -515,7 +487,7 @@ func TestSyntaxErrorLineNum(t *testing.T) { ...@@ -515,7 +487,7 @@ func TestSyntaxErrorLineNum(t *testing.T) {
func TestTrailingRawToken(t *testing.T) { func TestTrailingRawToken(t *testing.T) {
input := `<FOO></FOO> ` input := `<FOO></FOO> `
p := NewParser(StringReader(input)) p := NewParser(strings.NewReader(input))
var err error var err error
for _, err = p.RawToken(); err == nil; _, err = p.RawToken() { for _, err = p.RawToken(); err == nil; _, err = p.RawToken() {
} }
...@@ -526,7 +498,7 @@ func TestTrailingRawToken(t *testing.T) { ...@@ -526,7 +498,7 @@ func TestTrailingRawToken(t *testing.T) {
func TestTrailingToken(t *testing.T) { func TestTrailingToken(t *testing.T) {
input := `<FOO></FOO> ` input := `<FOO></FOO> `
p := NewParser(StringReader(input)) p := NewParser(strings.NewReader(input))
var err error var err error
for _, err = p.Token(); err == nil; _, err = p.Token() { for _, err = p.Token(); err == nil; _, err = p.Token() {
} }
...@@ -537,7 +509,7 @@ func TestTrailingToken(t *testing.T) { ...@@ -537,7 +509,7 @@ func TestTrailingToken(t *testing.T) {
func TestEntityInsideCDATA(t *testing.T) { func TestEntityInsideCDATA(t *testing.T) {
input := `<test><![CDATA[ &val=foo ]]></test>` input := `<test><![CDATA[ &val=foo ]]></test>`
p := NewParser(StringReader(input)) p := NewParser(strings.NewReader(input))
var err error var err error
for _, err = p.Token(); err == nil; _, err = p.Token() { for _, err = p.Token(); err == nil; _, err = p.Token() {
} }
...@@ -569,7 +541,7 @@ var characterTests = []struct { ...@@ -569,7 +541,7 @@ var characterTests = []struct {
func TestDisallowedCharacters(t *testing.T) { func TestDisallowedCharacters(t *testing.T) {
for i, tt := range characterTests { for i, tt := range characterTests {
p := NewParser(StringReader(tt.in)) p := NewParser(strings.NewReader(tt.in))
var err error var err error
for err == nil { for err == nil {
......
...@@ -8,7 +8,7 @@ import "unicode/utf8" ...@@ -8,7 +8,7 @@ import "unicode/utf8"
type input interface { type input interface {
skipASCII(p int) int skipASCII(p int) int
skipNonStarter() int skipNonStarter(p int) int
appendSlice(buf []byte, s, e int) []byte appendSlice(buf []byte, s, e int) []byte
copySlice(buf []byte, s, e int) copySlice(buf []byte, s, e int)
charinfo(p int) (uint16, int) charinfo(p int) (uint16, int)
...@@ -25,8 +25,7 @@ func (s inputString) skipASCII(p int) int { ...@@ -25,8 +25,7 @@ func (s inputString) skipASCII(p int) int {
return p return p
} }
func (s inputString) skipNonStarter() int { func (s inputString) skipNonStarter(p int) int {
p := 0
for ; p < len(s) && !utf8.RuneStart(s[p]); p++ { for ; p < len(s) && !utf8.RuneStart(s[p]); p++ {
} }
return p return p
...@@ -71,8 +70,7 @@ func (s inputBytes) skipASCII(p int) int { ...@@ -71,8 +70,7 @@ func (s inputBytes) skipASCII(p int) int {
return p return p
} }
func (s inputBytes) skipNonStarter() int { func (s inputBytes) skipNonStarter(p int) int {
p := 0
for ; p < len(s) && !utf8.RuneStart(s[p]); p++ { for ; p < len(s) && !utf8.RuneStart(s[p]); p++ {
} }
return p return p
......
...@@ -34,24 +34,28 @@ const ( ...@@ -34,24 +34,28 @@ const (
// Bytes returns f(b). May return b if f(b) = b. // Bytes returns f(b). May return b if f(b) = b.
func (f Form) Bytes(b []byte) []byte { func (f Form) Bytes(b []byte) []byte {
n := f.QuickSpan(b) rb := reorderBuffer{}
rb.init(f, b)
n := quickSpan(&rb, 0)
if n == len(b) { if n == len(b) {
return b return b
} }
out := make([]byte, n, len(b)) out := make([]byte, n, len(b))
copy(out, b[0:n]) copy(out, b[0:n])
return f.Append(out, b[n:]...) return doAppend(&rb, out, n)
} }
// String returns f(s). // String returns f(s).
func (f Form) String(s string) string { func (f Form) String(s string) string {
n := f.QuickSpanString(s) rb := reorderBuffer{}
rb.initString(f, s)
n := quickSpan(&rb, 0)
if n == len(s) { if n == len(s) {
return s return s
} }
out := make([]byte, 0, len(s)) out := make([]byte, n, len(s))
copy(out, s[0:n]) copy(out, s[0:n])
return string(f.AppendString(out, s[n:])) return string(doAppend(&rb, out, n))
} }
// IsNormal returns true if b == f(b). // IsNormal returns true if b == f(b).
...@@ -122,23 +126,27 @@ func (f Form) IsNormalString(s string) bool { ...@@ -122,23 +126,27 @@ func (f Form) IsNormalString(s string) bool {
// patchTail fixes a case where a rune may be incorrectly normalized // patchTail fixes a case where a rune may be incorrectly normalized
// if it is followed by illegal continuation bytes. It returns the // if it is followed by illegal continuation bytes. It returns the
// patched buffer and the number of trailing continuation bytes that // patched buffer and whether there were trailing continuation bytes.
// have been dropped. func patchTail(rb *reorderBuffer, buf []byte) ([]byte, bool) {
func patchTail(rb *reorderBuffer, buf []byte) ([]byte, int) {
info, p := lastRuneStart(&rb.f, buf) info, p := lastRuneStart(&rb.f, buf)
if p == -1 || info.size == 0 { if p == -1 || info.size == 0 {
return buf, 0 return buf, false
} }
end := p + int(info.size) end := p + int(info.size)
extra := len(buf) - end extra := len(buf) - end
if extra > 0 { if extra > 0 {
// Potentially allocating memory. However, this only
// happens with ill-formed UTF-8.
x := make([]byte, 0)
x = append(x, buf[len(buf)-extra:]...)
buf = decomposeToLastBoundary(rb, buf[:end]) buf = decomposeToLastBoundary(rb, buf[:end])
if rb.f.composing { if rb.f.composing {
rb.compose() rb.compose()
} }
return rb.flush(buf), extra buf = rb.flush(buf)
return append(buf, x...), true
} }
return buf, 0 return buf, false
} }
func appendQuick(rb *reorderBuffer, dst []byte, i int) ([]byte, int) { func appendQuick(rb *reorderBuffer, dst []byte, i int) ([]byte, int) {
...@@ -157,23 +165,23 @@ func (f Form) Append(out []byte, src ...byte) []byte { ...@@ -157,23 +165,23 @@ func (f Form) Append(out []byte, src ...byte) []byte {
} }
rb := reorderBuffer{} rb := reorderBuffer{}
rb.init(f, src) rb.init(f, src)
return doAppend(&rb, out) return doAppend(&rb, out, 0)
} }
func doAppend(rb *reorderBuffer, out []byte) []byte { func doAppend(rb *reorderBuffer, out []byte, p int) []byte {
src, n := rb.src, rb.nsrc src, n := rb.src, rb.nsrc
doMerge := len(out) > 0 doMerge := len(out) > 0
p := 0 if q := src.skipNonStarter(p); q > p {
if p = src.skipNonStarter(); p > 0 {
// Move leading non-starters to destination. // Move leading non-starters to destination.
out = src.appendSlice(out, 0, p) out = src.appendSlice(out, p, q)
buf, ndropped := patchTail(rb, out) buf, endsInError := patchTail(rb, out)
if ndropped > 0 { if endsInError {
out = src.appendSlice(buf, p-ndropped, p) out = buf
doMerge = false // no need to merge, ends with illegal UTF-8 doMerge = false // no need to merge, ends with illegal UTF-8
} else { } else {
out = decomposeToLastBoundary(rb, buf) // force decomposition out = decomposeToLastBoundary(rb, buf) // force decomposition
} }
p = q
} }
fd := &rb.f fd := &rb.f
if doMerge { if doMerge {
...@@ -217,7 +225,7 @@ func (f Form) AppendString(out []byte, src string) []byte { ...@@ -217,7 +225,7 @@ func (f Form) AppendString(out []byte, src string) []byte {
} }
rb := reorderBuffer{} rb := reorderBuffer{}
rb.initString(f, src) rb.initString(f, src)
return doAppend(&rb, out) return doAppend(&rb, out, 0)
} }
// QuickSpan returns a boundary n such that b[0:n] == f(b[0:n]). // QuickSpan returns a boundary n such that b[0:n] == f(b[0:n]).
...@@ -225,7 +233,8 @@ func (f Form) AppendString(out []byte, src string) []byte { ...@@ -225,7 +233,8 @@ func (f Form) AppendString(out []byte, src string) []byte {
func (f Form) QuickSpan(b []byte) int { func (f Form) QuickSpan(b []byte) int {
rb := reorderBuffer{} rb := reorderBuffer{}
rb.init(f, b) rb.init(f, b)
return quickSpan(&rb, 0) n := quickSpan(&rb, 0)
return n
} }
func quickSpan(rb *reorderBuffer, i int) int { func quickSpan(rb *reorderBuffer, i int) int {
...@@ -301,7 +310,7 @@ func (f Form) FirstBoundary(b []byte) int { ...@@ -301,7 +310,7 @@ func (f Form) FirstBoundary(b []byte) int {
func firstBoundary(rb *reorderBuffer) int { func firstBoundary(rb *reorderBuffer) int {
src, nsrc := rb.src, rb.nsrc src, nsrc := rb.src, rb.nsrc
i := src.skipNonStarter() i := src.skipNonStarter(0)
if i >= nsrc { if i >= nsrc {
return -1 return -1
} }
......
...@@ -253,7 +253,7 @@ var quickSpanNFDTests = []PositionTest{ ...@@ -253,7 +253,7 @@ var quickSpanNFDTests = []PositionTest{
{"\u0316\u0300cd", 6, ""}, {"\u0316\u0300cd", 6, ""},
{"\u043E\u0308b", 5, ""}, {"\u043E\u0308b", 5, ""},
// incorrectly ordered combining characters // incorrectly ordered combining characters
{"ab\u0300\u0316", 1, ""}, // TODO(mpvl): we could skip 'b' as well. {"ab\u0300\u0316", 1, ""}, // TODO: we could skip 'b' as well.
{"ab\u0300\u0316cd", 1, ""}, {"ab\u0300\u0316cd", 1, ""},
// Hangul // Hangul
{"같은", 0, ""}, {"같은", 0, ""},
...@@ -465,6 +465,7 @@ var appendTests = []AppendTest{ ...@@ -465,6 +465,7 @@ var appendTests = []AppendTest{
{"\u0300", "\xFC\x80\x80\x80\x80\x80\u0300", "\u0300\xFC\x80\x80\x80\x80\x80\u0300"}, {"\u0300", "\xFC\x80\x80\x80\x80\x80\u0300", "\u0300\xFC\x80\x80\x80\x80\x80\u0300"},
{"\xF8\x80\x80\x80\x80\u0300", "\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"}, {"\xF8\x80\x80\x80\x80\u0300", "\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"},
{"\xFC\x80\x80\x80\x80\x80\u0300", "\u0300", "\xFC\x80\x80\x80\x80\x80\u0300\u0300"}, {"\xFC\x80\x80\x80\x80\x80\u0300", "\u0300", "\xFC\x80\x80\x80\x80\x80\u0300\u0300"},
{"\xF8\x80\x80\x80", "\x80\u0300\u0300", "\xF8\x80\x80\x80\x80\u0300\u0300"},
} }
func appendF(f Form, out []byte, s string) []byte { func appendF(f Form, out []byte, s string) []byte {
...@@ -475,9 +476,23 @@ func appendStringF(f Form, out []byte, s string) []byte { ...@@ -475,9 +476,23 @@ func appendStringF(f Form, out []byte, s string) []byte {
return f.AppendString(out, s) return f.AppendString(out, s)
} }
func bytesF(f Form, out []byte, s string) []byte {
buf := []byte{}
buf = append(buf, out...)
buf = append(buf, s...)
return f.Bytes(buf)
}
func stringF(f Form, out []byte, s string) []byte {
outs := string(out) + s
return []byte(f.String(outs))
}
func TestAppend(t *testing.T) { func TestAppend(t *testing.T) {
runAppendTests(t, "TestAppend", NFKC, appendF, appendTests) runAppendTests(t, "TestAppend", NFKC, appendF, appendTests)
runAppendTests(t, "TestAppendString", NFKC, appendStringF, appendTests) runAppendTests(t, "TestAppendString", NFKC, appendStringF, appendTests)
runAppendTests(t, "TestBytes", NFKC, bytesF, appendTests)
runAppendTests(t, "TestString", NFKC, stringF, appendTests)
} }
func doFormBenchmark(b *testing.B, f Form, s string) { func doFormBenchmark(b *testing.B, f Form, s string) {
......
...@@ -27,7 +27,7 @@ func (w *normWriter) Write(data []byte) (n int, err error) { ...@@ -27,7 +27,7 @@ func (w *normWriter) Write(data []byte) (n int, err error) {
} }
w.rb.src = inputBytes(data[:m]) w.rb.src = inputBytes(data[:m])
w.rb.nsrc = m w.rb.nsrc = m
w.buf = doAppend(&w.rb, w.buf) w.buf = doAppend(&w.rb, w.buf, 0)
data = data[m:] data = data[m:]
n += m n += m
...@@ -101,7 +101,7 @@ func (r *normReader) Read(p []byte) (int, error) { ...@@ -101,7 +101,7 @@ func (r *normReader) Read(p []byte) (int, error) {
r.rb.src = inputBytes(r.inbuf[0:n]) r.rb.src = inputBytes(r.inbuf[0:n])
r.rb.nsrc, r.err = n, err r.rb.nsrc, r.err = n, err
if n > 0 { if n > 0 {
r.outbuf = doAppend(&r.rb, r.outbuf) r.outbuf = doAppend(&r.rb, r.outbuf, 0)
} }
if err == io.EOF { if err == io.EOF {
r.lastBoundary = len(r.outbuf) r.lastBoundary = len(r.outbuf)
......
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
)
type direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var Direct = direct{}
func (direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
"strings"
)
// A PerHost directs connections to a default Dailer unless the hostname
// requested matches one of a number of exceptions.
type PerHost struct {
def, bypass Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
return &PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the network net through either
// defaultDialer or bypass.
func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *PerHost) dialerForRequest(host string) Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone "example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifing hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a hostname
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddIP specifies an IP range that will use the bypass proxy. Note that this
// will only take effect if a literal IP address is dialed. A connection to a
// named host will never match.
func (p *PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a hostname that will use the bypass proxy.
func (p *PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"errors"
"net"
"reflect"
"testing"
)
type recordingProxy struct {
addrs []string
}
func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {
r.addrs = append(r.addrs, addr)
return nil, errors.New("recordingProxy")
}
func TestPerHost(t *testing.T) {
var def, bypass recordingProxy
perHost := NewPerHost(&def, &bypass)
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
expectedDef := []string{
"example.com:123",
"1.2.3.4:123",
"[1001::]:123",
}
expectedBypass := []string{
"localhost:123",
"zone:123",
"foo.zone:123",
"127.0.0.1:123",
"10.1.2.3:123",
"[1000::]:123",
}
for _, addr := range expectedDef {
perHost.Dial("tcp", addr)
}
for _, addr := range expectedBypass {
perHost.Dial("tcp", addr)
}
if !reflect.DeepEqual(expectedDef, def.addrs) {
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
}
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package proxy provides support for a variety of protocols to proxy network
// data.
package proxy
import (
"errors"
"net"
"net/url"
"os"
"strings"
)
// A Dialer is a means to establish a connection.
type Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type Auth struct {
User, Password string
}
// DefaultDialer returns the dialer specified by the proxy related variables in
// the environment.
func FromEnvironment() Dialer {
allProxy := os.Getenv("all_proxy")
if len(allProxy) == 0 {
return Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return Direct
}
proxy, err := FromURL(proxyURL, Direct)
if err != nil {
return Direct
}
noProxy := os.Getenv("no_proxy")
if len(noProxy) == 0 {
return proxy
}
perHost := NewPerHost(proxy, Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
if proxySchemes == nil {
proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
}
proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
var auth *Auth
if len(u.RawUserinfo) > 0 {
auth = new(Auth)
parts := strings.SplitN(u.RawUserinfo, ":", 1)
if len(parts) == 1 {
auth.User = parts[0]
} else if len(parts) >= 2 {
auth.User = parts[0]
auth.Password = parts[1]
}
}
switch u.Scheme {
case "socks5":
return SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxySchemes != nil {
if f, ok := proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"net"
"net/url"
"testing"
)
type testDialer struct {
network, addr string
}
func (t *testDialer) Dial(network, addr string) (net.Conn, error) {
t.network = network
t.addr = addr
return nil, t
}
func (t *testDialer) Error() string {
return "testDialer " + t.network + " " + t.addr
}
func TestFromURL(t *testing.T) {
u, err := url.Parse("socks5://user:password@1.2.3.4:5678")
if err != nil {
t.Fatalf("failed to parse URL: %s", err)
}
tp := &testDialer{}
proxy, err := FromURL(u, tp)
if err != nil {
t.Fatalf("FromURL failed: %s", err)
}
conn, err := proxy.Dial("tcp", "example.com:80")
if conn != nil {
t.Error("Dial unexpected didn't return an error")
}
if tp, ok := err.(*testDialer); ok {
if tp.network != "tcp" || tp.addr != "1.2.3.4:5678" {
t.Errorf("Dialer connected to wrong host. Wanted 1.2.3.4:5678, got: %v", tp)
}
} else {
t.Errorf("Unexpected error from Dial: %s", err)
}
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"errors"
"io"
"net"
"strconv"
)
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928.
func SOCKS5(network, addr string, auth *Auth, forward Dialer) (Dialer, error) {
s := &socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type socks5 struct {
user, password string
network, addr string
forward Dialer
}
const socks5Version = 5
const (
socks5AuthNone = 0
socks5AuthPassword = 2
)
const socks5Connect = 1
const (
socks5IP4 = 1
socks5Domain = 3
socks5IP6 = 4
)
var socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the network net via the SOCKS5 proxy.
func (s *socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
break
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
closeConn := &conn
defer func() {
if closeConn != nil {
(*closeConn).Close()
}
}()
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return nil, errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2, /* num auth methods */ socks5AuthNone, socks5AuthPassword)
} else {
buf = append(buf, 1, /* num auth methods */ socks5AuthNone)
}
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
if buf[1] == socks5AuthPassword {
buf = buf[:0]
buf = append(buf, socks5Version)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, socks5Version, socks5Connect, 0 /* reserved */ )
if ip := net.ParseIP(host); ip != nil {
if len(ip) == 4 {
buf = append(buf, socks5IP4)
} else {
buf = append(buf, socks5IP6)
}
buf = append(buf, []byte(ip)...)
} else {
buf = append(buf, socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err = conn.Write(buf); err != nil {
return nil, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err = io.ReadFull(conn, buf[:4]); err != nil {
return nil, errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(socks5Errors) {
failure = socks5Errors[buf[1]]
}
if len(failure) > 0 {
return nil, errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case socks5IP4:
bytesToDiscard = 4
case socks5IP6:
bytesToDiscard = 16
case socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return nil, errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return nil, errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err = io.ReadFull(conn, buf); err != nil {
return nil, errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err = io.ReadFull(conn, buf[:2]); err != nil {
return nil, errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
closeConn = nil
return conn, nil
}
...@@ -8,8 +8,11 @@ import ( ...@@ -8,8 +8,11 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
"time"
) )
var someTime = time.Unix(123, 0)
type conversionTest struct { type conversionTest struct {
s, d interface{} // source and destination s, d interface{} // source and destination
...@@ -19,6 +22,7 @@ type conversionTest struct { ...@@ -19,6 +22,7 @@ type conversionTest struct {
wantstr string wantstr string
wantf32 float32 wantf32 float32
wantf64 float64 wantf64 float64
wanttime time.Time
wantbool bool // used if d is of type *bool wantbool bool // used if d is of type *bool
wanterr string wanterr string
} }
...@@ -35,12 +39,14 @@ var ( ...@@ -35,12 +39,14 @@ var (
scanbool bool scanbool bool
scanf32 float32 scanf32 float32
scanf64 float64 scanf64 float64
scantime time.Time
) )
var conversionTests = []conversionTest{ var conversionTests = []conversionTest{
// Exact conversions (destination pointer type matches source type) // Exact conversions (destination pointer type matches source type)
{s: "foo", d: &scanstr, wantstr: "foo"}, {s: "foo", d: &scanstr, wantstr: "foo"},
{s: 123, d: &scanint, wantint: 123}, {s: 123, d: &scanint, wantint: 123},
{s: someTime, d: &scantime, wanttime: someTime},
// To strings // To strings
{s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
...@@ -106,6 +112,10 @@ func float32Value(ptr interface{}) float32 { ...@@ -106,6 +112,10 @@ func float32Value(ptr interface{}) float32 {
return *(ptr.(*float32)) return *(ptr.(*float32))
} }
func timeValue(ptr interface{}) time.Time {
return *(ptr.(*time.Time))
}
func TestConversions(t *testing.T) { func TestConversions(t *testing.T) {
for n, ct := range conversionTests { for n, ct := range conversionTests {
err := convertAssign(ct.d, ct.s) err := convertAssign(ct.d, ct.s)
...@@ -138,6 +148,9 @@ func TestConversions(t *testing.T) { ...@@ -138,6 +148,9 @@ func TestConversions(t *testing.T) {
if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
errf("want bool %v, got %v", ct.wantbool, *bp) errf("want bool %v, got %v", ct.wantbool, *bp)
} }
if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
}
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment