Commit f038dae6 by Ian Lance Taylor

libgo: Update to October 24 version of master library.

From-SVN: r204466
parent f20f2613
a7bd9a33067b 7ebbddd21330
The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources.
...@@ -37,7 +37,8 @@ AM_CPPFLAGS = -I $(srcdir)/runtime $(LIBFFIINCS) $(PTHREAD_CFLAGS) ...@@ -37,7 +37,8 @@ AM_CPPFLAGS = -I $(srcdir)/runtime $(LIBFFIINCS) $(PTHREAD_CFLAGS)
ACLOCAL_AMFLAGS = -I ./config -I ../config ACLOCAL_AMFLAGS = -I ./config -I ../config
AM_CFLAGS = -fexceptions -fplan9-extensions $(SPLIT_STACK) $(WARN_CFLAGS) \ AM_CFLAGS = -fexceptions -fnon-call-exceptions -fplan9-extensions \
$(SPLIT_STACK) $(WARN_CFLAGS) \
$(STRINGOPS_FLAG) $(OSCFLAGS) \ $(STRINGOPS_FLAG) $(OSCFLAGS) \
-I $(srcdir)/../libgcc -I $(srcdir)/../libbacktrace \ -I $(srcdir)/../libgcc -I $(srcdir)/../libbacktrace \
-I $(MULTIBUILDTOP)../../gcc/include -I $(MULTIBUILDTOP)../../gcc/include
...@@ -103,6 +104,7 @@ toolexeclibgo_DATA = \ ...@@ -103,6 +104,7 @@ toolexeclibgo_DATA = \
bufio.gox \ bufio.gox \
bytes.gox \ bytes.gox \
crypto.gox \ crypto.gox \
encoding.gox \
errors.gox \ errors.gox \
expvar.gox \ expvar.gox \
flag.gox \ flag.gox \
...@@ -251,6 +253,11 @@ toolexeclibgoimage_DATA = \ ...@@ -251,6 +253,11 @@ toolexeclibgoimage_DATA = \
image/jpeg.gox \ image/jpeg.gox \
image/png.gox image/png.gox
toolexeclibgoimagecolordir = $(toolexeclibgoimagedir)/color
toolexeclibgoimagecolor_DATA = \
image/color/palette.gox
toolexeclibgoindexdir = $(toolexeclibgodir)/index toolexeclibgoindexdir = $(toolexeclibgodir)/index
toolexeclibgoindex_DATA = \ toolexeclibgoindex_DATA = \
...@@ -573,6 +580,9 @@ go_bytes_c_files = \ ...@@ -573,6 +580,9 @@ go_bytes_c_files = \
go_crypto_files = \ go_crypto_files = \
go/crypto/crypto.go go/crypto/crypto.go
go_encoding_files = \
go/encoding/encoding.go
go_errors_files = \ go_errors_files = \
go/errors/errors.go go/errors/errors.go
...@@ -669,7 +679,7 @@ go_net_fd_os_file = ...@@ -669,7 +679,7 @@ go_net_fd_os_file =
go_net_newpollserver_file = go_net_newpollserver_file =
else # !LIBGO_IS_LINUX && !LIBGO_IS_RTEMS else # !LIBGO_IS_LINUX && !LIBGO_IS_RTEMS
if LIBGO_IS_NETBSD if LIBGO_IS_NETBSD
go_net_fd_os_file = go/net/fd_bsd.go go_net_fd_os_file =
go_net_newpollserver_file = go_net_newpollserver_file =
else # !LIBGO_IS_NETBSD && !LIBGO_IS_LINUX && !LIBGO_IS_RTEMS else # !LIBGO_IS_NETBSD && !LIBGO_IS_LINUX && !LIBGO_IS_RTEMS
# By default use select with pipes. Most systems should have # By default use select with pipes. Most systems should have
...@@ -726,9 +736,13 @@ else ...@@ -726,9 +736,13 @@ else
if LIBGO_IS_FREEBSD if LIBGO_IS_FREEBSD
go_net_sendfile_file = go/net/sendfile_freebsd.go go_net_sendfile_file = go/net/sendfile_freebsd.go
else else
if LIBGO_IS_DRAGONFLY
go_net_sendfile_file = go/net/sendfile_dragonfly.go
else
go_net_sendfile_file = go/net/sendfile_stub.go go_net_sendfile_file = go/net/sendfile_stub.go
endif endif
endif endif
endif
if LIBGO_IS_LINUX if LIBGO_IS_LINUX
go_net_interface_file = go/net/interface_linux.go go_net_interface_file = go/net/interface_linux.go
...@@ -736,9 +750,13 @@ else ...@@ -736,9 +750,13 @@ else
if LIBGO_IS_NETBSD if LIBGO_IS_NETBSD
go_net_interface_file = go/net/interface_netbsd.go go_net_interface_file = go/net/interface_netbsd.go
else else
if LIBGO_IS_DRAGONFLY
go_net_interface_file = go/net/interface_dragonfly.go
else
go_net_interface_file = go/net/interface_stub.go go_net_interface_file = go/net/interface_stub.go
endif endif
endif endif
endif
if LIBGO_IS_LINUX if LIBGO_IS_LINUX
go_net_cloexec_file = go/net/sock_cloexec.go go_net_cloexec_file = go/net/sock_cloexec.go
...@@ -746,13 +764,13 @@ else ...@@ -746,13 +764,13 @@ else
go_net_cloexec_file = go/net/sys_cloexec.go go_net_cloexec_file = go/net/sys_cloexec.go
endif endif
if LIBGO_IS_LINUX if LIBGO_IS_OPENBSD
go_net_poll_file = go/net/fd_poll_runtime.go go_net_tcpsockopt_file = go/net/tcpsockopt_openbsd.go
else else
if LIBGO_IS_DARWIN if LIBGO_IS_DARWIN
go_net_poll_file = go/net/fd_poll_runtime.go go_net_tcpsockopt_file = go/net/tcpsockopt_darwin.go
else else
go_net_poll_file = go/net/fd_poll_unix.go go_net_tcpsockopt_file = go/net/tcpsockopt_unix.go
endif endif
endif endif
...@@ -766,6 +784,7 @@ go_net_files = \ ...@@ -766,6 +784,7 @@ go_net_files = \
go/net/dnsconfig_unix.go \ go/net/dnsconfig_unix.go \
go/net/dnsmsg.go \ go/net/dnsmsg.go \
$(go_net_newpollserver_file) \ $(go_net_newpollserver_file) \
go/net/fd_mutex.go \
go/net/fd_unix.go \ go/net/fd_unix.go \
$(go_net_fd_os_file) \ $(go_net_fd_os_file) \
go/net/file_unix.go \ go/net/file_unix.go \
...@@ -783,18 +802,21 @@ go_net_files = \ ...@@ -783,18 +802,21 @@ go_net_files = \
go/net/net.go \ go/net/net.go \
go/net/parse.go \ go/net/parse.go \
go/net/pipe.go \ go/net/pipe.go \
$(go_net_poll_file) \ go/net/fd_poll_runtime.go \
go/net/port.go \ go/net/port.go \
go/net/port_unix.go \ go/net/port_unix.go \
go/net/race0.go \
$(go_net_sendfile_file) \ $(go_net_sendfile_file) \
go/net/singleflight.go \
go/net/sock_posix.go \ go/net/sock_posix.go \
go/net/sock_unix.go \
$(go_net_sock_file) \ $(go_net_sock_file) \
go/net/sockopt_posix.go \ go/net/sockopt_posix.go \
$(go_net_sockopt_file) \ $(go_net_sockopt_file) \
$(go_net_sockoptip_file) \ $(go_net_sockoptip_file) \
go/net/tcpsock.go \ go/net/tcpsock.go \
go/net/tcpsock_posix.go \ go/net/tcpsock_posix.go \
go/net/tcpsockopt_posix.go \
$(go_net_tcpsockopt_file) \
go/net/udpsock.go \ go/net/udpsock.go \
go/net/udpsock_posix.go \ go/net/udpsock_posix.go \
go/net/unixsock.go \ go/net/unixsock.go \
...@@ -818,6 +840,12 @@ go_os_dir_file = go/os/dir_regfile.go ...@@ -818,6 +840,12 @@ go_os_dir_file = go/os/dir_regfile.go
endif endif
endif endif
if LIBGO_IS_DARWIN
go_os_getwd_file = go/os/getwd_darwin.go
else
go_os_getwd_file =
endif
if LIBGO_IS_LINUX if LIBGO_IS_LINUX
go_os_sys_file = go/os/sys_linux.go go_os_sys_file = go/os/sys_linux.go
else else
...@@ -854,6 +882,9 @@ else ...@@ -854,6 +882,9 @@ else
if LIBGO_IS_NETBSD if LIBGO_IS_NETBSD
go_os_stat_file = go/os/stat_atimespec.go go_os_stat_file = go/os/stat_atimespec.go
else else
if LIBGO_IS_DRAGONFLY
go_os_stat_file = go/os/stat_dragonfly.go
else
go_os_stat_file = go/os/stat.go go_os_stat_file = go/os/stat.go
endif endif
endif endif
...@@ -861,6 +892,7 @@ endif ...@@ -861,6 +892,7 @@ endif
endif endif
endif endif
endif endif
endif
if LIBGO_IS_LINUX if LIBGO_IS_LINUX
go_os_pipe_file = go/os/pipe_linux.go go_os_pipe_file = go/os/pipe_linux.go
...@@ -874,7 +906,7 @@ go_os_files = \ ...@@ -874,7 +906,7 @@ go_os_files = \
go/os/doc.go \ go/os/doc.go \
go/os/env.go \ go/os/env.go \
go/os/error.go \ go/os/error.go \
go/os/error_posix.go \ go/os/error_unix.go \
go/os/exec.go \ go/os/exec.go \
go/os/exec_posix.go \ go/os/exec_posix.go \
go/os/exec_unix.go \ go/os/exec_unix.go \
...@@ -882,6 +914,7 @@ go_os_files = \ ...@@ -882,6 +914,7 @@ go_os_files = \
go/os/file_posix.go \ go/os/file_posix.go \
go/os/file_unix.go \ go/os/file_unix.go \
go/os/getwd.go \ go/os/getwd.go \
$(go_os_getwd_file) \
go/os/path.go \ go/os/path.go \
go/os/path_unix.go \ go/os/path_unix.go \
$(go_os_pipe_file) \ $(go_os_pipe_file) \
...@@ -970,7 +1003,10 @@ go_strings_files = \ ...@@ -970,7 +1003,10 @@ go_strings_files = \
go/strings/reader.go \ go/strings/reader.go \
go/strings/replace.go \ go/strings/replace.go \
go/strings/search.go \ go/strings/search.go \
go/strings/strings.go go/strings/strings.go \
go/strings/strings_decl.go
go_strings_c_files = \
go/strings/indexbyte.c
go_sync_files = \ go_sync_files = \
go/sync/cond.go \ go/sync/cond.go \
...@@ -1000,6 +1036,7 @@ go_syslog_c_files = \ ...@@ -1000,6 +1036,7 @@ go_syslog_c_files = \
go_testing_files = \ go_testing_files = \
go/testing/allocs.go \ go/testing/allocs.go \
go/testing/benchmark.go \ go/testing/benchmark.go \
go/testing/cover.go \
go/testing/example.go \ go/testing/example.go \
go/testing/testing.go go/testing/testing.go
...@@ -1048,6 +1085,7 @@ go_archive_tar_files = \ ...@@ -1048,6 +1085,7 @@ go_archive_tar_files = \
go_archive_zip_files = \ go_archive_zip_files = \
go/archive/zip/reader.go \ go/archive/zip/reader.go \
go/archive/zip/register.go \
go/archive/zip/struct.go \ go/archive/zip/struct.go \
go/archive/zip/writer.go go/archive/zip/writer.go
...@@ -1098,6 +1136,7 @@ go_crypto_cipher_files = \ ...@@ -1098,6 +1136,7 @@ go_crypto_cipher_files = \
go/crypto/cipher/cfb.go \ go/crypto/cipher/cfb.go \
go/crypto/cipher/cipher.go \ go/crypto/cipher/cipher.go \
go/crypto/cipher/ctr.go \ go/crypto/cipher/ctr.go \
go/crypto/cipher/gcm.go \
go/crypto/cipher/io.go \ go/crypto/cipher/io.go \
go/crypto/cipher/ofb.go go/crypto/cipher/ofb.go
go_crypto_des_files = \ go_crypto_des_files = \
...@@ -1110,7 +1149,8 @@ go_crypto_ecdsa_files = \ ...@@ -1110,7 +1149,8 @@ go_crypto_ecdsa_files = \
go/crypto/ecdsa/ecdsa.go go/crypto/ecdsa/ecdsa.go
go_crypto_elliptic_files = \ go_crypto_elliptic_files = \
go/crypto/elliptic/elliptic.go \ go/crypto/elliptic/elliptic.go \
go/crypto/elliptic/p224.go go/crypto/elliptic/p224.go \
go/crypto/elliptic/p256.go
go_crypto_hmac_files = \ go_crypto_hmac_files = \
go/crypto/hmac/hmac.go go/crypto/hmac/hmac.go
go_crypto_md5_files = \ go_crypto_md5_files = \
...@@ -1125,6 +1165,7 @@ go_crypto_rc4_files = \ ...@@ -1125,6 +1165,7 @@ go_crypto_rc4_files = \
go/crypto/rc4/rc4_ref.go go/crypto/rc4/rc4_ref.go
go_crypto_rsa_files = \ go_crypto_rsa_files = \
go/crypto/rsa/pkcs1v15.go \ go/crypto/rsa/pkcs1v15.go \
go/crypto/rsa/pss.go \
go/crypto/rsa/rsa.go go/crypto/rsa/rsa.go
go_crypto_sha1_files = \ go_crypto_sha1_files = \
go/crypto/sha1/sha1.go \ go/crypto/sha1/sha1.go \
...@@ -1308,11 +1349,15 @@ go_image_color_files = \ ...@@ -1308,11 +1349,15 @@ go_image_color_files = \
go/image/color/color.go \ go/image/color/color.go \
go/image/color/ycbcr.go go/image/color/ycbcr.go
go_image_color_palette_files = \
go/image/color/palette/palette.go
go_image_draw_files = \ go_image_draw_files = \
go/image/draw/draw.go go/image/draw/draw.go
go_image_gif_files = \ go_image_gif_files = \
go/image/gif/reader.go go/image/gif/reader.go \
go/image/gif/writer.go
go_image_jpeg_files = \ go_image_jpeg_files = \
go/image/jpeg/fdct.go \ go/image/jpeg/fdct.go \
...@@ -1766,6 +1811,7 @@ libgo_go_objs = \ ...@@ -1766,6 +1811,7 @@ libgo_go_objs = \
bytes.lo \ bytes.lo \
bytes/index.lo \ bytes/index.lo \
crypto.lo \ crypto.lo \
encoding.lo \
errors.lo \ errors.lo \
expvar.lo \ expvar.lo \
flag.lo \ flag.lo \
...@@ -1787,6 +1833,7 @@ libgo_go_objs = \ ...@@ -1787,6 +1833,7 @@ libgo_go_objs = \
sort.lo \ sort.lo \
strconv.lo \ strconv.lo \
strings.lo \ strings.lo \
strings/index.lo \
sync.lo \ sync.lo \
syscall.lo \ syscall.lo \
syscall/errno.lo \ syscall/errno.lo \
...@@ -1863,6 +1910,7 @@ libgo_go_objs = \ ...@@ -1863,6 +1910,7 @@ libgo_go_objs = \
net/http/httputil.lo \ net/http/httputil.lo \
net/http/pprof.lo \ net/http/pprof.lo \
image/color.lo \ image/color.lo \
image/color/palette.lo \
image/draw.lo \ image/draw.lo \
image/gif.lo \ image/gif.lo \
image/jpeg.lo \ image/jpeg.lo \
...@@ -2033,6 +2081,15 @@ crypto/check: $(CHECK_DEPS) ...@@ -2033,6 +2081,15 @@ crypto/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: crypto/check .PHONY: crypto/check
@go_include@ encoding.lo.dep
encoding.lo.dep: $(go_encoding_files)
$(BUILDDEPS)
encoding.lo: $(go_encoding_files)
$(BUILDPACKAGE)
encoding/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: encoding/check
@go_include@ errors.lo.dep @go_include@ errors.lo.dep
errors.lo.dep: $(go_errors_files) errors.lo.dep: $(go_errors_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -2214,6 +2271,9 @@ strings.lo.dep: $(go_strings_files) ...@@ -2214,6 +2271,9 @@ strings.lo.dep: $(go_strings_files)
$(BUILDDEPS) $(BUILDDEPS)
strings.lo: $(go_strings_files) strings.lo: $(go_strings_files)
$(BUILDPACKAGE) $(BUILDPACKAGE)
strings/index.lo: $(go_strings_c_files)
@$(MKDIR_P) strings
$(LTCOMPILE) -c -o strings/index.lo $(srcdir)/go/strings/indexbyte.c
strings/check: $(CHECK_DEPS) strings/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: strings/check .PHONY: strings/check
...@@ -2821,6 +2881,15 @@ image/color/check: $(CHECK_DEPS) ...@@ -2821,6 +2881,15 @@ image/color/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: image/color/check .PHONY: image/color/check
@go_include@ image/color/palette.lo.dep
image/color/palette.lo.dep: $(go_image_color_palette_files)
$(BUILDDEPS)
image/color/palette.lo: $(go_image_color_palette_files)
$(BUILDPACKAGE)
image/color/palette/check: $(CHECK_DEPS)
@$(CHECK)
.PHONY: image/color/palette/check
@go_include@ image/draw.lo.dep @go_include@ image/draw.lo.dep
image/draw.lo.dep: $(go_image_draw_files) image/draw.lo.dep: $(go_image_draw_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -3236,6 +3305,8 @@ bytes.gox: bytes.lo ...@@ -3236,6 +3305,8 @@ bytes.gox: bytes.lo
$(BUILDGOX) $(BUILDGOX)
crypto.gox: crypto.lo crypto.gox: crypto.lo
$(BUILDGOX) $(BUILDGOX)
encoding.gox: encoding.lo
$(BUILDGOX)
errors.gox: errors.lo errors.gox: errors.lo
$(BUILDGOX) $(BUILDGOX)
expvar.gox: expvar.lo expvar.gox: expvar.lo
...@@ -3433,6 +3504,9 @@ image/jpeg.gox: image/jpeg.lo ...@@ -3433,6 +3504,9 @@ image/jpeg.gox: image/jpeg.lo
image/png.gox: image/png.lo image/png.gox: image/png.lo
$(BUILDGOX) $(BUILDGOX)
image/color/palette.gox: image/color/palette.lo
$(BUILDGOX)
index/suffixarray.gox: index/suffixarray.lo index/suffixarray.gox: index/suffixarray.lo
$(BUILDGOX) $(BUILDGOX)
......
...@@ -147,6 +147,9 @@ ...@@ -147,6 +147,9 @@
/* Define to 1 if you have the `mknodat' function. */ /* Define to 1 if you have the `mknodat' function. */
#undef HAVE_MKNODAT #undef HAVE_MKNODAT
/* Define to 1 if you have the <netinet/icmp6.h> header file. */
#undef HAVE_NETINET_ICMP6_H
/* Define to 1 if you have the <netinet/if_ether.h> header file. */ /* Define to 1 if you have the <netinet/if_ether.h> header file. */
#undef HAVE_NETINET_IF_ETHER_H #undef HAVE_NETINET_IF_ETHER_H
......
...@@ -659,6 +659,8 @@ LIBGO_IS_SOLARIS_FALSE ...@@ -659,6 +659,8 @@ LIBGO_IS_SOLARIS_FALSE
LIBGO_IS_SOLARIS_TRUE LIBGO_IS_SOLARIS_TRUE
LIBGO_IS_RTEMS_FALSE LIBGO_IS_RTEMS_FALSE
LIBGO_IS_RTEMS_TRUE LIBGO_IS_RTEMS_TRUE
LIBGO_IS_DRAGONFLY_FALSE
LIBGO_IS_DRAGONFLY_TRUE
LIBGO_IS_OPENBSD_FALSE LIBGO_IS_OPENBSD_FALSE
LIBGO_IS_OPENBSD_TRUE LIBGO_IS_OPENBSD_TRUE
LIBGO_IS_NETBSD_FALSE LIBGO_IS_NETBSD_FALSE
...@@ -11111,7 +11113,7 @@ else ...@@ -11111,7 +11113,7 @@ else
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2 lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
lt_status=$lt_dlunknown lt_status=$lt_dlunknown
cat > conftest.$ac_ext <<_LT_EOF cat > conftest.$ac_ext <<_LT_EOF
#line 11114 "configure" #line 11116 "configure"
#include "confdefs.h" #include "confdefs.h"
#if HAVE_DLFCN_H #if HAVE_DLFCN_H
...@@ -11217,7 +11219,7 @@ else ...@@ -11217,7 +11219,7 @@ else
lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2 lt_dlunknown=0; lt_dlno_uscore=1; lt_dlneed_uscore=2
lt_status=$lt_dlunknown lt_status=$lt_dlunknown
cat > conftest.$ac_ext <<_LT_EOF cat > conftest.$ac_ext <<_LT_EOF
#line 11220 "configure" #line 11222 "configure"
#include "confdefs.h" #include "confdefs.h"
#if HAVE_DLFCN_H #if HAVE_DLFCN_H
...@@ -13490,6 +13492,7 @@ is_irix=no ...@@ -13490,6 +13492,7 @@ is_irix=no
is_linux=no is_linux=no
is_netbsd=no is_netbsd=no
is_openbsd=no is_openbsd=no
is_dragonfly=no
is_rtems=no is_rtems=no
is_solaris=no is_solaris=no
GOOS=unknown GOOS=unknown
...@@ -13500,6 +13503,7 @@ case ${host} in ...@@ -13500,6 +13503,7 @@ case ${host} in
*-*-linux*) is_linux=yes; GOOS=linux ;; *-*-linux*) is_linux=yes; GOOS=linux ;;
*-*-netbsd*) is_netbsd=yes; GOOS=netbsd ;; *-*-netbsd*) is_netbsd=yes; GOOS=netbsd ;;
*-*-openbsd*) is_openbsd=yes; GOOS=openbsd ;; *-*-openbsd*) is_openbsd=yes; GOOS=openbsd ;;
*-*-dragonfly*) is_dragonfly=yes; GOOS=dragonfly ;;
*-*-rtems*) is_rtems=yes; GOOS=rtems ;; *-*-rtems*) is_rtems=yes; GOOS=rtems ;;
*-*-solaris2*) is_solaris=yes; GOOS=solaris ;; *-*-solaris2*) is_solaris=yes; GOOS=solaris ;;
esac esac
...@@ -13551,6 +13555,14 @@ else ...@@ -13551,6 +13555,14 @@ else
LIBGO_IS_OPENBSD_FALSE= LIBGO_IS_OPENBSD_FALSE=
fi fi
if test $is_dragonfly = yes; then
LIBGO_IS_DRAGONFLY_TRUE=
LIBGO_IS_DRAGONFLY_FALSE='#'
else
LIBGO_IS_DRAGONFLY_TRUE='#'
LIBGO_IS_DRAGONFLY_FALSE=
fi
if test $is_rtems = yes; then if test $is_rtems = yes; then
LIBGO_IS_RTEMS_TRUE= LIBGO_IS_RTEMS_TRUE=
LIBGO_IS_RTEMS_FALSE='#' LIBGO_IS_RTEMS_FALSE='#'
...@@ -14600,7 +14612,7 @@ no) ...@@ -14600,7 +14612,7 @@ no)
;; ;;
esac esac
for ac_header in sys/file.h sys/mman.h syscall.h sys/epoll.h sys/inotify.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h net/if_arp.h net/route.h netpacket/packet.h sys/prctl.h sys/mount.h sys/vfs.h sys/statfs.h sys/timex.h sys/sysinfo.h utime.h linux/ether.h linux/fs.h linux/reboot.h netinet/in_syst.h netinet/ip.h netinet/ip_mroute.h netinet/if_ether.h for ac_header in sys/file.h sys/mman.h syscall.h sys/epoll.h sys/inotify.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h net/if_arp.h net/route.h netpacket/packet.h sys/prctl.h sys/mount.h sys/vfs.h sys/statfs.h sys/timex.h sys/sysinfo.h utime.h linux/ether.h linux/fs.h linux/reboot.h netinet/icmp6.h netinet/in_syst.h netinet/ip.h netinet/ip_mroute.h netinet/if_ether.h
do : do :
as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh` as_ac_Header=`$as_echo "ac_cv_header_$ac_header" | $as_tr_sh`
ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default" ac_fn_c_check_header_mongrel "$LINENO" "$ac_header" "$as_ac_Header" "$ac_includes_default"
...@@ -15502,6 +15514,10 @@ if test -z "${LIBGO_IS_OPENBSD_TRUE}" && test -z "${LIBGO_IS_OPENBSD_FALSE}"; th ...@@ -15502,6 +15514,10 @@ if test -z "${LIBGO_IS_OPENBSD_TRUE}" && test -z "${LIBGO_IS_OPENBSD_FALSE}"; th
as_fn_error "conditional \"LIBGO_IS_OPENBSD\" was never defined. as_fn_error "conditional \"LIBGO_IS_OPENBSD\" was never defined.
Usually this means the macro was only invoked conditionally." "$LINENO" 5 Usually this means the macro was only invoked conditionally." "$LINENO" 5
fi fi
if test -z "${LIBGO_IS_DRAGONFLY_TRUE}" && test -z "${LIBGO_IS_DRAGONFLY_FALSE}"; then
as_fn_error "conditional \"LIBGO_IS_DRAGONFLY\" was never defined.
Usually this means the macro was only invoked conditionally." "$LINENO" 5
fi
if test -z "${LIBGO_IS_RTEMS_TRUE}" && test -z "${LIBGO_IS_RTEMS_FALSE}"; then if test -z "${LIBGO_IS_RTEMS_TRUE}" && test -z "${LIBGO_IS_RTEMS_FALSE}"; then
as_fn_error "conditional \"LIBGO_IS_RTEMS\" was never defined. as_fn_error "conditional \"LIBGO_IS_RTEMS\" was never defined.
Usually this means the macro was only invoked conditionally." "$LINENO" 5 Usually this means the macro was only invoked conditionally." "$LINENO" 5
......
...@@ -133,6 +133,7 @@ is_irix=no ...@@ -133,6 +133,7 @@ is_irix=no
is_linux=no is_linux=no
is_netbsd=no is_netbsd=no
is_openbsd=no is_openbsd=no
is_dragonfly=no
is_rtems=no is_rtems=no
is_solaris=no is_solaris=no
GOOS=unknown GOOS=unknown
...@@ -143,6 +144,7 @@ case ${host} in ...@@ -143,6 +144,7 @@ case ${host} in
*-*-linux*) is_linux=yes; GOOS=linux ;; *-*-linux*) is_linux=yes; GOOS=linux ;;
*-*-netbsd*) is_netbsd=yes; GOOS=netbsd ;; *-*-netbsd*) is_netbsd=yes; GOOS=netbsd ;;
*-*-openbsd*) is_openbsd=yes; GOOS=openbsd ;; *-*-openbsd*) is_openbsd=yes; GOOS=openbsd ;;
*-*-dragonfly*) is_dragonfly=yes; GOOS=dragonfly ;;
*-*-rtems*) is_rtems=yes; GOOS=rtems ;; *-*-rtems*) is_rtems=yes; GOOS=rtems ;;
*-*-solaris2*) is_solaris=yes; GOOS=solaris ;; *-*-solaris2*) is_solaris=yes; GOOS=solaris ;;
esac esac
...@@ -152,6 +154,7 @@ AM_CONDITIONAL(LIBGO_IS_IRIX, test $is_irix = yes) ...@@ -152,6 +154,7 @@ AM_CONDITIONAL(LIBGO_IS_IRIX, test $is_irix = yes)
AM_CONDITIONAL(LIBGO_IS_LINUX, test $is_linux = yes) AM_CONDITIONAL(LIBGO_IS_LINUX, test $is_linux = yes)
AM_CONDITIONAL(LIBGO_IS_NETBSD, test $is_netbsd = yes) AM_CONDITIONAL(LIBGO_IS_NETBSD, test $is_netbsd = yes)
AM_CONDITIONAL(LIBGO_IS_OPENBSD, test $is_openbsd = yes) AM_CONDITIONAL(LIBGO_IS_OPENBSD, test $is_openbsd = yes)
AM_CONDITIONAL(LIBGO_IS_DRAGONFLY, test $is_dragonfly = yes) AM_CONDITIONAL(LIBGO_IS_RTEMS, test $is_rtems = yes) AM_CONDITIONAL(LIBGO_IS_RTEMS, test $is_rtems = yes)
AM_CONDITIONAL(LIBGO_IS_RTEMS, test $is_rtems = yes) AM_CONDITIONAL(LIBGO_IS_RTEMS, test $is_rtems = yes)
AM_CONDITIONAL(LIBGO_IS_SOLARIS, test $is_solaris = yes) AM_CONDITIONAL(LIBGO_IS_SOLARIS, test $is_solaris = yes)
AC_SUBST(GOOS) AC_SUBST(GOOS)
...@@ -471,7 +474,7 @@ no) ...@@ -471,7 +474,7 @@ no)
;; ;;
esac esac
AC_CHECK_HEADERS(sys/file.h sys/mman.h syscall.h sys/epoll.h sys/inotify.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h net/if_arp.h net/route.h netpacket/packet.h sys/prctl.h sys/mount.h sys/vfs.h sys/statfs.h sys/timex.h sys/sysinfo.h utime.h linux/ether.h linux/fs.h linux/reboot.h netinet/in_syst.h netinet/ip.h netinet/ip_mroute.h netinet/if_ether.h) AC_CHECK_HEADERS(sys/file.h sys/mman.h syscall.h sys/epoll.h sys/inotify.h sys/ptrace.h sys/syscall.h sys/user.h sys/utsname.h sys/select.h sys/socket.h net/if.h net/if_arp.h net/route.h netpacket/packet.h sys/prctl.h sys/mount.h sys/vfs.h sys/statfs.h sys/timex.h sys/sysinfo.h utime.h linux/ether.h linux/fs.h linux/reboot.h netinet/icmp6.h netinet/in_syst.h netinet/ip.h netinet/ip_mroute.h netinet/if_ether.h)
AC_CHECK_HEADERS([linux/filter.h linux/if_addr.h linux/if_ether.h linux/if_tun.h linux/netlink.h linux/rtnetlink.h], [], [], AC_CHECK_HEADERS([linux/filter.h linux/if_addr.h linux/if_ether.h linux/if_tun.h linux/netlink.h linux/rtnetlink.h], [], [],
[#ifdef HAVE_SYS_SOCKET_H [#ifdef HAVE_SYS_SOCKET_H
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
package tar package tar
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"os" "os"
...@@ -82,9 +83,9 @@ func (fi headerFileInfo) Sys() interface{} { return fi.h } ...@@ -82,9 +83,9 @@ func (fi headerFileInfo) Sys() interface{} { return fi.h }
// Name returns the base name of the file. // Name returns the base name of the file.
func (fi headerFileInfo) Name() string { func (fi headerFileInfo) Name() string {
if fi.IsDir() { if fi.IsDir() {
return path.Clean(fi.h.Name) return path.Base(path.Clean(fi.h.Name))
} }
return fi.h.Name return path.Base(fi.h.Name)
} }
// Mode returns the permission and mode bits for the headerFileInfo. // Mode returns the permission and mode bits for the headerFileInfo.
...@@ -174,9 +175,29 @@ const ( ...@@ -174,9 +175,29 @@ const (
c_ISSOCK = 0140000 // Socket c_ISSOCK = 0140000 // Socket
) )
// Keywords for the PAX Extended Header
// These are the record keys written/read in "key=value" pairs of a
// pax extended header block (POSIX.1-2001, plus the common atime/ctime
// extensions). paxNone marks "no pax record" for a field.
const (
paxAtime = "atime"
paxCharset = "charset"
paxComment = "comment"
paxCtime = "ctime" // please note that ctime is not a valid pax header.
paxGid = "gid"
paxGname = "gname"
paxLinkpath = "linkpath"
paxMtime = "mtime"
paxPath = "path"
paxSize = "size"
paxUid = "uid"
paxUname = "uname"
paxNone = ""
)
// FileInfoHeader creates a partially-populated Header from fi. // FileInfoHeader creates a partially-populated Header from fi.
// If fi describes a symlink, FileInfoHeader records link as the link target. // If fi describes a symlink, FileInfoHeader records link as the link target.
// If fi describes a directory, a slash is appended to the name. // If fi describes a directory, a slash is appended to the name.
// Because os.FileInfo's Name method returns only the base name of
// the file it describes, it may be necessary to modify the Name field
// of the returned header to provide the full path name of the file.
func FileInfoHeader(fi os.FileInfo, link string) (*Header, error) { func FileInfoHeader(fi os.FileInfo, link string) (*Header, error) {
if fi == nil { if fi == nil {
return nil, errors.New("tar: FileInfo is nil") return nil, errors.New("tar: FileInfo is nil")
...@@ -257,3 +278,25 @@ func (sp *slicer) next(n int) (b []byte) { ...@@ -257,3 +278,25 @@ func (sp *slicer) next(n int) (b []byte) {
b, *sp = s[0:n], s[n:] b, *sp = s[0:n], s[n:]
return return
} }
// isASCII reports whether s consists entirely of 7-bit ASCII characters.
// A byte-wise scan is sufficient: any non-ASCII rune (and any invalid
// UTF-8 sequence) necessarily contains a byte >= 0x80.
func isASCII(s string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] >= 0x80 {
			return false
		}
	}
	return true
}
// toASCII returns s with every non-ASCII rune removed.
// A string that is already pure ASCII is returned unchanged (no copy).
func toASCII(s string) string {
	// Fast path: scan for any rune outside the 7-bit range.
	ascii := true
	for _, r := range s {
		if r >= 0x80 {
			ascii = false
			break
		}
	}
	if ascii {
		return s
	}
	// Slow path: rebuild the string, keeping only ASCII runes.
	var b bytes.Buffer
	for _, r := range s {
		if r < 0x80 {
			b.WriteByte(byte(r))
		}
	}
	return b.String()
}
...@@ -95,45 +95,45 @@ func (tr *Reader) Next() (*Header, error) { ...@@ -95,45 +95,45 @@ func (tr *Reader) Next() (*Header, error) {
func mergePAX(hdr *Header, headers map[string]string) error { func mergePAX(hdr *Header, headers map[string]string) error {
for k, v := range headers { for k, v := range headers {
switch k { switch k {
case "path": case paxPath:
hdr.Name = v hdr.Name = v
case "linkpath": case paxLinkpath:
hdr.Linkname = v hdr.Linkname = v
case "gname": case paxGname:
hdr.Gname = v hdr.Gname = v
case "uname": case paxUname:
hdr.Uname = v hdr.Uname = v
case "uid": case paxUid:
uid, err := strconv.ParseInt(v, 10, 0) uid, err := strconv.ParseInt(v, 10, 0)
if err != nil { if err != nil {
return err return err
} }
hdr.Uid = int(uid) hdr.Uid = int(uid)
case "gid": case paxGid:
gid, err := strconv.ParseInt(v, 10, 0) gid, err := strconv.ParseInt(v, 10, 0)
if err != nil { if err != nil {
return err return err
} }
hdr.Gid = int(gid) hdr.Gid = int(gid)
case "atime": case paxAtime:
t, err := parsePAXTime(v) t, err := parsePAXTime(v)
if err != nil { if err != nil {
return err return err
} }
hdr.AccessTime = t hdr.AccessTime = t
case "mtime": case paxMtime:
t, err := parsePAXTime(v) t, err := parsePAXTime(v)
if err != nil { if err != nil {
return err return err
} }
hdr.ModTime = t hdr.ModTime = t
case "ctime": case paxCtime:
t, err := parsePAXTime(v) t, err := parsePAXTime(v)
if err != nil { if err != nil {
return err return err
} }
hdr.ChangeTime = t hdr.ChangeTime = t
case "size": case paxSize:
size, err := strconv.ParseInt(v, 10, 0) size, err := strconv.ParseInt(v, 10, 0)
if err != nil { if err != nil {
return err return err
...@@ -243,13 +243,15 @@ func (tr *Reader) octal(b []byte) int64 { ...@@ -243,13 +243,15 @@ func (tr *Reader) octal(b []byte) int64 {
return x return x
} }
// Removing leading spaces. // Because unused fields are filled with NULs, we need
for len(b) > 0 && b[0] == ' ' { // to skip leading NULs. Fields may also be padded with
b = b[1:] // spaces or NULs.
} // So we remove leading and trailing NULs and spaces to
// Removing trailing NULs and spaces. // be sure.
for len(b) > 0 && (b[len(b)-1] == ' ' || b[len(b)-1] == '\x00') { b = bytes.Trim(b, " \x00")
b = b[0 : len(b)-1]
if len(b) == 0 {
return 0
} }
x, err := strconv.ParseUint(cString(b), 8, 64) x, err := strconv.ParseUint(cString(b), 8, 64)
if err != nil { if err != nil {
......
...@@ -142,6 +142,25 @@ var untarTests = []*untarTest{ ...@@ -142,6 +142,25 @@ var untarTests = []*untarTest{
}, },
}, },
}, },
{
file: "testdata/nil-uid.tar", // golang.org/issue/5290
headers: []*Header{
{
Name: "P1050238.JPG.log",
Mode: 0664,
Uid: 0,
Gid: 0,
Size: 14,
ModTime: time.Unix(1365454838, 0),
Typeflag: TypeReg,
Linkname: "",
Uname: "eyefi",
Gname: "eyefi",
Devmajor: 0,
Devminor: 0,
},
},
},
} }
func TestReader(t *testing.T) { func TestReader(t *testing.T) {
...@@ -152,6 +171,7 @@ testLoop: ...@@ -152,6 +171,7 @@ testLoop:
t.Errorf("test %d: Unexpected error: %v", i, err) t.Errorf("test %d: Unexpected error: %v", i, err)
continue continue
} }
defer f.Close()
tr := NewReader(f) tr := NewReader(f)
for j, header := range test.headers { for j, header := range test.headers {
hdr, err := tr.Next() hdr, err := tr.Next()
...@@ -172,7 +192,6 @@ testLoop: ...@@ -172,7 +192,6 @@ testLoop:
if hdr != nil || err != nil { if hdr != nil || err != nil {
t.Errorf("test %d: Unexpected entry or error: hdr=%v err=%v", i, hdr, err) t.Errorf("test %d: Unexpected entry or error: hdr=%v err=%v", i, hdr, err)
} }
f.Close()
} }
} }
......
...@@ -8,7 +8,9 @@ import ( ...@@ -8,7 +8,9 @@ import (
"bytes" "bytes"
"io/ioutil" "io/ioutil"
"os" "os"
"path"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
) )
...@@ -249,7 +251,14 @@ func TestHeaderRoundTrip(t *testing.T) { ...@@ -249,7 +251,14 @@ func TestHeaderRoundTrip(t *testing.T) {
t.Error(err) t.Error(err)
continue continue
} }
if got, want := h2.Name, g.h.Name; got != want { if strings.Contains(fi.Name(), "/") {
t.Errorf("FileInfo of %q contains slash: %q", g.h.Name, fi.Name())
}
name := path.Base(g.h.Name)
if fi.IsDir() {
name += "/"
}
if got, want := h2.Name, name; got != want {
t.Errorf("i=%d: Name: got %v, want %v", i, got, want) t.Errorf("i=%d: Name: got %v, want %v", i, got, want)
} }
if got, want := h2.Size, g.h.Size; got != want { if got, want := h2.Size, g.h.Size; got != want {
......
...@@ -24,6 +24,7 @@ var ( ...@@ -24,6 +24,7 @@ var (
ErrFieldTooLong = errors.New("archive/tar: header field too long") ErrFieldTooLong = errors.New("archive/tar: header field too long")
ErrWriteAfterClose = errors.New("archive/tar: write after close") ErrWriteAfterClose = errors.New("archive/tar: write after close")
errNameTooLong = errors.New("archive/tar: name too long") errNameTooLong = errors.New("archive/tar: name too long")
errInvalidHeader = errors.New("archive/tar: header field too long or contains invalid values")
) )
// A Writer provides sequential writing of a tar archive in POSIX.1 format. // A Writer provides sequential writing of a tar archive in POSIX.1 format.
...@@ -37,6 +38,7 @@ type Writer struct { ...@@ -37,6 +38,7 @@ type Writer struct {
pad int64 // amount of padding to write after current file entry pad int64 // amount of padding to write after current file entry
closed bool closed bool
usedBinary bool // whether the binary numeric field extension was used usedBinary bool // whether the binary numeric field extension was used
preferPax bool // use pax header instead of binary numeric header
} }
// NewWriter creates a new Writer writing to w. // NewWriter creates a new Writer writing to w.
...@@ -65,16 +67,23 @@ func (tw *Writer) Flush() error { ...@@ -65,16 +67,23 @@ func (tw *Writer) Flush() error {
} }
// Write s into b, terminating it with a NUL if there is room. // Write s into b, terminating it with a NUL if there is room.
func (tw *Writer) cString(b []byte, s string) { // If the value is too long for the field and allowPax is true add a paxheader record instead
func (tw *Writer) cString(b []byte, s string, allowPax bool, paxKeyword string, paxHeaders map[string]string) {
needsPaxHeader := allowPax && len(s) > len(b) || !isASCII(s)
if needsPaxHeader {
paxHeaders[paxKeyword] = s
return
}
if len(s) > len(b) { if len(s) > len(b) {
if tw.err == nil { if tw.err == nil {
tw.err = ErrFieldTooLong tw.err = ErrFieldTooLong
} }
return return
} }
copy(b, s) ascii := toASCII(s)
if len(s) < len(b) { copy(b, ascii)
b[len(s)] = 0 if len(ascii) < len(b) {
b[len(ascii)] = 0
} }
} }
...@@ -85,17 +94,27 @@ func (tw *Writer) octal(b []byte, x int64) { ...@@ -85,17 +94,27 @@ func (tw *Writer) octal(b []byte, x int64) {
for len(s)+1 < len(b) { for len(s)+1 < len(b) {
s = "0" + s s = "0" + s
} }
tw.cString(b, s) tw.cString(b, s, false, paxNone, nil)
} }
// Write x into b, either as octal or as binary (GNUtar/star extension). // Write x into b, either as octal or as binary (GNUtar/star extension).
func (tw *Writer) numeric(b []byte, x int64) { // If the value is too long for the field and writingPax is enabled both for the field and the add a paxheader record instead
func (tw *Writer) numeric(b []byte, x int64, allowPax bool, paxKeyword string, paxHeaders map[string]string) {
// Try octal first. // Try octal first.
s := strconv.FormatInt(x, 8) s := strconv.FormatInt(x, 8)
if len(s) < len(b) { if len(s) < len(b) {
tw.octal(b, x) tw.octal(b, x)
return return
} }
// If it is too long for octal, and pax is preferred, use a pax header
if allowPax && tw.preferPax {
tw.octal(b, 0)
s := strconv.FormatInt(x, 10)
paxHeaders[paxKeyword] = s
return
}
// Too big: use binary (big-endian). // Too big: use binary (big-endian).
tw.usedBinary = true tw.usedBinary = true
for i := len(b) - 1; x > 0 && i >= 0; i-- { for i := len(b) - 1; x > 0 && i >= 0; i-- {
...@@ -115,6 +134,15 @@ var ( ...@@ -115,6 +134,15 @@ var (
// WriteHeader calls Flush if it is not the first header. // WriteHeader calls Flush if it is not the first header.
// Calling after a Close will return ErrWriteAfterClose. // Calling after a Close will return ErrWriteAfterClose.
func (tw *Writer) WriteHeader(hdr *Header) error { func (tw *Writer) WriteHeader(hdr *Header) error {
return tw.writeHeader(hdr, true)
}
// WriteHeader writes hdr and prepares to accept the file's contents.
// WriteHeader calls Flush if it is not the first header.
// Calling after a Close will return ErrWriteAfterClose.
// As this method is called internally by writePax header to allow it to
// suppress writing the pax header.
func (tw *Writer) writeHeader(hdr *Header, allowPax bool) error {
if tw.closed { if tw.closed {
return ErrWriteAfterClose return ErrWriteAfterClose
} }
...@@ -124,31 +152,21 @@ func (tw *Writer) WriteHeader(hdr *Header) error { ...@@ -124,31 +152,21 @@ func (tw *Writer) WriteHeader(hdr *Header) error {
if tw.err != nil { if tw.err != nil {
return tw.err return tw.err
} }
// Decide whether or not to use PAX extensions
// a map to hold pax header records, if any are needed
paxHeaders := make(map[string]string)
// TODO(shanemhansen): we might want to use PAX headers for // TODO(shanemhansen): we might want to use PAX headers for
// subsecond time resolution, but for now let's just capture // subsecond time resolution, but for now let's just capture
// the long name/long symlink use case. // too long fields or non ascii characters
suffix := hdr.Name
prefix := ""
if len(hdr.Name) > fileNameSize || len(hdr.Linkname) > fileNameSize {
var err error
prefix, suffix, err = tw.splitUSTARLongName(hdr.Name)
// Either we were unable to pack the long name into ustar format
// or the link name is too long; use PAX headers.
if err == errNameTooLong || len(hdr.Linkname) > fileNameSize {
if err := tw.writePAXHeader(hdr); err != nil {
return err
}
} else if err != nil {
return err
}
}
tw.nb = int64(hdr.Size)
tw.pad = -tw.nb & (blockSize - 1) // blockSize is a power of two
header := make([]byte, blockSize) header := make([]byte, blockSize)
s := slicer(header) s := slicer(header)
tw.cString(s.next(fileNameSize), suffix)
// keep a reference to the filename to allow to overwrite it later if we detect that we can use ustar longnames instead of pax
pathHeaderBytes := s.next(fileNameSize)
tw.cString(pathHeaderBytes, hdr.Name, true, paxPath, paxHeaders)
// Handle out of range ModTime carefully. // Handle out of range ModTime carefully.
var modTime int64 var modTime int64
...@@ -157,27 +175,55 @@ func (tw *Writer) WriteHeader(hdr *Header) error { ...@@ -157,27 +175,55 @@ func (tw *Writer) WriteHeader(hdr *Header) error {
} }
tw.octal(s.next(8), hdr.Mode) // 100:108 tw.octal(s.next(8), hdr.Mode) // 100:108
tw.numeric(s.next(8), int64(hdr.Uid)) // 108:116 tw.numeric(s.next(8), int64(hdr.Uid), true, paxUid, paxHeaders) // 108:116
tw.numeric(s.next(8), int64(hdr.Gid)) // 116:124 tw.numeric(s.next(8), int64(hdr.Gid), true, paxGid, paxHeaders) // 116:124
tw.numeric(s.next(12), hdr.Size) // 124:136 tw.numeric(s.next(12), hdr.Size, true, paxSize, paxHeaders) // 124:136
tw.numeric(s.next(12), modTime) // 136:148 tw.numeric(s.next(12), modTime, false, paxNone, nil) // 136:148 --- consider using pax for finer granularity
s.next(8) // chksum (148:156) s.next(8) // chksum (148:156)
s.next(1)[0] = hdr.Typeflag // 156:157 s.next(1)[0] = hdr.Typeflag // 156:157
tw.cString(s.next(100), hdr.Linkname) // linkname (157:257)
tw.cString(s.next(100), hdr.Linkname, true, paxLinkpath, paxHeaders)
copy(s.next(8), []byte("ustar\x0000")) // 257:265 copy(s.next(8), []byte("ustar\x0000")) // 257:265
tw.cString(s.next(32), hdr.Uname) // 265:297 tw.cString(s.next(32), hdr.Uname, true, paxUname, paxHeaders) // 265:297
tw.cString(s.next(32), hdr.Gname) // 297:329 tw.cString(s.next(32), hdr.Gname, true, paxGname, paxHeaders) // 297:329
tw.numeric(s.next(8), hdr.Devmajor) // 329:337 tw.numeric(s.next(8), hdr.Devmajor, false, paxNone, nil) // 329:337
tw.numeric(s.next(8), hdr.Devminor) // 337:345 tw.numeric(s.next(8), hdr.Devminor, false, paxNone, nil) // 337:345
tw.cString(s.next(155), prefix) // 345:500
// keep a reference to the prefix to allow to overwrite it later if we detect that we can use ustar longnames instead of pax
prefixHeaderBytes := s.next(155)
tw.cString(prefixHeaderBytes, "", false, paxNone, nil) // 345:500 prefix
// Use the GNU magic instead of POSIX magic if we used any GNU extensions. // Use the GNU magic instead of POSIX magic if we used any GNU extensions.
if tw.usedBinary { if tw.usedBinary {
copy(header[257:265], []byte("ustar \x00")) copy(header[257:265], []byte("ustar \x00"))
} }
_, paxPathUsed := paxHeaders[paxPath]
// try to use a ustar header when only the name is too long
if !tw.preferPax && len(paxHeaders) == 1 && paxPathUsed {
suffix := hdr.Name
prefix := ""
if len(hdr.Name) > fileNameSize && isASCII(hdr.Name) {
var err error
prefix, suffix, err = tw.splitUSTARLongName(hdr.Name)
if err == nil {
// ok we can use a ustar long name instead of pax, now correct the fields
// remove the path field from the pax header. this will suppress the pax header
delete(paxHeaders, paxPath)
// update the path fields
tw.cString(pathHeaderBytes, suffix, false, paxNone, nil)
tw.cString(prefixHeaderBytes, prefix, false, paxNone, nil)
// Use the ustar magic if we used ustar long names. // Use the ustar magic if we used ustar long names.
if len(prefix) > 0 { if len(prefix) > 0 {
copy(header[257:265], []byte("ustar\000")) copy(header[257:265], []byte("ustar\000"))
} }
}
}
}
// The chksum field is terminated by a NUL and a space. // The chksum field is terminated by a NUL and a space.
// This is different from the other octal fields. // This is different from the other octal fields.
...@@ -190,8 +236,18 @@ func (tw *Writer) WriteHeader(hdr *Header) error { ...@@ -190,8 +236,18 @@ func (tw *Writer) WriteHeader(hdr *Header) error {
return tw.err return tw.err
} }
_, tw.err = tw.w.Write(header) if len(paxHeaders) > 0 {
if !allowPax {
return errInvalidHeader
}
if err := tw.writePAXHeader(hdr, paxHeaders); err != nil {
return err
}
}
tw.nb = int64(hdr.Size)
tw.pad = (blockSize - (tw.nb % blockSize)) % blockSize
_, tw.err = tw.w.Write(header)
return tw.err return tw.err
} }
...@@ -207,8 +263,11 @@ func (tw *Writer) splitUSTARLongName(name string) (prefix, suffix string, err er ...@@ -207,8 +263,11 @@ func (tw *Writer) splitUSTARLongName(name string) (prefix, suffix string, err er
length-- length--
} }
i := strings.LastIndex(name[:length], "/") i := strings.LastIndex(name[:length], "/")
nlen := length - i - 1 // nlen contains the resulting length in the name field.
if i <= 0 || nlen > fileNameSize || nlen == 0 { // plen contains the resulting length in the prefix field.
nlen := len(name) - i - 1
plen := i
if i <= 0 || nlen > fileNameSize || nlen == 0 || plen > fileNamePrefixSize {
err = errNameTooLong err = errNameTooLong
return return
} }
...@@ -218,7 +277,7 @@ func (tw *Writer) splitUSTARLongName(name string) (prefix, suffix string, err er ...@@ -218,7 +277,7 @@ func (tw *Writer) splitUSTARLongName(name string) (prefix, suffix string, err er
// writePaxHeader writes an extended pax header to the // writePaxHeader writes an extended pax header to the
// archive. // archive.
func (tw *Writer) writePAXHeader(hdr *Header) error { func (tw *Writer) writePAXHeader(hdr *Header, paxHeaders map[string]string) error {
// Prepare extended header // Prepare extended header
ext := new(Header) ext := new(Header)
ext.Typeflag = TypeXHeader ext.Typeflag = TypeXHeader
...@@ -229,18 +288,23 @@ func (tw *Writer) writePAXHeader(hdr *Header) error { ...@@ -229,18 +288,23 @@ func (tw *Writer) writePAXHeader(hdr *Header) error {
// with the current pid. // with the current pid.
pid := os.Getpid() pid := os.Getpid()
dir, file := path.Split(hdr.Name) dir, file := path.Split(hdr.Name)
ext.Name = path.Join(dir, fullName := path.Join(dir,
fmt.Sprintf("PaxHeaders.%d", pid), file)[0:100] fmt.Sprintf("PaxHeaders.%d", pid), file)
ascii := toASCII(fullName)
if len(ascii) > 100 {
ascii = ascii[:100]
}
ext.Name = ascii
// Construct the body // Construct the body
var buf bytes.Buffer var buf bytes.Buffer
if len(hdr.Name) > fileNameSize {
fmt.Fprint(&buf, paxHeader("path="+hdr.Name)) for k, v := range paxHeaders {
} fmt.Fprint(&buf, paxHeader(k+"="+v))
if len(hdr.Linkname) > fileNameSize {
fmt.Fprint(&buf, paxHeader("linkpath="+hdr.Linkname))
} }
ext.Size = int64(len(buf.Bytes())) ext.Size = int64(len(buf.Bytes()))
if err := tw.WriteHeader(ext); err != nil { if err := tw.writeHeader(ext, false); err != nil {
return err return err
} }
if _, err := tw.Write(buf.Bytes()); err != nil { if _, err := tw.Write(buf.Bytes()); err != nil {
......
...@@ -243,15 +243,110 @@ func TestPax(t *testing.T) { ...@@ -243,15 +243,110 @@ func TestPax(t *testing.T) {
} }
} }
func TestPaxSymlink(t *testing.T) {
// Create an archive with a large linkname
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
hdr.Typeflag = TypeSymlink
if err != nil {
t.Fatalf("os.Stat:1 %v", err)
}
// Force a PAX long linkname to be written
longLinkname := strings.Repeat("1234567890/1234567890", 10)
hdr.Linkname = longLinkname
hdr.Size = 0
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Simple test to make sure PAX extensions are in effect
if !bytes.Contains(buf.Bytes(), []byte("PaxHeaders.")) {
t.Fatal("Expected at least one PAX header to be written.")
}
// Test that we can get a long name back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if hdr.Linkname != longLinkname {
t.Fatal("Couldn't recover long link name")
}
}
func TestPaxNonAscii(t *testing.T) {
// Create an archive with non ascii. These should trigger a pax header
// because pax headers have a defined utf-8 encoding.
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
if err != nil {
t.Fatalf("os.Stat:1 %v", err)
}
// some sample data
chineseFilename := "文件名"
chineseGroupname := "組"
chineseUsername := "用戶名"
hdr.Name = chineseFilename
hdr.Gname = chineseGroupname
hdr.Uname = chineseUsername
contents := strings.Repeat(" ", int(hdr.Size))
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if _, err = writer.Write([]byte(contents)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Simple test to make sure PAX extensions are in effect
if !bytes.Contains(buf.Bytes(), []byte("PaxHeaders.")) {
t.Fatal("Expected at least one PAX header to be written.")
}
// Test that we can get a long name back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if hdr.Name != chineseFilename {
t.Fatal("Couldn't recover unicode name")
}
if hdr.Gname != chineseGroupname {
t.Fatal("Couldn't recover unicode group")
}
if hdr.Uname != chineseUsername {
t.Fatal("Couldn't recover unicode user")
}
}
func TestPAXHeader(t *testing.T) { func TestPAXHeader(t *testing.T) {
medName := strings.Repeat("CD", 50) medName := strings.Repeat("CD", 50)
longName := strings.Repeat("AB", 100) longName := strings.Repeat("AB", 100)
paxTests := [][2]string{ paxTests := [][2]string{
{"name=/etc/hosts", "19 name=/etc/hosts\n"}, {paxPath + "=/etc/hosts", "19 path=/etc/hosts\n"},
{"a=b", "6 a=b\n"}, // Single digit length {"a=b", "6 a=b\n"}, // Single digit length
{"a=names", "11 a=names\n"}, // Test case involving carries {"a=names", "11 a=names\n"}, // Test case involving carries
{"name=" + longName, fmt.Sprintf("210 name=%s\n", longName)}, {paxPath + "=" + longName, fmt.Sprintf("210 path=%s\n", longName)},
{"name=" + medName, fmt.Sprintf("110 name=%s\n", medName)}} {paxPath + "=" + medName, fmt.Sprintf("110 path=%s\n", medName)}}
for _, test := range paxTests { for _, test := range paxTests {
key, expected := test[0], test[1] key, expected := test[0], test[1]
...@@ -260,3 +355,39 @@ func TestPAXHeader(t *testing.T) { ...@@ -260,3 +355,39 @@ func TestPAXHeader(t *testing.T) {
} }
} }
} }
func TestUSTARLongName(t *testing.T) {
// Create an archive with a path that failed to split with USTAR extension in previous versions.
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
hdr.Typeflag = TypeDir
if err != nil {
t.Fatalf("os.Stat:1 %v", err)
}
// Force a PAX long name to be written. The name was taken from a practical example
// that fails and replaced ever char through numbers to anonymize the sample.
longName := "/0000_0000000/00000-000000000/0000_0000000/00000-0000000000000/0000_0000000/00000-0000000-00000000/0000_0000000/00000000/0000_0000000/000/0000_0000000/00000000v00/0000_0000000/000000/0000_0000000/0000000/0000_0000000/00000y-00/0000/0000/00000000/0x000000/"
hdr.Name = longName
hdr.Size = 0
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Test that we can get a long name back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if hdr.Name != longName {
t.Fatal("Couldn't recover long name")
}
}
...@@ -6,13 +6,11 @@ package zip ...@@ -6,13 +6,11 @@ package zip
import ( import (
"bufio" "bufio"
"compress/flate"
"encoding/binary" "encoding/binary"
"errors" "errors"
"hash" "hash"
"hash/crc32" "hash/crc32"
"io" "io"
"io/ioutil"
"os" "os"
) )
...@@ -116,6 +114,19 @@ func (rc *ReadCloser) Close() error { ...@@ -116,6 +114,19 @@ func (rc *ReadCloser) Close() error {
return rc.f.Close() return rc.f.Close()
} }
// DataOffset returns the offset of the file's possibly-compressed
// data, relative to the beginning of the zip file.
//
// Most callers should instead use Open, which transparently
// decompresses data and verifies checksums.
func (f *File) DataOffset() (offset int64, err error) {
bodyOffset, err := f.findBodyOffset()
if err != nil {
return
}
return f.headerOffset + bodyOffset, nil
}
// Open returns a ReadCloser that provides access to the File's contents. // Open returns a ReadCloser that provides access to the File's contents.
// Multiple files may be read concurrently. // Multiple files may be read concurrently.
func (f *File) Open() (rc io.ReadCloser, err error) { func (f *File) Open() (rc io.ReadCloser, err error) {
...@@ -125,15 +136,12 @@ func (f *File) Open() (rc io.ReadCloser, err error) { ...@@ -125,15 +136,12 @@ func (f *File) Open() (rc io.ReadCloser, err error) {
} }
size := int64(f.CompressedSize64) size := int64(f.CompressedSize64)
r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size) r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, size)
switch f.Method { dcomp := decompressor(f.Method)
case Store: // (no compression) if dcomp == nil {
rc = ioutil.NopCloser(r)
case Deflate:
rc = flate.NewReader(r)
default:
err = ErrAlgorithm err = ErrAlgorithm
return return
} }
rc = dcomp(r)
var desr io.Reader var desr io.Reader
if f.hasDataDescriptor() { if f.hasDataDescriptor() {
desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen) desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen)
...@@ -184,9 +192,8 @@ func (r *checksumReader) Close() error { return r.rc.Close() } ...@@ -184,9 +192,8 @@ func (r *checksumReader) Close() error { return r.rc.Close() }
// findBodyOffset does the minimum work to verify the file has a header // findBodyOffset does the minimum work to verify the file has a header
// and returns the file body offset. // and returns the file body offset.
func (f *File) findBodyOffset() (int64, error) { func (f *File) findBodyOffset() (int64, error) {
r := io.NewSectionReader(f.zipr, f.headerOffset, f.zipsize-f.headerOffset)
var buf [fileHeaderLen]byte var buf [fileHeaderLen]byte
if _, err := io.ReadFull(r, buf[:]); err != nil { if _, err := f.zipr.ReadAt(buf[:], f.headerOffset); err != nil {
return 0, err return 0, err
} }
b := readBuf(buf[:]) b := readBuf(buf[:])
......
...@@ -276,6 +276,7 @@ func readTestZip(t *testing.T, zt ZipTest) { ...@@ -276,6 +276,7 @@ func readTestZip(t *testing.T, zt ZipTest) {
var rc *ReadCloser var rc *ReadCloser
rc, err = OpenReader(filepath.Join("testdata", zt.Name)) rc, err = OpenReader(filepath.Join("testdata", zt.Name))
if err == nil { if err == nil {
defer rc.Close()
z = &rc.Reader z = &rc.Reader
} }
} }
......
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zip
import (
"compress/flate"
"io"
"io/ioutil"
"sync"
)
// A Compressor returns a compressing writer, writing to the
// provided writer. On Close, any pending data should be flushed.
type Compressor func(io.Writer) (io.WriteCloser, error)
// Decompressor is a function that wraps a Reader with a decompressing Reader.
// The decompressed ReadCloser is returned to callers who open files from
// within the archive. These callers are responsible for closing this reader
// when they're finished reading.
type Decompressor func(io.Reader) io.ReadCloser
var (
mu sync.RWMutex // guards compressor and decompressor maps
compressors = map[uint16]Compressor{
Store: func(w io.Writer) (io.WriteCloser, error) { return &nopCloser{w}, nil },
Deflate: func(w io.Writer) (io.WriteCloser, error) { return flate.NewWriter(w, 5) },
}
decompressors = map[uint16]Decompressor{
Store: ioutil.NopCloser,
Deflate: flate.NewReader,
}
)
// RegisterDecompressor allows custom decompressors for a specified method ID.
func RegisterDecompressor(method uint16, d Decompressor) {
mu.Lock()
defer mu.Unlock()
if _, ok := decompressors[method]; ok {
panic("decompressor already registered")
}
decompressors[method] = d
}
// RegisterCompressor registers custom compressors for a specified method ID.
// The common methods Store and Deflate are built in.
func RegisterCompressor(method uint16, comp Compressor) {
mu.Lock()
defer mu.Unlock()
if _, ok := compressors[method]; ok {
panic("compressor already registered")
}
compressors[method] = comp
}
func compressor(method uint16) Compressor {
mu.RLock()
defer mu.RUnlock()
return compressors[method]
}
func decompressor(method uint16) Decompressor {
mu.RLock()
defer mu.RUnlock()
return decompressors[method]
}
...@@ -21,6 +21,7 @@ package zip ...@@ -21,6 +21,7 @@ package zip
import ( import (
"os" "os"
"path"
"time" "time"
) )
...@@ -99,7 +100,7 @@ type headerFileInfo struct { ...@@ -99,7 +100,7 @@ type headerFileInfo struct {
fh *FileHeader fh *FileHeader
} }
func (fi headerFileInfo) Name() string { return fi.fh.Name } func (fi headerFileInfo) Name() string { return path.Base(fi.fh.Name) }
func (fi headerFileInfo) Size() int64 { func (fi headerFileInfo) Size() int64 {
if fi.fh.UncompressedSize64 > 0 { if fi.fh.UncompressedSize64 > 0 {
return int64(fi.fh.UncompressedSize64) return int64(fi.fh.UncompressedSize64)
...@@ -113,6 +114,9 @@ func (fi headerFileInfo) Sys() interface{} { return fi.fh } ...@@ -113,6 +114,9 @@ func (fi headerFileInfo) Sys() interface{} { return fi.fh }
// FileInfoHeader creates a partially-populated FileHeader from an // FileInfoHeader creates a partially-populated FileHeader from an
// os.FileInfo. // os.FileInfo.
// Because os.FileInfo's Name method returns only the base name of
// the file it describes, it may be necessary to modify the Name field
// of the returned header to provide the full path name of the file.
func FileInfoHeader(fi os.FileInfo) (*FileHeader, error) { func FileInfoHeader(fi os.FileInfo) (*FileHeader, error) {
size := fi.Size() size := fi.Size()
fh := &FileHeader{ fh := &FileHeader{
......
...@@ -6,7 +6,6 @@ package zip ...@@ -6,7 +6,6 @@ package zip
import ( import (
"bufio" "bufio"
"compress/flate"
"encoding/binary" "encoding/binary"
"errors" "errors"
"hash" "hash"
...@@ -198,18 +197,15 @@ func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) { ...@@ -198,18 +197,15 @@ func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) {
compCount: &countWriter{w: w.cw}, compCount: &countWriter{w: w.cw},
crc32: crc32.NewIEEE(), crc32: crc32.NewIEEE(),
} }
switch fh.Method { comp := compressor(fh.Method)
case Store: if comp == nil {
fw.comp = nopCloser{fw.compCount} return nil, ErrAlgorithm
case Deflate: }
var err error var err error
fw.comp, err = flate.NewWriter(fw.compCount, 5) fw.comp, err = comp(fw.compCount)
if err != nil { if err != nil {
return nil, err return nil, err
} }
default:
return nil, ErrAlgorithm
}
fw.rawCount = &countWriter{w: fw.comp} fw.rawCount = &countWriter{w: fw.comp}
h := &header{ h := &header{
......
...@@ -9,22 +9,24 @@ package zip ...@@ -9,22 +9,24 @@ package zip
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"hash"
"io" "io"
"io/ioutil" "io/ioutil"
"sort"
"strings" "strings"
"testing" "testing"
"time" "time"
) )
func TestOver65kFiles(t *testing.T) { func TestOver65kFiles(t *testing.T) {
if testing.Short() {
t.Skip("slow test; skipping")
}
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
w := NewWriter(buf) w := NewWriter(buf)
const nFiles = (1 << 16) + 42 const nFiles = (1 << 16) + 42
for i := 0; i < nFiles; i++ { for i := 0; i < nFiles; i++ {
_, err := w.Create(fmt.Sprintf("%d.dat", i)) _, err := w.CreateHeader(&FileHeader{
Name: fmt.Sprintf("%d.dat", i),
Method: Store, // avoid Issue 6136 and Issue 6138
})
if err != nil { if err != nil {
t.Fatalf("creating file %d: %v", i, err) t.Fatalf("creating file %d: %v", i, err)
} }
...@@ -105,29 +107,156 @@ func TestFileHeaderRoundTrip64(t *testing.T) { ...@@ -105,29 +107,156 @@ func TestFileHeaderRoundTrip64(t *testing.T) {
testHeaderRoundTrip(fh, uint32max, fh.UncompressedSize64, t) testHeaderRoundTrip(fh, uint32max, fh.UncompressedSize64, t)
} }
type repeatedByte struct {
off int64
b byte
n int64
}
// rleBuffer is a run-length-encoded byte buffer.
// It's an io.Writer (like a bytes.Buffer) and also an io.ReaderAt,
// allowing random-access reads.
type rleBuffer struct {
buf []repeatedByte
}
func (r *rleBuffer) Size() int64 {
if len(r.buf) == 0 {
return 0
}
last := &r.buf[len(r.buf)-1]
return last.off + last.n
}
func (r *rleBuffer) Write(p []byte) (n int, err error) {
var rp *repeatedByte
if len(r.buf) > 0 {
rp = &r.buf[len(r.buf)-1]
// Fast path, if p is entirely the same byte repeated.
if lastByte := rp.b; len(p) > 0 && p[0] == lastByte {
all := true
for _, b := range p {
if b != lastByte {
all = false
break
}
}
if all {
rp.n += int64(len(p))
return len(p), nil
}
}
}
for _, b := range p {
if rp == nil || rp.b != b {
r.buf = append(r.buf, repeatedByte{r.Size(), b, 1})
rp = &r.buf[len(r.buf)-1]
} else {
rp.n++
}
}
return len(p), nil
}
func (r *rleBuffer) ReadAt(p []byte, off int64) (n int, err error) {
if len(p) == 0 {
return
}
skipParts := sort.Search(len(r.buf), func(i int) bool {
part := &r.buf[i]
return part.off+part.n > off
})
parts := r.buf[skipParts:]
if len(parts) > 0 {
skipBytes := off - parts[0].off
for len(parts) > 0 {
part := parts[0]
for i := skipBytes; i < part.n; i++ {
if n == len(p) {
return
}
p[n] = part.b
n++
}
parts = parts[1:]
skipBytes = 0
}
}
if n != len(p) {
err = io.ErrUnexpectedEOF
}
return
}
// Just testing the rleBuffer used in the Zip64 test above. Not used by the zip code.
func TestRLEBuffer(t *testing.T) {
b := new(rleBuffer)
var all []byte
writes := []string{"abcdeee", "eeeeeee", "eeeefghaaiii"}
for _, w := range writes {
b.Write([]byte(w))
all = append(all, w...)
}
if len(b.buf) != 10 {
t.Fatalf("len(b.buf) = %d; want 10", len(b.buf))
}
for i := 0; i < len(all); i++ {
for j := 0; j < len(all)-i; j++ {
buf := make([]byte, j)
n, err := b.ReadAt(buf, int64(i))
if err != nil || n != len(buf) {
t.Errorf("ReadAt(%d, %d) = %d, %v; want %d, nil", i, j, n, err, len(buf))
}
if !bytes.Equal(buf, all[i:i+j]) {
t.Errorf("ReadAt(%d, %d) = %q; want %q", i, j, buf, all[i:i+j])
}
}
}
}
// fakeHash32 is a dummy Hash32 that always returns 0.
type fakeHash32 struct {
hash.Hash32
}
func (fakeHash32) Write(p []byte) (int, error) { return len(p), nil }
func (fakeHash32) Sum32() uint32 { return 0 }
func TestZip64(t *testing.T) { func TestZip64(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("slow test; skipping") t.Skip("slow test; skipping")
} }
const size = 1 << 32 // before the "END\n" part
testZip64(t, size)
}
func testZip64(t testing.TB, size int64) {
const chunkSize = 1024
chunks := int(size / chunkSize)
// write 2^32 bytes plus "END\n" to a zip file // write 2^32 bytes plus "END\n" to a zip file
buf := new(bytes.Buffer) buf := new(rleBuffer)
w := NewWriter(buf) w := NewWriter(buf)
f, err := w.Create("huge.txt") f, err := w.CreateHeader(&FileHeader{
Name: "huge.txt",
Method: Store,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
chunk := make([]byte, 1024) f.(*fileWriter).crc32 = fakeHash32{}
chunk := make([]byte, chunkSize)
for i := range chunk { for i := range chunk {
chunk[i] = '.' chunk[i] = '.'
} }
chunk[len(chunk)-1] = '\n' for i := 0; i < chunks; i++ {
end := []byte("END\n")
for i := 0; i < (1<<32)/1024; i++ {
_, err := f.Write(chunk) _, err := f.Write(chunk)
if err != nil { if err != nil {
t.Fatal("write chunk:", err) t.Fatal("write chunk:", err)
} }
} }
end := []byte("END\n")
_, err = f.Write(end) _, err = f.Write(end)
if err != nil { if err != nil {
t.Fatal("write end:", err) t.Fatal("write end:", err)
...@@ -137,7 +266,7 @@ func TestZip64(t *testing.T) { ...@@ -137,7 +266,7 @@ func TestZip64(t *testing.T) {
} }
// read back zip file and check that we get to the end of it // read back zip file and check that we get to the end of it
r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len())) r, err := NewReader(buf, int64(buf.Size()))
if err != nil { if err != nil {
t.Fatal("reader:", err) t.Fatal("reader:", err)
} }
...@@ -146,7 +275,8 @@ func TestZip64(t *testing.T) { ...@@ -146,7 +275,8 @@ func TestZip64(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("opening:", err) t.Fatal("opening:", err)
} }
for i := 0; i < (1<<32)/1024; i++ { rc.(*checksumReader).hash = fakeHash32{}
for i := 0; i < chunks; i++ {
_, err := io.ReadFull(rc, chunk) _, err := io.ReadFull(rc, chunk)
if err != nil { if err != nil {
t.Fatal("read:", err) t.Fatal("read:", err)
...@@ -163,11 +293,13 @@ func TestZip64(t *testing.T) { ...@@ -163,11 +293,13 @@ func TestZip64(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("closing:", err) t.Fatal("closing:", err)
} }
if size == 1<<32 {
if got, want := f0.UncompressedSize, uint32(uint32max); got != want { if got, want := f0.UncompressedSize, uint32(uint32max); got != want {
t.Errorf("UncompressedSize %d, want %d", got, want) t.Errorf("UncompressedSize %d, want %d", got, want)
} }
}
if got, want := f0.UncompressedSize64, (1<<32)+uint64(len(end)); got != want { if got, want := f0.UncompressedSize64, uint64(size)+uint64(len(end)); got != want {
t.Errorf("UncompressedSize64 %d, want %d", got, want) t.Errorf("UncompressedSize64 %d, want %d", got, want)
} }
} }
...@@ -253,3 +385,11 @@ func TestZeroLengthHeader(t *testing.T) { ...@@ -253,3 +385,11 @@ func TestZeroLengthHeader(t *testing.T) {
} }
testValidHeader(&h, t) testValidHeader(&h, t)
} }
// Just benchmarking how fast the Zip64 test above is. Not related to
// our zip performance, since the test above disabled CRC32 and flate.
func BenchmarkZip64Test(b *testing.B) {
for i := 0; i < b.N; i++ {
testZip64(b, 1<<26)
}
}
...@@ -51,12 +51,9 @@ func NewReaderSize(rd io.Reader, size int) *Reader { ...@@ -51,12 +51,9 @@ func NewReaderSize(rd io.Reader, size int) *Reader {
if size < minReadBufferSize { if size < minReadBufferSize {
size = minReadBufferSize size = minReadBufferSize
} }
return &Reader{ r := new(Reader)
buf: make([]byte, size), r.reset(make([]byte, size), rd)
rd: rd, return r
lastByte: -1,
lastRuneSize: -1,
}
} }
// NewReader returns a new Reader whose buffer has the default size. // NewReader returns a new Reader whose buffer has the default size.
...@@ -64,6 +61,21 @@ func NewReader(rd io.Reader) *Reader { ...@@ -64,6 +61,21 @@ func NewReader(rd io.Reader) *Reader {
return NewReaderSize(rd, defaultBufSize) return NewReaderSize(rd, defaultBufSize)
} }
// Reset discards any buffered data, resets all state, and switches
// the buffered reader to read from r.
func (b *Reader) Reset(r io.Reader) {
b.reset(b.buf, r)
}
func (b *Reader) reset(buf []byte, r io.Reader) {
*b = Reader{
buf: buf,
rd: r,
lastByte: -1,
lastRuneSize: -1,
}
}
var errNegativeRead = errors.New("bufio: reader returned negative count from Read") var errNegativeRead = errors.New("bufio: reader returned negative count from Read")
// fill reads a new chunk into the buffer. // fill reads a new chunk into the buffer.
...@@ -234,7 +246,7 @@ func (b *Reader) Buffered() int { return b.w - b.r } ...@@ -234,7 +246,7 @@ func (b *Reader) Buffered() int { return b.w - b.r }
// ReadSlice reads until the first occurrence of delim in the input, // ReadSlice reads until the first occurrence of delim in the input,
// returning a slice pointing at the bytes in the buffer. // returning a slice pointing at the bytes in the buffer.
// The bytes stop being valid at the next read call. // The bytes stop being valid at the next read.
// If ReadSlice encounters an error before finding a delimiter, // If ReadSlice encounters an error before finding a delimiter,
// it returns all the data in the buffer and the error itself (often io.EOF). // it returns all the data in the buffer and the error itself (often io.EOF).
// ReadSlice fails with error ErrBufferFull if the buffer fills without a delim. // ReadSlice fails with error ErrBufferFull if the buffer fills without a delim.
...@@ -381,7 +393,8 @@ func (b *Reader) ReadBytes(delim byte) (line []byte, err error) { ...@@ -381,7 +393,8 @@ func (b *Reader) ReadBytes(delim byte) (line []byte, err error) {
// For simple uses, a Scanner may be more convenient. // For simple uses, a Scanner may be more convenient.
func (b *Reader) ReadString(delim byte) (line string, err error) { func (b *Reader) ReadString(delim byte) (line string, err error) {
bytes, err := b.ReadBytes(delim) bytes, err := b.ReadBytes(delim)
return string(bytes), err line = string(bytes)
return line, err
} }
// WriteTo implements io.WriterTo. // WriteTo implements io.WriterTo.
...@@ -424,6 +437,9 @@ func (b *Reader) writeBuf(w io.Writer) (int64, error) { ...@@ -424,6 +437,9 @@ func (b *Reader) writeBuf(w io.Writer) (int64, error) {
// Writer implements buffering for an io.Writer object. // Writer implements buffering for an io.Writer object.
// If an error occurs writing to a Writer, no more data will be // If an error occurs writing to a Writer, no more data will be
// accepted and all subsequent writes will return the error. // accepted and all subsequent writes will return the error.
// After all data has been written, the client should call the
// Flush method to guarantee all data has been forwarded to
// the underlying io.Writer.
type Writer struct { type Writer struct {
err error err error
buf []byte buf []byte
...@@ -434,28 +450,41 @@ type Writer struct { ...@@ -434,28 +450,41 @@ type Writer struct {
// NewWriterSize returns a new Writer whose buffer has at least the specified // NewWriterSize returns a new Writer whose buffer has at least the specified
// size. If the argument io.Writer is already a Writer with large enough // size. If the argument io.Writer is already a Writer with large enough
// size, it returns the underlying Writer. // size, it returns the underlying Writer.
func NewWriterSize(wr io.Writer, size int) *Writer { func NewWriterSize(w io.Writer, size int) *Writer {
// Is it already a Writer? // Is it already a Writer?
b, ok := wr.(*Writer) b, ok := w.(*Writer)
if ok && len(b.buf) >= size { if ok && len(b.buf) >= size {
return b return b
} }
if size <= 0 { if size <= 0 {
size = defaultBufSize size = defaultBufSize
} }
b = new(Writer) return &Writer{
b.buf = make([]byte, size) buf: make([]byte, size),
b.wr = wr wr: w,
return b }
} }
// NewWriter returns a new Writer whose buffer has the default size. // NewWriter returns a new Writer whose buffer has the default size.
func NewWriter(wr io.Writer) *Writer { func NewWriter(w io.Writer) *Writer {
return NewWriterSize(wr, defaultBufSize) return NewWriterSize(w, defaultBufSize)
}
// Reset discards any unflushed buffered data, clears any error, and
// resets b to write its output to w.
func (b *Writer) Reset(w io.Writer) {
b.err = nil
b.n = 0
b.wr = w
} }
// Flush writes any buffered data to the underlying io.Writer. // Flush writes any buffered data to the underlying io.Writer.
func (b *Writer) Flush() error { func (b *Writer) Flush() error {
err := b.flush()
return err
}
func (b *Writer) flush() error {
if b.err != nil { if b.err != nil {
return b.err return b.err
} }
...@@ -498,7 +527,7 @@ func (b *Writer) Write(p []byte) (nn int, err error) { ...@@ -498,7 +527,7 @@ func (b *Writer) Write(p []byte) (nn int, err error) {
} else { } else {
n = copy(b.buf[b.n:], p) n = copy(b.buf[b.n:], p)
b.n += n b.n += n
b.Flush() b.flush()
} }
nn += n nn += n
p = p[n:] p = p[n:]
...@@ -517,7 +546,7 @@ func (b *Writer) WriteByte(c byte) error { ...@@ -517,7 +546,7 @@ func (b *Writer) WriteByte(c byte) error {
if b.err != nil { if b.err != nil {
return b.err return b.err
} }
if b.Available() <= 0 && b.Flush() != nil { if b.Available() <= 0 && b.flush() != nil {
return b.err return b.err
} }
b.buf[b.n] = c b.buf[b.n] = c
...@@ -540,7 +569,7 @@ func (b *Writer) WriteRune(r rune) (size int, err error) { ...@@ -540,7 +569,7 @@ func (b *Writer) WriteRune(r rune) (size int, err error) {
} }
n := b.Available() n := b.Available()
if n < utf8.UTFMax { if n < utf8.UTFMax {
if b.Flush(); b.err != nil { if b.flush(); b.err != nil {
return 0, b.err return 0, b.err
} }
n = b.Available() n = b.Available()
...@@ -565,7 +594,7 @@ func (b *Writer) WriteString(s string) (int, error) { ...@@ -565,7 +594,7 @@ func (b *Writer) WriteString(s string) (int, error) {
b.n += n b.n += n
nn += n nn += n
s = s[n:] s = s[n:]
b.Flush() b.flush()
} }
if b.err != nil { if b.err != nil {
return nn, b.err return nn, b.err
...@@ -585,24 +614,29 @@ func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) { ...@@ -585,24 +614,29 @@ func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) {
} }
var m int var m int
for { for {
if b.Available() == 0 {
if err1 := b.flush(); err1 != nil {
return n, err1
}
}
m, err = r.Read(b.buf[b.n:]) m, err = r.Read(b.buf[b.n:])
if m == 0 { if m == 0 {
break break
} }
b.n += m b.n += m
n += int64(m) n += int64(m)
if b.Available() == 0 {
if err1 := b.Flush(); err1 != nil {
return n, err1
}
}
if err != nil { if err != nil {
break break
} }
} }
if err == io.EOF { if err == io.EOF {
// If we filled the buffer exactly, flush pre-emptively.
if b.Available() == 0 {
err = b.flush()
} else {
err = nil err = nil
} }
}
return n, err return n, err
} }
......
...@@ -847,6 +847,10 @@ func TestWriterReadFrom(t *testing.T) { ...@@ -847,6 +847,10 @@ func TestWriterReadFrom(t *testing.T) {
t.Errorf("ws[%d],rs[%d]: w.ReadFrom(r) = %d, %v, want %d, nil", wi, ri, n, err, len(input)) t.Errorf("ws[%d],rs[%d]: w.ReadFrom(r) = %d, %v, want %d, nil", wi, ri, n, err, len(input))
continue continue
} }
if err := w.Flush(); err != nil {
t.Errorf("Flush returned %v", err)
continue
}
if got, want := b.String(), string(input); got != want { if got, want := b.String(), string(input); got != want {
t.Errorf("ws[%d], rs[%d]:\ngot %q\nwant %q\n", wi, ri, got, want) t.Errorf("ws[%d], rs[%d]:\ngot %q\nwant %q\n", wi, ri, got, want)
} }
...@@ -1003,6 +1007,56 @@ func TestReaderClearError(t *testing.T) { ...@@ -1003,6 +1007,56 @@ func TestReaderClearError(t *testing.T) {
} }
} }
// Test for golang.org/issue/5947
func TestWriterReadFromWhileFull(t *testing.T) {
buf := new(bytes.Buffer)
w := NewWriterSize(buf, 10)
// Fill buffer exactly.
n, err := w.Write([]byte("0123456789"))
if n != 10 || err != nil {
t.Fatalf("Write returned (%v, %v), want (10, nil)", n, err)
}
// Use ReadFrom to read in some data.
n2, err := w.ReadFrom(strings.NewReader("abcdef"))
if n2 != 6 || err != nil {
t.Fatalf("ReadFrom returned (%v, %v), want (6, nil)", n, err)
}
}
func TestReaderReset(t *testing.T) {
r := NewReader(strings.NewReader("foo foo"))
buf := make([]byte, 3)
r.Read(buf)
if string(buf) != "foo" {
t.Errorf("buf = %q; want foo", buf)
}
r.Reset(strings.NewReader("bar bar"))
all, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
if string(all) != "bar bar" {
t.Errorf("ReadAll = %q; want bar bar", all)
}
}
func TestWriterReset(t *testing.T) {
var buf1, buf2 bytes.Buffer
w := NewWriter(&buf1)
w.WriteString("foo")
w.Reset(&buf2) // and not flushed
w.WriteString("bar")
w.Flush()
if buf1.String() != "" {
t.Errorf("buf1 = %q; want empty", buf1.String())
}
if buf2.String() != "bar" {
t.Errorf("buf2 = %q; want bar", buf2.String())
}
}
// An onlyReader only implements io.Reader, no matter what other methods the underlying implementation may have. // An onlyReader only implements io.Reader, no matter what other methods the underlying implementation may have.
type onlyReader struct { type onlyReader struct {
r io.Reader r io.Reader
...@@ -1083,3 +1137,46 @@ func BenchmarkWriterCopyNoReadFrom(b *testing.B) { ...@@ -1083,3 +1137,46 @@ func BenchmarkWriterCopyNoReadFrom(b *testing.B) {
io.Copy(dst, src) io.Copy(dst, src)
} }
} }
func BenchmarkReaderEmpty(b *testing.B) {
b.ReportAllocs()
str := strings.Repeat("x", 16<<10)
for i := 0; i < b.N; i++ {
br := NewReader(strings.NewReader(str))
n, err := io.Copy(ioutil.Discard, br)
if err != nil {
b.Fatal(err)
}
if n != int64(len(str)) {
b.Fatal("wrong length")
}
}
}
func BenchmarkWriterEmpty(b *testing.B) {
b.ReportAllocs()
str := strings.Repeat("x", 1<<10)
bs := []byte(str)
for i := 0; i < b.N; i++ {
bw := NewWriter(ioutil.Discard)
bw.Flush()
bw.WriteByte('a')
bw.Flush()
bw.WriteRune('B')
bw.Flush()
bw.Write(bs)
bw.Flush()
bw.WriteString(str)
bw.Flush()
}
}
func BenchmarkWriterFlush(b *testing.B) {
b.ReportAllocs()
bw := NewWriter(ioutil.Discard)
str := strings.Repeat("x", 50)
for i := 0; i < b.N; i++ {
bw.WriteString(str)
bw.Flush()
}
}
...@@ -12,6 +12,14 @@ import ( ...@@ -12,6 +12,14 @@ import (
"strings" "strings"
) )
func ExampleWriter() {
w := bufio.NewWriter(os.Stdout)
fmt.Fprint(w, "Hello, ")
fmt.Fprint(w, "world!")
w.Flush() // Don't forget to flush!
// Output: Hello, world!
}
// The simplest use of a Scanner, to read standard input as a set of lines. // The simplest use of a Scanner, to read standard input as a set of lines.
func ExampleScanner_lines() { func ExampleScanner_lines() {
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
......
...@@ -44,8 +44,8 @@ type Scanner struct { ...@@ -44,8 +44,8 @@ type Scanner struct {
// to give. The return values are the number of bytes to advance the input // to give. The return values are the number of bytes to advance the input
// and the next token to return to the user, plus an error, if any. If the // and the next token to return to the user, plus an error, if any. If the
// data does not yet hold a complete token, for instance if it has no newline // data does not yet hold a complete token, for instance if it has no newline
// while scanning lines, SplitFunc can return (0, nil) to signal the Scanner // while scanning lines, SplitFunc can return (0, nil, nil) to signal the
// to read more data into the slice and try again with a longer slice // Scanner to read more data into the slice and try again with a longer slice
// starting at the same point in the input. // starting at the same point in the input.
// //
// If the returned error is non-nil, scanning stops and the error // If the returned error is non-nil, scanning stops and the error
...@@ -287,7 +287,7 @@ func ScanLines(data []byte, atEOF bool) (advance int, token []byte, err error) { ...@@ -287,7 +287,7 @@ func ScanLines(data []byte, atEOF bool) (advance int, token []byte, err error) {
return 0, nil, nil return 0, nil, nil
} }
// isSpace returns whether the character is a Unicode white space character. // isSpace reports whether the character is a Unicode white space character.
// We avoid dependency on the unicode package, but check validity of the implementation // We avoid dependency on the unicode package, but check validity of the implementation
// in the tests. // in the tests.
func isSpace(r rune) bool { func isSpace(r rune) bool {
......
...@@ -236,6 +236,19 @@ func panic(v interface{}) ...@@ -236,6 +236,19 @@ func panic(v interface{})
// panicking. // panicking.
func recover() interface{} func recover() interface{}
// The print built-in function formats its arguments in an implementation-
// specific way and writes the result to standard error.
// Print is useful for bootstrapping and debugging; it is not guaranteed
// to stay in the language.
func print(args ...Type)
// The println built-in function formats its arguments in an implementation-
// specific way and writes the result to standard error.
// Spaces are always added between arguments and a newline is appended.
// Println is useful for bootstrapping and debugging; it is not guaranteed
// to stay in the language.
func println(args ...Type)
// The error built-in interface type is the conventional interface for // The error built-in interface type is the conventional interface for
// representing an error condition, with the nil value representing no error. // representing an error condition, with the nil value representing no error.
type error interface { type error interface {
......
...@@ -11,32 +11,6 @@ import ( ...@@ -11,32 +11,6 @@ import (
"unicode/utf8" "unicode/utf8"
) )
// Compare returns an integer comparing two byte slices lexicographically.
// The result will be 0 if a==b, -1 if a < b, and +1 if a > b.
// A nil argument is equivalent to an empty slice.
func Compare(a, b []byte) int {
m := len(a)
if m > len(b) {
m = len(b)
}
for i, ac := range a[0:m] {
bc := b[i]
switch {
case ac > bc:
return 1
case ac < bc:
return -1
}
}
switch {
case len(a) < len(b):
return -1
case len(a) > len(b):
return 1
}
return 0
}
func equalPortable(a, b []byte) bool { func equalPortable(a, b []byte) bool {
if len(a) != len(b) { if len(a) != len(b) {
return false return false
...@@ -103,7 +77,7 @@ func Count(s, sep []byte) int { ...@@ -103,7 +77,7 @@ func Count(s, sep []byte) int {
return count return count
} }
// Contains returns whether subslice is within b. // Contains reports whether subslice is within b.
func Contains(b, subslice []byte) bool { func Contains(b, subslice []byte) bool {
return Index(b, subslice) != -1 return Index(b, subslice) != -1
} }
...@@ -401,10 +375,7 @@ func Repeat(b []byte, count int) []byte { ...@@ -401,10 +375,7 @@ func Repeat(b []byte, count int) []byte {
nb := make([]byte, len(b)*count) nb := make([]byte, len(b)*count)
bp := 0 bp := 0
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
for j := 0; j < len(b); j++ { bp += copy(nb[bp:], b)
nb[bp] = b[j]
bp++
}
} }
return nb return nb
} }
......
...@@ -7,10 +7,18 @@ package bytes ...@@ -7,10 +7,18 @@ package bytes
//go:noescape //go:noescape
// IndexByte returns the index of the first instance of c in s, or -1 if c is not present in s. // IndexByte returns the index of the first instance of c in s, or -1 if c is not present in s.
func IndexByte(s []byte, c byte) int // asm_$GOARCH.s func IndexByte(s []byte, c byte) int // ../runtime/asm_$GOARCH.s
//go:noescape //go:noescape
// Equal returns a boolean reporting whether a == b. // Equal returns a boolean reporting whether a and b
// are the same length and contain the same bytes.
// A nil argument is equivalent to an empty slice. // A nil argument is equivalent to an empty slice.
func Equal(a, b []byte) bool // asm_arm.s or ../runtime/asm_{386,amd64}.s func Equal(a, b []byte) bool // ../runtime/asm_$GOARCH.s
//go:noescape
// Compare returns an integer comparing two byte slices lexicographically.
// The result will be 0 if a==b, -1 if a < b, and +1 if a > b.
// A nil argument is equivalent to an empty slice.
func Compare(a, b []byte) int // ../runtime/noasm_arm.goc or ../runtime/asm_{386,amd64}.s
...@@ -47,7 +47,7 @@ type BinOpTest struct { ...@@ -47,7 +47,7 @@ type BinOpTest struct {
i int i int
} }
var compareTests = []struct { var equalTests = []struct {
a, b []byte a, b []byte
i int i int
}{ }{
...@@ -73,12 +73,8 @@ var compareTests = []struct { ...@@ -73,12 +73,8 @@ var compareTests = []struct {
{nil, []byte("a"), -1}, {nil, []byte("a"), -1},
} }
func TestCompare(t *testing.T) { func TestEqual(t *testing.T) {
for _, tt := range compareTests { for _, tt := range compareTests {
cmp := Compare(tt.a, tt.b)
if cmp != tt.i {
t.Errorf(`Compare(%q, %q) = %v`, tt.a, tt.b, cmp)
}
eql := Equal(tt.a, tt.b) eql := Equal(tt.a, tt.b)
if eql != (tt.i == 0) { if eql != (tt.i == 0) {
t.Errorf(`Equal(%q, %q) = %v`, tt.a, tt.b, eql) t.Errorf(`Equal(%q, %q) = %v`, tt.a, tt.b, eql)
...@@ -90,7 +86,7 @@ func TestCompare(t *testing.T) { ...@@ -90,7 +86,7 @@ func TestCompare(t *testing.T) {
} }
} }
func TestEqual(t *testing.T) { func TestEqualExhaustive(t *testing.T) {
var size = 128 var size = 128
if testing.Short() { if testing.Short() {
size = 32 size = 32
...@@ -147,6 +143,7 @@ var indexTests = []BinOpTest{ ...@@ -147,6 +143,7 @@ var indexTests = []BinOpTest{
{"", "a", -1}, {"", "a", -1},
{"", "foo", -1}, {"", "foo", -1},
{"fo", "foo", -1}, {"fo", "foo", -1},
{"foo", "baz", -1},
{"foo", "foo", 0}, {"foo", "foo", 0},
{"oofofoofooo", "f", 2}, {"oofofoofooo", "f", 2},
{"oofofoofooo", "foo", 4}, {"oofofoofooo", "foo", 4},
...@@ -1086,6 +1083,24 @@ func TestTitle(t *testing.T) { ...@@ -1086,6 +1083,24 @@ func TestTitle(t *testing.T) {
} }
} }
var ToTitleTests = []TitleTest{
{"", ""},
{"a", "A"},
{" aaa aaa aaa ", " AAA AAA AAA "},
{" Aaa Aaa Aaa ", " AAA AAA AAA "},
{"123a456", "123A456"},
{"double-blind", "DOUBLE-BLIND"},
{"ÿøû", "ŸØÛ"},
}
func TestToTitle(t *testing.T) {
for _, tt := range ToTitleTests {
if s := string(ToTitle([]byte(tt.in))); s != tt.out {
t.Errorf("ToTitle(%q) = %q, want %q", tt.in, s, tt.out)
}
}
}
var EqualFoldTests = []struct { var EqualFoldTests = []struct {
s, t string s, t string
out bool out bool
...@@ -1114,6 +1129,37 @@ func TestEqualFold(t *testing.T) { ...@@ -1114,6 +1129,37 @@ func TestEqualFold(t *testing.T) {
} }
} }
func TestBufferGrowNegative(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Fatal("Grow(-1) should have paniced")
}
}()
var b Buffer
b.Grow(-1)
}
func TestBufferTruncateNegative(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Fatal("Truncate(-1) should have paniced")
}
}()
var b Buffer
b.Truncate(-1)
}
func TestBufferTruncateOutOfRange(t *testing.T) {
defer func() {
if err := recover(); err == nil {
t.Fatal("Truncate(20) should have paniced")
}
}()
var b Buffer
b.Write(make([]byte, 10))
b.Truncate(20)
}
var makeFieldsInput = func() []byte { var makeFieldsInput = func() []byte {
x := make([]byte, 1<<20) x := make([]byte, 1<<20)
// Input is ~10% space, ~10% 2-byte UTF-8, rest ASCII non-space. // Input is ~10% space, ~10% 2-byte UTF-8, rest ASCII non-space.
......
package bytes_test
import (
. "bytes"
"testing"
)
var compareTests = []struct {
a, b []byte
i int
}{
{[]byte(""), []byte(""), 0},
{[]byte("a"), []byte(""), 1},
{[]byte(""), []byte("a"), -1},
{[]byte("abc"), []byte("abc"), 0},
{[]byte("ab"), []byte("abc"), -1},
{[]byte("abc"), []byte("ab"), 1},
{[]byte("x"), []byte("ab"), 1},
{[]byte("ab"), []byte("x"), -1},
{[]byte("x"), []byte("a"), 1},
{[]byte("b"), []byte("x"), -1},
// test runtime·memeq's chunked implementation
{[]byte("abcdefgh"), []byte("abcdefgh"), 0},
{[]byte("abcdefghi"), []byte("abcdefghi"), 0},
{[]byte("abcdefghi"), []byte("abcdefghj"), -1},
// nil tests
{nil, nil, 0},
{[]byte(""), nil, 0},
{nil, []byte(""), 0},
{[]byte("a"), nil, 1},
{nil, []byte("a"), -1},
}
func TestCompare(t *testing.T) {
for _, tt := range compareTests {
cmp := Compare(tt.a, tt.b)
if cmp != tt.i {
t.Errorf(`Compare(%q, %q) = %v`, tt.a, tt.b, cmp)
}
}
}
func TestCompareIdenticalSlice(t *testing.T) {
var b = []byte("Hello Gophers!")
if Compare(b, b) != 0 {
t.Error("b != b")
}
if Compare(b, b[:1]) != 1 {
t.Error("b > b[:1] failed")
}
}
func TestCompareBytes(t *testing.T) {
n := 128
a := make([]byte, n+1)
b := make([]byte, n+1)
for len := 0; len < 128; len++ {
// randomish but deterministic data. No 0 or 255.
for i := 0; i < len; i++ {
a[i] = byte(1 + 31*i%254)
b[i] = byte(1 + 31*i%254)
}
// data past the end is different
for i := len; i <= n; i++ {
a[i] = 8
b[i] = 9
}
cmp := Compare(a[:len], b[:len])
if cmp != 0 {
t.Errorf(`CompareIdentical(%d) = %d`, len, cmp)
}
if len > 0 {
cmp = Compare(a[:len-1], b[:len])
if cmp != -1 {
t.Errorf(`CompareAshorter(%d) = %d`, len, cmp)
}
cmp = Compare(a[:len], b[:len-1])
if cmp != 1 {
t.Errorf(`CompareBshorter(%d) = %d`, len, cmp)
}
}
for k := 0; k < len; k++ {
b[k] = a[k] - 1
cmp = Compare(a[:len], b[:len])
if cmp != 1 {
t.Errorf(`CompareAbigger(%d,%d) = %d`, len, k, cmp)
}
b[k] = a[k] + 1
cmp = Compare(a[:len], b[:len])
if cmp != -1 {
t.Errorf(`CompareBbigger(%d,%d) = %d`, len, k, cmp)
}
b[k] = a[k]
}
}
}
func BenchmarkCompareBytesEqual(b *testing.B) {
b1 := []byte("Hello Gophers!")
b2 := []byte("Hello Gophers!")
for i := 0; i < b.N; i++ {
if Compare(b1, b2) != 0 {
b.Fatal("b1 != b2")
}
}
}
func BenchmarkCompareBytesToNil(b *testing.B) {
b1 := []byte("Hello Gophers!")
var b2 []byte
for i := 0; i < b.N; i++ {
if Compare(b1, b2) != 1 {
b.Fatal("b1 > b2 failed")
}
}
}
func BenchmarkCompareBytesEmpty(b *testing.B) {
b1 := []byte("")
b2 := b1
for i := 0; i < b.N; i++ {
if Compare(b1, b2) != 0 {
b.Fatal("b1 != b2")
}
}
}
func BenchmarkCompareBytesIdentical(b *testing.B) {
b1 := []byte("Hello Gophers!")
b2 := b1
for i := 0; i < b.N; i++ {
if Compare(b1, b2) != 0 {
b.Fatal("b1 != b2")
}
}
}
func BenchmarkCompareBytesSameLength(b *testing.B) {
b1 := []byte("Hello Gophers!")
b2 := []byte("Hello, Gophers")
for i := 0; i < b.N; i++ {
if Compare(b1, b2) != -1 {
b.Fatal("b1 < b2 failed")
}
}
}
func BenchmarkCompareBytesDifferentLength(b *testing.B) {
b1 := []byte("Hello Gophers!")
b2 := []byte("Hello, Gophers!")
for i := 0; i < b.N; i++ {
if Compare(b1, b2) != -1 {
b.Fatal("b1 < b2 failed")
}
}
}
func BenchmarkCompareBytesBigUnaligned(b *testing.B) {
b.StopTimer()
b1 := make([]byte, 0, 1<<20)
for len(b1) < 1<<20 {
b1 = append(b1, "Hello Gophers!"...)
}
b2 := append([]byte("hello"), b1...)
b.StartTimer()
for i := 0; i < b.N; i++ {
if Compare(b1, b2[len("hello"):]) != 0 {
b.Fatal("b1 != b2")
}
}
b.SetBytes(int64(len(b1)))
}
func BenchmarkCompareBytesBig(b *testing.B) {
b.StopTimer()
b1 := make([]byte, 0, 1<<20)
for len(b1) < 1<<20 {
b1 = append(b1, "Hello Gophers!"...)
}
b2 := append([]byte{}, b1...)
b.StartTimer()
for i := 0; i < b.N; i++ {
if Compare(b1, b2) != 0 {
b.Fatal("b1 != b2")
}
}
b.SetBytes(int64(len(b1)))
}
func BenchmarkCompareBytesBigIdentical(b *testing.B) {
b.StopTimer()
b1 := make([]byte, 0, 1<<20)
for len(b1) < 1<<20 {
b1 = append(b1, "Hello Gophers!"...)
}
b2 := b1
b.StartTimer()
for i := 0; i < b.N; i++ {
if Compare(b1, b2) != 0 {
b.Fatal("b1 != b2")
}
}
b.SetBytes(int64(len(b1)))
}
...@@ -41,3 +41,33 @@ Equal (struct __go_open_array a, struct __go_open_array b) ...@@ -41,3 +41,33 @@ Equal (struct __go_open_array a, struct __go_open_array b)
return 0; return 0;
return __builtin_memcmp (a.__values, b.__values, a.__count) == 0; return __builtin_memcmp (a.__values, b.__values, a.__count) == 0;
} }
intgo Compare (struct __go_open_array a, struct __go_open_array b)
__asm__ (GOSYM_PREFIX "bytes.Compare")
__attribute__ ((no_split_stack));
intgo
Compare (struct __go_open_array a, struct __go_open_array b)
{
intgo len;
len = a.__count;
if (len > b.__count)
len = b.__count;
if (len > 0)
{
intgo ret;
ret = __builtin_memcmp (a.__values, b.__values, len);
if (ret < 0)
return -1;
else if (ret > 0)
return 1;
}
if (a.__count < b.__count)
return -1;
else if (a.__count > b.__count)
return 1;
else
return 0;
}
...@@ -113,6 +113,41 @@ func TestReaderWriteTo(t *testing.T) { ...@@ -113,6 +113,41 @@ func TestReaderWriteTo(t *testing.T) {
} }
} }
func TestReaderLen(t *testing.T) {
const data = "hello world"
r := NewReader([]byte(data))
if got, want := r.Len(), 11; got != want {
t.Errorf("r.Len(): got %d, want %d", got, want)
}
if n, err := r.Read(make([]byte, 10)); err != nil || n != 10 {
t.Errorf("Read failed: read %d %v", n, err)
}
if got, want := r.Len(), 1; got != want {
t.Errorf("r.Len(): got %d, want %d", got, want)
}
if n, err := r.Read(make([]byte, 1)); err != nil || n != 1 {
t.Errorf("Read failed: read %d %v", n, err)
}
if got, want := r.Len(), 0; got != want {
t.Errorf("r.Len(): got %d, want %d", got, want)
}
}
func TestReaderDoubleUnreadRune(t *testing.T) {
buf := NewBuffer([]byte("groucho"))
if _, _, err := buf.ReadRune(); err != nil {
// should not happen
t.Fatal(err)
}
if err := buf.UnreadByte(); err != nil {
// should not happen
t.Fatal(err)
}
if err := buf.UnreadByte(); err == nil {
t.Fatal("UnreadByte: expected error, got nil")
}
}
// verify that copying from an empty reader always has the same results, // verify that copying from an empty reader always has the same results,
// regardless of the presence of a WriteTo method. // regardless of the presence of a WriteTo method.
func TestReaderCopyNothing(t *testing.T) { func TestReaderCopyNothing(t *testing.T) {
......
...@@ -77,6 +77,14 @@ func (br *bitReader) ReadBit() bool { ...@@ -77,6 +77,14 @@ func (br *bitReader) ReadBit() bool {
return n != 0 return n != 0
} }
func (br *bitReader) TryReadBit() (bit byte, ok bool) {
if br.bits > 0 {
br.bits--
return byte(br.n>>br.bits) & 1, true
}
return 0, false
}
func (br *bitReader) Err() error { func (br *bitReader) Err() error {
return br.err return br.err
} }
...@@ -23,6 +23,9 @@ func (s StructuralError) Error() string { ...@@ -23,6 +23,9 @@ func (s StructuralError) Error() string {
// A reader decompresses bzip2 compressed data. // A reader decompresses bzip2 compressed data.
type reader struct { type reader struct {
br bitReader br bitReader
fileCRC uint32
blockCRC uint32
wantBlockCRC uint32
setupDone bool // true if we have parsed the bzip2 header. setupDone bool // true if we have parsed the bzip2 header.
blockSize int // blockSize in bytes, i.e. 900 * 1024. blockSize int // blockSize in bytes, i.e. 900 * 1024.
eof bool eof bool
...@@ -50,13 +53,15 @@ const bzip2BlockMagic = 0x314159265359 ...@@ -50,13 +53,15 @@ const bzip2BlockMagic = 0x314159265359
const bzip2FinalMagic = 0x177245385090 const bzip2FinalMagic = 0x177245385090
// setup parses the bzip2 header. // setup parses the bzip2 header.
func (bz2 *reader) setup() error { func (bz2 *reader) setup(needMagic bool) error {
br := &bz2.br br := &bz2.br
if needMagic {
magic := br.ReadBits(16) magic := br.ReadBits(16)
if magic != bzip2FileMagic { if magic != bzip2FileMagic {
return StructuralError("bad magic value") return StructuralError("bad magic value")
} }
}
t := br.ReadBits(8) t := br.ReadBits(8)
if t != 'h' { if t != 'h' {
...@@ -68,8 +73,11 @@ func (bz2 *reader) setup() error { ...@@ -68,8 +73,11 @@ func (bz2 *reader) setup() error {
return StructuralError("invalid compression level") return StructuralError("invalid compression level")
} }
bz2.fileCRC = 0
bz2.blockSize = 100 * 1024 * (int(level) - '0') bz2.blockSize = 100 * 1024 * (int(level) - '0')
if bz2.blockSize > len(bz2.tt) {
bz2.tt = make([]uint32, bz2.blockSize) bz2.tt = make([]uint32, bz2.blockSize)
}
return nil return nil
} }
...@@ -79,7 +87,7 @@ func (bz2 *reader) Read(buf []byte) (n int, err error) { ...@@ -79,7 +87,7 @@ func (bz2 *reader) Read(buf []byte) (n int, err error) {
} }
if !bz2.setupDone { if !bz2.setupDone {
err = bz2.setup() err = bz2.setup(true)
brErr := bz2.br.Err() brErr := bz2.br.Err()
if brErr != nil { if brErr != nil {
err = brErr err = brErr
...@@ -98,14 +106,14 @@ func (bz2 *reader) Read(buf []byte) (n int, err error) { ...@@ -98,14 +106,14 @@ func (bz2 *reader) Read(buf []byte) (n int, err error) {
return return
} }
func (bz2 *reader) read(buf []byte) (n int, err error) { func (bz2 *reader) readFromBlock(buf []byte) int {
// bzip2 is a block based compressor, except that it has a run-length // bzip2 is a block based compressor, except that it has a run-length
// preprocessing step. The block based nature means that we can // preprocessing step. The block based nature means that we can
// preallocate fixed-size buffers and reuse them. However, the RLE // preallocate fixed-size buffers and reuse them. However, the RLE
// preprocessing would require allocating huge buffers to store the // preprocessing would require allocating huge buffers to store the
// maximum expansion. Thus we process blocks all at once, except for // maximum expansion. Thus we process blocks all at once, except for
// the RLE which we decompress as required. // the RLE which we decompress as required.
n := 0
for (bz2.repeats > 0 || bz2.preRLEUsed < len(bz2.preRLE)) && n < len(buf) { for (bz2.repeats > 0 || bz2.preRLEUsed < len(bz2.preRLE)) && n < len(buf) {
// We have RLE data pending. // We have RLE data pending.
...@@ -148,34 +156,87 @@ func (bz2 *reader) read(buf []byte) (n int, err error) { ...@@ -148,34 +156,87 @@ func (bz2 *reader) read(buf []byte) (n int, err error) {
n++ n++
} }
return n
}
func (bz2 *reader) read(buf []byte) (int, error) {
for {
n := bz2.readFromBlock(buf)
if n > 0 { if n > 0 {
return bz2.blockCRC = updateCRC(bz2.blockCRC, buf[:n])
return n, nil
} }
// No RLE data is pending so we need to read a block. // End of block. Check CRC.
if bz2.blockCRC != bz2.wantBlockCRC {
bz2.br.err = StructuralError("block checksum mismatch")
return 0, bz2.br.err
}
// Find next block.
br := &bz2.br br := &bz2.br
magic := br.ReadBits64(48) switch br.ReadBits64(48) {
if magic == bzip2FinalMagic { default:
br.ReadBits64(32) // ignored CRC
bz2.eof = true
return 0, io.EOF
} else if magic != bzip2BlockMagic {
return 0, StructuralError("bad magic value found") return 0, StructuralError("bad magic value found")
}
err = bz2.readBlock() case bzip2BlockMagic:
// Start of block.
err := bz2.readBlock()
if err != nil { if err != nil {
return 0, err return 0, err
} }
return bz2.read(buf) case bzip2FinalMagic:
// Check end-of-file CRC.
wantFileCRC := uint32(br.ReadBits64(32))
if br.err != nil {
return 0, br.err
}
if bz2.fileCRC != wantFileCRC {
br.err = StructuralError("file checksum mismatch")
return 0, br.err
}
// Skip ahead to byte boundary.
// Is there a file concatenated to this one?
// It would start with BZ.
if br.bits%8 != 0 {
br.ReadBits(br.bits % 8)
}
b, err := br.r.ReadByte()
if err == io.EOF {
br.err = io.EOF
bz2.eof = true
return 0, io.EOF
}
if err != nil {
br.err = err
return 0, err
}
z, err := br.r.ReadByte()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
br.err = err
return 0, err
}
if b != 'B' || z != 'Z' {
return 0, StructuralError("bad magic value in continuation file")
}
if err := bz2.setup(false); err != nil {
return 0, err
}
}
}
} }
// readBlock reads a bzip2 block. The magic number should already have been consumed. // readBlock reads a bzip2 block. The magic number should already have been consumed.
func (bz2 *reader) readBlock() (err error) { func (bz2 *reader) readBlock() (err error) {
br := &bz2.br br := &bz2.br
br.ReadBits64(32) // skip checksum. TODO: check it if we can figure out what it is. bz2.wantBlockCRC = uint32(br.ReadBits64(32)) // skip checksum. TODO: check it if we can figure out what it is.
bz2.blockCRC = 0
bz2.fileCRC = (bz2.fileCRC<<1 | bz2.fileCRC>>31) ^ bz2.wantBlockCRC
randomized := br.ReadBits(1) randomized := br.ReadBits(1)
if randomized != 0 { if randomized != 0 {
return StructuralError("deprecated randomized files") return StructuralError("deprecated randomized files")
...@@ -316,6 +377,9 @@ func (bz2 *reader) readBlock() (err error) { ...@@ -316,6 +377,9 @@ func (bz2 *reader) readBlock() (err error) {
if repeat > 0 { if repeat > 0 {
// We have decoded a complete run-length so we need to // We have decoded a complete run-length so we need to
// replicate the last output symbol. // replicate the last output symbol.
if repeat > bz2.blockSize-bufIndex {
return StructuralError("repeats past end of block")
}
for i := 0; i < repeat; i++ { for i := 0; i < repeat; i++ {
b := byte(mtf.First()) b := byte(mtf.First())
bz2.tt[bufIndex] = uint32(b) bz2.tt[bufIndex] = uint32(b)
...@@ -339,6 +403,9 @@ func (bz2 *reader) readBlock() (err error) { ...@@ -339,6 +403,9 @@ func (bz2 *reader) readBlock() (err error) {
// doesn't need to be encoded and we have |v-1| in the next // doesn't need to be encoded and we have |v-1| in the next
// line. // line.
b := byte(mtf.Decode(int(v - 1))) b := byte(mtf.Decode(int(v - 1)))
if bufIndex >= bz2.blockSize {
return StructuralError("data exceeds block size")
}
bz2.tt[bufIndex] = uint32(b) bz2.tt[bufIndex] = uint32(b)
bz2.c[b]++ bz2.c[b]++
bufIndex++ bufIndex++
...@@ -385,3 +452,33 @@ func inverseBWT(tt []uint32, origPtr uint, c []uint) uint32 { ...@@ -385,3 +452,33 @@ func inverseBWT(tt []uint32, origPtr uint, c []uint) uint32 {
return tt[origPtr] >> 8 return tt[origPtr] >> 8
} }
// This is a standard CRC32 like in hash/crc32 except that all the shifts are reversed,
// causing the bits in the input to be processed in the reverse of the usual order.
var crctab [256]uint32
func init() {
const poly = 0x04C11DB7
for i := range crctab {
crc := uint32(i) << 24
for j := 0; j < 8; j++ {
if crc&0x80000000 != 0 {
crc = (crc << 1) ^ poly
} else {
crc <<= 1
}
}
crctab[i] = crc
}
}
// updateCRC updates the crc value to incorporate the data in b.
// The initial value is 0.
func updateCRC(val uint32, b []byte) uint32 {
crc := ^val
for _, v := range b {
crc = crctab[byte(crc>>24)^v] ^ (crc << 8)
}
return ^crc
}
...@@ -33,14 +33,17 @@ const invalidNodeValue = 0xffff ...@@ -33,14 +33,17 @@ const invalidNodeValue = 0xffff
// Decode reads bits from the given bitReader and navigates the tree until a // Decode reads bits from the given bitReader and navigates the tree until a
// symbol is found. // symbol is found.
func (t huffmanTree) Decode(br *bitReader) (v uint16) { func (t *huffmanTree) Decode(br *bitReader) (v uint16) {
nodeIndex := uint16(0) // node 0 is the root of the tree. nodeIndex := uint16(0) // node 0 is the root of the tree.
for { for {
node := &t.nodes[nodeIndex] node := &t.nodes[nodeIndex]
bit := br.ReadBit() bit, ok := br.TryReadBit()
if !ok && br.ReadBit() {
bit = 1
}
// bzip2 encodes left as a true bit. // bzip2 encodes left as a true bit.
if bit { if bit != 0 {
// left // left
if node.left == invalidNodeValue { if node.left == invalidNodeValue {
return node.leftValue return node.leftValue
......
...@@ -15,10 +15,11 @@ type moveToFrontDecoder struct { ...@@ -15,10 +15,11 @@ type moveToFrontDecoder struct {
// Rather than actually keep the list in memory, the symbols are stored // Rather than actually keep the list in memory, the symbols are stored
// as a circular, double linked list with the symbol indexed by head // as a circular, double linked list with the symbol indexed by head
// at the front of the list. // at the front of the list.
symbols []byte symbols [256]byte
next []uint8 next [256]uint8
prev []uint8 prev [256]uint8
head uint8 head uint8
len int
} }
// newMTFDecoder creates a move-to-front decoder with an explicit initial list // newMTFDecoder creates a move-to-front decoder with an explicit initial list
...@@ -28,12 +29,9 @@ func newMTFDecoder(symbols []byte) *moveToFrontDecoder { ...@@ -28,12 +29,9 @@ func newMTFDecoder(symbols []byte) *moveToFrontDecoder {
panic("too many symbols") panic("too many symbols")
} }
m := &moveToFrontDecoder{ m := new(moveToFrontDecoder)
symbols: symbols, copy(m.symbols[:], symbols)
next: make([]uint8, len(symbols)), m.len = len(symbols)
prev: make([]uint8, len(symbols)),
}
m.threadLinkedList() m.threadLinkedList()
return m return m
} }
...@@ -45,34 +43,29 @@ func newMTFDecoderWithRange(n int) *moveToFrontDecoder { ...@@ -45,34 +43,29 @@ func newMTFDecoderWithRange(n int) *moveToFrontDecoder {
panic("newMTFDecoderWithRange: cannot have > 256 symbols") panic("newMTFDecoderWithRange: cannot have > 256 symbols")
} }
m := &moveToFrontDecoder{ m := new(moveToFrontDecoder)
symbols: make([]uint8, n),
next: make([]uint8, n),
prev: make([]uint8, n),
}
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
m.symbols[i] = byte(i) m.symbols[byte(i)] = byte(i)
} }
m.len = n
m.threadLinkedList() m.threadLinkedList()
return m return m
} }
// threadLinkedList creates the initial linked-list pointers. // threadLinkedList creates the initial linked-list pointers.
func (m *moveToFrontDecoder) threadLinkedList() { func (m *moveToFrontDecoder) threadLinkedList() {
if len(m.symbols) == 0 { if m.len == 0 {
return return
} }
m.prev[0] = uint8(len(m.symbols) - 1) m.prev[0] = uint8(m.len - 1)
for i := 0; i < len(m.symbols)-1; i++ { for i := byte(0); int(i) < m.len-1; i++ {
m.next[i] = uint8(i + 1) m.next[i] = uint8(i + 1)
m.prev[i+1] = uint8(i) m.prev[i+1] = uint8(i)
} }
m.next[len(m.symbols)-1] = 0 m.next[m.len-1] = 0
} }
func (m *moveToFrontDecoder) Decode(n int) (b byte) { func (m *moveToFrontDecoder) Decode(n int) (b byte) {
......
...@@ -6,12 +6,27 @@ package flate ...@@ -6,12 +6,27 @@ package flate
// forwardCopy is like the built-in copy function except that it always goes // forwardCopy is like the built-in copy function except that it always goes
// forward from the start, even if the dst and src overlap. // forward from the start, even if the dst and src overlap.
func forwardCopy(dst, src []byte) int { // It is equivalent to:
if len(src) > len(dst) { // for i := 0; i < n; i++ {
src = src[:len(dst)] // mem[dst+i] = mem[src+i]
// }
func forwardCopy(mem []byte, dst, src, n int) {
if dst <= src {
copy(mem[dst:dst+n], mem[src:src+n])
return
} }
for i, x := range src { for {
dst[i] = x if dst >= src+n {
copy(mem[dst:dst+n], mem[src:src+n])
return
}
// There is some forward overlap. The destination
// will be filled with a repeated pattern of mem[src:src+k].
// We copy one instance of the pattern here, then repeat.
// Each time around this loop k will double.
k := dst - src
copy(mem[dst:dst+k], mem[src:src+k])
n -= k
dst += k
} }
return len(src)
} }
...@@ -30,10 +30,12 @@ func TestForwardCopy(t *testing.T) { ...@@ -30,10 +30,12 @@ func TestForwardCopy(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
b := []byte("0123456789") b := []byte("0123456789")
dst := b[tc.dst0:tc.dst1] n := tc.dst1 - tc.dst0
src := b[tc.src0:tc.src1] if tc.src1-tc.src0 < n {
n := forwardCopy(dst, src) n = tc.src1 - tc.src0
got := string(dst[:n]) }
forwardCopy(b, tc.dst0, tc.src0, n)
got := string(b[tc.dst0 : tc.dst0+n])
if got != tc.want { if got != tc.want {
t.Errorf("dst=b[%d:%d], src=b[%d:%d]: got %q, want %q", t.Errorf("dst=b[%d:%d], src=b[%d:%d]: got %q, want %q",
tc.dst0, tc.dst1, tc.src0, tc.src1, got, tc.want) tc.dst0, tc.dst1, tc.src0, tc.src1, got, tc.want)
......
...@@ -416,6 +416,50 @@ func (d *compressor) init(w io.Writer, level int) (err error) { ...@@ -416,6 +416,50 @@ func (d *compressor) init(w io.Writer, level int) (err error) {
return nil return nil
} }
// zeroes and bzeroes are blocks of zero-valued memory that reset copies
// from, so slices can be cleared with copy in chunks rather than with a
// per-element assignment loop.
var zeroes [32]int
var bzeroes [256]byte

// reset discards the compressor's buffered state and redirects output to
// w, leaving d as if it had been freshly initialized at the same
// compression level.
func (d *compressor) reset(w io.Writer) {
	d.w.reset(w)
	d.sync = false
	d.err = nil
	switch d.compressionLevel.chain {
	case 0:
		// level was NoCompression.
		for i := range d.window {
			d.window[i] = 0
		}
		d.windowEnd = 0
	default:
		d.chainHead = -1
		// Clear the hash-chain heads in len(zeroes)-sized chunks;
		// copy returns how much was actually cleared.
		for s := d.hashHead; len(s) > 0; {
			n := copy(s, zeroes[:])
			s = s[n:]
		}
		// NOTE(review): this loop advances by a fixed len(zeroes) stride,
		// so it appears to assume len(d.hashPrev) is a multiple of
		// len(zeroes) — TODO confirm against the field's allocation.
		for s := d.hashPrev; len(s) > 0; s = s[len(zeroes):] {
			copy(s, zeroes[:])
		}
		d.hashOffset = 1
		d.index, d.windowEnd = 0, 0
		// Clear the sliding window in len(bzeroes)-sized chunks.
		for s := d.window; len(s) > 0; {
			n := copy(s, bzeroes[:])
			s = s[n:]
		}
		d.blockStart, d.byteAvailable = 0, false
		// Re-slice tokens to its full extent so every stored token is
		// zeroed, then truncate back to empty for reuse.
		d.tokens = d.tokens[:maxFlateBlockTokens+1]
		for i := 0; i <= maxFlateBlockTokens; i++ {
			d.tokens[i] = 0
		}
		d.tokens = d.tokens[:0]
		d.length = minMatchLength - 1
		d.offset = 0
		d.hash = 0
		d.maxInsertIndex = 0
	}
}
func (d *compressor) close() error { func (d *compressor) close() error {
d.sync = true d.sync = true
d.step(d) d.step(d)
...@@ -439,7 +483,6 @@ func (d *compressor) close() error { ...@@ -439,7 +483,6 @@ func (d *compressor) close() error {
// If level is in the range [-1, 9] then the error returned will be nil. // If level is in the range [-1, 9] then the error returned will be nil.
// Otherwise the error returned will be non-nil. // Otherwise the error returned will be non-nil.
func NewWriter(w io.Writer, level int) (*Writer, error) { func NewWriter(w io.Writer, level int) (*Writer, error) {
const logWindowSize = logMaxOffsetSize
var dw Writer var dw Writer
if err := dw.d.init(w, level); err != nil { if err := dw.d.init(w, level); err != nil {
return nil, err return nil, err
...@@ -462,6 +505,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { ...@@ -462,6 +505,7 @@ func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) {
zw.Write(dict) zw.Write(dict)
zw.Flush() zw.Flush()
dw.enabled = true dw.enabled = true
zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method.
return zw, err return zw, err
} }
...@@ -481,6 +525,7 @@ func (w *dictWriter) Write(b []byte) (n int, err error) { ...@@ -481,6 +525,7 @@ func (w *dictWriter) Write(b []byte) (n int, err error) {
// form of that data to an underlying writer (see NewWriter). // form of that data to an underlying writer (see NewWriter).
type Writer struct { type Writer struct {
d compressor d compressor
dict []byte
} }
// Write writes data to w, which will eventually write the // Write writes data to w, which will eventually write the
...@@ -506,3 +551,21 @@ func (w *Writer) Flush() error { ...@@ -506,3 +551,21 @@ func (w *Writer) Flush() error {
func (w *Writer) Close() error { func (w *Writer) Close() error {
return w.d.close() return w.d.close()
} }
// Reset discards the writer's state and makes it equivalent to
// the result of NewWriter or NewWriterDict called with dst
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
	dw, ok := w.d.w.w.(*dictWriter)
	if !ok {
		// w was created with NewWriter.
		w.d.reset(dst)
		return
	}
	// w was created with NewWriterDict: retarget the dictWriter,
	// reset the compressor, and replay the dictionary with
	// forwarding disabled so it is not emitted to dst.
	dw.w = dst
	w.d.reset(dw)
	dw.enabled = false
	w.Write(w.dict)
	w.Flush()
	dw.enabled = true
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"reflect"
"sync" "sync"
"testing" "testing"
) )
...@@ -424,3 +425,66 @@ func TestRegression2508(t *testing.T) { ...@@ -424,3 +425,66 @@ func TestRegression2508(t *testing.T) {
} }
w.Close() w.Close()
} }
// TestWriterReset checks that Reset returns a Writer to the exact state
// of a freshly constructed one, for both plain and dictionary writers.
func TestWriterReset(t *testing.T) {
	payload := []byte("hello world")
	for level := 0; level <= 9; level++ {
		if testing.Short() && level > 1 {
			break
		}
		w, err := NewWriter(ioutil.Discard, level)
		if err != nil {
			t.Fatalf("NewWriter: %v", err)
		}
		for i := 0; i < 1024; i++ {
			w.Write(payload)
		}
		w.Reset(ioutil.Discard)

		fresh, err := NewWriter(ioutil.Discard, level)
		if err != nil {
			t.Fatalf("NewWriter: %v", err)
		}
		// DeepEqual doesn't compare functions.
		w.d.fill, fresh.d.fill = nil, nil
		w.d.step, fresh.d.step = nil, nil
		if !reflect.DeepEqual(w, fresh) {
			t.Errorf("level %d Writer not reset after Reset", level)
		}
	}

	levels := []int{NoCompression, DefaultCompression, BestCompression}
	for _, level := range levels {
		level := level
		testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriter(w, level) })
	}
	dict := []byte("we are the world")
	for _, level := range levels {
		level := level
		testResetOutput(t, func(w io.Writer) (*Writer, error) { return NewWriterDict(w, level, dict) })
	}
}
// testResetOutput verifies that a Writer produces byte-identical output
// before and after a Reset.
func testResetOutput(t *testing.T, newWriter func(w io.Writer) (*Writer, error)) {
	first := new(bytes.Buffer)
	w, err := newWriter(first)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}
	msg := []byte("hello world")
	for i := 0; i < 1024; i++ {
		w.Write(msg)
	}
	w.Close()
	want := first.String()

	second := new(bytes.Buffer)
	w.Reset(second)
	for i := 0; i < 1024; i++ {
		w.Write(msg)
	}
	w.Close()
	got := second.String()

	if got != want {
		t.Errorf("got %q, expected %q", got, want)
	}
	t.Logf("got %d bytes", len(want))
}
...@@ -24,3 +24,39 @@ func TestUncompressedSource(t *testing.T) { ...@@ -24,3 +24,39 @@ func TestUncompressedSource(t *testing.T) {
t.Errorf("output[0] = %x, want 0x11", output[0]) t.Errorf("output[0] = %x, want 0x11", output[0])
} }
} }
// The following test should not panic.
func TestIssue5915(t *testing.T) {
	bits := []int{4, 0, 0, 6, 4, 3, 2, 3, 3, 4, 4, 5, 0, 0, 0, 0, 5, 5, 6,
		0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
		0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 6, 0, 11, 0, 8, 0, 6, 6, 10, 8}
	h := new(huffmanDecoder)
	// init must reject this malformed code-length sequence without panicking.
	if h.init(bits) {
		t.Fatalf("Given sequence of bits is bad, and should not succeed.")
	}
}
// The following test should not panic.
func TestIssue5962(t *testing.T) {
	bits := []int{4, 0, 0, 6, 4, 3, 2, 3, 3, 4, 4, 5, 0, 0, 0, 0,
		5, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11}
	h := new(huffmanDecoder)
	// init must reject this malformed code-length sequence without panicking.
	if h.init(bits) {
		t.Fatalf("Given sequence of bits is bad, and should not succeed.")
	}
}
// The following test should not panic.
func TestIssue6255(t *testing.T) {
	bits1 := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11}
	bits2 := []int{11, 13}
	h := new(huffmanDecoder)
	// A valid table followed by an invalid one exercises decoder reuse.
	if ok := h.init(bits1); !ok {
		t.Fatalf("Given sequence of bits is good and should succeed.")
	}
	if ok := h.init(bits2); ok {
		t.Fatalf("Given sequence of bits is bad and should not succeed.")
	}
}
...@@ -97,6 +97,31 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { ...@@ -97,6 +97,31 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
} }
} }
// reset clears the writer's accumulated bit, byte, and frequency state
// and redirects output to writer, so the huffmanBitWriter can be reused.
func (w *huffmanBitWriter) reset(writer io.Writer) {
	w.w = writer
	w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
	w.bytes = [64]byte{}
	for i := range w.codegen {
		w.codegen[i] = 0
	}
	// Zero all three frequency tables.
	zeroInt32 := func(s []int32) {
		for i := range s {
			s[i] = 0
		}
	}
	zeroInt32(w.literalFreq)
	zeroInt32(w.offsetFreq)
	zeroInt32(w.codegenFreq)
	// Clear the code assignments of all three encoders.
	encoders := []*huffmanEncoder{w.literalEncoding, w.offsetEncoding, w.codegenEncoding}
	for _, enc := range encoders {
		for i := range enc.code {
			enc.code[i] = 0
		}
		for i := range enc.codeBits {
			enc.codeBits[i] = 0
		}
	}
}
func (w *huffmanBitWriter) flushBits() { func (w *huffmanBitWriter) flushBits() {
if w.err != nil { if w.err != nil {
w.nbits = 0 w.nbits = 0
......
...@@ -19,23 +19,13 @@ type literalNode struct { ...@@ -19,23 +19,13 @@ type literalNode struct {
freq int32 freq int32
} }
type chain struct { // A levelInfo describes the state of the constructed tree for a given depth.
// The sum of the leaves in this tree
freq int32
// The number of literals to the left of this item at this level
leafCount int32
// The right child of this chain in the previous level.
up *chain
}
type levelInfo struct { type levelInfo struct {
// Our level. for better printing // Our level. for better printing
level int32 level int32
// The most recent chain generated for this level // The frequency of the last node at this level
lastChain *chain lastFreq int32
// The frequency of the next character to add to this level // The frequency of the next character to add to this level
nextCharFreq int32 nextCharFreq int32
...@@ -47,12 +37,6 @@ type levelInfo struct { ...@@ -47,12 +37,6 @@ type levelInfo struct {
// The number of chains remaining to generate for this level before moving // The number of chains remaining to generate for this level before moving
// up to the next level // up to the next level
needed int32 needed int32
// The levelInfo for level+1
up *levelInfo
// The levelInfo for level-1
down *levelInfo
} }
func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxInt32} } func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxInt32} }
...@@ -121,6 +105,8 @@ func (h *huffmanEncoder) bitLength(freq []int32) int64 { ...@@ -121,6 +105,8 @@ func (h *huffmanEncoder) bitLength(freq []int32) int64 {
return total return total
} }
const maxBitsLimit = 16
// Return the number of literals assigned to each bit size in the Huffman encoding // Return the number of literals assigned to each bit size in the Huffman encoding
// //
// This method is only called when list.length >= 3 // This method is only called when list.length >= 3
...@@ -131,9 +117,13 @@ func (h *huffmanEncoder) bitLength(freq []int32) int64 { ...@@ -131,9 +117,13 @@ func (h *huffmanEncoder) bitLength(freq []int32) int64 {
// frequency, and has as its last element a special element with frequency // frequency, and has as its last element a special element with frequency
// MaxInt32 // MaxInt32
// maxBits The maximum number of bits that should be used to encode any literal. // maxBits The maximum number of bits that should be used to encode any literal.
// Must be less than 16.
// return An integer array in which array[i] indicates the number of literals // return An integer array in which array[i] indicates the number of literals
// that should be encoded in i bits. // that should be encoded in i bits.
func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 { func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
if maxBits >= maxBitsLimit {
panic("flate: maxBits too large")
}
n := int32(len(list)) n := int32(len(list))
list = list[0 : n+1] list = list[0 : n+1]
list[n] = maxNode() list[n] = maxNode()
...@@ -148,53 +138,61 @@ func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 { ...@@ -148,53 +138,61 @@ func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
// A bogus "Level 0" whose sole purpose is so that // A bogus "Level 0" whose sole purpose is so that
// level1.prev.needed==0. This makes level1.nextPairFreq // level1.prev.needed==0. This makes level1.nextPairFreq
// be a legitimate value that never gets chosen. // be a legitimate value that never gets chosen.
top := &levelInfo{needed: 0} var levels [maxBitsLimit]levelInfo
chain2 := &chain{list[1].freq, 2, new(chain)} // leafCounts[i] counts the number of literals at the left
// of ancestors of the rightmost node at level i.
// leafCounts[i][j] is the number of literals at the left
// of the level j ancestor.
var leafCounts [maxBitsLimit][maxBitsLimit]int32
for level := int32(1); level <= maxBits; level++ { for level := int32(1); level <= maxBits; level++ {
// For every level, the first two items are the first two characters. // For every level, the first two items are the first two characters.
// We initialize the levels as if we had already figured this out. // We initialize the levels as if we had already figured this out.
top = &levelInfo{ levels[level] = levelInfo{
level: level, level: level,
lastChain: chain2, lastFreq: list[1].freq,
nextCharFreq: list[2].freq, nextCharFreq: list[2].freq,
nextPairFreq: list[0].freq + list[1].freq, nextPairFreq: list[0].freq + list[1].freq,
down: top,
} }
top.down.up = top leafCounts[level][level] = 2
if level == 1 { if level == 1 {
top.nextPairFreq = math.MaxInt32 levels[level].nextPairFreq = math.MaxInt32
} }
} }
// We need a total of 2*n - 2 items at top level and have already generated 2. // We need a total of 2*n - 2 items at top level and have already generated 2.
top.needed = 2*n - 4 levels[maxBits].needed = 2*n - 4
l := top level := maxBits
for { for {
l := &levels[level]
if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 { if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
// We've run out of both leafs and pairs. // We've run out of both leafs and pairs.
// End all calculations for this level. // End all calculations for this level.
// To m sure we never come back to this level or any lower level, // To make sure we never come back to this level or any lower level,
// set nextPairFreq impossibly large. // set nextPairFreq impossibly large.
l.lastChain = nil
l.needed = 0 l.needed = 0
l = l.up levels[level+1].nextPairFreq = math.MaxInt32
l.nextPairFreq = math.MaxInt32 level++
continue continue
} }
prevFreq := l.lastChain.freq prevFreq := l.lastFreq
if l.nextCharFreq < l.nextPairFreq { if l.nextCharFreq < l.nextPairFreq {
// The next item on this row is a leaf node. // The next item on this row is a leaf node.
n := l.lastChain.leafCount + 1 n := leafCounts[level][level] + 1
l.lastChain = &chain{l.nextCharFreq, n, l.lastChain.up} l.lastFreq = l.nextCharFreq
// Lower leafCounts are the same of the previous node.
leafCounts[level][level] = n
l.nextCharFreq = list[n].freq l.nextCharFreq = list[n].freq
} else { } else {
// The next item on this row is a pair from the previous row. // The next item on this row is a pair from the previous row.
// nextPairFreq isn't valid until we generate two // nextPairFreq isn't valid until we generate two
// more values in the level below // more values in the level below
l.lastChain = &chain{l.nextPairFreq, l.lastChain.leafCount, l.down.lastChain} l.lastFreq = l.nextPairFreq
l.down.needed = 2 // Take leaf counts from the lower level, except counts[level] remains the same.
copy(leafCounts[level][:level], leafCounts[level-1][:level])
levels[l.level-1].needed = 2
} }
if l.needed--; l.needed == 0 { if l.needed--; l.needed == 0 {
...@@ -202,33 +200,33 @@ func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 { ...@@ -202,33 +200,33 @@ func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 {
// Continue calculating one level up. Fill in nextPairFreq // Continue calculating one level up. Fill in nextPairFreq
// of that level with the sum of the two nodes we've just calculated on // of that level with the sum of the two nodes we've just calculated on
// this level. // this level.
up := l.up if l.level == maxBits {
if up == nil {
// All done! // All done!
break break
} }
up.nextPairFreq = prevFreq + l.lastChain.freq levels[l.level+1].nextPairFreq = prevFreq + l.lastFreq
l = up level++
} else { } else {
// If we stole from below, move down temporarily to replenish it. // If we stole from below, move down temporarily to replenish it.
for l.down.needed > 0 { for levels[level-1].needed > 0 {
l = l.down level--
} }
} }
} }
// Somethings is wrong if at the end, the top level is null or hasn't used // Somethings is wrong if at the end, the top level is null or hasn't used
// all of the leaves. // all of the leaves.
if top.lastChain.leafCount != n { if leafCounts[maxBits][maxBits] != n {
panic("top.lastChain.leafCount != n") panic("leafCounts[maxBits][maxBits] != n")
} }
bitCount := make([]int32, maxBits+1) bitCount := make([]int32, maxBits+1)
bits := 1 bits := 1
for chain := top.lastChain; chain.up != nil; chain = chain.up { counts := &leafCounts[maxBits]
for level := maxBits; level > 0; level-- {
// chain.leafCount gives the number of literals requiring at least "bits" // chain.leafCount gives the number of literals requiring at least "bits"
// bits to encode. // bits to encode.
bitCount[bits] = chain.leafCount - chain.up.leafCount bitCount[bits] = counts[level] - counts[level-1]
bits++ bits++
} }
return bitCount return bitCount
......
...@@ -91,6 +91,10 @@ type huffmanDecoder struct { ...@@ -91,6 +91,10 @@ type huffmanDecoder struct {
// Initialize Huffman decoding tables from array of code lengths. // Initialize Huffman decoding tables from array of code lengths.
func (h *huffmanDecoder) init(bits []int) bool { func (h *huffmanDecoder) init(bits []int) bool {
if h.min != 0 {
*h = huffmanDecoder{}
}
// Count number of codes of each length, // Count number of codes of each length,
// compute min and max length. // compute min and max length.
var count [maxCodeLen]int var count [maxCodeLen]int
...@@ -125,6 +129,9 @@ func (h *huffmanDecoder) init(bits []int) bool { ...@@ -125,6 +129,9 @@ func (h *huffmanDecoder) init(bits []int) bool {
if i == huffmanChunkBits+1 { if i == huffmanChunkBits+1 {
// create link tables // create link tables
link := code >> 1 link := code >> 1
if huffmanNumChunks < link {
return false
}
h.links = make([][]uint32, huffmanNumChunks-link) h.links = make([][]uint32, huffmanNumChunks-link)
for j := uint(link); j < huffmanNumChunks; j++ { for j := uint(link); j < huffmanNumChunks; j++ {
reverse := int(reverseByte[j>>8]) | int(reverseByte[j&0xff])<<8 reverse := int(reverseByte[j>>8]) | int(reverseByte[j&0xff])<<8
...@@ -154,7 +161,11 @@ func (h *huffmanDecoder) init(bits []int) bool { ...@@ -154,7 +161,11 @@ func (h *huffmanDecoder) init(bits []int) bool {
h.chunks[off] = chunk h.chunks[off] = chunk
} }
} else { } else {
linktab := h.links[h.chunks[reverse&(huffmanNumChunks-1)]>>huffmanValueShift] value := h.chunks[reverse&(huffmanNumChunks-1)] >> huffmanValueShift
if value >= uint32(len(h.links)) {
return false
}
linktab := h.links[value]
reverse >>= huffmanChunkBits reverse >>= huffmanChunkBits
for off := reverse; off < numLinks; off += 1 << uint(n-huffmanChunkBits) { for off := reverse; off < numLinks; off += 1 << uint(n-huffmanChunkBits) {
linktab[off] = chunk linktab[off] = chunk
...@@ -511,7 +522,7 @@ func (f *decompressor) copyHist() bool { ...@@ -511,7 +522,7 @@ func (f *decompressor) copyHist() bool {
if x := len(f.hist) - p; n > x { if x := len(f.hist) - p; n > x {
n = x n = x
} }
forwardCopy(f.hist[f.hp:f.hp+n], f.hist[p:p+n]) forwardCopy(f.hist[:], f.hp, p, n)
p += n p += n
f.hp += n f.hp += n
f.copyLen -= n f.copyLen -= n
...@@ -633,6 +644,10 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) { ...@@ -633,6 +644,10 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) {
if n > huffmanChunkBits { if n > huffmanChunkBits {
chunk = h.links[chunk>>huffmanValueShift][(f.b>>huffmanChunkBits)&h.linkMask] chunk = h.links[chunk>>huffmanValueShift][(f.b>>huffmanChunkBits)&h.linkMask]
n = uint(chunk & huffmanCountMask) n = uint(chunk & huffmanCountMask)
if n == 0 {
f.err = CorruptInputError(f.roffset)
return 0, f.err
}
} }
if n <= f.nb { if n <= f.nb {
f.b >>= n f.b >>= n
......
...@@ -37,6 +37,7 @@ var testfiles = []string{ ...@@ -37,6 +37,7 @@ var testfiles = []string{
} }
func benchmarkDecode(b *testing.B, testfile, level, n int) { func benchmarkDecode(b *testing.B, testfile, level, n int) {
b.ReportAllocs()
b.StopTimer() b.StopTimer()
b.SetBytes(int64(n)) b.SetBytes(int64(n))
buf0, err := ioutil.ReadFile(testfiles[testfile]) buf0, err := ioutil.ReadFile(testfiles[testfile])
...@@ -55,7 +56,7 @@ func benchmarkDecode(b *testing.B, testfile, level, n int) { ...@@ -55,7 +56,7 @@ func benchmarkDecode(b *testing.B, testfile, level, n int) {
if len(buf0) > n-i { if len(buf0) > n-i {
buf0 = buf0[:n-i] buf0 = buf0[:n-i]
} }
io.Copy(w, bytes.NewBuffer(buf0)) io.Copy(w, bytes.NewReader(buf0))
} }
w.Close() w.Close()
buf1 := compressed.Bytes() buf1 := compressed.Bytes()
...@@ -63,7 +64,7 @@ func benchmarkDecode(b *testing.B, testfile, level, n int) { ...@@ -63,7 +64,7 @@ func benchmarkDecode(b *testing.B, testfile, level, n int) {
runtime.GC() runtime.GC()
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
io.Copy(ioutil.Discard, NewReader(bytes.NewBuffer(buf1))) io.Copy(ioutil.Discard, NewReader(bytes.NewReader(buf1)))
} }
} }
......
...@@ -7,7 +7,10 @@ package gzip ...@@ -7,7 +7,10 @@ package gzip
import ( import (
"bytes" "bytes"
"io" "io"
"io/ioutil"
"os"
"testing" "testing"
"time"
) )
type gunzipTest struct { type gunzipTest struct {
...@@ -302,3 +305,31 @@ func TestDecompressor(t *testing.T) { ...@@ -302,3 +305,31 @@ func TestDecompressor(t *testing.T) {
} }
} }
} }
// TestIssue6550 checks that reading a corrupt gzip stream fails with an
// error within a bounded time instead of hanging.
func TestIssue6550(t *testing.T) {
	f, err := os.Open("testdata/issue6550.gz")
	if err != nil {
		t.Fatal(err)
	}
	gz, err := NewReader(f)
	if err != nil {
		t.Fatalf("NewReader(testdata/issue6550.gz): %v", err)
	}
	defer gz.Close()
	done := make(chan bool, 1)
	go func() {
		defer func() { done <- true }()
		if _, err := io.Copy(ioutil.Discard, gz); err != nil {
			t.Logf("Copy failed (correctly): %v", err)
		} else {
			t.Errorf("Copy succeeded")
		}
	}()
	select {
	case <-done:
		// Copy returned in time.
	case <-time.After(time.Second):
		t.Errorf("Copy hung")
	}
}
...@@ -28,6 +28,7 @@ type Writer struct { ...@@ -28,6 +28,7 @@ type Writer struct {
Header Header
w io.Writer w io.Writer
level int level int
wroteHeader bool
compressor *flate.Writer compressor *flate.Writer
digest hash.Hash32 digest hash.Hash32
size uint32 size uint32
...@@ -62,14 +63,39 @@ func NewWriterLevel(w io.Writer, level int) (*Writer, error) { ...@@ -62,14 +63,39 @@ func NewWriterLevel(w io.Writer, level int) (*Writer, error) {
if level < DefaultCompression || level > BestCompression { if level < DefaultCompression || level > BestCompression {
return nil, fmt.Errorf("gzip: invalid compression level: %d", level) return nil, fmt.Errorf("gzip: invalid compression level: %d", level)
} }
return &Writer{ z := new(Writer)
z.init(w, level)
return z, nil
}
func (z *Writer) init(w io.Writer, level int) {
digest := z.digest
if digest != nil {
digest.Reset()
} else {
digest = crc32.NewIEEE()
}
compressor := z.compressor
if compressor != nil {
compressor.Reset(w)
}
*z = Writer{
Header: Header{ Header: Header{
OS: 255, // unknown OS: 255, // unknown
}, },
w: w, w: w,
level: level, level: level,
digest: crc32.NewIEEE(), digest: digest,
}, nil compressor: compressor,
}
}
// Reset discards the Writer z's state and makes it equivalent to the
// result of its original state from NewWriter or NewWriterLevel, but
// writing to w instead. This permits reusing a Writer rather than
// allocating a new one.
func (z *Writer) Reset(w io.Writer) {
z.init(w, z.level)
} }
// GZIP (RFC 1952) is little-endian, unlike ZLIB (RFC 1950). // GZIP (RFC 1952) is little-endian, unlike ZLIB (RFC 1950).
...@@ -138,7 +164,8 @@ func (z *Writer) Write(p []byte) (int, error) { ...@@ -138,7 +164,8 @@ func (z *Writer) Write(p []byte) (int, error) {
} }
var n int var n int
// Write the GZIP header lazily. // Write the GZIP header lazily.
if z.compressor == nil { if !z.wroteHeader {
z.wroteHeader = true
z.buf[0] = gzipID1 z.buf[0] = gzipID1
z.buf[1] = gzipID2 z.buf[1] = gzipID2
z.buf[2] = gzipDeflate z.buf[2] = gzipDeflate
...@@ -183,8 +210,10 @@ func (z *Writer) Write(p []byte) (int, error) { ...@@ -183,8 +210,10 @@ func (z *Writer) Write(p []byte) (int, error) {
return n, z.err return n, z.err
} }
} }
if z.compressor == nil {
z.compressor, _ = flate.NewWriter(z.w, z.level) z.compressor, _ = flate.NewWriter(z.w, z.level)
} }
}
z.size += uint32(len(p)) z.size += uint32(len(p))
z.digest.Write(p) z.digest.Write(p)
n, z.err = z.compressor.Write(p) n, z.err = z.compressor.Write(p)
...@@ -206,8 +235,11 @@ func (z *Writer) Flush() error { ...@@ -206,8 +235,11 @@ func (z *Writer) Flush() error {
if z.closed { if z.closed {
return nil return nil
} }
if z.compressor == nil { if !z.wroteHeader {
z.Write(nil) z.Write(nil)
if z.err != nil {
return z.err
}
} }
z.err = z.compressor.Flush() z.err = z.compressor.Flush()
return z.err return z.err
...@@ -222,7 +254,7 @@ func (z *Writer) Close() error { ...@@ -222,7 +254,7 @@ func (z *Writer) Close() error {
return nil return nil
} }
z.closed = true z.closed = true
if z.compressor == nil { if !z.wroteHeader {
z.Write(nil) z.Write(nil)
if z.err != nil { if z.err != nil {
return z.err return z.err
......
...@@ -197,3 +197,35 @@ func TestWriterFlush(t *testing.T) { ...@@ -197,3 +197,35 @@ func TestWriterFlush(t *testing.T) {
t.Fatal("Flush didn't flush any data") t.Fatal("Flush didn't flush any data")
} }
} }
// Multiple gzip files concatenated form a valid gzip file.
func TestConcat(t *testing.T) {
	var buf bytes.Buffer
	w := NewWriter(&buf)
	w.Write([]byte("hello "))
	w.Close()
	w = NewWriter(&buf)
	w.Write([]byte("world\n"))
	w.Close()

	r, err := NewReader(&buf)
	// Check the NewReader error immediately: the original code let the
	// subsequent ReadAll overwrite err, and a failed NewReader would have
	// left r nil, panicking inside ReadAll.
	if err != nil {
		t.Fatalf("NewReader: %v", err)
	}
	data, err := ioutil.ReadAll(r)
	// The expected value in the message includes the trailing newline,
	// matching the comparison above.
	if string(data) != "hello world\n" || err != nil {
		t.Fatalf("ReadAll = %q, %v, want %q, nil", data, err, "hello world\n")
	}
}
// TestWriterReset checks that a Reset gzip Writer produces the same
// bytes as it did before the Reset.
func TestWriterReset(t *testing.T) {
	first := new(bytes.Buffer)
	second := new(bytes.Buffer)
	z := NewWriter(first)
	msg := []byte("hello world")
	z.Write(msg)
	z.Close()
	z.Reset(second)
	z.Write(msg)
	z.Close()
	if first.String() != second.String() {
		t.Errorf("buf2 %q != original buf of %q", second.String(), first.String())
	}
}
...@@ -70,6 +70,23 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) { ...@@ -70,6 +70,23 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) {
}, nil }, nil
} }
// Reset clears the state of the Writer z such that it is equivalent to its
// initial state from NewWriterLevel or NewWriterLevelDict, but instead writing
// to w.
func (z *Writer) Reset(w io.Writer) {
	z.w = w
	// z.level and z.dict left unchanged.
	if c := z.compressor; c != nil {
		c.Reset(w)
	}
	if d := z.digest; d != nil {
		d.Reset()
	}
	z.err = nil
	z.scratch = [4]byte{}
	z.wroteHeader = false
}
// writeHeader writes the ZLIB header. // writeHeader writes the ZLIB header.
func (z *Writer) writeHeader() (err error) { func (z *Writer) writeHeader() (err error) {
z.wroteHeader = true z.wroteHeader = true
...@@ -111,11 +128,15 @@ func (z *Writer) writeHeader() (err error) { ...@@ -111,11 +128,15 @@ func (z *Writer) writeHeader() (err error) {
return err return err
} }
} }
if z.compressor == nil {
// Initialize deflater unless the Writer is being reused
// after a Reset call.
z.compressor, err = flate.NewWriterDict(z.w, z.level, z.dict) z.compressor, err = flate.NewWriterDict(z.w, z.level, z.dict)
if err != nil { if err != nil {
return err return err
} }
z.digest = adler32.New() z.digest = adler32.New()
}
return nil return nil
} }
......
...@@ -89,6 +89,56 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) { ...@@ -89,6 +89,56 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) {
} }
} }
// testFileLevelDictReset compresses the contents of fn (empty input when
// fn is "") at the given level, optionally with a preset dictionary,
// then Resets the writer and compresses again, checking that both
// outputs are identical.
func testFileLevelDictReset(t *testing.T, fn string, level int, dict []byte) {
	var b0 []byte
	var err error
	if fn != "" {
		b0, err = ioutil.ReadFile(fn)
		if err != nil {
			t.Errorf("%s (level=%d): %v", fn, level, err)
			return
		}
	}

	// Compress once.
	buf := new(bytes.Buffer)
	var zlibw *Writer
	if dict == nil {
		zlibw, err = NewWriterLevel(buf, level)
	} else {
		zlibw, err = NewWriterLevelDict(buf, level, dict)
	}
	if err == nil {
		_, err = zlibw.Write(b0)
	}
	if err == nil {
		err = zlibw.Close()
	}
	if err != nil {
		t.Errorf("%s (level=%d): %v", fn, level, err)
		return
	}
	out := buf.String()

	// Reset and compress again.
	buf2 := new(bytes.Buffer)
	zlibw.Reset(buf2)
	_, err = zlibw.Write(b0)
	if err == nil {
		err = zlibw.Close()
	}
	if err != nil {
		t.Errorf("%s (level=%d): %v", fn, level, err)
		return
	}
	out2 := buf2.String()

	if out2 != out {
		// Message fixed: the original omitted the closing parenthesis.
		t.Errorf("%s (level=%d): different output after reset (got %d bytes, expected %d)",
			fn, level, len(out2), len(out))
	}
}
func TestWriter(t *testing.T) { func TestWriter(t *testing.T) {
for i, s := range data { for i, s := range data {
b := []byte(s) b := []byte(s)
...@@ -122,6 +172,21 @@ func TestWriterDict(t *testing.T) { ...@@ -122,6 +172,21 @@ func TestWriterDict(t *testing.T) {
} }
} }
// TestWriterReset exercises Reset across files, levels, and the
// with/without-dictionary cases.
func TestWriterReset(t *testing.T) {
	const dictionary = "0123456789."
	for _, fn := range filenames {
		for _, dict := range [][]byte{nil, []byte(dictionary)} {
			testFileLevelDictReset(t, fn, NoCompression, dict)
			testFileLevelDictReset(t, fn, DefaultCompression, dict)
		}
		if testing.Short() {
			continue
		}
		for level := BestSpeed; level <= BestCompression; level++ {
			testFileLevelDictReset(t, fn, level, nil)
		}
	}
}
func TestWriterDictIsUsed(t *testing.T) { func TestWriterDictIsUsed(t *testing.T) {
var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
var buf bytes.Buffer var buf bytes.Buffer
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
// heap.Interface. A heap is a tree with the property that each node is the // heap.Interface. A heap is a tree with the property that each node is the
// minimum-valued node in its subtree. // minimum-valued node in its subtree.
// //
// The minimum element in the tree is the root, at index 0.
//
// A heap is a common way to implement a priority queue. To build a priority // A heap is a common way to implement a priority queue. To build a priority
// queue, implement the Heap interface with the (negative) priority as the // queue, implement the Heap interface with the (negative) priority as the
// ordering for the Less method, so Push adds items while Pop removes the // ordering for the Less method, so Push adds items while Pop removes the
...@@ -54,7 +56,7 @@ func Push(h Interface, x interface{}) { ...@@ -54,7 +56,7 @@ func Push(h Interface, x interface{}) {
// Pop removes the minimum element (according to Less) from the heap // Pop removes the minimum element (according to Less) from the heap
// and returns it. The complexity is O(log(n)) where n = h.Len(). // and returns it. The complexity is O(log(n)) where n = h.Len().
// Same as Remove(h, 0). // It is equivalent to Remove(h, 0).
// //
func Pop(h Interface) interface{} { func Pop(h Interface) interface{} {
n := h.Len() - 1 n := h.Len() - 1
...@@ -76,6 +78,15 @@ func Remove(h Interface, i int) interface{} { ...@@ -76,6 +78,15 @@ func Remove(h Interface, i int) interface{} {
return h.Pop() return h.Pop()
} }
// Fix reestablishes the heap ordering after the element at index i has changed its value.
// Changing the value of the element at index i and then calling Fix is equivalent to,
// but less expensive than, calling Remove(h, i) followed by a Push of the new value.
// The complexity is O(log(n)) where n = h.Len().
func Fix(h Interface, i int) {
	// The changed element may need to move toward the leaves or toward the
	// root; try both directions to restore the invariant.
	down(h, i, h.Len())
	up(h, i)
}
func up(h Interface, j int) { func up(h Interface, j int) {
for { for {
i := (j - 1) / 2 // parent i := (j - 1) / 2 // parent
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package heap package heap
import ( import (
"math/rand"
"testing" "testing"
) )
...@@ -182,3 +183,31 @@ func BenchmarkDup(b *testing.B) { ...@@ -182,3 +183,31 @@ func BenchmarkDup(b *testing.B) {
} }
} }
} }
// TestFix mutates heap elements in place and verifies that Fix restores the
// heap invariant after every change.
func TestFix(t *testing.T) {
	h := new(myHeap)
	h.verify(t, 0)

	for v := 200; v > 0; v -= 10 {
		Push(h, v)
	}
	h.verify(t, 0)

	if (*h)[0] != 10 {
		t.Fatalf("Expected head to be 10, was %d", (*h)[0])
	}
	// Make the root the largest value and let Fix sift it down.
	(*h)[0] = 210
	Fix(h, 0)
	h.verify(t, 0)

	// Randomly grow or shrink elements, fixing up after each mutation.
	for n := 100; n > 0; n-- {
		idx := rand.Intn(h.Len())
		if n&1 == 0 {
			(*h)[idx] *= 2
		} else {
			(*h)[idx] /= 2
		}
		Fix(h, idx)
		h.verify(t, 0)
	}
}
...@@ -29,7 +29,7 @@ type Element struct { ...@@ -29,7 +29,7 @@ type Element struct {
// Next returns the next list element or nil. // Next returns the next list element or nil.
func (e *Element) Next() *Element { func (e *Element) Next() *Element {
if p := e.next; p != &e.list.root { if p := e.next; e.list != nil && p != &e.list.root {
return p return p
} }
return nil return nil
...@@ -37,7 +37,7 @@ func (e *Element) Next() *Element { ...@@ -37,7 +37,7 @@ func (e *Element) Next() *Element {
// Prev returns the previous list element or nil. // Prev returns the previous list element or nil.
func (e *Element) Prev() *Element { func (e *Element) Prev() *Element {
if p := e.prev; p != &e.list.root { if p := e.prev; e.list != nil && p != &e.list.root {
return p return p
} }
return nil return nil
...@@ -62,6 +62,7 @@ func (l *List) Init() *List { ...@@ -62,6 +62,7 @@ func (l *List) Init() *List {
func New() *List { return new(List).Init() } func New() *List { return new(List).Init() }
// Len returns the number of elements of list l. // Len returns the number of elements of list l.
// The complexity is O(1).
func (l *List) Len() int { return l.len } func (l *List) Len() int { return l.len }
// Front returns the first element of list l or nil // Front returns the first element of list l or nil
...@@ -126,7 +127,7 @@ func (l *List) Remove(e *Element) interface{} { ...@@ -126,7 +127,7 @@ func (l *List) Remove(e *Element) interface{} {
return e.Value return e.Value
} }
// Pushfront inserts a new element e with value v at the front of list l and returns e. // PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *List) PushFront(v interface{}) *Element { func (l *List) PushFront(v interface{}) *Element {
l.lazyInit() l.lazyInit()
return l.insertValue(v, &l.root) return l.insertValue(v, &l.root)
...@@ -178,6 +179,24 @@ func (l *List) MoveToBack(e *Element) { ...@@ -178,6 +179,24 @@ func (l *List) MoveToBack(e *Element) {
l.insert(l.remove(e), l.root.prev) l.insert(l.remove(e), l.root.prev)
} }
// MoveBefore moves element e to its new position before mark.
// If e is not an element of l, or e == mark, the list is not modified.
func (l *List) MoveBefore(e, mark *Element) {
	if e.list == l && e != mark {
		l.insert(l.remove(e), mark.prev)
	}
}
// MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified.
func (l *List) MoveAfter(e, mark *Element) {
	if e.list == l && e != mark {
		l.insert(l.remove(e), mark)
	}
}
// PushBackList inserts a copy of an other list at the back of list l. // PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. // The lists l and other may be the same.
func (l *List) PushBackList(other *List) { func (l *List) PushBackList(other *List) {
......
...@@ -233,3 +233,55 @@ func TestIssue4103(t *testing.T) { ...@@ -233,3 +233,55 @@ func TestIssue4103(t *testing.T) {
t.Errorf("l1.Len() = %d, want 3", n) t.Errorf("l1.Len() = %d, want 3", n)
} }
} }
// TestIssue6349 verifies that calling Next or Prev on an element that has
// been removed from its list returns nil rather than dereferencing a nil
// list pointer.
func TestIssue6349(t *testing.T) {
	l := New()
	l.PushBack(1)
	l.PushBack(2)

	e := l.Front()
	l.Remove(e)
	if got := e.Value; got != 1 {
		t.Errorf("e.value = %d, want 1", got)
	}
	if next := e.Next(); next != nil {
		t.Errorf("e.Next() != nil")
	}
	if prev := e.Prev(); prev != nil {
		t.Errorf("e.Prev() != nil")
	}
}
// TestMove exercises MoveBefore and MoveAfter, covering the no-op cases
// (e == mark, adjacent elements) and real moves. After each move, the
// eN variables are re-labeled so that e1..e4 always name the elements
// in their current list order.
func TestMove(t *testing.T) {
	l := New()
	e1 := l.PushBack(1)
	e2 := l.PushBack(2)
	e3 := l.PushBack(3)
	e4 := l.PushBack(4)

	// Moving an element relative to itself must not change the list.
	l.MoveAfter(e3, e3)
	checkListPointers(t, l, []*Element{e1, e2, e3, e4})
	l.MoveBefore(e2, e2)
	checkListPointers(t, l, []*Element{e1, e2, e3, e4})

	// Moving an element to where it already sits is also a no-op.
	l.MoveAfter(e3, e2)
	checkListPointers(t, l, []*Element{e1, e2, e3, e4})
	l.MoveBefore(e2, e3)
	checkListPointers(t, l, []*Element{e1, e2, e3, e4})

	l.MoveBefore(e2, e4)
	checkListPointers(t, l, []*Element{e1, e3, e2, e4})
	e1, e2, e3, e4 = e1, e3, e2, e4

	l.MoveBefore(e4, e1)
	checkListPointers(t, l, []*Element{e4, e1, e2, e3})
	e1, e2, e3, e4 = e4, e1, e2, e3

	l.MoveAfter(e4, e1)
	checkListPointers(t, l, []*Element{e1, e4, e2, e3})
	e1, e2, e3, e4 = e1, e4, e2, e3

	l.MoveAfter(e2, e3)
	checkListPointers(t, l, []*Element{e1, e3, e2, e4})
	// Final re-label kept for symmetry with the steps above; the values are
	// not read again.
	e1, e2, e3, e4 = e1, e3, e2, e4
}
...@@ -61,6 +61,13 @@ func (x *cbcEncrypter) CryptBlocks(dst, src []byte) { ...@@ -61,6 +61,13 @@ func (x *cbcEncrypter) CryptBlocks(dst, src []byte) {
} }
} }
// SetIV replaces the initialization vector. The new IV must be the same
// length as the current one; otherwise SetIV panics.
func (x *cbcEncrypter) SetIV(iv []byte) {
	if len(x.iv) != len(iv) {
		panic("cipher: incorrect length IV")
	}
	copy(x.iv, iv)
}
type cbcDecrypter cbc type cbcDecrypter cbc
// NewCBCDecrypter returns a BlockMode which decrypts in cipher block chaining // NewCBCDecrypter returns a BlockMode which decrypts in cipher block chaining
...@@ -94,3 +101,10 @@ func (x *cbcDecrypter) CryptBlocks(dst, src []byte) { ...@@ -94,3 +101,10 @@ func (x *cbcDecrypter) CryptBlocks(dst, src []byte) {
dst = dst[x.blockSize:] dst = dst[x.blockSize:]
} }
} }
// SetIV replaces the initialization vector. The new IV must be the same
// length as the current one; otherwise SetIV panics.
func (x *cbcDecrypter) SetIV(iv []byte) {
	if len(x.iv) != len(iv) {
		panic("cipher: incorrect length IV")
	}
	copy(x.iv, iv)
}
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cipher_test
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/hex"
"testing"
)
// AES-GCM test vectors taken from gcmEncryptExtIV128.rsp from
// http://csrc.nist.gov/groups/STM/cavp/index.html.
var aesGCMTests = []struct {
	// All fields are hex-encoded. result is the ciphertext followed by the
	// authentication tag (16 bytes in these vectors, as seen in the
	// empty-plaintext entries whose result is 16 bytes long).
	key, nonce, plaintext, ad, result string
}{
	{
		"11754cd72aec309bf52f7687212e8957",
		"3c819d9a9bed087615030b65",
		"",
		"",
		"250327c674aaf477aef2675748cf6971",
	},
	{
		"ca47248ac0b6f8372a97ac43508308ed",
		"ffd2b598feabc9019262d2be",
		"",
		"",
		"60d20404af527d248d893ae495707d1a",
	},
	{
		"77be63708971c4e240d1cb79e8d77feb",
		"e0e00f19fed7ba0136a797f3",
		"",
		"7a43ec1d9c0a5a78a0b16533a6213cab",
		"209fcc8d3675ed938e9c7166709dd946",
	},
	{
		"7680c5d3ca6154758e510f4d25b98820",
		"f8f105f9c3df4965780321f8",
		"",
		"c94c410194c765e3dccc7964379758ed3",
		"94dca8edfcf90bb74b153c8d48a17930",
	},
	{
		"7fddb57453c241d03efbed3ac44e371c",
		"ee283a3fc75575e33efd4887",
		"d5de42b461646c255c87bd2962d3b9a2",
		"",
		"2ccda4a5415cb91e135c2a0f78c9b2fdb36d1df9b9d5e596f83e8b7f52971cb3",
	},
	{
		"ab72c77b97cb5fe9a382d9fe81ffdbed",
		"54cc7dc2c37ec006bcc6d1da",
		"007c5e5b3e59df24a7c355584fc1518d",
		"",
		"0e1bde206a07a9c2c1b65300f8c649972b4401346697138c7a4891ee59867d0c",
	},
	{
		"fe47fcce5fc32665d2ae399e4eec72ba",
		"5adb9609dbaeb58cbd6e7275",
		"7c0e88c88899a779228465074797cd4c2e1498d259b54390b85e3eef1c02df60e743f1b840382c4bccaf3bafb4ca8429bea063",
		"88319d6e1d3ffa5f987199166c8a9b56c2aeba5a",
		"98f4826f05a265e6dd2be82db241c0fbbbf9ffb1c173aa83964b7cf5393043736365253ddbc5db8778371495da76d269e5db3e291ef1982e4defedaa2249f898556b47",
	},
	{
		"ec0c2ba17aa95cd6afffe949da9cc3a8",
		"296bce5b50b7d66096d627ef",
		"b85b3753535b825cbe5f632c0b843c741351f18aa484281aebec2f45bb9eea2d79d987b764b9611f6c0f8641843d5d58f3a242",
		"f8d00f05d22bf68599bcdeb131292ad6e2df5d14",
		"a7443d31c26bdf2a1c945e29ee4bd344a99cfaf3aa71f8b3f191f83c2adfc7a07162995506fde6309ffc19e716eddf1a828c5a890147971946b627c40016da1ecf3e77",
	},
	{
		"2c1f21cf0f6fb3661943155c3e3d8492",
		"23cb5ff362e22426984d1907",
		"42f758836986954db44bf37c6ef5e4ac0adaf38f27252a1b82d02ea949c8a1a2dbc0d68b5615ba7c1220ff6510e259f06655d8",
		"5d3624879d35e46849953e45a32a624d6a6c536ed9857c613b572b0333e701557a713e3f010ecdf9a6bd6c9e3e44b065208645aff4aabee611b391528514170084ccf587177f4488f33cfb5e979e42b6e1cfc0a60238982a7aec",
		"81824f0e0d523db30d3da369fdc0d60894c7a0a20646dd015073ad2732bd989b14a222b6ad57af43e1895df9dca2a5344a62cc57a3ee28136e94c74838997ae9823f3a",
	},
	{
		"d9f7d2411091f947b4d6f1e2d1f0fb2e",
		"e1934f5db57cc983e6b180e7",
		"73ed042327f70fe9c572a61545eda8b2a0c6e1d6c291ef19248e973aee6c312012f490c2c6f6166f4a59431e182663fcaea05a",
		"0a8a18a7150e940c3d87b38e73baee9a5c049ee21795663e264b694a949822b639092d0e67015e86363583fcf0ca645af9f43375f05fdb4ce84f411dcbca73c2220dea03a20115d2e51398344b16bee1ed7c499b353d6c597af8",
		"aaadbd5c92e9151ce3db7210b8714126b73e43436d242677afa50384f2149b831f1d573c7891c2a91fbc48db29967ec9542b2321b51ca862cb637cdd03b99a0f93b134",
	},
	{
		"fe9bb47deb3a61e423c2231841cfd1fb",
		"4d328eb776f500a2f7fb47aa",
		"f1cc3818e421876bb6b8bbd6c9",
		"",
		"b88c5c1977b35b517b0aeae96743fd4727fe5cdb4b5b42818dea7ef8c9",
	},
	{
		"6703df3701a7f54911ca72e24dca046a",
		"12823ab601c350ea4bc2488c",
		"793cd125b0b84a043e3ac67717",
		"",
		"b2051c80014f42f08735a7b0cd38e6bcd29962e5f2c13626b85a877101",
	},
}
// TestAESGCM checks the AES-GCM implementation against the NIST CAVP test
// vectors in aesGCMTests: Seal must produce the expected output, Open must
// round-trip it, and Open must fail once the additional data, nonce, or
// ciphertext has been tampered with.
func TestAESGCM(t *testing.T) {
	for i, test := range aesGCMTests {
		key, _ := hex.DecodeString(test.key)
		block, err := aes.NewCipher(key)
		if err != nil {
			t.Fatal(err)
		}

		nonce, _ := hex.DecodeString(test.nonce)
		plaintext, _ := hex.DecodeString(test.plaintext)
		ad, _ := hex.DecodeString(test.ad)
		aesgcm, err := cipher.NewGCM(block)
		if err != nil {
			t.Fatal(err)
		}

		ct := aesgcm.Seal(nil, nonce, plaintext, ad)
		if ctHex := hex.EncodeToString(ct); ctHex != test.result {
			t.Errorf("#%d: got %s, want %s", i, ctHex, test.result)
			continue
		}

		plaintext2, err := aesgcm.Open(nil, nonce, ct, ad)
		if err != nil {
			t.Errorf("#%d: Open failed", i)
			continue
		}
		if !bytes.Equal(plaintext, plaintext2) {
			t.Errorf("#%d: plaintext's don't match: got %x vs %x", i, plaintext2, plaintext)
			continue
		}

		// Flip one bit in each input in turn (restoring it afterwards) and
		// check that authentication fails.
		if len(ad) > 0 {
			ad[0] ^= 0x80
			if _, err := aesgcm.Open(nil, nonce, ct, ad); err == nil {
				t.Errorf("#%d: Open was successful after altering additional data", i)
			}
			ad[0] ^= 0x80
		}

		nonce[0] ^= 0x80
		if _, err := aesgcm.Open(nil, nonce, ct, ad); err == nil {
			t.Errorf("#%d: Open was successful after altering nonce", i)
		}
		nonce[0] ^= 0x80

		ct[0] ^= 0x80
		if _, err := aesgcm.Open(nil, nonce, ct, ad); err == nil {
			t.Errorf("#%d: Open was successful after altering ciphertext", i)
		}
		ct[0] ^= 0x80
	}
}
// BenchmarkAESGCM measures AES-GCM Seal throughput on 1 KiB messages using an
// all-zero key and nonce, with the nonce doubling as additional data.
func BenchmarkAESGCM(b *testing.B) {
	msg := make([]byte, 1024)
	b.SetBytes(int64(len(msg)))

	var key [16]byte
	var nonce [12]byte
	block, _ := aes.NewCipher(key[:])
	gcm, _ := cipher.NewGCM(block)
	var dst []byte

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		dst = gcm.Seal(dst[:0], nonce[:], msg, nonce[:])
	}
}
...@@ -25,6 +25,8 @@ func (r StreamReader) Read(dst []byte) (n int, err error) { ...@@ -25,6 +25,8 @@ func (r StreamReader) Read(dst []byte) (n int, err error) {
// StreamWriter wraps a Stream into an io.Writer. It calls XORKeyStream // StreamWriter wraps a Stream into an io.Writer. It calls XORKeyStream
// to process each slice of data which passes through. If any Write call // to process each slice of data which passes through. If any Write call
// returns short then the StreamWriter is out of sync and must be discarded. // returns short then the StreamWriter is out of sync and must be discarded.
// A StreamWriter has no internal buffering; Close does not need
// to be called to flush write data.
type StreamWriter struct { type StreamWriter struct {
S Stream S Stream
W io.Writer W io.Writer
...@@ -43,8 +45,11 @@ func (w StreamWriter) Write(src []byte) (n int, err error) { ...@@ -43,8 +45,11 @@ func (w StreamWriter) Write(src []byte) (n int, err error) {
return return
} }
// Close closes the underlying Writer and returns its Close return value, if the Writer
// is also an io.Closer. Otherwise it returns nil.
func (w StreamWriter) Close() error { func (w StreamWriter) Close() error {
// This saves us from either requiring a WriteCloser or having a if c, ok := w.W.(io.Closer); ok {
// StreamWriterCloser. return c.Close()
return w.W.(io.Closer).Close() }
return nil
} }
...@@ -7,6 +7,7 @@ package crypto ...@@ -7,6 +7,7 @@ package crypto
import ( import (
"hash" "hash"
"strconv"
) )
// Hash identifies a cryptographic hash function that is implemented in another // Hash identifies a cryptographic hash function that is implemented in another
...@@ -59,7 +60,7 @@ func (h Hash) New() hash.Hash { ...@@ -59,7 +60,7 @@ func (h Hash) New() hash.Hash {
return f() return f()
} }
} }
panic("crypto: requested hash function is unavailable") panic("crypto: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable")
} }
// Available reports whether the given hash function is linked into the binary. // Available reports whether the given hash function is linked into the binary.
...@@ -77,5 +78,8 @@ func RegisterHash(h Hash, f func() hash.Hash) { ...@@ -77,5 +78,8 @@ func RegisterHash(h Hash, f func() hash.Hash) {
hashes[h] = f hashes[h] = f
} }
// PublicKey represents a public key using an unspecified algorithm.
type PublicKey interface{}
// PrivateKey represents a private key using an unspecified algorithm. // PrivateKey represents a private key using an unspecified algorithm.
type PrivateKey interface{} type PrivateKey interface{}
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
func cryptBlock(subkeys []uint64, dst, src []byte, decrypt bool) { func cryptBlock(subkeys []uint64, dst, src []byte, decrypt bool) {
b := binary.BigEndian.Uint64(src) b := binary.BigEndian.Uint64(src)
b = permuteBlock(b, initialPermutation[:]) b = permuteInitialBlock(b)
left, right := uint32(b>>32), uint32(b) left, right := uint32(b>>32), uint32(b)
var subkey uint64 var subkey uint64
...@@ -25,7 +25,7 @@ func cryptBlock(subkeys []uint64, dst, src []byte, decrypt bool) { ...@@ -25,7 +25,7 @@ func cryptBlock(subkeys []uint64, dst, src []byte, decrypt bool) {
} }
// switch left & right and perform final permutation // switch left & right and perform final permutation
preOutput := (uint64(right) << 32) | uint64(left) preOutput := (uint64(right) << 32) | uint64(left)
binary.BigEndian.PutUint64(dst, permuteBlock(preOutput, finalPermutation[:])) binary.BigEndian.PutUint64(dst, permuteFinalBlock(preOutput))
} }
// Encrypt one block from src into dst, using the subkeys. // Encrypt one block from src into dst, using the subkeys.
...@@ -40,20 +40,24 @@ func decryptBlock(subkeys []uint64, dst, src []byte) { ...@@ -40,20 +40,24 @@ func decryptBlock(subkeys []uint64, dst, src []byte) {
// DES Feistel function // DES Feistel function
func feistel(right uint32, key uint64) (result uint32) { func feistel(right uint32, key uint64) (result uint32) {
sBoxLocations := key ^ permuteBlock(uint64(right), expansionFunction[:]) sBoxLocations := key ^ expandBlock(right)
var sBoxResult uint32 var sBoxResult uint32
for i := uint8(0); i < 8; i++ { for i := uint8(0); i < 8; i++ {
sBoxLocation := uint8(sBoxLocations>>42) & 0x3f sBoxLocation := uint8(sBoxLocations>>42) & 0x3f
sBoxLocations <<= 6 sBoxLocations <<= 6
// row determined by 1st and 6th bit // row determined by 1st and 6th bit
row := (sBoxLocation & 0x1) | ((sBoxLocation & 0x20) >> 4)
// column is middle four bits // column is middle four bits
row := (sBoxLocation & 0x1) | ((sBoxLocation & 0x20) >> 4)
column := (sBoxLocation >> 1) & 0xf column := (sBoxLocation >> 1) & 0xf
sBoxResult |= uint32(sBoxes[i][row][column]) << (4 * (7 - i)) sBoxResult ^= feistelBox[i][16*row+column]
} }
return uint32(permuteBlock(uint64(sBoxResult), permutationFunction[:])) return sBoxResult
} }
// feistelBox[s][16*i+j] contains the output of permutationFunction
// for sBoxes[s][i][j] << 4*(7-s)
var feistelBox [8][64]uint32
// general purpose function to perform DES block permutations // general purpose function to perform DES block permutations
func permuteBlock(src uint64, permutation []uint8) (block uint64) { func permuteBlock(src uint64, permutation []uint8) (block uint64) {
for position, n := range permutation { for position, n := range permutation {
...@@ -63,6 +67,127 @@ func permuteBlock(src uint64, permutation []uint8) (block uint64) { ...@@ -63,6 +67,127 @@ func permuteBlock(src uint64, permutation []uint8) (block uint64) {
return return
} }
// init precomputes feistelBox: for every S-box, row, and column it places the
// S-box output into its bit position and applies permutationFunction, so that
// feistel can replace per-call bit permutation with table lookups XORed
// together.
func init() {
	for s := range sBoxes {
		for i := 0; i < 4; i++ {
			for j := 0; j < 16; j++ {
				f := uint64(sBoxes[s][i][j]) << (4 * (7 - uint(s)))
				// f is already a uint64; the redundant conversion in the
				// original has been dropped.
				f = permuteBlock(f, permutationFunction[:])
				feistelBox[s][16*i+j] = uint32(f)
			}
		}
	}
}
// expandBlock expands an input block of 32 bits,
// producing an output block of 48 bits.
func expandBlock(src uint32) (block uint64) {
	// Rotate the 5 highest bits to the right so each 6-bit output group can
	// be read from the low bits of src.
	src = src<<5 | src>>27
	for i := 0; i < 8; i++ {
		// Append the low six bits of src to the output block, then advance
		// the window by four bits for the next group.
		block = block<<6 | uint64(src&0x3f)
		src = src<<4 | src>>28
	}
	return block
}
// permuteInitialBlock is equivalent to the permutation defined
// by initialPermutation.
func permuteInitialBlock(block uint64) uint64 {
	// block = b7 b6 b5 b4 b3 b2 b1 b0 (8 bytes)
	b1 := block >> 48
	b2 := block << 48
	block ^= b1 ^ b2 ^ b1<<48 ^ b2>>48
	// block = b1 b0 b5 b4 b3 b2 b7 b6
	b1 = block >> 32 & 0xff00ff
	b2 = (block & 0xff00ff00)
	block ^= b1<<32 ^ b2 ^ b1<<8 ^ b2<<24 // exchange b0 b4 with b3 b7
	// block is now b1 b3 b5 b7 b0 b2 b4 b6, the permutation:
	//                  ...  8
	//                  ... 24
	//                  ... 40
	//                  ... 56
	//  7  6  5  4  3  2  1  0
	// 23 22 21 20 19 18 17 16
	//                  ... 32
	//                  ... 48
	// exchange 4,5,6,7 with 32,33,34,35 etc.
	b1 = block & 0x0f0f00000f0f0000
	b2 = block & 0x0000f0f00000f0f0
	block ^= b1 ^ b2 ^ b1>>12 ^ b2<<12
	// block is the permutation:
	//
	//   [+8]      [+40]
	//
	//  7  6  5  4
	// 23 22 21 20
	//  3  2  1  0
	// 19 18 17 16    [+32]
	// exchange 0,1,4,5 with 18,19,22,23
	b1 = block & 0x3300330033003300
	b2 = block & 0x00cc00cc00cc00cc
	block ^= b1 ^ b2 ^ b1>>6 ^ b2<<6
	// block is the permutation:
	// 15 14
	// 13 12
	// 11 10
	//  9  8
	//  7  6
	//  5  4
	//  3  2
	//  1  0    [+16] [+32] [+64]
	// exchange 0,2,4,6 with 9,11,13,15:
	b1 = block & 0xaaaaaaaa55555555
	block ^= b1 ^ b1>>33 ^ b1<<33
	// block is the permutation:
	// 6 14 22 30 38 46 54 62
	// 4 12 20 28 36 44 52 60
	// 2 10 18 26 34 42 50 58
	// 0  8 16 24 32 40 48 56
	// 7 15 23 31 39 47 55 63
	// 5 13 21 29 37 45 53 61
	// 3 11 19 27 35 43 51 59
	// 1  9 17 25 33 41 49 57
	return block
}
// permuteFinalBlock is equivalent to the permutation defined
// by finalPermutation.
func permuteFinalBlock(block uint64) uint64 {
	// Perform the same bit exchanges as permuteInitialBlock
	// but in reverse order.
	b1 := block & 0xaaaaaaaa55555555
	block ^= b1 ^ b1>>33 ^ b1<<33

	b1 = block & 0x3300330033003300
	b2 := block & 0x00cc00cc00cc00cc
	block ^= b1 ^ b2 ^ b1>>6 ^ b2<<6

	b1 = block & 0x0f0f00000f0f0000
	b2 = block & 0x0000f0f00000f0f0
	block ^= b1 ^ b2 ^ b1>>12 ^ b2<<12

	b1 = block >> 32 & 0xff00ff
	b2 = (block & 0xff00ff00)
	block ^= b1<<32 ^ b2 ^ b1<<8 ^ b2<<24

	b1 = block >> 48
	b2 = block << 48
	block ^= b1 ^ b2 ^ b1<<48 ^ b2>>48
	return block
}
// creates 16 28-bit blocks rotated according // creates 16 28-bit blocks rotated according
// to the rotation schedule // to the rotation schedule
func ksRotate(in uint32) (out []uint32) { func ksRotate(in uint32) (out []uint32) {
......
...@@ -1504,20 +1504,63 @@ func TestSubstitutionTableKnownAnswerDecrypt(t *testing.T) { ...@@ -1504,20 +1504,63 @@ func TestSubstitutionTableKnownAnswerDecrypt(t *testing.T) {
} }
} }
func ExampleNewTripleDESCipher() { func TestInitialPermute(t *testing.T) {
// NewTripleDESCipher can also be used when EDE2 is required by for i := uint(0); i < 64; i++ {
// duplicating the first 8 bytes of the 16-byte key. bit := uint64(1) << i
ede2Key := []byte("example key 1234") got := permuteInitialBlock(bit)
want := uint64(1) << finalPermutation[63-i]
if got != want {
t.Errorf("permute(%x) = %x, want %x", bit, got, want)
}
}
}
func TestFinalPermute(t *testing.T) {
for i := uint(0); i < 64; i++ {
bit := uint64(1) << i
got := permuteFinalBlock(bit)
want := uint64(1) << initialPermutation[63-i]
if got != want {
t.Errorf("permute(%x) = %x, want %x", bit, got, want)
}
}
}
var tripleDESKey []byte func TestExpandBlock(t *testing.T) {
tripleDESKey = append(tripleDESKey, ede2Key[:16]...) for i := uint(0); i < 32; i++ {
tripleDESKey = append(tripleDESKey, ede2Key[:8]...) bit := uint32(1) << i
got := expandBlock(bit)
want := permuteBlock(uint64(bit), expansionFunction[:])
if got != want {
t.Errorf("expand(%x) = %x, want %x", bit, got, want)
}
}
}
_, err := NewTripleDESCipher(tripleDESKey) func BenchmarkEncrypt(b *testing.B) {
tt := encryptDESTests[0]
c, err := NewCipher(tt.key)
if err != nil { if err != nil {
panic(err) b.Fatal("NewCipher:", err)
} }
out := make([]byte, len(tt.in))
b.SetBytes(int64(len(out)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Encrypt(out, tt.in)
}
}
// See crypto/cipher for how to use a cipher.Block for encryption and func BenchmarkDecrypt(b *testing.B) {
// decryption. tt := encryptDESTests[0]
c, err := NewCipher(tt.key)
if err != nil {
b.Fatal("NewCipher:", err)
}
out := make([]byte, len(tt.out))
b.SetBytes(int64(len(out)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Decrypt(out, tt.out)
}
} }
...@@ -123,8 +123,8 @@ func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err err ...@@ -123,8 +123,8 @@ func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err err
return return
} }
// Verify verifies the signature in r, s of hash using the public key, pub. It // Verify verifies the signature in r, s of hash using the public key, pub. Its
// returns true iff the signature is valid. // return value records whether the signature is valid.
func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool { func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
// See [NSA] 3.4.2 // See [NSA] 3.4.2
c := pub.Curve c := pub.Curve
......
...@@ -322,7 +322,6 @@ func Unmarshal(curve Curve, data []byte) (x, y *big.Int) { ...@@ -322,7 +322,6 @@ func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
} }
var initonce sync.Once var initonce sync.Once
var p256 *CurveParams
var p384 *CurveParams var p384 *CurveParams
var p521 *CurveParams var p521 *CurveParams
...@@ -333,17 +332,6 @@ func initAll() { ...@@ -333,17 +332,6 @@ func initAll() {
initP521() initP521()
} }
func initP256() {
// See FIPS 186-3, section D.2.3
p256 = new(CurveParams)
p256.P, _ = new(big.Int).SetString("115792089210356248762697446949407573530086143415290314195533631308867097853951", 10)
p256.N, _ = new(big.Int).SetString("115792089210356248762697446949407573529996955224135760342422259061068512044369", 10)
p256.B, _ = new(big.Int).SetString("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", 16)
p256.Gx, _ = new(big.Int).SetString("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", 16)
p256.Gy, _ = new(big.Int).SetString("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", 16)
p256.BitSize = 256
}
func initP384() { func initP384() {
// See FIPS 186-3, section D.2.4 // See FIPS 186-3, section D.2.4
p384 = new(CurveParams) p384 = new(CurveParams)
......
...@@ -322,6 +322,52 @@ func TestGenericBaseMult(t *testing.T) { ...@@ -322,6 +322,52 @@ func TestGenericBaseMult(t *testing.T) {
} }
} }
// TestP256BaseMult cross-checks the P256 ScalarBaseMult against the generic
// CurveParams implementation, using the P-224 test scalars plus one large
// 500-bit scalar.
func TestP256BaseMult(t *testing.T) {
	p256 := P256()
	p256Generic := p256.Params()

	scalars := make([]*big.Int, 0, len(p224BaseMultTests)+1)
	for _, e := range p224BaseMultTests {
		k, _ := new(big.Int).SetString(e.k, 10)
		scalars = append(scalars, k)
	}
	scalars = append(scalars, new(big.Int).Lsh(big.NewInt(1), 500))

	for i, k := range scalars {
		x, y := p256.ScalarBaseMult(k.Bytes())
		x2, y2 := p256Generic.ScalarBaseMult(k.Bytes())
		if x.Cmp(x2) != 0 || y.Cmp(y2) != 0 {
			t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, x, y, x2, y2)
		}
		if testing.Short() && i > 5 {
			break
		}
	}
}
// TestP256Mult cross-checks the P256 ScalarMult against the generic
// CurveParams implementation using the P-224 test points and scalars.
func TestP256Mult(t *testing.T) {
	p256 := P256()
	generic := p256.Params()
	for i, e := range p224BaseMultTests {
		x, _ := new(big.Int).SetString(e.x, 16)
		y, _ := new(big.Int).SetString(e.y, 16)
		k, _ := new(big.Int).SetString(e.k, 10)

		gotX, gotY := p256.ScalarMult(x, y, k.Bytes())
		wantX, wantY := generic.ScalarMult(x, y, k.Bytes())
		if gotX.Cmp(wantX) != 0 || gotY.Cmp(wantY) != 0 {
			t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, gotX, gotY, wantX, wantY)
		}
		if testing.Short() && i > 5 {
			break
		}
	}
}
func TestInfinity(t *testing.T) { func TestInfinity(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
...@@ -371,6 +417,17 @@ func BenchmarkBaseMult(b *testing.B) { ...@@ -371,6 +417,17 @@ func BenchmarkBaseMult(b *testing.B) {
} }
} }
// BenchmarkBaseMultP256 measures ScalarBaseMult on the optimized P256 curve.
func BenchmarkBaseMultP256(b *testing.B) {
	p256 := P256()
	e := p224BaseMultTests[25]
	k, _ := new(big.Int).SetString(e.k, 10)
	// Reset the timer after setup so curve and scalar construction are not
	// measured. The original called ResetTimer before the setup and then
	// StartTimer (a no-op, since the timer was never stopped), which
	// included the setup in the measurement.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		p256.ScalarBaseMult(k.Bytes())
	}
}
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
p224 := P224() p224 := P224()
_, x, y, err := GenerateKey(p224, rand.Reader) _, x, y, err := GenerateKey(p224, rand.Reader)
......
...@@ -164,7 +164,7 @@ var program = ` ...@@ -164,7 +164,7 @@ var program = `
// DO NOT EDIT. // DO NOT EDIT.
// Generate with: go run gen.go{{if .Full}} -full{{end}} | gofmt >md5block.go // Generate with: go run gen.go{{if .Full}} -full{{end}} | gofmt >md5block.go
// +build !amd64 // +build !amd64,!386,!arm
package md5 package md5
......
...@@ -88,7 +88,11 @@ func (d *digest) Write(p []byte) (nn int, err error) { ...@@ -88,7 +88,11 @@ func (d *digest) Write(p []byte) (nn int, err error) {
func (d0 *digest) Sum(in []byte) []byte { func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing. // Make a copy of d0 so that caller can keep writing and summing.
d := *d0 d := *d0
hash := d.checkSum()
return append(in, hash[:]...)
}
func (d *digest) checkSum() [Size]byte {
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64. // Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
len := d.len len := d.len
var tmp [64]byte var tmp [64]byte
...@@ -118,5 +122,13 @@ func (d0 *digest) Sum(in []byte) []byte { ...@@ -118,5 +122,13 @@ func (d0 *digest) Sum(in []byte) []byte {
digest[i*4+3] = byte(s >> 24) digest[i*4+3] = byte(s >> 24)
} }
return append(in, digest[:]...) return digest
}
// Sum returns the MD5 checksum of the data.
func Sum(data []byte) [Size]byte {
var d digest
d.Reset()
d.Write(data)
return d.checkSum()
} }
...@@ -53,6 +53,10 @@ var golden = []md5Test{ ...@@ -53,6 +53,10 @@ var golden = []md5Test{
func TestGolden(t *testing.T) { func TestGolden(t *testing.T) {
for i := 0; i < len(golden); i++ { for i := 0; i < len(golden); i++ {
g := golden[i] g := golden[i]
s := fmt.Sprintf("%x", Sum([]byte(g.in)))
if s != g.out {
t.Fatalf("Sum function: md5(%s) = %s want %s", g.in, s, g.out)
}
c := New() c := New()
buf := make([]byte, len(g.in)+4) buf := make([]byte, len(g.in)+4)
for j := 0; j < 3+4; j++ { for j := 0; j < 3+4; j++ {
...@@ -77,12 +81,28 @@ func TestGolden(t *testing.T) { ...@@ -77,12 +81,28 @@ func TestGolden(t *testing.T) {
} }
} }
func ExampleNew() { func TestLarge(t *testing.T) {
h := New() const N = 10000
io.WriteString(h, "The fog is getting thicker!") ok := "2bb571599a4180e1d542f76904adc3df" // md5sum of "0123456789" * 1000
io.WriteString(h, "And Leon's getting laaarger!") block := make([]byte, 10004)
fmt.Printf("%x", h.Sum(nil)) c := New()
// Output: e2c569be17396eca2a2e3c11578123ed for offset := 0; offset < 4; offset++ {
for i := 0; i < N; i++ {
block[offset+i] = '0' + byte(i%10)
}
for blockSize := 10; blockSize <= N; blockSize *= 10 {
blocks := N / blockSize
b := block[offset : offset+blockSize]
c.Reset()
for i := 0; i < blocks; i++ {
c.Write(b)
}
s := fmt.Sprintf("%x", c.Sum(nil))
if s != ok {
t.Fatalf("md5 TestLarge offset=%d, blockSize=%d = %s want %s", offset, blockSize, s, ok)
}
}
}
} }
var bench = New() var bench = New()
......
// DO NOT EDIT. // DO NOT EDIT.
// Generate with: go run gen.go -full | gofmt >md5block.go // Generate with: go run gen.go -full | gofmt >md5block.go
// +build !amd64,!386 // +build !amd64,!386,!arm
package md5 package md5
......
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build amd64 386 // +build amd64 386 arm
package md5 package md5
//go:noescape
func block(dig *digest, p []byte) func block(dig *digest, p []byte)
...@@ -14,5 +14,8 @@ import "io" ...@@ -14,5 +14,8 @@ import "io"
// On Windows systems, Reader uses the CryptGenRandom API. // On Windows systems, Reader uses the CryptGenRandom API.
var Reader io.Reader var Reader io.Reader
// Read is a helper function that calls Reader.Read. // Read is a helper function that calls Reader.Read using io.ReadFull.
func Read(b []byte) (n int, err error) { return Reader.Read(b) } // On return, n == len(b) if and only if err == nil.
func Read(b []byte) (n int, err error) {
return io.ReadFull(Reader, b)
}
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build darwin freebsd linux netbsd openbsd plan9 // +build darwin dragonfly freebsd linux netbsd openbsd plan9
// Unix cryptographically secure pseudorandom number // Unix cryptographically secure pseudorandom number
// generator. // generator.
......
...@@ -124,7 +124,11 @@ func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid ...@@ -124,7 +124,11 @@ func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid
lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex) lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex)
} }
valid = firstByteIsZero & secondByteIsTwo & (^lookingForIndex & 1) // The PS padding must be at least 8 bytes long, and it starts two
// bytes into em.
validPS := subtle.ConstantTimeLessOrEq(2+8, index)
valid = firstByteIsZero & secondByteIsTwo & (^lookingForIndex & 1) & validPS
msg = em[index+1:] msg = em[index+1:]
return return
} }
......
...@@ -197,6 +197,14 @@ func TestVerifyPKCS1v15(t *testing.T) { ...@@ -197,6 +197,14 @@ func TestVerifyPKCS1v15(t *testing.T) {
} }
} }
func TestOverlongMessagePKCS1v15(t *testing.T) {
ciphertext := decodeBase64("fjOVdirUzFoLlukv80dBllMLjXythIf22feqPrNo0YoIjzyzyoMFiLjAc/Y4krkeZ11XFThIrEvw\nkRiZcCq5ng==")
_, err := DecryptPKCS1v15(nil, rsaPrivateKey, ciphertext)
if err == nil {
t.Error("RSA decrypted a message that was too long.")
}
}
// In order to generate new test vectors you'll need the PEM form of this key: // In order to generate new test vectors you'll need the PEM form of this key:
// -----BEGIN RSA PRIVATE KEY----- // -----BEGIN RSA PRIVATE KEY-----
// MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0 // MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0
......
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package rsa
// This file implements the PSS signature scheme [1].
//
// [1] http://www.rsa.com/rsalabs/pkcs/files/h11300-wp-pkcs-1v2-2-rsa-cryptography-standard.pdf
import (
"bytes"
"crypto"
"errors"
"hash"
"io"
"math/big"
)
// emsaPSSEncode produces an EMSA-PSS encoding (see [1], section 9.1.1) of
// mHash — the already-hashed message, which must be exactly hash.Size()
// bytes — into a buffer of emBits bits, using the caller-supplied salt and
// hash function. It returns the encoded message EM = maskedDB || H || 0xBC,
// or an error if the inputs are inconsistent.
func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
	// See [1], section 9.1.1
	hLen := hash.Size()
	sLen := len(salt)
	emLen := (emBits + 7) / 8

	// 1. If the length of M is greater than the input limitation for the
	//     hash function (2^61 - 1 octets for SHA-1), output "message too
	//     long" and stop.
	//
	// 2.  Let mHash = Hash(M), an octet string of length hLen.
	if len(mHash) != hLen {
		return nil, errors.New("crypto/rsa: input must be hashed message")
	}

	// 3.  If emLen < hLen + sLen + 2, output "encoding error" and stop.
	if emLen < hLen+sLen+2 {
		return nil, errors.New("crypto/rsa: encoding error")
	}

	em := make([]byte, emLen)
	// db and h are views into em: db covers the first emLen-hLen-1 bytes
	// and h the next hLen bytes (the index expressions below both
	// simplify to emLen-hLen-1). The final byte of em is reserved for
	// the 0xBC trailer written in step 12.
	db := em[:emLen-sLen-hLen-2+1+sLen]
	h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]

	// 4.  Generate a random octet string salt of length sLen; if sLen = 0,
	//     then salt is the empty string.
	//
	// 5.  Let
	//       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
	//
	//     M' is an octet string of length 8 + hLen + sLen with eight
	//     initial zero octets.
	//
	// 6.  Let H = Hash(M'), an octet string of length hLen.
	var prefix [8]byte // eight zero octets prepended to M'
	hash.Write(prefix[:])
	hash.Write(mHash)
	hash.Write(salt)
	// Sum into h[:0] so H is written directly into its slot in em.
	h = hash.Sum(h[:0])
	hash.Reset()

	// 7.  Generate an octet string PS consisting of emLen - sLen - hLen - 2
	//     zero octets. The length of PS may be 0.
	//
	// 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
	//     emLen - hLen - 1.
	// PS is already zero because em was freshly allocated.
	db[emLen-sLen-hLen-2] = 0x01
	copy(db[emLen-sLen-hLen-1:], salt)

	// 9.  Let dbMask = MGF(H, emLen - hLen - 1).
	//
	// 10. Let maskedDB = DB \xor dbMask.
	// mgf1XOR masks db in place using MGF1 keyed on h.
	mgf1XOR(db, hash, h)

	// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
	//     maskedDB to zero.
	db[0] &= (0xFF >> uint(8*emLen-emBits))

	// 12. Let EM = maskedDB || H || 0xbc.
	em[emLen-1] = 0xBC

	// 13. Output EM.
	return em, nil
}
// emsaPSSVerify checks that em is a consistent EMSA-PSS encoding (see [1],
// section 9.1.2) of mHash for the given emBits, salt length, and hash
// function, returning nil on success and ErrVerification on any mismatch.
// sLen may be PSSSaltLengthAuto, in which case the salt length is detected
// from the encoding itself. Note that em is modified in place while DB is
// unmasked.
func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
	// 1.  If the length of M is greater than the input limitation for the
	//     hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
	//     and stop.
	//
	// 2.  Let mHash = Hash(M), an octet string of length hLen.
	hLen := hash.Size()
	if hLen != len(mHash) {
		return ErrVerification
	}

	// 3.  If emLen < hLen + sLen + 2, output "inconsistent" and stop.
	emLen := (emBits + 7) / 8
	if emLen < hLen+sLen+2 {
		return ErrVerification
	}

	// 4.  If the rightmost octet of EM does not have hexadecimal value
	//     0xbc, output "inconsistent" and stop.
	if em[len(em)-1] != 0xBC {
		return ErrVerification
	}

	// 5.  Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
	//     let H be the next hLen octets.
	db := em[:emLen-hLen-1]
	h := em[emLen-hLen-1 : len(em)-1]

	// 6.  If the leftmost 8 * emLen - emBits bits of the leftmost octet in
	//     maskedDB are not all equal to zero, output "inconsistent" and
	//     stop.
	if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
		return ErrVerification
	}

	// 7.  Let dbMask = MGF(H, emLen - hLen - 1).
	//
	// 8.  Let DB = maskedDB \xor dbMask.
	// mgf1XOR unmasks db in place using MGF1 keyed on h.
	mgf1XOR(db, hash, h)

	// 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
	//     to zero.
	db[0] &= (0xFF >> uint(8*emLen-emBits))

	if sLen == PSSSaltLengthAuto {
		// Auto-detect the salt length: scan backwards from the largest
		// possible salt for the 0x01 separator, requiring only zero
		// padding bytes before it.
	FindSaltLength:
		for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
			switch db[emLen-hLen-sLen-2] {
			case 1:
				break FindSaltLength
			case 0:
				continue
			default:
				return ErrVerification
			}
		}
		if sLen < 0 {
			return ErrVerification
		}
	} else {
		// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
		//     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
		//     position is "position 1") does not have hexadecimal value 0x01,
		//     output "inconsistent" and stop.
		for _, e := range db[:emLen-hLen-sLen-2] {
			if e != 0x00 {
				return ErrVerification
			}
		}
		if db[emLen-hLen-sLen-2] != 0x01 {
			return ErrVerification
		}
	}

	// 11.  Let salt be the last sLen octets of DB.
	salt := db[len(db)-sLen:]

	// 12.  Let
	//          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
	//     M' is an octet string of length 8 + hLen + sLen with eight
	//     initial zero octets.
	//
	// 13. Let H' = Hash(M'), an octet string of length hLen.
	var prefix [8]byte // eight zero octets prepended to M'
	hash.Write(prefix[:])
	hash.Write(mHash)
	hash.Write(salt)

	h0 := hash.Sum(nil)

	// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
	if !bytes.Equal(h0, h) {
		return ErrVerification
	}
	return nil
}
// signPSSWithSalt calculates the signature of hashed using PSS [1] with the
// specified salt. hashed must be the result of hashing the input message with
// the given hash function; the length of salt will later be used to verify
// the signature.
func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
	nBits := priv.N.BitLen()

	// EMSA-PSS encode the digest into an emBits = nBits-1 sized message.
	encoded, encodeErr := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
	if encodeErr != nil {
		return nil, encodeErr
	}

	// Apply the RSA private-key operation to the encoded message.
	emInt := new(big.Int).SetBytes(encoded)
	sigInt, signErr := decrypt(rand, priv, emInt)
	if signErr != nil {
		return nil, signErr
	}

	// Left-pad the signature out to the full modulus size.
	sig := make([]byte, (nBits+7)/8)
	copyWithLeftPad(sig, sigInt.Bytes())
	return sig, nil
}
// Special values accepted by PSSOptions.SaltLength.
const (
	// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
	// as possible when signing, and to be auto-detected when verifying.
	PSSSaltLengthAuto = 0
	// PSSSaltLengthEqualsHash causes the salt length to equal the length
	// of the hash used in the signature.
	PSSSaltLengthEqualsHash = -1
)
// PSSOptions contains options for creating and verifying PSS signatures.
// A nil *PSSOptions is valid and is treated as PSSSaltLengthAuto.
type PSSOptions struct {
	// SaltLength controls the length of the salt used in the PSS
	// signature. It can either be a number of bytes, or one of the special
	// PSSSaltLength constants.
	SaltLength int
}
// saltLength returns the salt-length setting carried by opts, with a nil
// receiver selecting automatic salt sizing.
func (opts *PSSOptions) saltLength() int {
	if opts != nil {
		return opts.SaltLength
	}
	return PSSSaltLengthAuto
}
// SignPSS calculates the signature of hashed using RSASSA-PSS [1].
// Note that hashed must be the result of hashing the input message using the
// given hash function. The opts argument may be nil, in which case sensible
// defaults are used.
func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) (s []byte, err error) {
	// Resolve the special salt-length constants into a byte count.
	sLen := opts.saltLength()
	if sLen == PSSSaltLengthAuto {
		// The largest salt that still fits the encoding.
		sLen = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
	} else if sLen == PSSSaltLengthEqualsHash {
		sLen = hash.Size()
	}

	// Draw a fresh random salt of the chosen length from rand.
	salt := make([]byte, sLen)
	if _, err = io.ReadFull(rand, salt); err != nil {
		return nil, err
	}
	return signPSSWithSalt(rand, priv, hash, hashed, salt)
}
// VerifyPSS verifies a PSS signature.
// hashed is the result of hashing the input message using the given hash
// function and sig is the signature. A valid signature is indicated by
// returning a nil error. The opts argument may be nil, in which case sensible
// defaults are used.
func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
	saltLen := opts.saltLength()
	return verifyPSS(pub, hash, hashed, sig, saltLen)
}
// verifyPSS verifies a PSS signature with an explicitly resolved salt length
// (which may still be PSSSaltLengthAuto for auto-detection during decoding).
func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
	modBits := pub.N.BitLen()
	// The signature must be exactly the size of the modulus.
	if len(sig) != (modBits+7)/8 {
		return ErrVerification
	}

	// Apply the RSA public-key operation to recover the encoded message.
	sigInt := new(big.Int).SetBytes(sig)
	emInt := encrypt(new(big.Int), pub, sigInt)

	emBits := modBits - 1
	emLen := (emBits + 7) / 8
	emBytes := emInt.Bytes()
	if len(emBytes) > emLen {
		return ErrVerification
	}
	// Left-pad the recovered value out to the full encoded-message size.
	em := make([]byte, emLen)
	copyWithLeftPad(em, emBytes)

	if saltLen == PSSSaltLengthEqualsHash {
		saltLen = hash.Size()
	}
	return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
}
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment