Commit 94252f4b by Ian Lance Taylor

libgo: Update to weekly.2012-02-07.

From-SVN: r184034
parent cd636811
...@@ -35,14 +35,17 @@ func main() { ...@@ -35,14 +35,17 @@ func main() {
go sender(c, 100000) go sender(c, 100000)
receiver(c, dummy, 100000) receiver(c, dummy, 100000)
runtime.GC() runtime.GC()
runtime.MemStats.Alloc = 0 memstats := new(runtime.MemStats)
runtime.ReadMemStats(memstats)
alloc := memstats.Alloc
// second time shouldn't increase footprint by much // second time shouldn't increase footprint by much
go sender(c, 100000) go sender(c, 100000)
receiver(c, dummy, 100000) receiver(c, dummy, 100000)
runtime.GC() runtime.GC()
runtime.ReadMemStats(memstats)
if runtime.MemStats.Alloc > 1e5 { if memstats.Alloc-alloc > 1e5 {
println("BUG: too much memory for 100,000 selects:", runtime.MemStats.Alloc) println("BUG: too much memory for 100,000 selects:", memstats.Alloc-alloc)
} }
} }
...@@ -19,7 +19,9 @@ import ( ...@@ -19,7 +19,9 @@ import (
func main() { func main() {
const N = 10000 const N = 10000
st := runtime.MemStats st := new(runtime.MemStats)
memstats := new(runtime.MemStats)
runtime.ReadMemStats(st)
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
c := make(chan int, 10) c := make(chan int, 10)
_ = c _ = c
...@@ -33,8 +35,8 @@ func main() { ...@@ -33,8 +35,8 @@ func main() {
} }
} }
runtime.UpdateMemStats() runtime.ReadMemStats(memstats)
obj := runtime.MemStats.HeapObjects - st.HeapObjects obj := memstats.HeapObjects - st.HeapObjects
if obj > N/5 { if obj > N/5 {
fmt.Println("too many objects left:", obj) fmt.Println("too many objects left:", obj)
os.Exit(1) os.Exit(1)
......
...@@ -17,9 +17,10 @@ import ( ...@@ -17,9 +17,10 @@ import (
var chatty = flag.Bool("v", false, "chatty") var chatty = flag.Bool("v", false, "chatty")
func main() { func main() {
memstats := new(runtime.MemStats)
runtime.Free(runtime.Alloc(1)) runtime.Free(runtime.Alloc(1))
runtime.UpdateMemStats() runtime.ReadMemStats(memstats)
if *chatty { if *chatty {
fmt.Printf("%+v %v\n", runtime.MemStats, uint64(0)) fmt.Printf("%+v %v\n", memstats, uint64(0))
} }
} }
...@@ -21,8 +21,9 @@ var footprint uint64 ...@@ -21,8 +21,9 @@ var footprint uint64
var allocated uint64 var allocated uint64
func bigger() { func bigger() {
runtime.UpdateMemStats() memstats := new(runtime.MemStats)
if f := runtime.MemStats.Sys; footprint < f { runtime.ReadMemStats(memstats)
if f := memstats.Sys; footprint < f {
footprint = f footprint = f
if *chatty { if *chatty {
println("Footprint", footprint, " for ", allocated) println("Footprint", footprint, " for ", allocated)
......
...@@ -16,10 +16,12 @@ import ( ...@@ -16,10 +16,12 @@ import (
var chatty = flag.Bool("v", false, "chatty") var chatty = flag.Bool("v", false, "chatty")
var oldsys uint64 var oldsys uint64
var memstats runtime.MemStats
func bigger() { func bigger() {
runtime.UpdateMemStats() st := &memstats
if st := runtime.MemStats; oldsys < st.Sys { runtime.ReadMemStats(st)
if oldsys < st.Sys {
oldsys = st.Sys oldsys = st.Sys
if *chatty { if *chatty {
println(st.Sys, " system bytes for ", st.Alloc, " Go bytes") println(st.Sys, " system bytes for ", st.Alloc, " Go bytes")
...@@ -32,26 +34,26 @@ func bigger() { ...@@ -32,26 +34,26 @@ func bigger() {
} }
func main() { func main() {
runtime.GC() // clean up garbage from init runtime.GC() // clean up garbage from init
runtime.UpdateMemStats() // first call can do some allocations runtime.ReadMemStats(&memstats) // first call can do some allocations
runtime.MemProfileRate = 0 // disable profiler runtime.MemProfileRate = 0 // disable profiler
runtime.MemStats.Alloc = 0 // ignore stacks stacks := memstats.Alloc // ignore stacks
flag.Parse() flag.Parse()
for i := 0; i < 1<<7; i++ { for i := 0; i < 1<<7; i++ {
for j := 1; j <= 1<<22; j <<= 1 { for j := 1; j <= 1<<22; j <<= 1 {
if i == 0 && *chatty { if i == 0 && *chatty {
println("First alloc:", j) println("First alloc:", j)
} }
if a := runtime.MemStats.Alloc; a != 0 { if a := memstats.Alloc - stacks; a != 0 {
println("no allocations but stats report", a, "bytes allocated") println("no allocations but stats report", a, "bytes allocated")
panic("fail") panic("fail")
} }
b := runtime.Alloc(uintptr(j)) b := runtime.Alloc(uintptr(j))
runtime.UpdateMemStats() runtime.ReadMemStats(&memstats)
during := runtime.MemStats.Alloc during := memstats.Alloc - stacks
runtime.Free(b) runtime.Free(b)
runtime.UpdateMemStats() runtime.ReadMemStats(&memstats)
if a := runtime.MemStats.Alloc; a != 0 { if a := memstats.Alloc - stacks; a != 0 {
println("allocated ", j, ": wrong stats: during=", during, " after=", a, " (want 0)") println("allocated ", j, ": wrong stats: during=", during, " after=", a, " (want 0)")
panic("fail") panic("fail")
} }
......
...@@ -20,7 +20,7 @@ var reverse = flag.Bool("r", false, "reverse") ...@@ -20,7 +20,7 @@ var reverse = flag.Bool("r", false, "reverse")
var longtest = flag.Bool("l", false, "long test") var longtest = flag.Bool("l", false, "long test")
var b []*byte var b []*byte
var stats = &runtime.MemStats var stats = new(runtime.MemStats)
func OkAmount(size, n uintptr) bool { func OkAmount(size, n uintptr) bool {
if n < size { if n < size {
...@@ -42,7 +42,7 @@ func AllocAndFree(size, count int) { ...@@ -42,7 +42,7 @@ func AllocAndFree(size, count int) {
if *chatty { if *chatty {
fmt.Printf("size=%d count=%d ...\n", size, count) fmt.Printf("size=%d count=%d ...\n", size, count)
} }
runtime.UpdateMemStats() runtime.ReadMemStats(stats)
n1 := stats.Alloc n1 := stats.Alloc
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
b[i] = runtime.Alloc(uintptr(size)) b[i] = runtime.Alloc(uintptr(size))
...@@ -51,13 +51,13 @@ func AllocAndFree(size, count int) { ...@@ -51,13 +51,13 @@ func AllocAndFree(size, count int) {
println("lookup failed: got", base, n, "for", b[i]) println("lookup failed: got", base, n, "for", b[i])
panic("fail") panic("fail")
} }
runtime.UpdateMemStats() runtime.ReadMemStats(stats)
if stats.Sys > 1e9 { if stats.Sys > 1e9 {
println("too much memory allocated") println("too much memory allocated")
panic("fail") panic("fail")
} }
} }
runtime.UpdateMemStats() runtime.ReadMemStats(stats)
n2 := stats.Alloc n2 := stats.Alloc
if *chatty { if *chatty {
fmt.Printf("size=%d count=%d stats=%+v\n", size, count, *stats) fmt.Printf("size=%d count=%d stats=%+v\n", size, count, *stats)
...@@ -75,17 +75,17 @@ func AllocAndFree(size, count int) { ...@@ -75,17 +75,17 @@ func AllocAndFree(size, count int) {
panic("fail") panic("fail")
} }
runtime.Free(b[i]) runtime.Free(b[i])
runtime.UpdateMemStats() runtime.ReadMemStats(stats)
if stats.Alloc != uint64(alloc-n) { if stats.Alloc != uint64(alloc-n) {
println("free alloc got", stats.Alloc, "expected", alloc-n, "after free of", n) println("free alloc got", stats.Alloc, "expected", alloc-n, "after free of", n)
panic("fail") panic("fail")
} }
if runtime.MemStats.Sys > 1e9 { if stats.Sys > 1e9 {
println("too much memory allocated") println("too much memory allocated")
panic("fail") panic("fail")
} }
} }
runtime.UpdateMemStats() runtime.ReadMemStats(stats)
n4 := stats.Alloc n4 := stats.Alloc
if *chatty { if *chatty {
......
1107a7d3cb07 52ba9506bd99
The first line of this file holds the Mercurial revision number of the The first line of this file holds the Mercurial revision number of the
last merge done from the master library sources. last merge done from the master library sources.
...@@ -226,6 +226,7 @@ toolexeclibgoexp_DATA = \ ...@@ -226,6 +226,7 @@ toolexeclibgoexp_DATA = \
$(exp_inotify_gox) \ $(exp_inotify_gox) \
exp/norm.gox \ exp/norm.gox \
exp/proxy.gox \ exp/proxy.gox \
exp/signal.gox \
exp/terminal.gox \ exp/terminal.gox \
exp/types.gox \ exp/types.gox \
exp/utf8string.gox exp/utf8string.gox
...@@ -257,13 +258,11 @@ toolexeclibgohtml_DATA = \ ...@@ -257,13 +258,11 @@ toolexeclibgohtml_DATA = \
toolexeclibgoimagedir = $(toolexeclibgodir)/image toolexeclibgoimagedir = $(toolexeclibgodir)/image
toolexeclibgoimage_DATA = \ toolexeclibgoimage_DATA = \
image/bmp.gox \
image/color.gox \ image/color.gox \
image/draw.gox \ image/draw.gox \
image/gif.gox \ image/gif.gox \
image/jpeg.gox \ image/jpeg.gox \
image/png.gox \ image/png.gox
image/tiff.gox
toolexeclibgoindexdir = $(toolexeclibgodir)/index toolexeclibgoindexdir = $(toolexeclibgodir)/index
...@@ -327,8 +326,7 @@ toolexeclibgoosdir = $(toolexeclibgodir)/os ...@@ -327,8 +326,7 @@ toolexeclibgoosdir = $(toolexeclibgodir)/os
toolexeclibgoos_DATA = \ toolexeclibgoos_DATA = \
os/exec.gox \ os/exec.gox \
os/user.gox \ os/user.gox
os/signal.gox
toolexeclibgopathdir = $(toolexeclibgodir)/path toolexeclibgopathdir = $(toolexeclibgodir)/path
...@@ -949,7 +947,6 @@ go_crypto_cipher_files = \ ...@@ -949,7 +947,6 @@ go_crypto_cipher_files = \
go/crypto/cipher/cipher.go \ go/crypto/cipher/cipher.go \
go/crypto/cipher/ctr.go \ go/crypto/cipher/ctr.go \
go/crypto/cipher/io.go \ go/crypto/cipher/io.go \
go/crypto/cipher/ocfb.go \
go/crypto/cipher/ofb.go go/crypto/cipher/ofb.go
go_crypto_des_files = \ go_crypto_des_files = \
go/crypto/des/block.go \ go/crypto/des/block.go \
...@@ -1107,6 +1104,8 @@ go_exp_proxy_files = \ ...@@ -1107,6 +1104,8 @@ go_exp_proxy_files = \
go/exp/proxy/per_host.go \ go/exp/proxy/per_host.go \
go/exp/proxy/proxy.go \ go/exp/proxy/proxy.go \
go/exp/proxy/socks5.go go/exp/proxy/socks5.go
go_exp_signal_files = \
go/exp/signal/signal.go
go_exp_terminal_files = \ go_exp_terminal_files = \
go/exp/terminal/terminal.go \ go/exp/terminal/terminal.go \
go/exp/terminal/util.go go/exp/terminal/util.go
...@@ -1179,9 +1178,6 @@ go_html_template_files = \ ...@@ -1179,9 +1178,6 @@ go_html_template_files = \
go/html/template/transition.go \ go/html/template/transition.go \
go/html/template/url.go go/html/template/url.go
go_image_bmp_files = \
go/image/bmp/reader.go
go_image_color_files = \ go_image_color_files = \
go/image/color/color.go \ go/image/color/color.go \
go/image/color/ycbcr.go go/image/color/ycbcr.go
...@@ -1203,12 +1199,6 @@ go_image_png_files = \ ...@@ -1203,12 +1199,6 @@ go_image_png_files = \
go/image/png/reader.go \ go/image/png/reader.go \
go/image/png/writer.go go/image/png/writer.go
go_image_tiff_files = \
go/image/tiff/buffer.go \
go/image/tiff/compress.go \
go/image/tiff/consts.go \
go/image/tiff/reader.go
go_index_suffixarray_files = \ go_index_suffixarray_files = \
go/index/suffixarray/qsufsort.go \ go/index/suffixarray/qsufsort.go \
go/index/suffixarray/suffixarray.go go/index/suffixarray/suffixarray.go
...@@ -1317,9 +1307,6 @@ go_os_user_files = \ ...@@ -1317,9 +1307,6 @@ go_os_user_files = \
go/os/user/user.go \ go/os/user/user.go \
go/os/user/lookup_unix.go go/os/user/lookup_unix.go
go_os_signal_files = \
go/os/signal/signal.go
go_path_filepath_files = \ go_path_filepath_files = \
go/path/filepath/match.go \ go/path/filepath/match.go \
go/path/filepath/path.go \ go/path/filepath/path.go \
...@@ -1673,6 +1660,7 @@ libgo_go_objs = \ ...@@ -1673,6 +1660,7 @@ libgo_go_objs = \
exp/html.lo \ exp/html.lo \
exp/norm.lo \ exp/norm.lo \
exp/proxy.lo \ exp/proxy.lo \
exp/signal.lo \
exp/terminal.lo \ exp/terminal.lo \
exp/types.lo \ exp/types.lo \
exp/utf8string.lo \ exp/utf8string.lo \
...@@ -1693,13 +1681,11 @@ libgo_go_objs = \ ...@@ -1693,13 +1681,11 @@ libgo_go_objs = \
net/http/httptest.lo \ net/http/httptest.lo \
net/http/httputil.lo \ net/http/httputil.lo \
net/http/pprof.lo \ net/http/pprof.lo \
image/bmp.lo \
image/color.lo \ image/color.lo \
image/draw.lo \ image/draw.lo \
image/gif.lo \ image/gif.lo \
image/jpeg.lo \ image/jpeg.lo \
image/png.lo \ image/png.lo \
image/tiff.lo \
index/suffixarray.lo \ index/suffixarray.lo \
io/ioutil.lo \ io/ioutil.lo \
log/syslog.lo \ log/syslog.lo \
...@@ -1720,7 +1706,6 @@ libgo_go_objs = \ ...@@ -1720,7 +1706,6 @@ libgo_go_objs = \
old/template.lo \ old/template.lo \
$(os_lib_inotify_lo) \ $(os_lib_inotify_lo) \
os/user.lo \ os/user.lo \
os/signal.lo \
path/filepath.lo \ path/filepath.lo \
regexp/syntax.lo \ regexp/syntax.lo \
net/rpc/jsonrpc.lo \ net/rpc/jsonrpc.lo \
...@@ -2607,6 +2592,16 @@ exp/proxy/check: $(CHECK_DEPS) ...@@ -2607,6 +2592,16 @@ exp/proxy/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: exp/proxy/check .PHONY: exp/proxy/check
@go_include@ exp/signal.lo.dep
exp/signal.lo.dep: $(go_exp_signal_files)
$(BUILDDEPS)
exp/signal.lo: $(go_exp_signal_files)
$(BUILDPACKAGE)
exp/signal/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/signal
@$(CHECK)
.PHONY: exp/signal/check
@go_include@ exp/terminal.lo.dep @go_include@ exp/terminal.lo.dep
exp/terminal.lo.dep: $(go_exp_terminal_files) exp/terminal.lo.dep: $(go_exp_terminal_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -2776,16 +2771,6 @@ hash/fnv/check: $(CHECK_DEPS) ...@@ -2776,16 +2771,6 @@ hash/fnv/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: hash/fnv/check .PHONY: hash/fnv/check
@go_include@ image/bmp.lo.dep
image/bmp.lo.dep: $(go_image_bmp_files)
$(BUILDDEPS)
image/bmp.lo: $(go_image_bmp_files)
$(BUILDPACKAGE)
image/bmp/check: $(CHECK_DEPS)
@$(MKDIR_P) image/bmp
@$(CHECK)
.PHONY: image/bmp/check
@go_include@ image/color.lo.dep @go_include@ image/color.lo.dep
image/color.lo.dep: $(go_image_color_files) image/color.lo.dep: $(go_image_color_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -2836,16 +2821,6 @@ image/png/check: $(CHECK_DEPS) ...@@ -2836,16 +2821,6 @@ image/png/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: image/png/check .PHONY: image/png/check
@go_include@ image/tiff.lo.dep
image/tiff.lo.dep: $(go_image_tiff_files)
$(BUILDDEPS)
image/tiff.lo: $(go_image_tiff_files)
$(BUILDPACKAGE)
image/tiff/check: $(CHECK_DEPS)
@$(MKDIR_P) image/tiff
@$(CHECK)
.PHONY: image/tiff/check
@go_include@ index/suffixarray.lo.dep @go_include@ index/suffixarray.lo.dep
index/suffixarray.lo.dep: $(go_index_suffixarray_files) index/suffixarray.lo.dep: $(go_index_suffixarray_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -3088,16 +3063,6 @@ os/user/check: $(CHECK_DEPS) ...@@ -3088,16 +3063,6 @@ os/user/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: os/user/check .PHONY: os/user/check
@go_include@ os/signal.lo.dep
os/signal.lo.dep: $(go_os_signal_files)
$(BUILDDEPS)
os/signal.lo: $(go_os_signal_files)
$(BUILDPACKAGE)
os/signal/check: $(CHECK_DEPS)
@$(MKDIR_P) os/signal
@$(CHECK)
.PHONY: os/signal/check
@go_include@ path/filepath.lo.dep @go_include@ path/filepath.lo.dep
path/filepath.lo.dep: $(go_path_filepath_files) path/filepath.lo.dep: $(go_path_filepath_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -3412,6 +3377,8 @@ exp/norm.gox: exp/norm.lo ...@@ -3412,6 +3377,8 @@ exp/norm.gox: exp/norm.lo
$(BUILDGOX) $(BUILDGOX)
exp/proxy.gox: exp/proxy.lo exp/proxy.gox: exp/proxy.lo
$(BUILDGOX) $(BUILDGOX)
exp/signal.gox: exp/signal.lo
$(BUILDGOX)
exp/terminal.gox: exp/terminal.lo exp/terminal.gox: exp/terminal.lo
$(BUILDGOX) $(BUILDGOX)
exp/types.gox: exp/types.lo exp/types.gox: exp/types.lo
...@@ -3446,8 +3413,6 @@ hash/crc64.gox: hash/crc64.lo ...@@ -3446,8 +3413,6 @@ hash/crc64.gox: hash/crc64.lo
hash/fnv.gox: hash/fnv.lo hash/fnv.gox: hash/fnv.lo
$(BUILDGOX) $(BUILDGOX)
image/bmp.gox: image/bmp.lo
$(BUILDGOX)
image/color.gox: image/color.lo image/color.gox: image/color.lo
$(BUILDGOX) $(BUILDGOX)
image/draw.gox: image/draw.lo image/draw.gox: image/draw.lo
...@@ -3458,8 +3423,6 @@ image/jpeg.gox: image/jpeg.lo ...@@ -3458,8 +3423,6 @@ image/jpeg.gox: image/jpeg.lo
$(BUILDGOX) $(BUILDGOX)
image/png.gox: image/png.lo image/png.gox: image/png.lo
$(BUILDGOX) $(BUILDGOX)
image/tiff.gox: image/tiff.lo
$(BUILDGOX)
index/suffixarray.gox: index/suffixarray.lo index/suffixarray.gox: index/suffixarray.lo
$(BUILDGOX) $(BUILDGOX)
...@@ -3518,8 +3481,6 @@ os/exec.gox: os/exec.lo ...@@ -3518,8 +3481,6 @@ os/exec.gox: os/exec.lo
$(BUILDGOX) $(BUILDGOX)
os/user.gox: os/user.lo os/user.gox: os/user.lo
$(BUILDGOX) $(BUILDGOX)
os/signal.gox: os/signal.lo
$(BUILDGOX)
path/filepath.gox: path/filepath.lo path/filepath.gox: path/filepath.lo
$(BUILDGOX) $(BUILDGOX)
...@@ -3637,6 +3598,7 @@ TEST_PACKAGES = \ ...@@ -3637,6 +3598,7 @@ TEST_PACKAGES = \
$(exp_inotify_check) \ $(exp_inotify_check) \
exp/norm/check \ exp/norm/check \
exp/proxy/check \ exp/proxy/check \
exp/signal/check \
exp/terminal/check \ exp/terminal/check \
exp/utf8string/check \ exp/utf8string/check \
html/template/check \ html/template/check \
...@@ -3656,7 +3618,6 @@ TEST_PACKAGES = \ ...@@ -3656,7 +3618,6 @@ TEST_PACKAGES = \
image/draw/check \ image/draw/check \
image/jpeg/check \ image/jpeg/check \
image/png/check \ image/png/check \
image/tiff/check \
index/suffixarray/check \ index/suffixarray/check \
io/ioutil/check \ io/ioutil/check \
log/syslog/check \ log/syslog/check \
...@@ -3679,7 +3640,6 @@ TEST_PACKAGES = \ ...@@ -3679,7 +3640,6 @@ TEST_PACKAGES = \
old/template/check \ old/template/check \
os/exec/check \ os/exec/check \
os/user/check \ os/user/check \
os/signal/check \
path/filepath/check \ path/filepath/check \
regexp/syntax/check \ regexp/syntax/check \
sync/atomic/check \ sync/atomic/check \
......
...@@ -154,26 +154,26 @@ am__DEPENDENCIES_2 = bufio/bufio.lo bytes/bytes.lo bytes/index.lo \ ...@@ -154,26 +154,26 @@ am__DEPENDENCIES_2 = bufio/bufio.lo bytes/bytes.lo bytes/index.lo \
encoding/base32.lo encoding/base64.lo encoding/binary.lo \ encoding/base32.lo encoding/base64.lo encoding/binary.lo \
encoding/csv.lo encoding/gob.lo encoding/hex.lo \ encoding/csv.lo encoding/gob.lo encoding/hex.lo \
encoding/json.lo encoding/pem.lo encoding/xml.lo exp/ebnf.lo \ encoding/json.lo encoding/pem.lo encoding/xml.lo exp/ebnf.lo \
exp/html.lo exp/norm.lo exp/proxy.lo exp/terminal.lo \ exp/html.lo exp/norm.lo exp/proxy.lo exp/signal.lo \
exp/types.lo exp/utf8string.lo html/template.lo go/ast.lo \ exp/terminal.lo exp/types.lo exp/utf8string.lo \
go/build.lo go/doc.lo go/parser.lo go/printer.lo go/scanner.lo \ html/template.lo go/ast.lo go/build.lo go/doc.lo go/parser.lo \
go/token.lo hash/adler32.lo hash/crc32.lo hash/crc64.lo \ go/printer.lo go/scanner.lo go/token.lo hash/adler32.lo \
hash/fnv.lo net/http/cgi.lo net/http/fcgi.lo \ hash/crc32.lo hash/crc64.lo hash/fnv.lo net/http/cgi.lo \
net/http/httptest.lo net/http/httputil.lo net/http/pprof.lo \ net/http/fcgi.lo net/http/httptest.lo net/http/httputil.lo \
image/bmp.lo image/color.lo image/draw.lo image/gif.lo \ net/http/pprof.lo image/color.lo image/draw.lo image/gif.lo \
image/jpeg.lo image/png.lo image/tiff.lo index/suffixarray.lo \ image/jpeg.lo image/png.lo index/suffixarray.lo io/ioutil.lo \
io/ioutil.lo log/syslog.lo log/syslog/syslog_c.lo math/big.lo \ log/syslog.lo log/syslog/syslog_c.lo math/big.lo math/cmplx.lo \
math/cmplx.lo math/rand.lo mime/mime.lo mime/multipart.lo \ math/rand.lo mime/mime.lo mime/multipart.lo net/http.lo \
net/http.lo net/mail.lo net/rpc.lo net/smtp.lo \ net/mail.lo net/rpc.lo net/smtp.lo net/textproto.lo net/url.lo \
net/textproto.lo net/url.lo old/netchan.lo old/regexp.lo \ old/netchan.lo old/regexp.lo old/template.lo \
old/template.lo $(am__DEPENDENCIES_1) os/user.lo os/signal.lo \ $(am__DEPENDENCIES_1) os/user.lo path/filepath.lo \
path/filepath.lo regexp/syntax.lo net/rpc/jsonrpc.lo \ regexp/syntax.lo net/rpc/jsonrpc.lo runtime/debug.lo \
runtime/debug.lo runtime/pprof.lo sync/atomic.lo \ runtime/pprof.lo sync/atomic.lo sync/atomic_c.lo \
sync/atomic_c.lo syscall/syscall.lo syscall/errno.lo \ syscall/syscall.lo syscall/errno.lo syscall/wait.lo \
syscall/wait.lo text/scanner.lo text/tabwriter.lo \ text/scanner.lo text/tabwriter.lo text/template.lo \
text/template.lo text/template/parse.lo testing/testing.lo \ text/template/parse.lo testing/testing.lo testing/iotest.lo \
testing/iotest.lo testing/quick.lo testing/script.lo \ testing/quick.lo testing/script.lo unicode/utf16.lo \
unicode/utf16.lo unicode/utf8.lo unicode/utf8.lo
libgo_la_DEPENDENCIES = $(am__DEPENDENCIES_2) $(am__DEPENDENCIES_1) \ libgo_la_DEPENDENCIES = $(am__DEPENDENCIES_2) $(am__DEPENDENCIES_1) \
$(am__DEPENDENCIES_1) $(am__DEPENDENCIES_1) \ $(am__DEPENDENCIES_1) $(am__DEPENDENCIES_1) \
$(am__DEPENDENCIES_1) $(am__DEPENDENCIES_1)
...@@ -679,6 +679,7 @@ toolexeclibgoexp_DATA = \ ...@@ -679,6 +679,7 @@ toolexeclibgoexp_DATA = \
$(exp_inotify_gox) \ $(exp_inotify_gox) \
exp/norm.gox \ exp/norm.gox \
exp/proxy.gox \ exp/proxy.gox \
exp/signal.gox \
exp/terminal.gox \ exp/terminal.gox \
exp/types.gox \ exp/types.gox \
exp/utf8string.gox exp/utf8string.gox
...@@ -706,13 +707,11 @@ toolexeclibgohtml_DATA = \ ...@@ -706,13 +707,11 @@ toolexeclibgohtml_DATA = \
toolexeclibgoimagedir = $(toolexeclibgodir)/image toolexeclibgoimagedir = $(toolexeclibgodir)/image
toolexeclibgoimage_DATA = \ toolexeclibgoimage_DATA = \
image/bmp.gox \
image/color.gox \ image/color.gox \
image/draw.gox \ image/draw.gox \
image/gif.gox \ image/gif.gox \
image/jpeg.gox \ image/jpeg.gox \
image/png.gox \ image/png.gox
image/tiff.gox
toolexeclibgoindexdir = $(toolexeclibgodir)/index toolexeclibgoindexdir = $(toolexeclibgodir)/index
toolexeclibgoindex_DATA = \ toolexeclibgoindex_DATA = \
...@@ -766,8 +765,7 @@ toolexeclibgoold_DATA = \ ...@@ -766,8 +765,7 @@ toolexeclibgoold_DATA = \
toolexeclibgoosdir = $(toolexeclibgodir)/os toolexeclibgoosdir = $(toolexeclibgodir)/os
toolexeclibgoos_DATA = \ toolexeclibgoos_DATA = \
os/exec.gox \ os/exec.gox \
os/user.gox \ os/user.gox
os/signal.gox
toolexeclibgopathdir = $(toolexeclibgodir)/path toolexeclibgopathdir = $(toolexeclibgodir)/path
toolexeclibgopath_DATA = \ toolexeclibgopath_DATA = \
...@@ -1253,7 +1251,6 @@ go_crypto_cipher_files = \ ...@@ -1253,7 +1251,6 @@ go_crypto_cipher_files = \
go/crypto/cipher/cipher.go \ go/crypto/cipher/cipher.go \
go/crypto/cipher/ctr.go \ go/crypto/cipher/ctr.go \
go/crypto/cipher/io.go \ go/crypto/cipher/io.go \
go/crypto/cipher/ocfb.go \
go/crypto/cipher/ofb.go go/crypto/cipher/ofb.go
go_crypto_des_files = \ go_crypto_des_files = \
...@@ -1445,6 +1442,9 @@ go_exp_proxy_files = \ ...@@ -1445,6 +1442,9 @@ go_exp_proxy_files = \
go/exp/proxy/proxy.go \ go/exp/proxy/proxy.go \
go/exp/proxy/socks5.go go/exp/proxy/socks5.go
go_exp_signal_files = \
go/exp/signal/signal.go
go_exp_terminal_files = \ go_exp_terminal_files = \
go/exp/terminal/terminal.go \ go/exp/terminal/terminal.go \
go/exp/terminal/util.go go/exp/terminal/util.go
...@@ -1528,9 +1528,6 @@ go_html_template_files = \ ...@@ -1528,9 +1528,6 @@ go_html_template_files = \
go/html/template/transition.go \ go/html/template/transition.go \
go/html/template/url.go go/html/template/url.go
go_image_bmp_files = \
go/image/bmp/reader.go
go_image_color_files = \ go_image_color_files = \
go/image/color/color.go \ go/image/color/color.go \
go/image/color/ycbcr.go go/image/color/ycbcr.go
...@@ -1552,12 +1549,6 @@ go_image_png_files = \ ...@@ -1552,12 +1549,6 @@ go_image_png_files = \
go/image/png/reader.go \ go/image/png/reader.go \
go/image/png/writer.go go/image/png/writer.go
go_image_tiff_files = \
go/image/tiff/buffer.go \
go/image/tiff/compress.go \
go/image/tiff/consts.go \
go/image/tiff/reader.go
go_index_suffixarray_files = \ go_index_suffixarray_files = \
go/index/suffixarray/qsufsort.go \ go/index/suffixarray/qsufsort.go \
go/index/suffixarray/suffixarray.go go/index/suffixarray/suffixarray.go
...@@ -1677,9 +1668,6 @@ go_os_user_files = \ ...@@ -1677,9 +1668,6 @@ go_os_user_files = \
go/os/user/user.go \ go/os/user/user.go \
go/os/user/lookup_unix.go go/os/user/lookup_unix.go
go_os_signal_files = \
go/os/signal/signal.go
go_path_filepath_files = \ go_path_filepath_files = \
go/path/filepath/match.go \ go/path/filepath/match.go \
go/path/filepath/path.go \ go/path/filepath/path.go \
...@@ -1924,6 +1912,7 @@ libgo_go_objs = \ ...@@ -1924,6 +1912,7 @@ libgo_go_objs = \
exp/html.lo \ exp/html.lo \
exp/norm.lo \ exp/norm.lo \
exp/proxy.lo \ exp/proxy.lo \
exp/signal.lo \
exp/terminal.lo \ exp/terminal.lo \
exp/types.lo \ exp/types.lo \
exp/utf8string.lo \ exp/utf8string.lo \
...@@ -1944,13 +1933,11 @@ libgo_go_objs = \ ...@@ -1944,13 +1933,11 @@ libgo_go_objs = \
net/http/httptest.lo \ net/http/httptest.lo \
net/http/httputil.lo \ net/http/httputil.lo \
net/http/pprof.lo \ net/http/pprof.lo \
image/bmp.lo \
image/color.lo \ image/color.lo \
image/draw.lo \ image/draw.lo \
image/gif.lo \ image/gif.lo \
image/jpeg.lo \ image/jpeg.lo \
image/png.lo \ image/png.lo \
image/tiff.lo \
index/suffixarray.lo \ index/suffixarray.lo \
io/ioutil.lo \ io/ioutil.lo \
log/syslog.lo \ log/syslog.lo \
...@@ -1971,7 +1958,6 @@ libgo_go_objs = \ ...@@ -1971,7 +1958,6 @@ libgo_go_objs = \
old/template.lo \ old/template.lo \
$(os_lib_inotify_lo) \ $(os_lib_inotify_lo) \
os/user.lo \ os/user.lo \
os/signal.lo \
path/filepath.lo \ path/filepath.lo \
regexp/syntax.lo \ regexp/syntax.lo \
net/rpc/jsonrpc.lo \ net/rpc/jsonrpc.lo \
...@@ -2175,6 +2161,7 @@ TEST_PACKAGES = \ ...@@ -2175,6 +2161,7 @@ TEST_PACKAGES = \
$(exp_inotify_check) \ $(exp_inotify_check) \
exp/norm/check \ exp/norm/check \
exp/proxy/check \ exp/proxy/check \
exp/signal/check \
exp/terminal/check \ exp/terminal/check \
exp/utf8string/check \ exp/utf8string/check \
html/template/check \ html/template/check \
...@@ -2194,7 +2181,6 @@ TEST_PACKAGES = \ ...@@ -2194,7 +2181,6 @@ TEST_PACKAGES = \
image/draw/check \ image/draw/check \
image/jpeg/check \ image/jpeg/check \
image/png/check \ image/png/check \
image/tiff/check \
index/suffixarray/check \ index/suffixarray/check \
io/ioutil/check \ io/ioutil/check \
log/syslog/check \ log/syslog/check \
...@@ -2217,7 +2203,6 @@ TEST_PACKAGES = \ ...@@ -2217,7 +2203,6 @@ TEST_PACKAGES = \
old/template/check \ old/template/check \
os/exec/check \ os/exec/check \
os/user/check \ os/user/check \
os/signal/check \
path/filepath/check \ path/filepath/check \
regexp/syntax/check \ regexp/syntax/check \
sync/atomic/check \ sync/atomic/check \
...@@ -5171,6 +5156,16 @@ exp/proxy/check: $(CHECK_DEPS) ...@@ -5171,6 +5156,16 @@ exp/proxy/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: exp/proxy/check .PHONY: exp/proxy/check
@go_include@ exp/signal.lo.dep
exp/signal.lo.dep: $(go_exp_signal_files)
$(BUILDDEPS)
exp/signal.lo: $(go_exp_signal_files)
$(BUILDPACKAGE)
exp/signal/check: $(CHECK_DEPS)
@$(MKDIR_P) exp/signal
@$(CHECK)
.PHONY: exp/signal/check
@go_include@ exp/terminal.lo.dep @go_include@ exp/terminal.lo.dep
exp/terminal.lo.dep: $(go_exp_terminal_files) exp/terminal.lo.dep: $(go_exp_terminal_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -5340,16 +5335,6 @@ hash/fnv/check: $(CHECK_DEPS) ...@@ -5340,16 +5335,6 @@ hash/fnv/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: hash/fnv/check .PHONY: hash/fnv/check
@go_include@ image/bmp.lo.dep
image/bmp.lo.dep: $(go_image_bmp_files)
$(BUILDDEPS)
image/bmp.lo: $(go_image_bmp_files)
$(BUILDPACKAGE)
image/bmp/check: $(CHECK_DEPS)
@$(MKDIR_P) image/bmp
@$(CHECK)
.PHONY: image/bmp/check
@go_include@ image/color.lo.dep @go_include@ image/color.lo.dep
image/color.lo.dep: $(go_image_color_files) image/color.lo.dep: $(go_image_color_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -5400,16 +5385,6 @@ image/png/check: $(CHECK_DEPS) ...@@ -5400,16 +5385,6 @@ image/png/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: image/png/check .PHONY: image/png/check
@go_include@ image/tiff.lo.dep
image/tiff.lo.dep: $(go_image_tiff_files)
$(BUILDDEPS)
image/tiff.lo: $(go_image_tiff_files)
$(BUILDPACKAGE)
image/tiff/check: $(CHECK_DEPS)
@$(MKDIR_P) image/tiff
@$(CHECK)
.PHONY: image/tiff/check
@go_include@ index/suffixarray.lo.dep @go_include@ index/suffixarray.lo.dep
index/suffixarray.lo.dep: $(go_index_suffixarray_files) index/suffixarray.lo.dep: $(go_index_suffixarray_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -5652,16 +5627,6 @@ os/user/check: $(CHECK_DEPS) ...@@ -5652,16 +5627,6 @@ os/user/check: $(CHECK_DEPS)
@$(CHECK) @$(CHECK)
.PHONY: os/user/check .PHONY: os/user/check
@go_include@ os/signal.lo.dep
os/signal.lo.dep: $(go_os_signal_files)
$(BUILDDEPS)
os/signal.lo: $(go_os_signal_files)
$(BUILDPACKAGE)
os/signal/check: $(CHECK_DEPS)
@$(MKDIR_P) os/signal
@$(CHECK)
.PHONY: os/signal/check
@go_include@ path/filepath.lo.dep @go_include@ path/filepath.lo.dep
path/filepath.lo.dep: $(go_path_filepath_files) path/filepath.lo.dep: $(go_path_filepath_files)
$(BUILDDEPS) $(BUILDDEPS)
...@@ -5971,6 +5936,8 @@ exp/norm.gox: exp/norm.lo ...@@ -5971,6 +5936,8 @@ exp/norm.gox: exp/norm.lo
$(BUILDGOX) $(BUILDGOX)
exp/proxy.gox: exp/proxy.lo exp/proxy.gox: exp/proxy.lo
$(BUILDGOX) $(BUILDGOX)
exp/signal.gox: exp/signal.lo
$(BUILDGOX)
exp/terminal.gox: exp/terminal.lo exp/terminal.gox: exp/terminal.lo
$(BUILDGOX) $(BUILDGOX)
exp/types.gox: exp/types.lo exp/types.gox: exp/types.lo
...@@ -6005,8 +5972,6 @@ hash/crc64.gox: hash/crc64.lo ...@@ -6005,8 +5972,6 @@ hash/crc64.gox: hash/crc64.lo
hash/fnv.gox: hash/fnv.lo hash/fnv.gox: hash/fnv.lo
$(BUILDGOX) $(BUILDGOX)
image/bmp.gox: image/bmp.lo
$(BUILDGOX)
image/color.gox: image/color.lo image/color.gox: image/color.lo
$(BUILDGOX) $(BUILDGOX)
image/draw.gox: image/draw.lo image/draw.gox: image/draw.lo
...@@ -6017,8 +5982,6 @@ image/jpeg.gox: image/jpeg.lo ...@@ -6017,8 +5982,6 @@ image/jpeg.gox: image/jpeg.lo
$(BUILDGOX) $(BUILDGOX)
image/png.gox: image/png.lo image/png.gox: image/png.lo
$(BUILDGOX) $(BUILDGOX)
image/tiff.gox: image/tiff.lo
$(BUILDGOX)
index/suffixarray.gox: index/suffixarray.lo index/suffixarray.gox: index/suffixarray.lo
$(BUILDGOX) $(BUILDGOX)
...@@ -6077,8 +6040,6 @@ os/exec.gox: os/exec.lo ...@@ -6077,8 +6040,6 @@ os/exec.gox: os/exec.lo
$(BUILDGOX) $(BUILDGOX)
os/user.gox: os/user.lo os/user.gox: os/user.lo
$(BUILDGOX) $(BUILDGOX)
os/signal.gox: os/signal.lo
$(BUILDGOX)
path/filepath.gox: path/filepath.lo path/filepath.gox: path/filepath.lo
$(BUILDGOX) $(BUILDGOX)
......
...@@ -117,7 +117,7 @@ func (rc *ReadCloser) Close() error { ...@@ -117,7 +117,7 @@ func (rc *ReadCloser) Close() error {
} }
// Open returns a ReadCloser that provides access to the File's contents. // Open returns a ReadCloser that provides access to the File's contents.
// It is safe to Open and Read from files concurrently. // Multiple files may be read concurrently.
func (f *File) Open() (rc io.ReadCloser, err error) { func (f *File) Open() (rc io.ReadCloser, err error) {
bodyOffset, err := f.findBodyOffset() bodyOffset, err := f.findBodyOffset()
if err != nil { if err != nil {
......
...@@ -69,8 +69,23 @@ var tests = []ZipTest{ ...@@ -69,8 +69,23 @@ var tests = []ZipTest{
}, },
}, },
}, },
{Name: "readme.zip"}, {
{Name: "readme.notzip", Error: ErrFormat}, Name: "symlink.zip",
File: []ZipTestFile{
{
Name: "symlink",
Content: []byte("../target"),
Mode: 0777 | os.ModeSymlink,
},
},
},
{
Name: "readme.zip",
},
{
Name: "readme.notzip",
Error: ErrFormat,
},
{ {
Name: "dd.zip", Name: "dd.zip",
File: []ZipTestFile{ File: []ZipTestFile{
......
...@@ -57,8 +57,8 @@ type FileHeader struct { ...@@ -57,8 +57,8 @@ type FileHeader struct {
} }
// FileInfo returns an os.FileInfo for the FileHeader. // FileInfo returns an os.FileInfo for the FileHeader.
func (fh *FileHeader) FileInfo() os.FileInfo { func (h *FileHeader) FileInfo() os.FileInfo {
return headerFileInfo{fh} return headerFileInfo{h}
} }
// headerFileInfo implements os.FileInfo. // headerFileInfo implements os.FileInfo.
...@@ -71,6 +71,7 @@ func (fi headerFileInfo) Size() int64 { return int64(fi.fh.UncompressedSi ...@@ -71,6 +71,7 @@ func (fi headerFileInfo) Size() int64 { return int64(fi.fh.UncompressedSi
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() } func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time { return fi.fh.ModTime() } func (fi headerFileInfo) ModTime() time.Time { return fi.fh.ModTime() }
func (fi headerFileInfo) Mode() os.FileMode { return fi.fh.Mode() } func (fi headerFileInfo) Mode() os.FileMode { return fi.fh.Mode() }
func (fi headerFileInfo) Sys() interface{} { return fi.fh }
// FileInfoHeader creates a partially-populated FileHeader from an // FileInfoHeader creates a partially-populated FileHeader from an
// os.FileInfo. // os.FileInfo.
...@@ -151,13 +152,20 @@ func (h *FileHeader) SetModTime(t time.Time) { ...@@ -151,13 +152,20 @@ func (h *FileHeader) SetModTime(t time.Time) {
h.ModifiedDate, h.ModifiedTime = timeToMsDosTime(t) h.ModifiedDate, h.ModifiedTime = timeToMsDosTime(t)
} }
// traditional names for Unix constants
const ( const (
s_IFMT = 0xf000 // Unix constants. The specification doesn't mention them,
s_IFDIR = 0x4000 // but these seem to be the values agreed on by tools.
s_IFREG = 0x8000 s_IFMT = 0xf000
s_ISUID = 0x800 s_IFSOCK = 0xc000
s_ISGID = 0x400 s_IFLNK = 0xa000
s_IFREG = 0x8000
s_IFBLK = 0x6000
s_IFDIR = 0x4000
s_IFCHR = 0x2000
s_IFIFO = 0x1000
s_ISUID = 0x800
s_ISGID = 0x400
s_ISVTX = 0x200
msdosDir = 0x10 msdosDir = 0x10
msdosReadOnly = 0x01 msdosReadOnly = 0x01
...@@ -205,10 +213,23 @@ func msdosModeToFileMode(m uint32) (mode os.FileMode) { ...@@ -205,10 +213,23 @@ func msdosModeToFileMode(m uint32) (mode os.FileMode) {
func fileModeToUnixMode(mode os.FileMode) uint32 { func fileModeToUnixMode(mode os.FileMode) uint32 {
var m uint32 var m uint32
if mode&os.ModeDir != 0 { switch mode & os.ModeType {
m = s_IFDIR default:
} else {
m = s_IFREG m = s_IFREG
case os.ModeDir:
m = s_IFDIR
case os.ModeSymlink:
m = s_IFLNK
case os.ModeNamedPipe:
m = s_IFIFO
case os.ModeSocket:
m = s_IFSOCK
case os.ModeDevice:
if mode&os.ModeCharDevice != 0 {
m = s_IFCHR
} else {
m = s_IFBLK
}
} }
if mode&os.ModeSetuid != 0 { if mode&os.ModeSetuid != 0 {
m |= s_ISUID m |= s_ISUID
...@@ -216,13 +237,29 @@ func fileModeToUnixMode(mode os.FileMode) uint32 { ...@@ -216,13 +237,29 @@ func fileModeToUnixMode(mode os.FileMode) uint32 {
if mode&os.ModeSetgid != 0 { if mode&os.ModeSetgid != 0 {
m |= s_ISGID m |= s_ISGID
} }
if mode&os.ModeSticky != 0 {
m |= s_ISVTX
}
return m | uint32(mode&0777) return m | uint32(mode&0777)
} }
func unixModeToFileMode(m uint32) os.FileMode { func unixModeToFileMode(m uint32) os.FileMode {
var mode os.FileMode mode := os.FileMode(m & 0777)
if m&s_IFMT == s_IFDIR { switch m & s_IFMT {
case s_IFBLK:
mode |= os.ModeDevice
case s_IFCHR:
mode |= os.ModeDevice | os.ModeCharDevice
case s_IFDIR:
mode |= os.ModeDir mode |= os.ModeDir
case s_IFIFO:
mode |= os.ModeNamedPipe
case s_IFLNK:
mode |= os.ModeSymlink
case s_IFREG:
// nothing to do
case s_IFSOCK:
mode |= os.ModeSocket
} }
if m&s_ISGID != 0 { if m&s_ISGID != 0 {
mode |= os.ModeSetgid mode |= os.ModeSetgid
...@@ -230,5 +267,8 @@ func unixModeToFileMode(m uint32) os.FileMode { ...@@ -230,5 +267,8 @@ func unixModeToFileMode(m uint32) os.FileMode {
if m&s_ISUID != 0 { if m&s_ISUID != 0 {
mode |= os.ModeSetuid mode |= os.ModeSetuid
} }
return mode | os.FileMode(m&0777) if m&s_ISVTX != 0 {
mode |= os.ModeSticky
}
return mode
} }
...@@ -19,7 +19,7 @@ import ( ...@@ -19,7 +19,7 @@ import (
// Writer implements a zip file writer. // Writer implements a zip file writer.
type Writer struct { type Writer struct {
*countWriter countWriter
dir []*header dir []*header
last *fileWriter last *fileWriter
closed bool closed bool
...@@ -32,7 +32,7 @@ type header struct { ...@@ -32,7 +32,7 @@ type header struct {
// NewWriter returns a new Writer writing a zip file to w. // NewWriter returns a new Writer writing a zip file to w.
func NewWriter(w io.Writer) *Writer { func NewWriter(w io.Writer) *Writer {
return &Writer{countWriter: &countWriter{w: bufio.NewWriter(w)}} return &Writer{countWriter: countWriter{w: bufio.NewWriter(w)}}
} }
// Close finishes writing the zip file by writing the central directory. // Close finishes writing the zip file by writing the central directory.
......
...@@ -47,10 +47,10 @@ var writeTests = []WriteTest{ ...@@ -47,10 +47,10 @@ var writeTests = []WriteTest{
Mode: 0755 | os.ModeSetgid, Mode: 0755 | os.ModeSetgid,
}, },
{ {
Name: "setgid", Name: "symlink",
Data: []byte("setgid file"), Data: []byte("../link/target"),
Method: Deflate, Method: Deflate,
Mode: 0755 | os.ModeSetgid, Mode: 0755 | os.ModeSymlink,
}, },
} }
......
...@@ -85,4 +85,7 @@ func TestFileHeaderRoundTrip(t *testing.T) { ...@@ -85,4 +85,7 @@ func TestFileHeaderRoundTrip(t *testing.T) {
if !reflect.DeepEqual(fh, fh2) { if !reflect.DeepEqual(fh, fh2) {
t.Errorf("mismatch\n input=%#v\noutput=%#v\nerr=%v", fh, fh2, err) t.Errorf("mismatch\n input=%#v\noutput=%#v\nerr=%v", fh, fh2, err)
} }
if sysfh, ok := fi.Sys().(*FileHeader); !ok && sysfh != fh {
t.Errorf("Sys didn't return original *FileHeader")
}
} }
...@@ -9,8 +9,8 @@ package bufio ...@@ -9,8 +9,8 @@ package bufio
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"strconv"
"unicode/utf8" "unicode/utf8"
) )
...@@ -18,28 +18,14 @@ const ( ...@@ -18,28 +18,14 @@ const (
defaultBufSize = 4096 defaultBufSize = 4096
) )
// Errors introduced by this package.
type Error struct {
ErrorString string
}
func (err *Error) Error() string { return err.ErrorString }
var ( var (
ErrInvalidUnreadByte error = &Error{"bufio: invalid use of UnreadByte"} ErrInvalidUnreadByte = errors.New("bufio: invalid use of UnreadByte")
ErrInvalidUnreadRune error = &Error{"bufio: invalid use of UnreadRune"} ErrInvalidUnreadRune = errors.New("bufio: invalid use of UnreadRune")
ErrBufferFull error = &Error{"bufio: buffer full"} ErrBufferFull = errors.New("bufio: buffer full")
ErrNegativeCount error = &Error{"bufio: negative count"} ErrNegativeCount = errors.New("bufio: negative count")
errInternal error = &Error{"bufio: internal error"} errInternal = errors.New("bufio: internal error")
) )
// BufSizeError is the error representing an invalid buffer size.
type BufSizeError int
func (b BufSizeError) Error() string {
return "bufio: bad buffer size " + strconv.Itoa(int(b))
}
// Buffered input. // Buffered input.
// Reader implements buffering for an io.Reader object. // Reader implements buffering for an io.Reader object.
...@@ -54,35 +40,29 @@ type Reader struct { ...@@ -54,35 +40,29 @@ type Reader struct {
const minReadBufferSize = 16 const minReadBufferSize = 16
// NewReaderSize creates a new Reader whose buffer has the specified size, // NewReaderSize returns a new Reader whose buffer has at least the specified
// which must be at least 16 bytes. If the argument io.Reader is already a // size. If the argument io.Reader is already a Reader with large enough
// Reader with large enough size, it returns the underlying Reader. // size, it returns the underlying Reader.
// It returns the Reader and any error. func NewReaderSize(rd io.Reader, size int) *Reader {
func NewReaderSize(rd io.Reader, size int) (*Reader, error) {
if size < minReadBufferSize {
return nil, BufSizeError(size)
}
// Is it already a Reader? // Is it already a Reader?
b, ok := rd.(*Reader) b, ok := rd.(*Reader)
if ok && len(b.buf) >= size { if ok && len(b.buf) >= size {
return b, nil return b
}
if size < minReadBufferSize {
size = minReadBufferSize
}
return &Reader{
buf: make([]byte, size),
rd: rd,
lastByte: -1,
lastRuneSize: -1,
} }
b = new(Reader)
b.buf = make([]byte, size)
b.rd = rd
b.lastByte = -1
b.lastRuneSize = -1
return b, nil
} }
// NewReader returns a new Reader whose buffer has the default size. // NewReader returns a new Reader whose buffer has the default size.
func NewReader(rd io.Reader) *Reader { func NewReader(rd io.Reader) *Reader {
b, err := NewReaderSize(rd, defaultBufSize) return NewReaderSize(rd, defaultBufSize)
if err != nil {
// cannot happen - defaultBufSize is a valid size
panic(err)
}
return b
} }
// fill reads a new chunk into the buffer. // fill reads a new chunk into the buffer.
...@@ -208,7 +188,8 @@ func (b *Reader) UnreadByte() error { ...@@ -208,7 +188,8 @@ func (b *Reader) UnreadByte() error {
} }
// ReadRune reads a single UTF-8 encoded Unicode character and returns the // ReadRune reads a single UTF-8 encoded Unicode character and returns the
// rune and its size in bytes. // rune and its size in bytes. If the encoded rune is invalid, it consumes one byte
// and returns unicode.ReplacementChar (U+FFFD) with a size of 1.
func (b *Reader) ReadRune() (r rune, size int, err error) { func (b *Reader) ReadRune() (r rune, size int, err error) {
for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil { for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil {
b.fill() b.fill()
...@@ -392,6 +373,8 @@ func (b *Reader) ReadString(delim byte) (line string, err error) { ...@@ -392,6 +373,8 @@ func (b *Reader) ReadString(delim byte) (line string, err error) {
// buffered output // buffered output
// Writer implements buffering for an io.Writer object. // Writer implements buffering for an io.Writer object.
// If an error occurs writing to a Writer, no more data will be
// accepted and all subsequent writes will return the error.
type Writer struct { type Writer struct {
err error err error
buf []byte buf []byte
...@@ -399,33 +382,27 @@ type Writer struct { ...@@ -399,33 +382,27 @@ type Writer struct {
wr io.Writer wr io.Writer
} }
// NewWriterSize creates a new Writer whose buffer has the specified size, // NewWriterSize returns a new Writer whose buffer has at least the specified
// which must be greater than zero. If the argument io.Writer is already a // size. If the argument io.Writer is already a Writer with large enough
// Writer with large enough size, it returns the underlying Writer. // size, it returns the underlying Writer.
// It returns the Writer and any error. func NewWriterSize(wr io.Writer, size int) *Writer {
func NewWriterSize(wr io.Writer, size int) (*Writer, error) {
if size <= 0 {
return nil, BufSizeError(size)
}
// Is it already a Writer? // Is it already a Writer?
b, ok := wr.(*Writer) b, ok := wr.(*Writer)
if ok && len(b.buf) >= size { if ok && len(b.buf) >= size {
return b, nil return b
}
if size <= 0 {
size = defaultBufSize
} }
b = new(Writer) b = new(Writer)
b.buf = make([]byte, size) b.buf = make([]byte, size)
b.wr = wr b.wr = wr
return b, nil return b
} }
// NewWriter returns a new Writer whose buffer has the default size. // NewWriter returns a new Writer whose buffer has the default size.
func NewWriter(wr io.Writer) *Writer { func NewWriter(wr io.Writer) *Writer {
b, err := NewWriterSize(wr, defaultBufSize) return NewWriterSize(wr, defaultBufSize)
if err != nil {
// cannot happen - defaultBufSize is valid size
panic(err)
}
return b
} }
// Flush writes any buffered data to the underlying io.Writer. // Flush writes any buffered data to the underlying io.Writer.
......
...@@ -161,7 +161,7 @@ func TestReader(t *testing.T) { ...@@ -161,7 +161,7 @@ func TestReader(t *testing.T) {
bufreader := bufreaders[j] bufreader := bufreaders[j]
bufsize := bufsizes[k] bufsize := bufsizes[k]
read := readmaker.fn(bytes.NewBufferString(text)) read := readmaker.fn(bytes.NewBufferString(text))
buf, _ := NewReaderSize(read, bufsize) buf := NewReaderSize(read, bufsize)
s := bufreader.fn(buf) s := bufreader.fn(buf)
if s != text { if s != text {
t.Errorf("reader=%s fn=%s bufsize=%d want=%q got=%q", t.Errorf("reader=%s fn=%s bufsize=%d want=%q got=%q",
...@@ -379,18 +379,14 @@ func TestWriter(t *testing.T) { ...@@ -379,18 +379,14 @@ func TestWriter(t *testing.T) {
// and that the data is correct. // and that the data is correct.
w.Reset() w.Reset()
buf, e := NewWriterSize(w, bs) buf := NewWriterSize(w, bs)
context := fmt.Sprintf("nwrite=%d bufsize=%d", nwrite, bs) context := fmt.Sprintf("nwrite=%d bufsize=%d", nwrite, bs)
if e != nil {
t.Errorf("%s: NewWriterSize %d: %v", context, bs, e)
continue
}
n, e1 := buf.Write(data[0:nwrite]) n, e1 := buf.Write(data[0:nwrite])
if e1 != nil || n != nwrite { if e1 != nil || n != nwrite {
t.Errorf("%s: buf.Write %d = %d, %v", context, nwrite, n, e1) t.Errorf("%s: buf.Write %d = %d, %v", context, nwrite, n, e1)
continue continue
} }
if e = buf.Flush(); e != nil { if e := buf.Flush(); e != nil {
t.Errorf("%s: buf.Flush = %v", context, e) t.Errorf("%s: buf.Flush = %v", context, e)
} }
...@@ -447,23 +443,14 @@ func TestWriteErrors(t *testing.T) { ...@@ -447,23 +443,14 @@ func TestWriteErrors(t *testing.T) {
func TestNewReaderSizeIdempotent(t *testing.T) { func TestNewReaderSizeIdempotent(t *testing.T) {
const BufSize = 1000 const BufSize = 1000
b, err := NewReaderSize(bytes.NewBufferString("hello world"), BufSize) b := NewReaderSize(bytes.NewBufferString("hello world"), BufSize)
if err != nil {
t.Error("NewReaderSize create fail", err)
}
// Does it recognize itself? // Does it recognize itself?
b1, err2 := NewReaderSize(b, BufSize) b1 := NewReaderSize(b, BufSize)
if err2 != nil {
t.Error("NewReaderSize #2 create fail", err2)
}
if b1 != b { if b1 != b {
t.Error("NewReaderSize did not detect underlying Reader") t.Error("NewReaderSize did not detect underlying Reader")
} }
// Does it wrap if existing buffer is too small? // Does it wrap if existing buffer is too small?
b2, err3 := NewReaderSize(b, 2*BufSize) b2 := NewReaderSize(b, 2*BufSize)
if err3 != nil {
t.Error("NewReaderSize #3 create fail", err3)
}
if b2 == b { if b2 == b {
t.Error("NewReaderSize did not enlarge buffer") t.Error("NewReaderSize did not enlarge buffer")
} }
...@@ -471,23 +458,14 @@ func TestNewReaderSizeIdempotent(t *testing.T) { ...@@ -471,23 +458,14 @@ func TestNewReaderSizeIdempotent(t *testing.T) {
func TestNewWriterSizeIdempotent(t *testing.T) { func TestNewWriterSizeIdempotent(t *testing.T) {
const BufSize = 1000 const BufSize = 1000
b, err := NewWriterSize(new(bytes.Buffer), BufSize) b := NewWriterSize(new(bytes.Buffer), BufSize)
if err != nil {
t.Error("NewWriterSize create fail", err)
}
// Does it recognize itself? // Does it recognize itself?
b1, err2 := NewWriterSize(b, BufSize) b1 := NewWriterSize(b, BufSize)
if err2 != nil {
t.Error("NewWriterSize #2 create fail", err2)
}
if b1 != b { if b1 != b {
t.Error("NewWriterSize did not detect underlying Writer") t.Error("NewWriterSize did not detect underlying Writer")
} }
// Does it wrap if existing buffer is too small? // Does it wrap if existing buffer is too small?
b2, err3 := NewWriterSize(b, 2*BufSize) b2 := NewWriterSize(b, 2*BufSize)
if err3 != nil {
t.Error("NewWriterSize #3 create fail", err3)
}
if b2 == b { if b2 == b {
t.Error("NewWriterSize did not enlarge buffer") t.Error("NewWriterSize did not enlarge buffer")
} }
...@@ -496,10 +474,7 @@ func TestNewWriterSizeIdempotent(t *testing.T) { ...@@ -496,10 +474,7 @@ func TestNewWriterSizeIdempotent(t *testing.T) {
func TestWriteString(t *testing.T) { func TestWriteString(t *testing.T) {
const BufSize = 8 const BufSize = 8
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
b, err := NewWriterSize(buf, BufSize) b := NewWriterSize(buf, BufSize)
if err != nil {
t.Error("NewWriterSize create fail", err)
}
b.WriteString("0") // easy b.WriteString("0") // easy
b.WriteString("123456") // still easy b.WriteString("123456") // still easy
b.WriteString("7890") // easy after flush b.WriteString("7890") // easy after flush
...@@ -516,10 +491,7 @@ func TestWriteString(t *testing.T) { ...@@ -516,10 +491,7 @@ func TestWriteString(t *testing.T) {
func TestBufferFull(t *testing.T) { func TestBufferFull(t *testing.T) {
const longString = "And now, hello, world! It is the time for all good men to come to the aid of their party" const longString = "And now, hello, world! It is the time for all good men to come to the aid of their party"
buf, err := NewReaderSize(strings.NewReader(longString), minReadBufferSize) buf := NewReaderSize(strings.NewReader(longString), minReadBufferSize)
if err != nil {
t.Fatal("NewReaderSize:", err)
}
line, err := buf.ReadSlice('!') line, err := buf.ReadSlice('!')
if string(line) != "And now, hello, " || err != ErrBufferFull { if string(line) != "And now, hello, " || err != ErrBufferFull {
t.Errorf("first ReadSlice(,) = %q, %v", line, err) t.Errorf("first ReadSlice(,) = %q, %v", line, err)
...@@ -533,7 +505,7 @@ func TestBufferFull(t *testing.T) { ...@@ -533,7 +505,7 @@ func TestBufferFull(t *testing.T) {
func TestPeek(t *testing.T) { func TestPeek(t *testing.T) {
p := make([]byte, 10) p := make([]byte, 10)
// string is 16 (minReadBufferSize) long. // string is 16 (minReadBufferSize) long.
buf, _ := NewReaderSize(strings.NewReader("abcdefghijklmnop"), minReadBufferSize) buf := NewReaderSize(strings.NewReader("abcdefghijklmnop"), minReadBufferSize)
if s, err := buf.Peek(1); string(s) != "a" || err != nil { if s, err := buf.Peek(1); string(s) != "a" || err != nil {
t.Fatalf("want %q got %q, err=%v", "a", string(s), err) t.Fatalf("want %q got %q, err=%v", "a", string(s), err)
} }
...@@ -609,7 +581,7 @@ func testReadLine(t *testing.T, input []byte) { ...@@ -609,7 +581,7 @@ func testReadLine(t *testing.T, input []byte) {
for stride := 1; stride < 2; stride++ { for stride := 1; stride < 2; stride++ {
done := 0 done := 0
reader := testReader{input, stride} reader := testReader{input, stride}
l, _ := NewReaderSize(&reader, len(input)+1) l := NewReaderSize(&reader, len(input)+1)
for { for {
line, isPrefix, err := l.ReadLine() line, isPrefix, err := l.ReadLine()
if len(line) > 0 && err != nil { if len(line) > 0 && err != nil {
...@@ -646,7 +618,7 @@ func TestLineTooLong(t *testing.T) { ...@@ -646,7 +618,7 @@ func TestLineTooLong(t *testing.T) {
data = append(data, '0'+byte(i%10)) data = append(data, '0'+byte(i%10))
} }
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
l, _ := NewReaderSize(buf, minReadBufferSize) l := NewReaderSize(buf, minReadBufferSize)
line, isPrefix, err := l.ReadLine() line, isPrefix, err := l.ReadLine()
if !isPrefix || !bytes.Equal(line, data[:minReadBufferSize]) || err != nil { if !isPrefix || !bytes.Equal(line, data[:minReadBufferSize]) || err != nil {
t.Errorf("bad result for first line: got %q want %q %v", line, data[:minReadBufferSize], err) t.Errorf("bad result for first line: got %q want %q %v", line, data[:minReadBufferSize], err)
...@@ -673,7 +645,7 @@ func TestReadAfterLines(t *testing.T) { ...@@ -673,7 +645,7 @@ func TestReadAfterLines(t *testing.T) {
inbuf := bytes.NewBuffer([]byte(line1 + "\n" + restData)) inbuf := bytes.NewBuffer([]byte(line1 + "\n" + restData))
outbuf := new(bytes.Buffer) outbuf := new(bytes.Buffer)
maxLineLength := len(line1) + len(restData)/2 maxLineLength := len(line1) + len(restData)/2
l, _ := NewReaderSize(inbuf, maxLineLength) l := NewReaderSize(inbuf, maxLineLength)
line, isPrefix, err := l.ReadLine() line, isPrefix, err := l.ReadLine()
if isPrefix || err != nil || string(line) != line1 { if isPrefix || err != nil || string(line) != line1 {
t.Errorf("bad result for first line: isPrefix=%v err=%v line=%q", isPrefix, err, string(line)) t.Errorf("bad result for first line: isPrefix=%v err=%v line=%q", isPrefix, err, string(line))
...@@ -688,7 +660,7 @@ func TestReadAfterLines(t *testing.T) { ...@@ -688,7 +660,7 @@ func TestReadAfterLines(t *testing.T) {
} }
func TestReadEmptyBuffer(t *testing.T) { func TestReadEmptyBuffer(t *testing.T) {
l, _ := NewReaderSize(bytes.NewBuffer(nil), minReadBufferSize) l := NewReaderSize(new(bytes.Buffer), minReadBufferSize)
line, isPrefix, err := l.ReadLine() line, isPrefix, err := l.ReadLine()
if err != io.EOF { if err != io.EOF {
t.Errorf("expected EOF from ReadLine, got '%s' %t %s", line, isPrefix, err) t.Errorf("expected EOF from ReadLine, got '%s' %t %s", line, isPrefix, err)
...@@ -696,7 +668,7 @@ func TestReadEmptyBuffer(t *testing.T) { ...@@ -696,7 +668,7 @@ func TestReadEmptyBuffer(t *testing.T) {
} }
func TestLinesAfterRead(t *testing.T) { func TestLinesAfterRead(t *testing.T) {
l, _ := NewReaderSize(bytes.NewBuffer([]byte("foo")), minReadBufferSize) l := NewReaderSize(bytes.NewBuffer([]byte("foo")), minReadBufferSize)
_, err := ioutil.ReadAll(l) _, err := ioutil.ReadAll(l)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
...@@ -752,10 +724,7 @@ func TestReadLineNewlines(t *testing.T) { ...@@ -752,10 +724,7 @@ func TestReadLineNewlines(t *testing.T) {
} }
func testReadLineNewlines(t *testing.T, input string, expect []readLineResult) { func testReadLineNewlines(t *testing.T, input string, expect []readLineResult) {
b, err := NewReaderSize(strings.NewReader(input), minReadBufferSize) b := NewReaderSize(strings.NewReader(input), minReadBufferSize)
if err != nil {
t.Fatal(err)
}
for i, e := range expect { for i, e := range expect {
line, isPrefix, err := b.ReadLine() line, isPrefix, err := b.ReadLine()
if bytes.Compare(line, e.line) != 0 { if bytes.Compare(line, e.line) != 0 {
......
...@@ -57,10 +57,13 @@ func (b *Buffer) String() string { ...@@ -57,10 +57,13 @@ func (b *Buffer) String() string {
func (b *Buffer) Len() int { return len(b.buf) - b.off } func (b *Buffer) Len() int { return len(b.buf) - b.off }
// Truncate discards all but the first n unread bytes from the buffer. // Truncate discards all but the first n unread bytes from the buffer.
// It is an error to call b.Truncate(n) with n > b.Len(). // It panics if n is negative or greater than the length of the buffer.
func (b *Buffer) Truncate(n int) { func (b *Buffer) Truncate(n int) {
b.lastRead = opInvalid b.lastRead = opInvalid
if n == 0 { switch {
case n < 0 || n > b.Len():
panic("bytes.Buffer: truncation out of range")
case n == 0:
// Reuse buffer space. // Reuse buffer space.
b.off = 0 b.off = 0
} }
...@@ -366,14 +369,15 @@ func (b *Buffer) ReadString(delim byte) (line string, err error) { ...@@ -366,14 +369,15 @@ func (b *Buffer) ReadString(delim byte) (line string, err error) {
// buf should have the desired capacity but a length of zero. // buf should have the desired capacity but a length of zero.
// //
// In most cases, new(Buffer) (or just declaring a Buffer variable) is // In most cases, new(Buffer) (or just declaring a Buffer variable) is
// preferable to NewBuffer. In particular, passing a non-empty buf to // sufficient to initialize a Buffer.
// NewBuffer and then writing to the Buffer will overwrite buf, not append to
// it.
func NewBuffer(buf []byte) *Buffer { return &Buffer{buf: buf} } func NewBuffer(buf []byte) *Buffer { return &Buffer{buf: buf} }
// NewBufferString creates and initializes a new Buffer using string s as its // NewBufferString creates and initializes a new Buffer using string s as its
// initial contents. It is intended to prepare a buffer to read an existing // initial contents. It is intended to prepare a buffer to read an existing
// string. See the warnings about NewBuffer; similar issues apply here. // string.
//
// In most cases, new(Buffer) (or just declaring a Buffer variable) is
// sufficient to initialize a Buffer.
func NewBufferString(s string) *Buffer { func NewBufferString(s string) *Buffer {
return &Buffer{buf: []byte(s)} return &Buffer{buf: []byte(s)}
} }
...@@ -102,7 +102,7 @@ func (d *compressor) fillDeflate(b []byte) int { ...@@ -102,7 +102,7 @@ func (d *compressor) fillDeflate(b []byte) int {
if d.blockStart >= windowSize { if d.blockStart >= windowSize {
d.blockStart -= windowSize d.blockStart -= windowSize
} else { } else {
d.blockStart = skipNever d.blockStart = math.MaxInt32
} }
d.hashOffset += windowSize d.hashOffset += windowSize
} }
......
...@@ -229,14 +229,14 @@ func testToFromWithLevel(t *testing.T, level int, input []byte, name string) err ...@@ -229,14 +229,14 @@ func testToFromWithLevel(t *testing.T, level int, input []byte, name string) err
} }
func testToFromWithLevelAndLimit(t *testing.T, level int, input []byte, name string, limit int) error { func testToFromWithLevelAndLimit(t *testing.T, level int, input []byte, name string, limit int) error {
buffer := bytes.NewBuffer(nil) var buffer bytes.Buffer
w := NewWriter(buffer, level) w := NewWriter(&buffer, level)
w.Write(input) w.Write(input)
w.Close() w.Close()
if limit > 0 && buffer.Len() > limit { if limit > 0 && buffer.Len() > limit {
t.Errorf("level: %d, len(compress(data)) = %d > limit = %d", level, buffer.Len(), limit) t.Errorf("level: %d, len(compress(data)) = %d > limit = %d", level, buffer.Len(), limit)
} }
r := NewReader(buffer) r := NewReader(&buffer)
out, err := ioutil.ReadAll(r) out, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("read: %s", err) t.Errorf("read: %s", err)
......
...@@ -121,61 +121,6 @@ func (h *huffmanEncoder) bitLength(freq []int32) int64 { ...@@ -121,61 +121,6 @@ func (h *huffmanEncoder) bitLength(freq []int32) int64 {
return total return total
} }
// Generate elements in the chain using an iterative algorithm.
func (h *huffmanEncoder) generateChains(top *levelInfo, list []literalNode) {
n := len(list)
list = list[0 : n+1]
list[n] = maxNode()
l := top
for {
if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 {
// We've run out of both leafs and pairs.
// End all calculations for this level.
		// To make sure we never come back to this level or any lower level,
// set nextPairFreq impossibly large.
l.lastChain = nil
l.needed = 0
l = l.up
l.nextPairFreq = math.MaxInt32
continue
}
prevFreq := l.lastChain.freq
if l.nextCharFreq < l.nextPairFreq {
// The next item on this row is a leaf node.
n := l.lastChain.leafCount + 1
l.lastChain = &chain{l.nextCharFreq, n, l.lastChain.up}
l.nextCharFreq = list[n].freq
} else {
// The next item on this row is a pair from the previous row.
// nextPairFreq isn't valid until we generate two
// more values in the level below
l.lastChain = &chain{l.nextPairFreq, l.lastChain.leafCount, l.down.lastChain}
l.down.needed = 2
}
if l.needed--; l.needed == 0 {
// We've done everything we need to do for this level.
// Continue calculating one level up. Fill in nextPairFreq
// of that level with the sum of the two nodes we've just calculated on
// this level.
up := l.up
if up == nil {
// All done!
return
}
up.nextPairFreq = prevFreq + l.lastChain.freq
l = up
} else {
// If we stole from below, move down temporarily to replenish it.
for l.down.needed > 0 {
l = l.down
}
}
}
}
// Return the number of literals assigned to each bit size in the Huffman encoding // Return the number of literals assigned to each bit size in the Huffman encoding
// //
// This method is only called when list.length >= 3 // This method is only called when list.length >= 3
......
...@@ -81,7 +81,7 @@ var lzwTests = []lzwTest{ ...@@ -81,7 +81,7 @@ var lzwTests = []lzwTest{
} }
func TestReader(t *testing.T) { func TestReader(t *testing.T) {
b := bytes.NewBuffer(nil) var b bytes.Buffer
for _, tt := range lzwTests { for _, tt := range lzwTests {
d := strings.Split(tt.desc, ";") d := strings.Split(tt.desc, ";")
var order Order var order Order
...@@ -97,7 +97,7 @@ func TestReader(t *testing.T) { ...@@ -97,7 +97,7 @@ func TestReader(t *testing.T) {
rc := NewReader(strings.NewReader(tt.compressed), order, litWidth) rc := NewReader(strings.NewReader(tt.compressed), order, litWidth)
defer rc.Close() defer rc.Close()
b.Reset() b.Reset()
n, err := io.Copy(b, rc) n, err := io.Copy(&b, rc)
if err != nil { if err != nil {
if err != tt.err { if err != tt.err {
t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err) t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
...@@ -116,7 +116,7 @@ func benchmarkDecoder(b *testing.B, n int) { ...@@ -116,7 +116,7 @@ func benchmarkDecoder(b *testing.B, n int) {
b.SetBytes(int64(n)) b.SetBytes(int64(n))
buf0, _ := ioutil.ReadFile("../testdata/e.txt") buf0, _ := ioutil.ReadFile("../testdata/e.txt")
buf0 = buf0[:10000] buf0 = buf0[:10000]
compressed := bytes.NewBuffer(nil) compressed := new(bytes.Buffer)
w := NewWriter(compressed, LSB, 8) w := NewWriter(compressed, LSB, 8)
for i := 0; i < n; i += len(buf0) { for i := 0; i < n; i += len(buf0) {
io.Copy(w, bytes.NewBuffer(buf0)) io.Copy(w, bytes.NewBuffer(buf0))
......
...@@ -124,8 +124,8 @@ func TestWriterDict(t *testing.T) { ...@@ -124,8 +124,8 @@ func TestWriterDict(t *testing.T) {
func TestWriterDictIsUsed(t *testing.T) { func TestWriterDictIsUsed(t *testing.T) {
var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
buf := bytes.NewBuffer(nil) var buf bytes.Buffer
compressor, err := NewWriterDict(buf, BestCompression, input) compressor, err := NewWriterDict(&buf, BestCompression, input)
if err != nil { if err != nil {
t.Errorf("error in NewWriterDict: %s", err) t.Errorf("error in NewWriterDict: %s", err)
return return
......
...@@ -56,7 +56,7 @@ type cbcDecrypter cbc ...@@ -56,7 +56,7 @@ type cbcDecrypter cbc
// NewCBCDecrypter returns a BlockMode which decrypts in cipher block chaining // NewCBCDecrypter returns a BlockMode which decrypts in cipher block chaining
// mode, using the given Block. The length of iv must be the same as the // mode, using the given Block. The length of iv must be the same as the
// Block's block size as must match the iv used to encrypt the data. // Block's block size and must match the iv used to encrypt the data.
func NewCBCDecrypter(b Block, iv []byte) BlockMode { func NewCBCDecrypter(b Block, iv []byte) BlockMode {
return (*cbcDecrypter)(newCBC(b, iv)) return (*cbcDecrypter)(newCBC(b, iv))
} }
......
...@@ -9,7 +9,7 @@ import "io" ...@@ -9,7 +9,7 @@ import "io"
// The Stream* objects are so simple that all their members are public. Users // The Stream* objects are so simple that all their members are public. Users
// can create them themselves. // can create them themselves.
// StreamReader wraps a Stream into an io.Reader. It simply calls XORKeyStream // StreamReader wraps a Stream into an io.Reader. It calls XORKeyStream
// to process each slice of data which passes through. // to process each slice of data which passes through.
type StreamReader struct { type StreamReader struct {
S Stream S Stream
...@@ -22,7 +22,7 @@ func (r StreamReader) Read(dst []byte) (n int, err error) { ...@@ -22,7 +22,7 @@ func (r StreamReader) Read(dst []byte) (n int, err error) {
return return
} }
// StreamWriter wraps a Stream into an io.Writer. It simply calls XORKeyStream // StreamWriter wraps a Stream into an io.Writer. It calls XORKeyStream
// to process each slice of data which passes through. If any Write call // to process each slice of data which passes through. If any Write call
// returns short then the StreamWriter is out of sync and must be discarded. // returns short then the StreamWriter is out of sync and must be discarded.
type StreamWriter struct { type StreamWriter struct {
......
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// OpenPGP CFB Mode. http://tools.ietf.org/html/rfc4880#section-13.9
package cipher
// ocfbEncrypter encrypts with OpenPGP's cipher feedback (OCFB) mode.
// It is constructed by NewOCFBEncrypter and used through the Stream
// interface via XORKeyStream.
type ocfbEncrypter struct {
	b Block // underlying block cipher
	fre []byte // feedback register; holds the current keystream block
	outUsed int // number of bytes of fre already consumed as keystream
}
// An OCFBResyncOption determines if the "resynchronization step" of OCFB is
// performed.
type OCFBResyncOption bool

const (
	// OCFBResync performs the resynchronization step from RFC 4880, 13.9
	// step 7.
	OCFBResync OCFBResyncOption = true
	// OCFBNoResync skips the resynchronization step.
	OCFBNoResync OCFBResyncOption = false
)
// NewOCFBEncrypter returns a Stream which encrypts data with OpenPGP's cipher
// feedback mode using the given Block, and an initial amount of ciphertext.
// randData must be random bytes and be the same length as the Block's block
// size. Resync determines if the "resynchronization step" from RFC 4880, 13.9
// step 7 is performed. Different parts of OpenPGP vary on this point.
func NewOCFBEncrypter(block Block, randData []byte, resync OCFBResyncOption) (Stream, []byte) {
	blockSize := block.BlockSize()
	if len(randData) != blockSize {
		return nil, nil
	}

	enc := &ocfbEncrypter{
		b:       block,
		fre:     make([]byte, blockSize),
		outUsed: 0,
	}
	prefix := make([]byte, blockSize+2)

	// Seed the feedback register with the encryption of the all-zero block,
	// then emit the random data XORed against it as the prefix.
	block.Encrypt(enc.fre, enc.fre)
	for i, r := range randData {
		prefix[i] = r ^ enc.fre[i]
	}

	// The last two random bytes are repeated (encrypted) as a quick key check.
	block.Encrypt(enc.fre, prefix[:blockSize])
	prefix[blockSize] = enc.fre[0] ^ randData[blockSize-2]
	prefix[blockSize+1] = enc.fre[1] ^ randData[blockSize-1]

	if resync {
		block.Encrypt(enc.fre, prefix[2:])
	} else {
		enc.fre[0] = prefix[blockSize]
		enc.fre[1] = prefix[blockSize+1]
		enc.outUsed = 2
	}
	return enc, prefix
}
// XORKeyStream encrypts src into dst, updating the feedback register with
// the ciphertext as it goes (dst and src may overlap entirely).
func (e *ocfbEncrypter) XORKeyStream(dst, src []byte) {
	for i, v := range src {
		if e.outUsed == len(e.fre) {
			// Keystream exhausted: encrypt the previous ciphertext block
			// to refill the feedback register.
			e.b.Encrypt(e.fre, e.fre)
			e.outUsed = 0
		}
		e.fre[e.outUsed] ^= v
		dst[i] = e.fre[e.outUsed]
		e.outUsed++
	}
}
// ocfbDecrypter decrypts with OpenPGP's cipher feedback (OCFB) mode.
// It is constructed by NewOCFBDecrypter and used through the Stream
// interface via XORKeyStream.
type ocfbDecrypter struct {
	b Block // underlying block cipher
	fre []byte // feedback register; holds the current keystream block
	outUsed int // number of bytes of fre already consumed as keystream
}
// NewOCFBDecrypter returns a Stream which decrypts data with OpenPGP's cipher
// feedback mode using the given Block. Prefix must be the first blockSize + 2
// bytes of the ciphertext, where blockSize is the Block's block size. If an
// incorrect key is detected then nil is returned. On successful exit,
// blockSize+2 bytes of decrypted data are written into prefix. Resync
// determines if the "resynchronization step" from RFC 4880, 13.9 step 7 is
// performed. Different parts of OpenPGP vary on this point.
func NewOCFBDecrypter(block Block, prefix []byte, resync OCFBResyncOption) Stream {
	blockSize := block.BlockSize()
	if len(prefix) != blockSize+2 {
		return nil
	}

	dec := &ocfbDecrypter{
		b:       block,
		fre:     make([]byte, blockSize),
		outUsed: 0,
	}

	// Decrypt into a scratch copy so prefix is only overwritten on success.
	plain := make([]byte, len(prefix))
	copy(plain, prefix)

	block.Encrypt(dec.fre, dec.fre)
	for i := 0; i < blockSize; i++ {
		plain[i] ^= dec.fre[i]
	}
	block.Encrypt(dec.fre, prefix[:blockSize])
	plain[blockSize] ^= dec.fre[0]
	plain[blockSize+1] ^= dec.fre[1]

	// The prefix repeats the last two random bytes; a mismatch indicates a
	// wrong key.
	if plain[blockSize-2] != plain[blockSize] ||
		plain[blockSize-1] != plain[blockSize+1] {
		return nil
	}

	if resync {
		block.Encrypt(dec.fre, prefix[2:])
	} else {
		dec.fre[0] = prefix[blockSize]
		dec.fre[1] = prefix[blockSize+1]
		dec.outUsed = 2
	}
	copy(prefix, plain)
	return dec
}
// XORKeyStream decrypts src into dst. The incoming ciphertext byte is saved
// before the output is written so that dst and src may alias; the register
// is then refilled with that ciphertext byte, keeping it in sync with the
// encrypter's register.
func (x *ocfbDecrypter) XORKeyStream(dst, src []byte) {
	for i := range src {
		if x.outUsed == len(x.fre) {
			// Keystream exhausted: encrypt the register (full of
			// ciphertext) to produce the next keystream block.
			x.b.Encrypt(x.fre, x.fre)
			x.outUsed = 0
		}
		ct := src[i]
		dst[i] = x.fre[x.outUsed] ^ ct
		x.fre[x.outUsed] = ct
		x.outUsed++
	}
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cipher
import (
"bytes"
"crypto/aes"
"crypto/rand"
"testing"
)
// testOCFB performs an encrypt/decrypt round trip through the OCFB stream
// ciphers with the given resync option and checks that the plaintext
// survives unchanged.
func testOCFB(t *testing.T, resync OCFBResyncOption) {
	block, err := aes.NewCipher(commonKey128)
	if err != nil {
		t.Error(err)
		return
	}

	plaintext := []byte("this is the plaintext, which is long enough to span several blocks.")
	randData := make([]byte, block.BlockSize())
	rand.Reader.Read(randData)

	enc, prefix := NewOCFBEncrypter(block, randData, resync)
	ciphertext := make([]byte, len(plaintext))
	enc.XORKeyStream(ciphertext, plaintext)

	dec := NewOCFBDecrypter(block, prefix, resync)
	if dec == nil {
		t.Errorf("NewOCFBDecrypter failed (resync: %t)", resync)
		return
	}
	decrypted := make([]byte, len(plaintext))
	dec.XORKeyStream(decrypted, ciphertext)

	if !bytes.Equal(decrypted, plaintext) {
		t.Errorf("got: %x, want: %x (resync: %t)", decrypted, plaintext, resync)
	}
}
// TestOCFB exercises the OCFB round trip with both resynchronization
// settings.
func TestOCFB(t *testing.T) {
	testOCFB(t, OCFBNoResync)
	testOCFB(t, OCFBResync)
}
...@@ -14,15 +14,15 @@ import ( ...@@ -14,15 +14,15 @@ import (
type Hash uint type Hash uint
const ( const (
MD4 Hash = 1 + iota // in package crypto/md4 MD4 Hash = 1 + iota // import code.google.com/p/go.crypto/md4
MD5 // in package crypto/md5 MD5 // import crypto/md5
SHA1 // in package crypto/sha1 SHA1 // import crypto/sha1
SHA224 // in package crypto/sha256 SHA224 // import crypto/sha256
SHA256 // in package crypto/sha256 SHA256 // import crypto/sha256
SHA384 // in package crypto/sha512 SHA384 // import crypto/sha512
SHA512 // in package crypto/sha512 SHA512 // import crypto/sha512
MD5SHA1 // no implementation; MD5+SHA1 used for TLS RSA MD5SHA1 // no implementation; MD5+SHA1 used for TLS RSA
RIPEMD160 // in package crypto/ripemd160 RIPEMD160 // import code.google.com/p/go.crypto/ripemd160
maxHash maxHash
) )
...@@ -50,8 +50,8 @@ func (h Hash) Size() int { ...@@ -50,8 +50,8 @@ func (h Hash) Size() int {
var hashes = make([]func() hash.Hash, maxHash) var hashes = make([]func() hash.Hash, maxHash)
// New returns a new hash.Hash calculating the given hash function. If the // New returns a new hash.Hash calculating the given hash function. New panics
// hash function is not linked into the binary, New returns nil. // if the hash function is not linked into the binary.
func (h Hash) New() hash.Hash { func (h Hash) New() hash.Hash {
if h > 0 && h < maxHash { if h > 0 && h < maxHash {
f := hashes[h] f := hashes[h]
...@@ -59,7 +59,12 @@ func (h Hash) New() hash.Hash { ...@@ -59,7 +59,12 @@ func (h Hash) New() hash.Hash {
return f() return f()
} }
} }
return nil panic("crypto: requested hash function is unavailable")
}
// Available reports whether the given hash function is linked into the binary.
func (h Hash) Available() bool {
return h < maxHash && hashes[h] != nil
} }
// RegisterHash registers a function that returns a new instance of the given // RegisterHash registers a function that returns a new instance of the given
......
...@@ -34,13 +34,13 @@ func NewCipher(key []byte) (*Cipher, error) { ...@@ -34,13 +34,13 @@ func NewCipher(key []byte) (*Cipher, error) {
// BlockSize returns the DES block size, 8 bytes. // BlockSize returns the DES block size, 8 bytes.
func (c *Cipher) BlockSize() int { return BlockSize } func (c *Cipher) BlockSize() int { return BlockSize }
// Encrypts the 8-byte buffer src and stores the result in dst. // Encrypt encrypts the 8-byte buffer src and stores the result in dst.
// Note that for amounts of data larger than a block, // Note that for amounts of data larger than a block,
// it is not safe to just call Encrypt on successive blocks; // it is not safe to just call Encrypt on successive blocks;
// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go). // instead, use an encryption mode like CBC (see crypto/cipher/cbc.go).
func (c *Cipher) Encrypt(dst, src []byte) { encryptBlock(c.subkeys[:], dst, src) } func (c *Cipher) Encrypt(dst, src []byte) { encryptBlock(c.subkeys[:], dst, src) }
// Decrypts the 8-byte buffer src and stores the result in dst. // Decrypt decrypts the 8-byte buffer src and stores the result in dst.
func (c *Cipher) Decrypt(dst, src []byte) { decryptBlock(c.subkeys[:], dst, src) } func (c *Cipher) Decrypt(dst, src []byte) { decryptBlock(c.subkeys[:], dst, src) }
// Reset zeros the key data, so that it will no longer // Reset zeros the key data, so that it will no longer
......
...@@ -102,7 +102,7 @@ GeneratePrimes: ...@@ -102,7 +102,7 @@ GeneratePrimes:
qBytes[0] |= 0x80 qBytes[0] |= 0x80
q.SetBytes(qBytes) q.SetBytes(qBytes)
if !big.ProbablyPrime(q, numMRTests) { if !q.ProbablyPrime(numMRTests) {
continue continue
} }
...@@ -123,7 +123,7 @@ GeneratePrimes: ...@@ -123,7 +123,7 @@ GeneratePrimes:
continue continue
} }
if !big.ProbablyPrime(p, numMRTests) { if !p.ProbablyPrime(numMRTests) {
continue continue
} }
......
...@@ -6,6 +6,7 @@ package elliptic ...@@ -6,6 +6,7 @@ package elliptic
import ( import (
"crypto/rand" "crypto/rand"
"encoding/hex"
"fmt" "fmt"
"math/big" "math/big"
"testing" "testing"
...@@ -350,3 +351,13 @@ func TestMarshal(t *testing.T) { ...@@ -350,3 +351,13 @@ func TestMarshal(t *testing.T) {
return return
} }
} }
func TestP224Overflow(t *testing.T) {
// This tests for a specific bug in the P224 implementation.
p224 := P224()
pointData, _ := hex.DecodeString("049B535B45FB0A2072398A6831834624C7E32CCFD5A4B933BCEAF77F1DD945E08BBE5178F5EDF5E733388F196D2A631D2E075BB16CBFEEA15B")
x, y := Unmarshal(p224, pointData)
if !p224.IsOnCurve(x, y) {
t.Error("P224 failed to validate a correct point")
}
}
...@@ -225,7 +225,7 @@ func p224ReduceLarge(out *p224FieldElement, in *p224LargeFieldElement) { ...@@ -225,7 +225,7 @@ func p224ReduceLarge(out *p224FieldElement, in *p224LargeFieldElement) {
in[i] += p224ZeroModP63[i] in[i] += p224ZeroModP63[i]
} }
// Elimintate the coefficients at 2**224 and greater. // Eliminate the coefficients at 2**224 and greater.
for i := 14; i >= 8; i-- { for i := 14; i >= 8; i-- {
in[i-8] -= in[i] in[i-8] -= in[i]
in[i-5] += (in[i] & 0xffff) << 12 in[i-5] += (in[i] & 0xffff) << 12
...@@ -288,7 +288,7 @@ func p224Reduce(a *p224FieldElement) { ...@@ -288,7 +288,7 @@ func p224Reduce(a *p224FieldElement) {
a[0] += mask & (1 << 28) a[0] += mask & (1 << 28)
} }
// p224Invert calcuates *out = in**-1 by computing in**(2**224 - 2**96 - 1), // p224Invert calculates *out = in**-1 by computing in**(2**224 - 2**96 - 1),
// i.e. Fermat's little theorem. // i.e. Fermat's little theorem.
func p224Invert(out, in *p224FieldElement) { func p224Invert(out, in *p224FieldElement) {
var f1, f2, f3, f4 p224FieldElement var f1, f2, f3, f4 p224FieldElement
...@@ -341,7 +341,7 @@ func p224Invert(out, in *p224FieldElement) { ...@@ -341,7 +341,7 @@ func p224Invert(out, in *p224FieldElement) {
// p224Contract converts a FieldElement to its unique, minimal form. // p224Contract converts a FieldElement to its unique, minimal form.
// //
// On entry, in[i] < 2**32 // On entry, in[i] < 2**29
// On exit, in[i] < 2**28 // On exit, in[i] < 2**28
func p224Contract(out, in *p224FieldElement) { func p224Contract(out, in *p224FieldElement) {
copy(out[:], in[:]) copy(out[:], in[:])
...@@ -365,6 +365,39 @@ func p224Contract(out, in *p224FieldElement) { ...@@ -365,6 +365,39 @@ func p224Contract(out, in *p224FieldElement) {
out[i+1] -= 1 & mask out[i+1] -= 1 & mask
} }
// We might have pushed out[3] over 2**28 so we perform another, partial,
// carry chain.
for i := 3; i < 7; i++ {
out[i+1] += out[i] >> 28
out[i] &= bottom28Bits
}
top = out[7] >> 28
out[7] &= bottom28Bits
// Eliminate top while maintaining the same value mod p.
out[0] -= top
out[3] += top << 12
// There are two cases to consider for out[3]:
// 1) The first time that we eliminated top, we didn't push out[3] over
// 2**28. In this case, the partial carry chain didn't change any values
// and top is zero.
// 2) We did push out[3] over 2**28 the first time that we eliminated top.
// The first value of top was in [0..16), therefore, prior to eliminating
// the first top, 0xfff1000 <= out[3] <= 0xfffffff. Therefore, after
// overflowing and being reduced by the second carry chain, out[3] <=
// 0xf000. Thus it cannot have overflowed when we eliminated top for the
// second time.
// Again, we may just have made out[0] negative, so do the same carry down.
// As before, if we made out[0] negative then we know that out[3] is
// sufficiently positive.
for i := 0; i < 3; i++ {
mask := uint32(int32(out[i]) >> 31)
out[i] += (1 << 28) & mask
out[i+1] -= 1 & mask
}
// Now we see if the value is >= p and, if so, subtract p. // Now we see if the value is >= p and, if so, subtract p.
// First we build a mask from the top four limbs, which must all be // First we build a mask from the top four limbs, which must all be
......
...@@ -39,7 +39,7 @@ func Prime(rand io.Reader, bits int) (p *big.Int, err error) { ...@@ -39,7 +39,7 @@ func Prime(rand io.Reader, bits int) (p *big.Int, err error) {
bytes[len(bytes)-1] |= 1 bytes[len(bytes)-1] |= 1
p.SetBytes(bytes) p.SetBytes(bytes)
if big.ProbablyPrime(p, 20) { if p.ProbablyPrime(20) {
return return
} }
} }
......
...@@ -65,7 +65,7 @@ func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (out [ ...@@ -65,7 +65,7 @@ func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (out [
// about the plaintext. // about the plaintext.
// See ``Chosen Ciphertext Attacks Against Protocols Based on the RSA // See ``Chosen Ciphertext Attacks Against Protocols Based on the RSA
// Encryption Standard PKCS #1'', Daniel Bleichenbacher, Advances in Cryptology // Encryption Standard PKCS #1'', Daniel Bleichenbacher, Advances in Cryptology
// (Crypto '98), // (Crypto '98).
func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []byte, key []byte) (err error) { func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []byte, key []byte) (err error) {
k := (priv.N.BitLen() + 7) / 8 k := (priv.N.BitLen() + 7) / 8
if k-(len(key)+3+8) < 0 { if k-(len(key)+3+8) < 0 {
......
...@@ -62,7 +62,7 @@ func (priv *PrivateKey) Validate() error { ...@@ -62,7 +62,7 @@ func (priv *PrivateKey) Validate() error {
// ProbablyPrime are deterministic, given the candidate number, it's // ProbablyPrime are deterministic, given the candidate number, it's
// easy for an attack to generate composites that pass this test. // easy for an attack to generate composites that pass this test.
for _, prime := range priv.Primes { for _, prime := range priv.Primes {
if !big.ProbablyPrime(prime, 20) { if !prime.ProbablyPrime(20) {
return errors.New("prime factor is composite") return errors.New("prime factor is composite")
} }
} }
...@@ -85,7 +85,7 @@ func (priv *PrivateKey) Validate() error { ...@@ -85,7 +85,7 @@ func (priv *PrivateKey) Validate() error {
gcd := new(big.Int) gcd := new(big.Int)
x := new(big.Int) x := new(big.Int)
y := new(big.Int) y := new(big.Int)
big.GcdInt(gcd, x, y, totient, e) gcd.GCD(x, y, totient, e)
if gcd.Cmp(bigOne) != 0 { if gcd.Cmp(bigOne) != 0 {
return errors.New("invalid public exponent E") return errors.New("invalid public exponent E")
} }
...@@ -156,7 +156,7 @@ NextSetOfPrimes: ...@@ -156,7 +156,7 @@ NextSetOfPrimes:
priv.D = new(big.Int) priv.D = new(big.Int)
y := new(big.Int) y := new(big.Int)
e := big.NewInt(int64(priv.E)) e := big.NewInt(int64(priv.E))
big.GcdInt(g, priv.D, y, e, totient) g.GCD(priv.D, y, e, totient)
if g.Cmp(bigOne) == 0 { if g.Cmp(bigOne) == 0 {
priv.D.Add(priv.D, totient) priv.D.Add(priv.D, totient)
...@@ -284,7 +284,7 @@ func modInverse(a, n *big.Int) (ia *big.Int, ok bool) { ...@@ -284,7 +284,7 @@ func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
g := new(big.Int) g := new(big.Int)
x := new(big.Int) x := new(big.Int)
y := new(big.Int) y := new(big.Int)
big.GcdInt(g, x, y, a, n) g.GCD(x, y, a, n)
if g.Cmp(bigOne) != 0 { if g.Cmp(bigOne) != 0 {
// In this case, a and n aren't coprime and we cannot calculate // In this case, a and n aren't coprime and we cannot calculate
// the inverse. This happens because the values of n are nearly // the inverse. This happens because the values of n are nearly
...@@ -412,7 +412,7 @@ func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err er ...@@ -412,7 +412,7 @@ func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err er
} }
// DecryptOAEP decrypts ciphertext using RSA-OAEP. // DecryptOAEP decrypts ciphertext using RSA-OAEP.
// If rand != nil, DecryptOAEP uses RSA blinding to avoid timing side-channel attacks. // If random != nil, DecryptOAEP uses RSA blinding to avoid timing side-channel attacks.
func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) (msg []byte, err error) { func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext []byte, label []byte) (msg []byte, err error) {
k := (priv.N.BitLen() + 7) / 8 k := (priv.N.BitLen() + 7) / 8
if len(ciphertext) > k || if len(ciphertext) > k ||
......
...@@ -59,7 +59,8 @@ func (c *Conn) clientHandshake() error { ...@@ -59,7 +59,8 @@ func (c *Conn) clientHandshake() error {
finishedHash.Write(serverHello.marshal()) finishedHash.Write(serverHello.marshal())
vers, ok := mutualVersion(serverHello.vers) vers, ok := mutualVersion(serverHello.vers)
if !ok { if !ok || vers < versionTLS10 {
// TLS 1.0 is the minimum version supported as a client.
return c.sendAlert(alertProtocolVersion) return c.sendAlert(alertProtocolVersion)
} }
c.vers = vers c.vers = vers
......
...@@ -33,16 +33,16 @@ func Client(conn net.Conn, config *Config) *Conn { ...@@ -33,16 +33,16 @@ func Client(conn net.Conn, config *Config) *Conn {
return &Conn{conn: conn, config: config, isClient: true} return &Conn{conn: conn, config: config, isClient: true}
} }
// A Listener implements a network listener (net.Listener) for TLS connections. // A listener implements a network listener (net.Listener) for TLS connections.
type Listener struct { type listener struct {
listener net.Listener net.Listener
config *Config config *Config
} }
// Accept waits for and returns the next incoming TLS connection. // Accept waits for and returns the next incoming TLS connection.
// The returned connection c is a *tls.Conn. // The returned connection c is a *tls.Conn.
func (l *Listener) Accept() (c net.Conn, err error) { func (l *listener) Accept() (c net.Conn, err error) {
c, err = l.listener.Accept() c, err = l.Listener.Accept()
if err != nil { if err != nil {
return return
} }
...@@ -50,28 +50,22 @@ func (l *Listener) Accept() (c net.Conn, err error) { ...@@ -50,28 +50,22 @@ func (l *Listener) Accept() (c net.Conn, err error) {
return return
} }
// Close closes the listener.
func (l *Listener) Close() error { return l.listener.Close() }
// Addr returns the listener's network address.
func (l *Listener) Addr() net.Addr { return l.listener.Addr() }
// NewListener creates a Listener which accepts connections from an inner // NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server. // Listener and wraps each connection with Server.
// The configuration config must be non-nil and must have // The configuration config must be non-nil and must have
// at least one certificate. // at least one certificate.
func NewListener(listener net.Listener, config *Config) (l *Listener) { func NewListener(inner net.Listener, config *Config) net.Listener {
l = new(Listener) l := new(listener)
l.listener = listener l.Listener = inner
l.config = config l.config = config
return return l
} }
// Listen creates a TLS listener accepting connections on the // Listen creates a TLS listener accepting connections on the
// given network address using net.Listen. // given network address using net.Listen.
// The configuration config must be non-nil and must have // The configuration config must be non-nil and must have
// at least one certificate. // at least one certificate.
func Listen(network, laddr string, config *Config) (*Listener, error) { func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || len(config.Certificates) == 0 { if config == nil || len(config.Certificates) == 0 {
return nil, errors.New("tls.Listen: no certificates in configuration") return nil, errors.New("tls.Listen: no certificates in configuration")
} }
......
...@@ -40,7 +40,7 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err error) { ...@@ -40,7 +40,7 @@ func ParsePKCS1PrivateKey(der []byte) (key *rsa.PrivateKey, err error) {
var priv pkcs1PrivateKey var priv pkcs1PrivateKey
rest, err := asn1.Unmarshal(der, &priv) rest, err := asn1.Unmarshal(der, &priv)
if len(rest) > 0 { if len(rest) > 0 {
err = asn1.SyntaxError{"trailing data"} err = asn1.SyntaxError{Msg: "trailing data"}
return return
} }
if err != nil { if err != nil {
......
...@@ -23,6 +23,8 @@ type RDNSequence []RelativeDistinguishedNameSET ...@@ -23,6 +23,8 @@ type RDNSequence []RelativeDistinguishedNameSET
type RelativeDistinguishedNameSET []AttributeTypeAndValue type RelativeDistinguishedNameSET []AttributeTypeAndValue
// AttributeTypeAndValue mirrors the ASN.1 structure of the same name in
// http://tools.ietf.org/html/rfc5280#section-4.1.2.4
type AttributeTypeAndValue struct { type AttributeTypeAndValue struct {
Type asn1.ObjectIdentifier Type asn1.ObjectIdentifier
Value interface{} Value interface{}
......
...@@ -7,6 +7,7 @@ package x509 ...@@ -7,6 +7,7 @@ package x509
import ( import (
"strings" "strings"
"time" "time"
"unicode/utf8"
) )
type InvalidReason int type InvalidReason int
...@@ -225,17 +226,51 @@ func matchHostnames(pattern, host string) bool { ...@@ -225,17 +226,51 @@ func matchHostnames(pattern, host string) bool {
return true return true
} }
// toLowerCaseASCII returns a lower-case version of in. See RFC 6125 6.4.1. We use
// an explicitly ASCII function to avoid any sharp corners resulting from
// performing Unicode operations on DNS labels.
func toLowerCaseASCII(in string) string {
// If the string is already lower-case then there's nothing to do.
isAlreadyLowerCase := true
for _, c := range in {
if c == utf8.RuneError {
// If we get a UTF-8 error then there might be
// upper-case ASCII bytes in the invalid sequence.
isAlreadyLowerCase = false
break
}
if 'A' <= c && c <= 'Z' {
isAlreadyLowerCase = false
break
}
}
if isAlreadyLowerCase {
return in
}
out := []byte(in)
for i, c := range out {
if 'A' <= c && c <= 'Z' {
out[i] += 'a' - 'A'
}
}
return string(out)
}
// VerifyHostname returns nil if c is a valid certificate for the named host. // VerifyHostname returns nil if c is a valid certificate for the named host.
// Otherwise it returns an error describing the mismatch. // Otherwise it returns an error describing the mismatch.
func (c *Certificate) VerifyHostname(h string) error { func (c *Certificate) VerifyHostname(h string) error {
lowered := toLowerCaseASCII(h)
if len(c.DNSNames) > 0 { if len(c.DNSNames) > 0 {
for _, match := range c.DNSNames { for _, match := range c.DNSNames {
if matchHostnames(match, h) { if matchHostnames(toLowerCaseASCII(match), lowered) {
return nil return nil
} }
} }
// If Subject Alt Name is given, we ignore the common name. // If Subject Alt Name is given, we ignore the common name.
} else if matchHostnames(c.Subject.CommonName, h) { } else if matchHostnames(toLowerCaseASCII(c.Subject.CommonName), lowered) {
return nil return nil
} }
......
...@@ -42,6 +42,17 @@ var verifyTests = []verifyTest{ ...@@ -42,6 +42,17 @@ var verifyTests = []verifyTest{
intermediates: []string{thawteIntermediate}, intermediates: []string{thawteIntermediate},
roots: []string{verisignRoot}, roots: []string{verisignRoot},
currentTime: 1302726541, currentTime: 1302726541,
dnsName: "WwW.GooGLE.coM",
expectedChains: [][]string{
{"Google", "Thawte", "VeriSign"},
},
},
{
leaf: googleLeaf,
intermediates: []string{thawteIntermediate},
roots: []string{verisignRoot},
currentTime: 1302726541,
dnsName: "www.example.com", dnsName: "www.example.com",
errorCallback: expectHostnameError, errorCallback: expectHostnameError,
......
...@@ -592,7 +592,7 @@ func parseCertificate(in *certificate) (*Certificate, error) { ...@@ -592,7 +592,7 @@ func parseCertificate(in *certificate) (*Certificate, error) {
return nil, err return nil, err
} }
if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 { if !seq.IsCompound || seq.Tag != 16 || seq.Class != 0 {
return nil, asn1.StructuralError{"bad SAN sequence"} return nil, asn1.StructuralError{Msg: "bad SAN sequence"}
} }
parsedName := false parsedName := false
...@@ -744,7 +744,7 @@ func ParseCertificate(asn1Data []byte) (*Certificate, error) { ...@@ -744,7 +744,7 @@ func ParseCertificate(asn1Data []byte) (*Certificate, error) {
return nil, err return nil, err
} }
if len(rest) > 0 { if len(rest) > 0 {
return nil, asn1.SyntaxError{"trailing data"} return nil, asn1.SyntaxError{Msg: "trailing data"}
} }
return parseCertificate(&cert) return parseCertificate(&cert)
......
...@@ -49,6 +49,11 @@ func convertAssign(dest, src interface{}) error { ...@@ -49,6 +49,11 @@ func convertAssign(dest, src interface{}) error {
case *string: case *string:
*d = string(s) *d = string(s)
return nil return nil
case *interface{}:
bcopy := make([]byte, len(s))
copy(bcopy, s)
*d = bcopy
return nil
case *[]byte: case *[]byte:
*d = s *d = s
return nil return nil
...@@ -80,6 +85,9 @@ func convertAssign(dest, src interface{}) error { ...@@ -80,6 +85,9 @@ func convertAssign(dest, src interface{}) error {
*d = bv.(bool) *d = bv.(bool)
} }
return err return err
case *interface{}:
*d = src
return nil
} }
if scanner, ok := dest.(ScannerInto); ok { if scanner, ok := dest.(ScannerInto); ok {
......
...@@ -18,14 +18,15 @@ type conversionTest struct { ...@@ -18,14 +18,15 @@ type conversionTest struct {
s, d interface{} // source and destination s, d interface{} // source and destination
// following are used if they're non-zero // following are used if they're non-zero
wantint int64 wantint int64
wantuint uint64 wantuint uint64
wantstr string wantstr string
wantf32 float32 wantf32 float32
wantf64 float64 wantf64 float64
wanttime time.Time wanttime time.Time
wantbool bool // used if d is of type *bool wantbool bool // used if d is of type *bool
wanterr string wanterr string
wantiface interface{}
} }
// Target variables for scanning into. // Target variables for scanning into.
...@@ -41,6 +42,7 @@ var ( ...@@ -41,6 +42,7 @@ var (
scanf32 float32 scanf32 float32
scanf64 float64 scanf64 float64
scantime time.Time scantime time.Time
scaniface interface{}
) )
var conversionTests = []conversionTest{ var conversionTests = []conversionTest{
...@@ -95,6 +97,14 @@ var conversionTests = []conversionTest{ ...@@ -95,6 +97,14 @@ var conversionTests = []conversionTest{
{s: float64(1.5), d: &scanf32, wantf32: float32(1.5)}, {s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
{s: "1.5", d: &scanf32, wantf32: float32(1.5)}, {s: "1.5", d: &scanf32, wantf32: float32(1.5)},
{s: "1.5", d: &scanf64, wantf64: float64(1.5)}, {s: "1.5", d: &scanf64, wantf64: float64(1.5)},
// To interface{}
{s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
{s: int64(1), d: &scaniface, wantiface: int64(1)},
{s: "str", d: &scaniface, wantiface: "str"},
{s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
{s: true, d: &scaniface, wantiface: true},
{s: nil, d: &scaniface},
} }
func intValue(intptr interface{}) int64 { func intValue(intptr interface{}) int64 {
...@@ -152,6 +162,18 @@ func TestConversions(t *testing.T) { ...@@ -152,6 +162,18 @@ func TestConversions(t *testing.T) {
if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) { if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
errf("want time %v, got %v", ct.wanttime, timeValue(ct.d)) errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
} }
if ifptr, ok := ct.d.(*interface{}); ok {
if !reflect.DeepEqual(ct.wantiface, scaniface) {
errf("want interface %#v, got %#v", ct.wantiface, scaniface)
continue
}
if srcBytes, ok := ct.s.([]byte); ok {
dstBytes := (*ifptr).([]byte)
if &dstBytes[0] == &srcBytes[0] {
errf("copy into interface{} didn't copy []byte data")
}
}
}
} }
} }
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
// Package driver defines interfaces to be implemented by database // Package driver defines interfaces to be implemented by database
// drivers as used by package sql. // drivers as used by package sql.
// //
// Code simply using databases should use package sql. // Most code should use package sql.
// //
// Drivers only need to be aware of a subset of Go's types. The sql package // Drivers only need to be aware of a subset of Go's types. The sql package
// will convert all types into one of the following: // will convert all types into one of the following:
......
...@@ -586,25 +586,25 @@ func converterForType(typ string) driver.ValueConverter { ...@@ -586,25 +586,25 @@ func converterForType(typ string) driver.ValueConverter {
case "bool": case "bool":
return driver.Bool return driver.Bool
case "nullbool": case "nullbool":
return driver.Null{driver.Bool} return driver.Null{Converter: driver.Bool}
case "int32": case "int32":
return driver.Int32 return driver.Int32
case "string": case "string":
return driver.NotNull{driver.String} return driver.NotNull{Converter: driver.String}
case "nullstring": case "nullstring":
return driver.Null{driver.String} return driver.Null{Converter: driver.String}
case "int64": case "int64":
// TODO(coopernurse): add type-specific converter // TODO(coopernurse): add type-specific converter
return driver.NotNull{driver.DefaultParameterConverter} return driver.NotNull{Converter: driver.DefaultParameterConverter}
case "nullint64": case "nullint64":
// TODO(coopernurse): add type-specific converter // TODO(coopernurse): add type-specific converter
return driver.Null{driver.DefaultParameterConverter} return driver.Null{Converter: driver.DefaultParameterConverter}
case "float64": case "float64":
// TODO(coopernurse): add type-specific converter // TODO(coopernurse): add type-specific converter
return driver.NotNull{driver.DefaultParameterConverter} return driver.NotNull{Converter: driver.DefaultParameterConverter}
case "nullfloat64": case "nullfloat64":
// TODO(coopernurse): add type-specific converter // TODO(coopernurse): add type-specific converter
return driver.Null{driver.DefaultParameterConverter} return driver.Null{Converter: driver.DefaultParameterConverter}
case "datetime": case "datetime":
return driver.DefaultParameterConverter return driver.DefaultParameterConverter
} }
......
...@@ -880,6 +880,10 @@ func (rs *Rows) Columns() ([]string, error) { ...@@ -880,6 +880,10 @@ func (rs *Rows) Columns() ([]string, error) {
// be modified and held indefinitely. The copy can be avoided by using // be modified and held indefinitely. The copy can be avoided by using
// an argument of type *RawBytes instead; see the documentation for // an argument of type *RawBytes instead; see the documentation for
// RawBytes for restrictions on its use. // RawBytes for restrictions on its use.
//
// If an argument has type *interface{}, Scan copies the value
// provided by the underlying driver without conversion. If the value
// is of type []byte, a copy is made and the caller owns the result.
func (rs *Rows) Scan(dest ...interface{}) error { func (rs *Rows) Scan(dest ...interface{}) error {
if rs.closed { if rs.closed {
return errors.New("sql: Rows closed") return errors.New("sql: Rows closed")
......
...@@ -24,7 +24,7 @@ type forkableWriter struct { ...@@ -24,7 +24,7 @@ type forkableWriter struct {
} }
func newForkableWriter() *forkableWriter { func newForkableWriter() *forkableWriter {
return &forkableWriter{bytes.NewBuffer(nil), nil, nil} return &forkableWriter{new(bytes.Buffer), nil, nil}
} }
func (f *forkableWriter) fork() (pre, post *forkableWriter) { func (f *forkableWriter) fork() (pre, post *forkableWriter) {
......
...@@ -125,6 +125,13 @@ func (enc *Encoding) Encode(dst, src []byte) { ...@@ -125,6 +125,13 @@ func (enc *Encoding) Encode(dst, src []byte) {
} }
} }
// EncodeToString returns the base32 encoding of src.
func (enc *Encoding) EncodeToString(src []byte) string {
buf := make([]byte, enc.EncodedLen(len(src)))
enc.Encode(buf, src)
return string(buf)
}
type encoder struct { type encoder struct {
err error err error
enc *Encoding enc *Encoding
...@@ -221,24 +228,32 @@ func (e CorruptInputError) Error() string { ...@@ -221,24 +228,32 @@ func (e CorruptInputError) Error() string {
// decode is like Decode but returns an additional 'end' value, which // decode is like Decode but returns an additional 'end' value, which
// indicates if end-of-message padding was encountered and thus any // indicates if end-of-message padding was encountered and thus any
// additional data is an error. decode also assumes len(src)%8==0, // additional data is an error.
// since it is meant for internal use.
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
for i := 0; i < len(src)/8 && !end; i++ { osrc := src
for len(src) > 0 && !end {
// Decode quantum using the base32 alphabet // Decode quantum using the base32 alphabet
var dbuf [8]byte var dbuf [8]byte
dlen := 8 dlen := 8
// do the top bytes contain any data? // do the top bytes contain any data?
dbufloop: dbufloop:
for j := 0; j < 8; j++ { for j := 0; j < 8; {
in := src[i*8+j] if len(src) == 0 {
if in == '=' && j >= 2 && i == len(src)/8-1 { return n, false, CorruptInputError(len(osrc) - len(src) - j)
}
in := src[0]
src = src[1:]
if in == '\r' || in == '\n' {
// Ignore this character.
continue
}
if in == '=' && j >= 2 && len(src) < 8 {
// We've reached the end and there's // We've reached the end and there's
// padding, the rest should be padded // padding, the rest should be padded
for k := j; k < 8; k++ { for k := 0; k < 8-j-1; k++ {
if src[i*8+k] != '=' { if len(src) > k && src[k] != '=' {
return n, false, CorruptInputError(i*8 + j) return n, false, CorruptInputError(len(osrc) - len(src) + k - 1)
} }
} }
dlen = j dlen = j
...@@ -247,28 +262,30 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -247,28 +262,30 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
} }
dbuf[j] = enc.decodeMap[in] dbuf[j] = enc.decodeMap[in]
if dbuf[j] == 0xFF { if dbuf[j] == 0xFF {
return n, false, CorruptInputError(i*8 + j) return n, false, CorruptInputError(len(osrc) - len(src) - 1)
} }
j++
} }
// Pack 8x 5-bit source blocks into 5 byte destination // Pack 8x 5-bit source blocks into 5 byte destination
// quantum // quantum
switch dlen { switch dlen {
case 7, 8: case 7, 8:
dst[i*5+4] = dbuf[6]<<5 | dbuf[7] dst[4] = dbuf[6]<<5 | dbuf[7]
fallthrough fallthrough
case 6, 5: case 6, 5:
dst[i*5+3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3 dst[3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3
fallthrough fallthrough
case 4: case 4:
dst[i*5+2] = dbuf[3]<<4 | dbuf[4]>>1 dst[2] = dbuf[3]<<4 | dbuf[4]>>1
fallthrough fallthrough
case 3: case 3:
dst[i*5+1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4 dst[1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4
fallthrough fallthrough
case 2: case 2:
dst[i*5+0] = dbuf[0]<<3 | dbuf[1]>>2 dst[0] = dbuf[0]<<3 | dbuf[1]>>2
} }
dst = dst[5:]
switch dlen { switch dlen {
case 2: case 2:
n += 1 n += 1
...@@ -289,15 +306,19 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -289,15 +306,19 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// DecodedLen(len(src)) bytes to dst and returns the number of bytes // DecodedLen(len(src)) bytes to dst and returns the number of bytes
// written. If src contains invalid base32 data, it will return the // written. If src contains invalid base32 data, it will return the
// number of bytes successfully written and CorruptInputError. // number of bytes successfully written and CorruptInputError.
// New line characters (\r and \n) are ignored.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
if len(src)%8 != 0 {
return 0, CorruptInputError(len(src) / 8 * 8)
}
n, _, err = enc.decode(dst, src) n, _, err = enc.decode(dst, src)
return return
} }
// DecodeString returns the bytes represented by the base32 string s.
func (enc *Encoding) DecodeString(s string) ([]byte, error) {
dbuf := make([]byte, enc.DecodedLen(len(s)))
n, err := enc.Decode(dbuf, []byte(s))
return dbuf[:n], err
}
type decoder struct { type decoder struct {
err error err error
enc *Encoding enc *Encoding
......
...@@ -51,9 +51,8 @@ func testEqual(t *testing.T, msg string, args ...interface{}) bool { ...@@ -51,9 +51,8 @@ func testEqual(t *testing.T, msg string, args ...interface{}) bool {
func TestEncode(t *testing.T) { func TestEncode(t *testing.T) {
for _, p := range pairs { for _, p := range pairs {
buf := make([]byte, StdEncoding.EncodedLen(len(p.decoded))) got := StdEncoding.EncodeToString([]byte(p.decoded))
StdEncoding.Encode(buf, []byte(p.decoded)) testEqual(t, "Encode(%q) = %q, want %q", p.decoded, got, p.encoded)
testEqual(t, "Encode(%q) = %q, want %q", p.decoded, string(buf), p.encoded)
} }
} }
...@@ -99,6 +98,10 @@ func TestDecode(t *testing.T) { ...@@ -99,6 +98,10 @@ func TestDecode(t *testing.T) {
testEqual(t, "Decode(%q) = %q, want %q", p.encoded, testEqual(t, "Decode(%q) = %q, want %q", p.encoded,
string(dbuf[0:count]), string(dbuf[0:count]),
p.decoded) p.decoded)
dbuf, err = StdEncoding.DecodeString(p.encoded)
testEqual(t, "DecodeString(%q) = error %v, want %v", p.encoded, err, error(nil))
testEqual(t, "DecodeString(%q) = %q, want %q", p.encoded, string(dbuf), p.decoded)
} }
} }
...@@ -191,3 +194,29 @@ func TestBig(t *testing.T) { ...@@ -191,3 +194,29 @@ func TestBig(t *testing.T) {
t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i) t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i)
} }
} }
func TestNewLineCharacters(t *testing.T) {
// Each of these should decode to the string "sure", without errors.
const expected = "sure"
examples := []string{
"ON2XEZI=",
"ON2XEZI=\r",
"ON2XEZI=\n",
"ON2XEZI=\r\n",
"ON2XEZ\r\nI=",
"ON2X\rEZ\nI=",
"ON2X\nEZ\rI=",
"ON2XEZ\nI=",
"ON2XEZI\n=",
}
for _, e := range examples {
buf, err := StdEncoding.DecodeString(e)
if err != nil {
t.Errorf("Decode(%q) failed: %v", e, err)
continue
}
if s := string(buf); s != expected {
t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
}
}
}
...@@ -208,22 +208,30 @@ func (e CorruptInputError) Error() string { ...@@ -208,22 +208,30 @@ func (e CorruptInputError) Error() string {
// decode is like Decode but returns an additional 'end' value, which // decode is like Decode but returns an additional 'end' value, which
// indicates if end-of-message padding was encountered and thus any // indicates if end-of-message padding was encountered and thus any
// additional data is an error. decode also assumes len(src)%4==0, // additional data is an error.
// since it is meant for internal use.
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
for i := 0; i < len(src)/4 && !end; i++ { osrc := src
for len(src) > 0 && !end {
// Decode quantum using the base64 alphabet // Decode quantum using the base64 alphabet
var dbuf [4]byte var dbuf [4]byte
dlen := 4 dlen := 4
dbufloop: dbufloop:
for j := 0; j < 4; j++ { for j := 0; j < 4; {
in := src[i*4+j] if len(src) == 0 {
if in == '=' && j >= 2 && i == len(src)/4-1 { return n, false, CorruptInputError(len(osrc) - len(src) - j)
}
in := src[0]
src = src[1:]
if in == '\r' || in == '\n' {
// Ignore this character.
continue
}
if in == '=' && j >= 2 && len(src) < 4 {
// We've reached the end and there's // We've reached the end and there's
// padding // padding
if src[i*4+3] != '=' { if len(src) > 0 && src[0] != '=' {
return n, false, CorruptInputError(i*4 + 2) return n, false, CorruptInputError(len(osrc) - len(src) - 1)
} }
dlen = j dlen = j
end = true end = true
...@@ -231,22 +239,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -231,22 +239,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
} }
dbuf[j] = enc.decodeMap[in] dbuf[j] = enc.decodeMap[in]
if dbuf[j] == 0xFF { if dbuf[j] == 0xFF {
return n, false, CorruptInputError(i*4 + j) return n, false, CorruptInputError(len(osrc) - len(src) - 1)
} }
j++
} }
// Pack 4x 6-bit source blocks into 3 byte destination // Pack 4x 6-bit source blocks into 3 byte destination
// quantum // quantum
switch dlen { switch dlen {
case 4: case 4:
dst[i*3+2] = dbuf[2]<<6 | dbuf[3] dst[2] = dbuf[2]<<6 | dbuf[3]
fallthrough fallthrough
case 3: case 3:
dst[i*3+1] = dbuf[1]<<4 | dbuf[2]>>2 dst[1] = dbuf[1]<<4 | dbuf[2]>>2
fallthrough fallthrough
case 2: case 2:
dst[i*3+0] = dbuf[0]<<2 | dbuf[1]>>4 dst[0] = dbuf[0]<<2 | dbuf[1]>>4
} }
dst = dst[3:]
n += dlen - 1 n += dlen - 1
} }
...@@ -257,11 +267,8 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { ...@@ -257,11 +267,8 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// DecodedLen(len(src)) bytes to dst and returns the number of bytes // DecodedLen(len(src)) bytes to dst and returns the number of bytes
// written. If src contains invalid base64 data, it will return the // written. If src contains invalid base64 data, it will return the
// number of bytes successfully written and CorruptInputError. // number of bytes successfully written and CorruptInputError.
// New line characters (\r and \n) are ignored.
func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
if len(src)%4 != 0 {
return 0, CorruptInputError(len(src) / 4 * 4)
}
n, _, err = enc.decode(dst, src) n, _, err = enc.decode(dst, src)
return return
} }
......
...@@ -197,3 +197,29 @@ func TestBig(t *testing.T) { ...@@ -197,3 +197,29 @@ func TestBig(t *testing.T) {
t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i) t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i)
} }
} }
func TestNewLineCharacters(t *testing.T) {
// Each of these should decode to the string "sure", without errors.
const expected = "sure"
examples := []string{
"c3VyZQ==",
"c3VyZQ==\r",
"c3VyZQ==\n",
"c3VyZQ==\r\n",
"c3VyZ\r\nQ==",
"c3V\ryZ\nQ==",
"c3V\nyZ\rQ==",
"c3VyZ\nQ==",
"c3VyZQ\n==",
}
for _, e := range examples {
buf, err := StdEncoding.DecodeString(e)
if err != nil {
t.Errorf("Decode(%q) failed: %v", e, err)
continue
}
if s := string(buf); s != expected {
t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
}
}
}
...@@ -163,7 +163,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error { ...@@ -163,7 +163,7 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error {
default: default:
return errors.New("binary.Read: invalid type " + d.Type().String()) return errors.New("binary.Read: invalid type " + d.Type().String())
} }
size := TotalSize(v) size := dataSize(v)
if size < 0 { if size < 0 {
return errors.New("binary.Read: invalid type " + v.Type().String()) return errors.New("binary.Read: invalid type " + v.Type().String())
} }
...@@ -242,7 +242,7 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error { ...@@ -242,7 +242,7 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error {
return err return err
} }
v := reflect.Indirect(reflect.ValueOf(data)) v := reflect.Indirect(reflect.ValueOf(data))
size := TotalSize(v) size := dataSize(v)
if size < 0 { if size < 0 {
return errors.New("binary.Write: invalid type " + v.Type().String()) return errors.New("binary.Write: invalid type " + v.Type().String())
} }
...@@ -253,7 +253,11 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error { ...@@ -253,7 +253,11 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error {
return err return err
} }
func TotalSize(v reflect.Value) int { // dataSize returns the number of bytes the actual data represented by v occupies in memory.
// For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice
// it returns the length of the slice times the element size and does not count the memory
// occupied by the header.
func dataSize(v reflect.Value) int {
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
elem := sizeof(v.Type().Elem()) elem := sizeof(v.Type().Elem())
if elem < 0 { if elem < 0 {
......
...@@ -187,7 +187,7 @@ func BenchmarkReadStruct(b *testing.B) { ...@@ -187,7 +187,7 @@ func BenchmarkReadStruct(b *testing.B) {
bsr := &byteSliceReader{} bsr := &byteSliceReader{}
var buf bytes.Buffer var buf bytes.Buffer
Write(&buf, BigEndian, &s) Write(&buf, BigEndian, &s)
n := TotalSize(reflect.ValueOf(s)) n := dataSize(reflect.ValueOf(s))
b.SetBytes(int64(n)) b.SetBytes(int64(n))
t := s t := s
b.ResetTimer() b.ResetTimer()
......
...@@ -156,6 +156,9 @@ func (r *Reader) Read() (record []string, err error) { ...@@ -156,6 +156,9 @@ func (r *Reader) Read() (record []string, err error) {
// ReadAll reads all the remaining records from r. // ReadAll reads all the remaining records from r.
// Each record is a slice of fields. // Each record is a slice of fields.
// A successful call returns err == nil, not err == EOF. Because ReadAll is
// defined to read until EOF, it does not treat end of file as an error to be
// reported.
func (r *Reader) ReadAll() (records [][]string, err error) { func (r *Reader) ReadAll() (records [][]string, err error) {
for { for {
record, err := r.Read() record, err := r.Read()
......
...@@ -8,9 +8,11 @@ import ( ...@@ -8,9 +8,11 @@ import (
"bytes" "bytes"
"errors" "errors"
"math" "math"
"math/rand"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"time"
"unsafe" "unsafe"
) )
...@@ -1407,3 +1409,60 @@ func TestDebugStruct(t *testing.T) { ...@@ -1407,3 +1409,60 @@ func TestDebugStruct(t *testing.T) {
} }
debugFunc(debugBuffer) debugFunc(debugBuffer)
} }
func encFuzzDec(rng *rand.Rand, in interface{}) error {
buf := new(bytes.Buffer)
enc := NewEncoder(buf)
if err := enc.Encode(&in); err != nil {
return err
}
b := buf.Bytes()
for i, bi := range b {
if rng.Intn(10) < 3 {
b[i] = bi + uint8(rng.Intn(256))
}
}
dec := NewDecoder(buf)
var e interface{}
if err := dec.Decode(&e); err != nil {
return err
}
return nil
}
// This does some "fuzz testing" by attempting to decode a sequence of random bytes.
func TestFuzz(t *testing.T) {
if testing.Short() {
return
}
// all possible inputs
input := []interface{}{
new(int),
new(float32),
new(float64),
new(complex128),
&ByteStruct{255},
&ArrayStruct{},
&StringStruct{"hello"},
&GobTest1{0, &StringStruct{"hello"}},
}
testFuzz(t, time.Now().UnixNano(), 100, input...)
}
func TestFuzzRegressions(t *testing.T) {
// An instance triggering a type name of length ~102 GB.
testFuzz(t, 1328492090837718000, 100, new(float32))
}
func testFuzz(t *testing.T, seed int64, n int, input ...interface{}) {
t.Logf("seed=%d n=%d\n", seed, n)
for _, e := range input {
rng := rand.New(rand.NewSource(seed))
for i := 0; i < n; i++ {
encFuzzDec(rng, e)
}
}
}
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Delete the next line to include this file in the gob package. // Delete the next line to include in the gob package.
// +build ignore // +build gob-debug
package gob package gob
......
...@@ -690,7 +690,11 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui ...@@ -690,7 +690,11 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, p ui
// Create a writable interface reflect.Value. We need one even for the nil case. // Create a writable interface reflect.Value. We need one even for the nil case.
ivalue := allocValue(ityp) ivalue := allocValue(ityp)
// Read the name of the concrete type. // Read the name of the concrete type.
b := make([]byte, state.decodeUint()) nr := state.decodeUint()
if nr < 0 || nr > 1<<31 { // zero is permissible for anonymous types
errorf("invalid type name length %d", nr)
}
b := make([]byte, nr)
state.b.Read(b) state.b.Read(b)
name := string(b) name := string(b)
if name == "" { if name == "" {
......
...@@ -135,7 +135,7 @@ func (dec *Decoder) nextUint() uint64 { ...@@ -135,7 +135,7 @@ func (dec *Decoder) nextUint() uint64 {
// and returns the type id of the next value. It returns -1 at // and returns the type id of the next value. It returns -1 at
// EOF. Upon return, the remainder of dec.buf is the value to be // EOF. Upon return, the remainder of dec.buf is the value to be
// decoded. If this is an interface value, it can be ignored by // decoded. If this is an interface value, it can be ignored by
// simply resetting that buffer. // resetting that buffer.
func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId { func (dec *Decoder) decodeTypeSequence(isInterface bool) typeId {
for dec.err == nil { for dec.err == nil {
if dec.buf.Len() == 0 { if dec.buf.Len() == 0 {
......
...@@ -70,7 +70,7 @@ operation will fail. ...@@ -70,7 +70,7 @@ operation will fail.
Structs, arrays and slices are also supported. Strings and arrays of bytes are Structs, arrays and slices are also supported. Strings and arrays of bytes are
supported with a special, efficient representation (see below). When a slice is supported with a special, efficient representation (see below). When a slice is
decoded, if the existing slice has capacity the slice will be extended in place; decoded, if the existing slice has capacity the slice will be extended in place;
if not, a new array is allocated. Regardless, the length of the resuling slice if not, a new array is allocated. Regardless, the length of the resulting slice
reports the number of elements decoded. reports the number of elements decoded.
Functions and channels cannot be sent in a gob. Attempting Functions and channels cannot be sent in a gob. Attempting
...@@ -162,7 +162,7 @@ description, constructed from these types: ...@@ -162,7 +162,7 @@ description, constructed from these types:
StructT *StructType StructT *StructType
MapT *MapType MapT *MapType
} }
type ArrayType struct { type arrayType struct {
CommonType CommonType
Elem typeId Elem typeId
Len int Len int
...@@ -171,19 +171,19 @@ description, constructed from these types: ...@@ -171,19 +171,19 @@ description, constructed from these types:
Name string // the name of the struct type Name string // the name of the struct type
Id int // the id of the type, repeated so it's inside the type Id int // the id of the type, repeated so it's inside the type
} }
type SliceType struct { type sliceType struct {
CommonType CommonType
Elem typeId Elem typeId
} }
type StructType struct { type structType struct {
CommonType CommonType
Field []*fieldType // the fields of the struct. Field []*fieldType // the fields of the struct.
} }
type FieldType struct { type fieldType struct {
Name string // the name of the field. Name string // the name of the field.
Id int // the type id of the field, which must be already defined Id int // the type id of the field, which must be already defined
} }
type MapType struct { type mapType struct {
CommonType CommonType
Key typeId Key typeId
Elem typeId Elem typeId
...@@ -308,15 +308,15 @@ reserved). ...@@ -308,15 +308,15 @@ reserved).
// Set the field number implicitly to -1; this is done at the beginning // Set the field number implicitly to -1; this is done at the beginning
// of every struct, including nested structs. // of every struct, including nested structs.
03 // Add 3 to field number; now 2 (wireType.structType; this is a struct). 03 // Add 3 to field number; now 2 (wireType.structType; this is a struct).
// structType starts with an embedded commonType, which appears // structType starts with an embedded CommonType, which appears
// as a regular structure here too. // as a regular structure here too.
01 // add 1 to field number (now 0); start of embedded commonType. 01 // add 1 to field number (now 0); start of embedded CommonType.
01 // add 1 to field number (now 0, the name of the type) 01 // add 1 to field number (now 0, the name of the type)
05 // string is (unsigned) 5 bytes long 05 // string is (unsigned) 5 bytes long
50 6f 69 6e 74 // wireType.structType.commonType.name = "Point" 50 6f 69 6e 74 // wireType.structType.CommonType.name = "Point"
01 // add 1 to field number (now 1, the id of the type) 01 // add 1 to field number (now 1, the id of the type)
ff 82 // wireType.structType.commonType._id = 65 ff 82 // wireType.structType.CommonType._id = 65
00 // end of embedded wiretype.structType.commonType struct 00 // end of embedded wiretype.structType.CommonType struct
01 // add 1 to field number (now 1, the field array in wireType.structType) 01 // add 1 to field number (now 1, the field array in wireType.structType)
02 // There are two fields in the type (len(structType.field)) 02 // There are two fields in the type (len(structType.field))
01 // Start of first field structure; add 1 to get field number 0: field[0].name 01 // Start of first field structure; add 1 to get field number 0: field[0].name
......
...@@ -570,8 +570,7 @@ func TestGobMapInterfaceEncode(t *testing.T) { ...@@ -570,8 +570,7 @@ func TestGobMapInterfaceEncode(t *testing.T) {
"bo": []bool{false}, "bo": []bool{false},
"st": []string{"s"}, "st": []string{"s"},
} }
buf := bytes.NewBuffer(nil) enc := NewEncoder(new(bytes.Buffer))
enc := NewEncoder(buf)
err := enc.Encode(m) err := enc.Encode(m)
if err != nil { if err != nil {
t.Errorf("encode map: %s", err) t.Errorf("encode map: %s", err)
...@@ -579,7 +578,7 @@ func TestGobMapInterfaceEncode(t *testing.T) { ...@@ -579,7 +578,7 @@ func TestGobMapInterfaceEncode(t *testing.T) {
} }
func TestSliceReusesMemory(t *testing.T) { func TestSliceReusesMemory(t *testing.T) {
buf := bytes.NewBuffer(nil) buf := new(bytes.Buffer)
// Bytes // Bytes
{ {
x := []byte("abcd") x := []byte("abcd")
......
...@@ -33,7 +33,11 @@ func error_(err error) { ...@@ -33,7 +33,11 @@ func error_(err error) {
// plain error. It overwrites the error return of the function that deferred its call. // plain error. It overwrites the error return of the function that deferred its call.
func catchError(err *error) { func catchError(err *error) {
if e := recover(); e != nil { if e := recover(); e != nil {
*err = e.(gobError).err // Will re-panic if not one of our errors, such as a runtime error. ge, ok := e.(gobError)
if !ok {
panic(e)
}
*err = ge.err
} }
return return
} }
...@@ -53,8 +53,9 @@ func TestCountEncodeMallocs(t *testing.T) { ...@@ -53,8 +53,9 @@ func TestCountEncodeMallocs(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
enc := NewEncoder(&buf) enc := NewEncoder(&buf)
bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")} bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")}
runtime.UpdateMemStats() memstats := new(runtime.MemStats)
mallocs := 0 - runtime.MemStats.Mallocs runtime.ReadMemStats(memstats)
mallocs := 0 - memstats.Mallocs
const count = 1000 const count = 1000
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
err := enc.Encode(bench) err := enc.Encode(bench)
...@@ -62,8 +63,8 @@ func TestCountEncodeMallocs(t *testing.T) { ...@@ -62,8 +63,8 @@ func TestCountEncodeMallocs(t *testing.T) {
t.Fatal("encode:", err) t.Fatal("encode:", err)
} }
} }
runtime.UpdateMemStats() runtime.ReadMemStats(memstats)
mallocs += runtime.MemStats.Mallocs mallocs += memstats.Mallocs
fmt.Printf("mallocs per encode of type Bench: %d\n", mallocs/count) fmt.Printf("mallocs per encode of type Bench: %d\n", mallocs/count)
} }
...@@ -79,8 +80,9 @@ func TestCountDecodeMallocs(t *testing.T) { ...@@ -79,8 +80,9 @@ func TestCountDecodeMallocs(t *testing.T) {
} }
} }
dec := NewDecoder(&buf) dec := NewDecoder(&buf)
runtime.UpdateMemStats() memstats := new(runtime.MemStats)
mallocs := 0 - runtime.MemStats.Mallocs runtime.ReadMemStats(memstats)
mallocs := 0 - memstats.Mallocs
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
*bench = Bench{} *bench = Bench{}
err := dec.Decode(&bench) err := dec.Decode(&bench)
...@@ -88,7 +90,7 @@ func TestCountDecodeMallocs(t *testing.T) { ...@@ -88,7 +90,7 @@ func TestCountDecodeMallocs(t *testing.T) {
t.Fatal("decode:", err) t.Fatal("decode:", err)
} }
} }
runtime.UpdateMemStats() runtime.ReadMemStats(memstats)
mallocs += runtime.MemStats.Mallocs mallocs += memstats.Mallocs
fmt.Printf("mallocs per decode of type Bench: %d\n", mallocs/count) fmt.Printf("mallocs per decode of type Bench: %d\n", mallocs/count)
} }
...@@ -180,7 +180,10 @@ func (t typeId) name() string { ...@@ -180,7 +180,10 @@ func (t typeId) name() string {
return t.gobType().name() return t.gobType().name()
} }
// Common elements of all types. // CommonType holds elements of all types.
// It is a historical artifact, kept for binary compatibility and exported
// only for the benefit of the package's encoding of type descriptors. It is
// not intended for direct use by clients.
type CommonType struct { type CommonType struct {
Name string Name string
Id typeId Id typeId
......
...@@ -7,8 +7,9 @@ package hex ...@@ -7,8 +7,9 @@ package hex
import ( import (
"bytes" "bytes"
"errors"
"fmt"
"io" "io"
"strconv"
) )
const hextable = "0123456789abcdef" const hextable = "0123456789abcdef"
...@@ -29,16 +30,14 @@ func Encode(dst, src []byte) int { ...@@ -29,16 +30,14 @@ func Encode(dst, src []byte) int {
return len(src) * 2 return len(src) * 2
} }
// OddLengthInputError results from decoding an odd length slice. // ErrLength results from decoding an odd length slice.
type OddLengthInputError struct{} var ErrLength = errors.New("encoding/hex: odd length hex string")
func (OddLengthInputError) Error() string { return "odd length hex string" } // InvalidByteError values describe errors resulting from an invalid byte in a hex string.
type InvalidByteError byte
// InvalidHexCharError results from finding an invalid character in a hex string. func (e InvalidByteError) Error() string {
type InvalidHexCharError byte return fmt.Sprintf("encoding/hex: invalid byte: %#U", rune(e))
func (e InvalidHexCharError) Error() string {
return "invalid hex char: " + strconv.Itoa(int(e))
} }
func DecodedLen(x int) int { return x / 2 } func DecodedLen(x int) int { return x / 2 }
...@@ -46,21 +45,20 @@ func DecodedLen(x int) int { return x / 2 } ...@@ -46,21 +45,20 @@ func DecodedLen(x int) int { return x / 2 }
// Decode decodes src into DecodedLen(len(src)) bytes, returning the actual // Decode decodes src into DecodedLen(len(src)) bytes, returning the actual
// number of bytes written to dst. // number of bytes written to dst.
// //
// If Decode encounters invalid input, it returns an OddLengthInputError or an // If Decode encounters invalid input, it returns an error describing the failure.
// InvalidHexCharError.
func Decode(dst, src []byte) (int, error) { func Decode(dst, src []byte) (int, error) {
if len(src)%2 == 1 { if len(src)%2 == 1 {
return 0, OddLengthInputError{} return 0, ErrLength
} }
for i := 0; i < len(src)/2; i++ { for i := 0; i < len(src)/2; i++ {
a, ok := fromHexChar(src[i*2]) a, ok := fromHexChar(src[i*2])
if !ok { if !ok {
return 0, InvalidHexCharError(src[i*2]) return 0, InvalidByteError(src[i*2])
} }
b, ok := fromHexChar(src[i*2+1]) b, ok := fromHexChar(src[i*2+1])
if !ok { if !ok {
return 0, InvalidHexCharError(src[i*2+1]) return 0, InvalidByteError(src[i*2+1])
} }
dst[i] = (a << 4) | b dst[i] = (a << 4) | b
} }
...@@ -103,8 +101,8 @@ func DecodeString(s string) ([]byte, error) { ...@@ -103,8 +101,8 @@ func DecodeString(s string) ([]byte, error) {
// Dump returns a string that contains a hex dump of the given data. The format // Dump returns a string that contains a hex dump of the given data. The format
// of the hex dump matches the output of `hexdump -C` on the command line. // of the hex dump matches the output of `hexdump -C` on the command line.
func Dump(data []byte) string { func Dump(data []byte) string {
buf := bytes.NewBuffer(nil) var buf bytes.Buffer
dumper := Dumper(buf) dumper := Dumper(&buf)
dumper.Write(data) dumper.Write(data)
dumper.Close() dumper.Close()
return string(buf.Bytes()) return string(buf.Bytes())
......
...@@ -9,141 +9,98 @@ import ( ...@@ -9,141 +9,98 @@ import (
"testing" "testing"
) )
type encodeTest struct { type encDecTest struct {
in, out []byte enc string
dec []byte
} }
var encodeTests = []encodeTest{ var encDecTests = []encDecTest{
{[]byte{}, []byte{}}, {"", []byte{}},
{[]byte{0x01}, []byte{'0', '1'}}, {"0001020304050607", []byte{0, 1, 2, 3, 4, 5, 6, 7}},
{[]byte{0xff}, []byte{'f', 'f'}}, {"08090a0b0c0d0e0f", []byte{8, 9, 10, 11, 12, 13, 14, 15}},
{[]byte{0xff, 00}, []byte{'f', 'f', '0', '0'}}, {"f0f1f2f3f4f5f6f7", []byte{0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7}},
{[]byte{0}, []byte{'0', '0'}}, {"f8f9fafbfcfdfeff", []byte{0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff}},
{[]byte{1}, []byte{'0', '1'}}, {"67", []byte{'g'}},
{[]byte{2}, []byte{'0', '2'}}, {"e3a1", []byte{0xe3, 0xa1}},
{[]byte{3}, []byte{'0', '3'}},
{[]byte{4}, []byte{'0', '4'}},
{[]byte{5}, []byte{'0', '5'}},
{[]byte{6}, []byte{'0', '6'}},
{[]byte{7}, []byte{'0', '7'}},
{[]byte{8}, []byte{'0', '8'}},
{[]byte{9}, []byte{'0', '9'}},
{[]byte{10}, []byte{'0', 'a'}},
{[]byte{11}, []byte{'0', 'b'}},
{[]byte{12}, []byte{'0', 'c'}},
{[]byte{13}, []byte{'0', 'd'}},
{[]byte{14}, []byte{'0', 'e'}},
{[]byte{15}, []byte{'0', 'f'}},
} }
func TestEncode(t *testing.T) { func TestEncode(t *testing.T) {
for i, test := range encodeTests { for i, test := range encDecTests {
dst := make([]byte, EncodedLen(len(test.in))) dst := make([]byte, EncodedLen(len(test.dec)))
n := Encode(dst, test.in) n := Encode(dst, test.dec)
if n != len(dst) { if n != len(dst) {
t.Errorf("#%d: bad return value: got: %d want: %d", i, n, len(dst)) t.Errorf("#%d: bad return value: got: %d want: %d", i, n, len(dst))
} }
if bytes.Compare(dst, test.out) != 0 { if string(dst) != test.enc {
t.Errorf("#%d: got: %#v want: %#v", i, dst, test.out) t.Errorf("#%d: got: %#v want: %#v", i, dst, test.enc)
} }
} }
} }
type decodeTest struct {
in, out []byte
ok bool
}
var decodeTests = []decodeTest{
{[]byte{}, []byte{}, true},
{[]byte{'0'}, []byte{}, false},
{[]byte{'0', 'g'}, []byte{}, false},
{[]byte{'0', '\x01'}, []byte{}, false},
{[]byte{'0', '0'}, []byte{0}, true},
{[]byte{'0', '1'}, []byte{1}, true},
{[]byte{'0', '2'}, []byte{2}, true},
{[]byte{'0', '3'}, []byte{3}, true},
{[]byte{'0', '4'}, []byte{4}, true},
{[]byte{'0', '5'}, []byte{5}, true},
{[]byte{'0', '6'}, []byte{6}, true},
{[]byte{'0', '7'}, []byte{7}, true},
{[]byte{'0', '8'}, []byte{8}, true},
{[]byte{'0', '9'}, []byte{9}, true},
{[]byte{'0', 'a'}, []byte{10}, true},
{[]byte{'0', 'b'}, []byte{11}, true},
{[]byte{'0', 'c'}, []byte{12}, true},
{[]byte{'0', 'd'}, []byte{13}, true},
{[]byte{'0', 'e'}, []byte{14}, true},
{[]byte{'0', 'f'}, []byte{15}, true},
{[]byte{'0', 'A'}, []byte{10}, true},
{[]byte{'0', 'B'}, []byte{11}, true},
{[]byte{'0', 'C'}, []byte{12}, true},
{[]byte{'0', 'D'}, []byte{13}, true},
{[]byte{'0', 'E'}, []byte{14}, true},
{[]byte{'0', 'F'}, []byte{15}, true},
}
func TestDecode(t *testing.T) { func TestDecode(t *testing.T) {
for i, test := range decodeTests { for i, test := range encDecTests {
dst := make([]byte, DecodedLen(len(test.in))) dst := make([]byte, DecodedLen(len(test.enc)))
n, err := Decode(dst, test.in) n, err := Decode(dst, []byte(test.enc))
if err == nil && n != len(dst) { if err != nil {
t.Errorf("#%d: bad return value: got:%d want:%d", i, n, len(dst)) t.Errorf("#%d: bad return value: got:%d want:%d", i, n, len(dst))
} } else if !bytes.Equal(dst, test.dec) {
if test.ok != (err == nil) { t.Errorf("#%d: got: %#v want: %#v", i, dst, test.dec)
t.Errorf("#%d: unexpected err value: %s", i, err)
}
if err == nil && bytes.Compare(dst, test.out) != 0 {
t.Errorf("#%d: got: %#v want: %#v", i, dst, test.out)
} }
} }
} }
type encodeStringTest struct { func TestEncodeToString(t *testing.T) {
in []byte for i, test := range encDecTests {
out string s := EncodeToString(test.dec)
} if s != test.enc {
t.Errorf("#%d got:%s want:%s", i, s, test.enc)
var encodeStringTests = []encodeStringTest{ }
{[]byte{}, ""}, }
{[]byte{0}, "00"},
{[]byte{0, 1}, "0001"},
{[]byte{0, 1, 255}, "0001ff"},
} }
func TestEncodeToString(t *testing.T) { func TestDecodeString(t *testing.T) {
for i, test := range encodeStringTests { for i, test := range encDecTests {
s := EncodeToString(test.in) dst, err := DecodeString(test.enc)
if s != test.out { if err != nil {
t.Errorf("#%d got:%s want:%s", i, s, test.out) t.Errorf("#%d: unexpected err value: %s", i, err)
continue
}
if bytes.Compare(dst, test.dec) != 0 {
t.Errorf("#%d: got: %#v want: #%v", i, dst, test.dec)
} }
} }
} }
type decodeStringTest struct { type errTest struct {
in string in string
out []byte err string
ok bool
} }
var decodeStringTests = []decodeStringTest{ var errTests = []errTest{
{"", []byte{}, true}, {"0", "encoding/hex: odd length hex string"},
{"0", []byte{}, false}, {"0g", "encoding/hex: invalid byte: U+0067 'g'"},
{"00", []byte{0}, true}, {"0\x01", "encoding/hex: invalid byte: U+0001"},
{"0\x01", []byte{}, false},
{"0g", []byte{}, false},
{"00ff00", []byte{0, 255, 0}, true},
{"0000ff", []byte{0, 0, 255}, true},
} }
func TestDecodeString(t *testing.T) { func TestInvalidErr(t *testing.T) {
for i, test := range decodeStringTests { for i, test := range errTests {
dst, err := DecodeString(test.in) dst := make([]byte, DecodedLen(len(test.in)))
if test.ok != (err == nil) { _, err := Decode(dst, []byte(test.in))
t.Errorf("#%d: unexpected err value: %s", i, err) if err == nil {
t.Errorf("#%d: expected error; got none")
} else if err.Error() != test.err {
t.Errorf("#%d: got: %v want: %v", i, err, test.err)
} }
if err == nil && bytes.Compare(dst, test.out) != 0 { }
t.Errorf("#%d: got: %#v want: #%v", i, dst, test.out) }
func TestInvalidStringErr(t *testing.T) {
for i, test := range errTests {
_, err := DecodeString(test.in)
if err == nil {
t.Errorf("#%d: expected error; got none")
} else if err.Error() != test.err {
t.Errorf("#%d: got: %v want: %v", i, err, test.err)
} }
} }
} }
...@@ -155,8 +112,8 @@ func TestDumper(t *testing.T) { ...@@ -155,8 +112,8 @@ func TestDumper(t *testing.T) {
} }
for stride := 1; stride < len(in); stride++ { for stride := 1; stride < len(in); stride++ {
out := bytes.NewBuffer(nil) var out bytes.Buffer
dumper := Dumper(out) dumper := Dumper(&out)
done := 0 done := 0
for done < len(in) { for done < len(in) {
todo := done + stride todo := done + stride
......
...@@ -598,3 +598,24 @@ var pallValueIndent = `{ ...@@ -598,3 +598,24 @@ var pallValueIndent = `{
}` }`
var pallValueCompact = strings.Map(noSpace, pallValueIndent) var pallValueCompact = strings.Map(noSpace, pallValueIndent)
func TestRefUnmarshal(t *testing.T) {
type S struct {
// Ref is defined in encode_test.go.
R0 Ref
R1 *Ref
}
want := S{
R0: 12,
R1: new(Ref),
}
*want.R1 = 12
var got S
if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref"}`), &got); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %+v, want %+v", got, want)
}
}
...@@ -262,8 +262,18 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) { ...@@ -262,8 +262,18 @@ func (e *encodeState) reflectValueQuoted(v reflect.Value, quoted bool) {
return return
} }
if j, ok := v.Interface().(Marshaler); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { m, ok := v.Interface().(Marshaler)
b, err := j.MarshalJSON() if !ok {
// T doesn't match the interface. Check against *T too.
if v.Kind() != reflect.Ptr && v.CanAddr() {
m, ok = v.Addr().Interface().(Marshaler)
if ok {
v = v.Addr()
}
}
}
if ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
b, err := m.MarshalJSON()
if err == nil { if err == nil {
// copy JSON into buffer, checking validity. // copy JSON into buffer, checking validity.
err = Compact(&e.Buffer, b) err = Compact(&e.Buffer, b)
......
...@@ -126,3 +126,44 @@ func TestUnsupportedValues(t *testing.T) { ...@@ -126,3 +126,44 @@ func TestUnsupportedValues(t *testing.T) {
} }
} }
} }
// Ref has Marshaler and Unmarshaler methods with pointer receiver.
type Ref int
func (*Ref) MarshalJSON() ([]byte, error) {
return []byte(`"ref"`), nil
}
func (r *Ref) UnmarshalJSON([]byte) error {
*r = 12
return nil
}
// Val has Marshaler methods with value receiver.
type Val int
func (Val) MarshalJSON() ([]byte, error) {
return []byte(`"val"`), nil
}
func TestRefValMarshal(t *testing.T) {
var s = struct {
R0 Ref
R1 *Ref
V0 Val
V1 *Val
}{
R0: 12,
R1: new(Ref),
V0: 13,
V1: new(Val),
}
const want = `{"R0":"ref","R1":"ref","V0":"val","V1":"val"}`
b, err := Marshal(&s)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if got := string(b); got != want {
t.Errorf("got %q, want %q", got, want)
}
}
...@@ -185,18 +185,9 @@ func isSpace(c rune) bool { ...@@ -185,18 +185,9 @@ func isSpace(c rune) bool {
return c == ' ' || c == '\t' || c == '\r' || c == '\n' return c == ' ' || c == '\t' || c == '\r' || c == '\n'
} }
// NOTE(rsc): The various instances of
//
// if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
//
// below should all be if c <= ' ' && isSpace(c), but inlining
// the checks makes a significant difference (>10%) in tight loops
// such as nextValue. These should be rewritten with the clearer
// function call once 6g knows to inline the call.
// stateBeginValueOrEmpty is the state after reading `[`. // stateBeginValueOrEmpty is the state after reading `[`.
func stateBeginValueOrEmpty(s *scanner, c int) int { func stateBeginValueOrEmpty(s *scanner, c int) int {
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') { if c <= ' ' && isSpace(rune(c)) {
return scanSkipSpace return scanSkipSpace
} }
if c == ']' { if c == ']' {
...@@ -207,7 +198,7 @@ func stateBeginValueOrEmpty(s *scanner, c int) int { ...@@ -207,7 +198,7 @@ func stateBeginValueOrEmpty(s *scanner, c int) int {
// stateBeginValue is the state at the beginning of the input. // stateBeginValue is the state at the beginning of the input.
func stateBeginValue(s *scanner, c int) int { func stateBeginValue(s *scanner, c int) int {
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') { if c <= ' ' && isSpace(rune(c)) {
return scanSkipSpace return scanSkipSpace
} }
switch c { switch c {
...@@ -247,7 +238,7 @@ func stateBeginValue(s *scanner, c int) int { ...@@ -247,7 +238,7 @@ func stateBeginValue(s *scanner, c int) int {
// stateBeginStringOrEmpty is the state after reading `{`. // stateBeginStringOrEmpty is the state after reading `{`.
func stateBeginStringOrEmpty(s *scanner, c int) int { func stateBeginStringOrEmpty(s *scanner, c int) int {
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') { if c <= ' ' && isSpace(rune(c)) {
return scanSkipSpace return scanSkipSpace
} }
if c == '}' { if c == '}' {
...@@ -260,7 +251,7 @@ func stateBeginStringOrEmpty(s *scanner, c int) int { ...@@ -260,7 +251,7 @@ func stateBeginStringOrEmpty(s *scanner, c int) int {
// stateBeginString is the state after reading `{"key": value,`. // stateBeginString is the state after reading `{"key": value,`.
func stateBeginString(s *scanner, c int) int { func stateBeginString(s *scanner, c int) int {
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') { if c <= ' ' && isSpace(rune(c)) {
return scanSkipSpace return scanSkipSpace
} }
if c == '"' { if c == '"' {
...@@ -280,7 +271,7 @@ func stateEndValue(s *scanner, c int) int { ...@@ -280,7 +271,7 @@ func stateEndValue(s *scanner, c int) int {
s.endTop = true s.endTop = true
return stateEndTop(s, c) return stateEndTop(s, c)
} }
if c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') { if c <= ' ' && isSpace(rune(c)) {
s.step = stateEndValue s.step = stateEndValue
return scanSkipSpace return scanSkipSpace
} }
......
...@@ -251,7 +251,7 @@ func Encode(out io.Writer, b *Block) (err error) { ...@@ -251,7 +251,7 @@ func Encode(out io.Writer, b *Block) (err error) {
} }
func EncodeToMemory(b *Block) []byte { func EncodeToMemory(b *Block) []byte {
buf := bytes.NewBuffer(nil) var buf bytes.Buffer
Encode(buf, b) Encode(&buf, b)
return buf.Bytes() return buf.Bytes()
} }
...@@ -73,7 +73,7 @@ var lineBreakerTests = []lineBreakerTest{ ...@@ -73,7 +73,7 @@ var lineBreakerTests = []lineBreakerTest{
func TestLineBreaker(t *testing.T) { func TestLineBreaker(t *testing.T) {
for i, test := range lineBreakerTests { for i, test := range lineBreakerTests {
buf := bytes.NewBuffer(nil) buf := new(bytes.Buffer)
var breaker lineBreaker var breaker lineBreaker
breaker.out = buf breaker.out = buf
_, err := breaker.Write([]byte(test.in)) _, err := breaker.Write([]byte(test.in))
...@@ -93,7 +93,7 @@ func TestLineBreaker(t *testing.T) { ...@@ -93,7 +93,7 @@ func TestLineBreaker(t *testing.T) {
} }
for i, test := range lineBreakerTests { for i, test := range lineBreakerTests {
buf := bytes.NewBuffer(nil) buf := new(bytes.Buffer)
var breaker lineBreaker var breaker lineBreaker
breaker.out = buf breaker.out = buf
......
...@@ -532,6 +532,11 @@ var marshalTests = []struct { ...@@ -532,6 +532,11 @@ var marshalTests = []struct {
Value: &NameInField{Name{Space: "ns", Local: "foo"}}, Value: &NameInField{Name{Space: "ns", Local: "foo"}},
ExpectXML: `<NameInField><foo xmlns="ns"></foo></NameInField>`, ExpectXML: `<NameInField><foo xmlns="ns"></foo></NameInField>`,
}, },
{
Value: &NameInField{Name{Space: "ns", Local: "foo"}},
ExpectXML: `<NameInField><foo xmlns="ns"><ignore></ignore></foo></NameInField>`,
UnmarshalOnly: true,
},
// Marshaling zero xml.Name uses the tag or field name. // Marshaling zero xml.Name uses the tag or field name.
{ {
......
...@@ -265,12 +265,13 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { ...@@ -265,12 +265,13 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
saveData = v saveData = v
case reflect.Struct: case reflect.Struct:
sv = v typ := v.Type()
typ := sv.Type()
if typ == nameType { if typ == nameType {
v.Set(reflect.ValueOf(start.Name)) v.Set(reflect.ValueOf(start.Name))
break break
} }
sv = v
tinfo, err = getTypeInfo(typ) tinfo, err = getTypeInfo(typ)
if err != nil { if err != nil {
return err return err
...@@ -541,19 +542,21 @@ Loop: ...@@ -541,19 +542,21 @@ Loop:
panic("unreachable") panic("unreachable")
} }
// Have already read a start element. // Skip reads tokens until it has consumed the end element
// Read tokens until we find the end element. // matching the most recent start element already consumed.
// Token is taking care of making sure the // It recurs if it encounters a start element, so it can be used to
// end element matches the start element we saw. // skip nested structures.
func (p *Decoder) Skip() error { // It returns nil if it finds an end element matching the start
// element; otherwise it returns an error describing the problem.
func (d *Decoder) Skip() error {
for { for {
tok, err := p.Token() tok, err := d.Token()
if err != nil { if err != nil {
return err return err
} }
switch tok.(type) { switch tok.(type) {
case StartElement: case StartElement:
if err := p.Skip(); err != nil { if err := d.Skip(); err != nil {
return err return err
} }
case EndElement: case EndElement:
......
...@@ -193,7 +193,7 @@ func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, erro ...@@ -193,7 +193,7 @@ func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, erro
// If the field type has an XMLName field, the names must match // If the field type has an XMLName field, the names must match
// so that the behavior of both marshalling and unmarshalling // so that the behavior of both marshalling and unmarshalling
// is straighforward and unambiguous. // is straightforward and unambiguous.
if finfo.flags&fElement != 0 { if finfo.flags&fElement != 0 {
ftyp := f.Type ftyp := f.Type
xmlname := lookupXMLName(ftyp) xmlname := lookupXMLName(ftyp)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
/* /*
Ebnflint verifies that EBNF productions are consistent and gramatically correct. Ebnflint verifies that EBNF productions are consistent and grammatically correct.
It reads them from an HTML document such as the Go specification. It reads them from an HTML document such as the Go specification.
Grammar productions are grouped in boxes demarcated by the HTML elements Grammar productions are grouped in boxes demarcated by the HTML elements
...@@ -13,7 +13,7 @@ Grammar productions are grouped in boxes demarcated by the HTML elements ...@@ -13,7 +13,7 @@ Grammar productions are grouped in boxes demarcated by the HTML elements
Usage: Usage:
ebnflint [--start production] [file] go tool ebnflint [--start production] [file]
The --start flag specifies the name of the start production for The --start flag specifies the name of the start production for
the grammar; it defaults to "Start". the grammar; it defaults to "Start".
......
...@@ -21,7 +21,7 @@ var fset = token.NewFileSet() ...@@ -21,7 +21,7 @@ var fset = token.NewFileSet()
var start = flag.String("start", "Start", "name of start production") var start = flag.String("start", "Start", "name of start production")
func usage() { func usage() {
fmt.Fprintf(os.Stderr, "usage: ebnflint [flags] [filename]\n") fmt.Fprintf(os.Stderr, "usage: go tool ebnflint [flags] [filename]\n")
flag.PrintDefaults() flag.PrintDefaults()
os.Exit(1) os.Exit(1)
} }
......
...@@ -233,8 +233,8 @@ func EscapeString(s string) string { ...@@ -233,8 +233,8 @@ func EscapeString(s string) string {
if strings.IndexAny(s, escapedChars) == -1 { if strings.IndexAny(s, escapedChars) == -1 {
return s return s
} }
buf := bytes.NewBuffer(nil) var buf bytes.Buffer
escape(buf, s) escape(&buf, s)
return buf.String() return buf.String()
} }
......
...@@ -159,9 +159,9 @@ func dump(n *Node) (string, error) { ...@@ -159,9 +159,9 @@ func dump(n *Node) (string, error) {
if n == nil || len(n.Child) == 0 { if n == nil || len(n.Child) == 0 {
return "", nil return "", nil
} }
b := bytes.NewBuffer(nil) var b bytes.Buffer
for _, child := range n.Child { for _, child := range n.Child {
if err := dumpLevel(b, child, 0); err != nil { if err := dumpLevel(&b, child, 0); err != nil {
return "", err return "", err
} }
} }
......
...@@ -77,8 +77,7 @@ func (t Token) tagString() string { ...@@ -77,8 +77,7 @@ func (t Token) tagString() string {
if len(t.Attr) == 0 { if len(t.Attr) == 0 {
return t.Data return t.Data
} }
buf := bytes.NewBuffer(nil) buf := bytes.NewBufferString(t.Data)
buf.WriteString(t.Data)
for _, a := range t.Attr { for _, a := range t.Attr {
buf.WriteByte(' ') buf.WriteByte(' ')
buf.WriteString(a.Key) buf.WriteString(a.Key)
......
...@@ -555,8 +555,8 @@ func TestUnescapeEscape(t *testing.T) { ...@@ -555,8 +555,8 @@ func TestUnescapeEscape(t *testing.T) {
func TestBufAPI(t *testing.T) { func TestBufAPI(t *testing.T) {
s := "0<a>1</a>2<b>3<a>4<a>5</a>6</b>7</a>8<a/>9" s := "0<a>1</a>2<b>3<a>4<a>5</a>6</b>7</a>8<a/>9"
z := NewTokenizer(bytes.NewBuffer([]byte(s))) z := NewTokenizer(bytes.NewBufferString(s))
result := bytes.NewBuffer(nil) var result bytes.Buffer
depth := 0 depth := 0
loop: loop:
for { for {
......
...@@ -107,7 +107,11 @@ func (w *Watcher) AddWatch(path string, flags uint32) error { ...@@ -107,7 +107,11 @@ func (w *Watcher) AddWatch(path string, flags uint32) error {
} }
wd, err := syscall.InotifyAddWatch(w.fd, path, flags) wd, err := syscall.InotifyAddWatch(w.fd, path, flags)
if err != nil { if err != nil {
return &os.PathError{"inotify_add_watch", path, err} return &os.PathError{
Op: "inotify_add_watch",
Path: path,
Err: err,
}
} }
if !found { if !found {
......
...@@ -98,10 +98,10 @@ func (rb *reorderBuffer) insertOrdered(info runeInfo) bool { ...@@ -98,10 +98,10 @@ func (rb *reorderBuffer) insertOrdered(info runeInfo) bool {
func (rb *reorderBuffer) insert(src input, i int, info runeInfo) bool { func (rb *reorderBuffer) insert(src input, i int, info runeInfo) bool {
if info.size == 3 { if info.size == 3 {
if rune := src.hangul(i); rune != 0 { if rune := src.hangul(i); rune != 0 {
return rb.decomposeHangul(uint32(rune)) return rb.decomposeHangul(rune)
} }
} }
if info.flags.hasDecomposition() { if info.hasDecomposition() {
dcomp := rb.f.decompose(src, i) dcomp := rb.f.decompose(src, i)
rb.tmpBytes = inputBytes(dcomp) rb.tmpBytes = inputBytes(dcomp)
for i := 0; i < len(dcomp); { for i := 0; i < len(dcomp); {
...@@ -126,26 +126,26 @@ func (rb *reorderBuffer) insert(src input, i int, info runeInfo) bool { ...@@ -126,26 +126,26 @@ func (rb *reorderBuffer) insert(src input, i int, info runeInfo) bool {
} }
// appendRune inserts a rune at the end of the buffer. It is used for Hangul. // appendRune inserts a rune at the end of the buffer. It is used for Hangul.
func (rb *reorderBuffer) appendRune(r uint32) { func (rb *reorderBuffer) appendRune(r rune) {
bn := rb.nbyte bn := rb.nbyte
sz := utf8.EncodeRune(rb.byte[bn:], rune(r)) sz := utf8.EncodeRune(rb.byte[bn:], rune(r))
rb.nbyte += utf8.UTFMax rb.nbyte += utf8.UTFMax
rb.rune[rb.nrune] = runeInfo{bn, uint8(sz), 0, 0} rb.rune[rb.nrune] = runeInfo{pos: bn, size: uint8(sz)}
rb.nrune++ rb.nrune++
} }
// assignRune sets a rune at position pos. It is used for Hangul and recomposition. // assignRune sets a rune at position pos. It is used for Hangul and recomposition.
func (rb *reorderBuffer) assignRune(pos int, r uint32) { func (rb *reorderBuffer) assignRune(pos int, r rune) {
bn := rb.rune[pos].pos bn := rb.rune[pos].pos
sz := utf8.EncodeRune(rb.byte[bn:], rune(r)) sz := utf8.EncodeRune(rb.byte[bn:], rune(r))
rb.rune[pos] = runeInfo{bn, uint8(sz), 0, 0} rb.rune[pos] = runeInfo{pos: bn, size: uint8(sz)}
} }
// runeAt returns the rune at position n. It is used for Hangul and recomposition. // runeAt returns the rune at position n. It is used for Hangul and recomposition.
func (rb *reorderBuffer) runeAt(n int) uint32 { func (rb *reorderBuffer) runeAt(n int) rune {
inf := rb.rune[n] inf := rb.rune[n]
r, _ := utf8.DecodeRune(rb.byte[inf.pos : inf.pos+inf.size]) r, _ := utf8.DecodeRune(rb.byte[inf.pos : inf.pos+inf.size])
return uint32(r) return r
} }
// bytesAt returns the UTF-8 encoding of the rune at position n. // bytesAt returns the UTF-8 encoding of the rune at position n.
...@@ -237,7 +237,7 @@ func isHangulWithoutJamoT(b []byte) bool { ...@@ -237,7 +237,7 @@ func isHangulWithoutJamoT(b []byte) bool {
// decomposeHangul algorithmically decomposes a Hangul rune into // decomposeHangul algorithmically decomposes a Hangul rune into
// its Jamo components. // its Jamo components.
// See http://unicode.org/reports/tr15/#Hangul for details on decomposing Hangul. // See http://unicode.org/reports/tr15/#Hangul for details on decomposing Hangul.
func (rb *reorderBuffer) decomposeHangul(r uint32) bool { func (rb *reorderBuffer) decomposeHangul(r rune) bool {
b := rb.rune[:] b := rb.rune[:]
n := rb.nrune n := rb.nrune
if n+3 > len(b) { if n+3 > len(b) {
...@@ -319,7 +319,7 @@ func (rb *reorderBuffer) compose() { ...@@ -319,7 +319,7 @@ func (rb *reorderBuffer) compose() {
// get the info for the combined character. This is more // get the info for the combined character. This is more
// expensive than using the filter. Using combinesBackward() // expensive than using the filter. Using combinesBackward()
// is safe. // is safe.
if ii.flags.combinesBackward() { if ii.combinesBackward() {
cccB := b[k-1].ccc cccB := b[k-1].ccc
cccC := ii.ccc cccC := ii.ccc
blocked := false // b[i] blocked by starter or greater or equal CCC? blocked := false // b[i] blocked by starter or greater or equal CCC?
......
...@@ -14,7 +14,6 @@ type runeInfo struct { ...@@ -14,7 +14,6 @@ type runeInfo struct {
} }
// functions dispatchable per form // functions dispatchable per form
type boundaryFunc func(f *formInfo, info runeInfo) bool
type lookupFunc func(b input, i int) runeInfo type lookupFunc func(b input, i int) runeInfo
type decompFunc func(b input, i int) []byte type decompFunc func(b input, i int) []byte
...@@ -24,10 +23,8 @@ type formInfo struct { ...@@ -24,10 +23,8 @@ type formInfo struct {
composing, compatibility bool // form type composing, compatibility bool // form type
decompose decompFunc decompose decompFunc
info lookupFunc info lookupFunc
boundaryBefore boundaryFunc
boundaryAfter boundaryFunc
} }
var formTable []*formInfo var formTable []*formInfo
...@@ -49,27 +46,17 @@ func init() { ...@@ -49,27 +46,17 @@ func init() {
} }
if Form(i) == NFC || Form(i) == NFKC { if Form(i) == NFC || Form(i) == NFKC {
f.composing = true f.composing = true
f.boundaryBefore = compBoundaryBefore
f.boundaryAfter = compBoundaryAfter
} else {
f.boundaryBefore = decompBoundary
f.boundaryAfter = decompBoundary
} }
} }
} }
func decompBoundary(f *formInfo, info runeInfo) bool { // We do not distinguish between boundaries for NFC, NFD, etc. to avoid
if info.ccc == 0 && info.flags.isYesD() { // Implies isHangul(b) == true // unexpected behavior for the user. For example, in NFD, there is a boundary
return true // after 'a'. However, a might combine with modifiers, so from the application's
} // perspective it is not a good boundary. We will therefore always use the
// We assume that the CCC of the first character in a decomposition // boundaries for the combining variants.
// is always non-zero if different from info.ccc and that we can return func (i runeInfo) boundaryBefore() bool {
// false at this point. This is verified by maketables. if i.ccc == 0 && !i.combinesBackward() {
return false
}
func compBoundaryBefore(f *formInfo, info runeInfo) bool {
if info.ccc == 0 && !info.flags.combinesBackward() {
return true return true
} }
// We assume that the CCC of the first character in a decomposition // We assume that the CCC of the first character in a decomposition
...@@ -78,15 +65,13 @@ func compBoundaryBefore(f *formInfo, info runeInfo) bool { ...@@ -78,15 +65,13 @@ func compBoundaryBefore(f *formInfo, info runeInfo) bool {
return false return false
} }
func compBoundaryAfter(f *formInfo, info runeInfo) bool { func (i runeInfo) boundaryAfter() bool {
// This misses values where the last char in a decomposition is a return i.isInert()
// boundary such as Hangul with JamoT.
return info.isInert()
} }
// We pack quick check data in 4 bits: // We pack quick check data in 4 bits:
// 0: NFD_QC Yes (0) or No (1). No also means there is a decomposition. // 0: NFD_QC Yes (0) or No (1). No also means there is a decomposition.
// 1..2: NFC_QC Yes(00), No (01), or Maybe (11) // 1..2: NFC_QC Yes(00), No (10), or Maybe (11)
// 3: Combines forward (0 == false, 1 == true) // 3: Combines forward (0 == false, 1 == true)
// //
// When all 4 bits are zero, the character is inert, meaning it is never // When all 4 bits are zero, the character is inert, meaning it is never
...@@ -95,15 +80,12 @@ func compBoundaryAfter(f *formInfo, info runeInfo) bool { ...@@ -95,15 +80,12 @@ func compBoundaryAfter(f *formInfo, info runeInfo) bool {
// We pack the bits for both NFC/D and NFKC/D in one byte. // We pack the bits for both NFC/D and NFKC/D in one byte.
type qcInfo uint8 type qcInfo uint8
func (i qcInfo) isYesC() bool { return i&0x2 == 0 } func (i runeInfo) isYesC() bool { return i.flags&0x4 == 0 }
func (i qcInfo) isNoC() bool { return i&0x6 == 0x2 } func (i runeInfo) isYesD() bool { return i.flags&0x1 == 0 }
func (i qcInfo) isMaybe() bool { return i&0x4 != 0 }
func (i qcInfo) isYesD() bool { return i&0x1 == 0 }
func (i qcInfo) isNoD() bool { return i&0x1 != 0 }
func (i qcInfo) combinesForward() bool { return i&0x8 != 0 } func (i runeInfo) combinesForward() bool { return i.flags&0x8 != 0 }
func (i qcInfo) combinesBackward() bool { return i&0x4 != 0 } // == isMaybe func (i runeInfo) combinesBackward() bool { return i.flags&0x2 != 0 } // == isMaybe
func (i qcInfo) hasDecomposition() bool { return i&0x1 != 0 } // == isNoD func (i runeInfo) hasDecomposition() bool { return i.flags&0x1 != 0 } // == isNoD
func (r runeInfo) isInert() bool { func (r runeInfo) isInert() bool {
return r.flags&0xf == 0 && r.ccc == 0 return r.flags&0xf == 0 && r.ccc == 0
...@@ -111,7 +93,7 @@ func (r runeInfo) isInert() bool { ...@@ -111,7 +93,7 @@ func (r runeInfo) isInert() bool {
// Wrappers for tables.go // Wrappers for tables.go
// The 16-bit value of the decompostion tries is an index into a byte // The 16-bit value of the decomposition tries is an index into a byte
// array of UTF-8 decomposition sequences. The first byte is the number // array of UTF-8 decomposition sequences. The first byte is the number
// of bytes in the decomposition (excluding this length byte). The actual // of bytes in the decomposition (excluding this length byte). The actual
// sequence starts at the offset+1. // sequence starts at the offset+1.
...@@ -137,7 +119,7 @@ func decomposeNFKC(s input, i int) []byte { ...@@ -137,7 +119,7 @@ func decomposeNFKC(s input, i int) []byte {
// Note that the recomposition map for NFC and NFKC are identical. // Note that the recomposition map for NFC and NFKC are identical.
// combine returns the combined rune or 0 if it doesn't exist. // combine returns the combined rune or 0 if it doesn't exist.
func combine(a, b uint32) uint32 { func combine(a, b rune) rune {
key := uint32(uint16(a))<<16 + uint32(uint16(b)) key := uint32(uint16(a))<<16 + uint32(uint16(b))
return recompMap[key] return recompMap[key]
} }
...@@ -148,10 +130,10 @@ func combine(a, b uint32) uint32 { ...@@ -148,10 +130,10 @@ func combine(a, b uint32) uint32 {
// 12..15 qcInfo for NFKC/NFKD // 12..15 qcInfo for NFKC/NFKD
func lookupInfoNFC(b input, i int) runeInfo { func lookupInfoNFC(b input, i int) runeInfo {
v, sz := b.charinfo(i) v, sz := b.charinfo(i)
return runeInfo{0, uint8(sz), uint8(v), qcInfo(v >> 8)} return runeInfo{size: uint8(sz), ccc: uint8(v), flags: qcInfo(v >> 8)}
} }
func lookupInfoNFKC(b input, i int) runeInfo { func lookupInfoNFKC(b input, i int) runeInfo {
v, sz := b.charinfo(i) v, sz := b.charinfo(i)
return runeInfo{0, uint8(sz), uint8(v), qcInfo(v >> 12)} return runeInfo{size: uint8(sz), ccc: uint8(v), flags: qcInfo(v >> 12)}
} }
...@@ -14,7 +14,7 @@ type input interface { ...@@ -14,7 +14,7 @@ type input interface {
charinfo(p int) (uint16, int) charinfo(p int) (uint16, int)
decomposeNFC(p int) uint16 decomposeNFC(p int) uint16
decomposeNFKC(p int) uint16 decomposeNFKC(p int) uint16
hangul(p int) uint32 hangul(p int) rune
} }
type inputString string type inputString string
...@@ -54,12 +54,12 @@ func (s inputString) decomposeNFKC(p int) uint16 { ...@@ -54,12 +54,12 @@ func (s inputString) decomposeNFKC(p int) uint16 {
return nfkcDecompTrie.lookupStringUnsafe(string(s[p:])) return nfkcDecompTrie.lookupStringUnsafe(string(s[p:]))
} }
func (s inputString) hangul(p int) uint32 { func (s inputString) hangul(p int) rune {
if !isHangulString(string(s[p:])) { if !isHangulString(string(s[p:])) {
return 0 return 0
} }
rune, _ := utf8.DecodeRuneInString(string(s[p:])) rune, _ := utf8.DecodeRuneInString(string(s[p:]))
return uint32(rune) return rune
} }
type inputBytes []byte type inputBytes []byte
...@@ -96,10 +96,10 @@ func (s inputBytes) decomposeNFKC(p int) uint16 { ...@@ -96,10 +96,10 @@ func (s inputBytes) decomposeNFKC(p int) uint16 {
return nfkcDecompTrie.lookupUnsafe(s[p:]) return nfkcDecompTrie.lookupUnsafe(s[p:])
} }
func (s inputBytes) hangul(p int) uint32 { func (s inputBytes) hangul(p int) rune {
if !isHangul(s[p:]) { if !isHangul(s[p:]) {
return 0 return 0
} }
rune, _ := utf8.DecodeRune(s[p:]) rune, _ := utf8.DecodeRune(s[p:])
return uint32(rune) return rune
} }
...@@ -562,7 +562,7 @@ func makeEntry(f *FormInfo) uint16 { ...@@ -562,7 +562,7 @@ func makeEntry(f *FormInfo) uint16 {
switch f.quickCheck[MComposed] { switch f.quickCheck[MComposed] {
case QCYes: case QCYes:
case QCNo: case QCNo:
e |= 0x2 e |= 0x4
case QCMaybe: case QCMaybe:
e |= 0x6 e |= 0x6
default: default:
...@@ -718,7 +718,7 @@ func makeTables() { ...@@ -718,7 +718,7 @@ func makeTables() {
sz := nrentries * 8 sz := nrentries * 8
size += sz size += sz
fmt.Printf("// recompMap: %d bytes (entries only)\n", sz) fmt.Printf("// recompMap: %d bytes (entries only)\n", sz)
fmt.Println("var recompMap = map[uint32]uint32{") fmt.Println("var recompMap = map[uint32]rune{")
for i, c := range chars { for i, c := range chars {
f := c.forms[FCanonical] f := c.forms[FCanonical]
d := f.decomp d := f.decomp
......
...@@ -188,11 +188,11 @@ func doAppend(rb *reorderBuffer, out []byte, p int) []byte { ...@@ -188,11 +188,11 @@ func doAppend(rb *reorderBuffer, out []byte, p int) []byte {
var info runeInfo var info runeInfo
if p < n { if p < n {
info = fd.info(src, p) info = fd.info(src, p)
if p == 0 && !fd.boundaryBefore(fd, info) { if p == 0 && !info.boundaryBefore() {
out = decomposeToLastBoundary(rb, out) out = decomposeToLastBoundary(rb, out)
} }
} }
if info.size == 0 || fd.boundaryBefore(fd, info) { if info.size == 0 || info.boundaryBefore() {
if fd.composing { if fd.composing {
rb.compose() rb.compose()
} }
...@@ -257,11 +257,11 @@ func quickSpan(rb *reorderBuffer, i int) int { ...@@ -257,11 +257,11 @@ func quickSpan(rb *reorderBuffer, i int) int {
} }
cc := info.ccc cc := info.ccc
if rb.f.composing { if rb.f.composing {
if !info.flags.isYesC() { if !info.isYesC() {
break break
} }
} else { } else {
if !info.flags.isYesD() { if !info.isYesD() {
break break
} }
} }
...@@ -316,13 +316,13 @@ func firstBoundary(rb *reorderBuffer) int { ...@@ -316,13 +316,13 @@ func firstBoundary(rb *reorderBuffer) int {
} }
fd := &rb.f fd := &rb.f
info := fd.info(src, i) info := fd.info(src, i)
for n := 0; info.size != 0 && !fd.boundaryBefore(fd, info); { for n := 0; info.size != 0 && !info.boundaryBefore(); {
i += int(info.size) i += int(info.size)
if n++; n >= maxCombiningChars { if n++; n >= maxCombiningChars {
return i return i
} }
if i >= nsrc { if i >= nsrc {
if !fd.boundaryAfter(fd, info) { if !info.boundaryAfter() {
return -1 return -1
} }
return nsrc return nsrc
...@@ -368,11 +368,11 @@ func lastBoundary(fd *formInfo, b []byte) int { ...@@ -368,11 +368,11 @@ func lastBoundary(fd *formInfo, b []byte) int {
if p+int(info.size) != i { // trailing non-starter bytes: illegal UTF-8 if p+int(info.size) != i { // trailing non-starter bytes: illegal UTF-8
return i return i
} }
if fd.boundaryAfter(fd, info) { if info.boundaryAfter() {
return i return i
} }
i = p i = p
for n := 0; i >= 0 && !fd.boundaryBefore(fd, info); { for n := 0; i >= 0 && !info.boundaryBefore(); {
info, p = lastRuneStart(fd, b[:i]) info, p = lastRuneStart(fd, b[:i])
if n++; n >= maxCombiningChars { if n++; n >= maxCombiningChars {
return len(b) return len(b)
...@@ -404,7 +404,7 @@ func decomposeSegment(rb *reorderBuffer, sp int) int { ...@@ -404,7 +404,7 @@ func decomposeSegment(rb *reorderBuffer, sp int) int {
break break
} }
info = rb.f.info(rb.src, sp) info = rb.f.info(rb.src, sp)
bound := rb.f.boundaryBefore(&rb.f, info) bound := info.boundaryBefore()
if bound || info.size == 0 { if bound || info.size == 0 {
break break
} }
...@@ -419,7 +419,7 @@ func lastRuneStart(fd *formInfo, buf []byte) (runeInfo, int) { ...@@ -419,7 +419,7 @@ func lastRuneStart(fd *formInfo, buf []byte) (runeInfo, int) {
for ; p >= 0 && !utf8.RuneStart(buf[p]); p-- { for ; p >= 0 && !utf8.RuneStart(buf[p]); p-- {
} }
if p < 0 { if p < 0 {
return runeInfo{0, 0, 0, 0}, -1 return runeInfo{}, -1
} }
return fd.info(inputBytes(buf), p), p return fd.info(inputBytes(buf), p), p
} }
...@@ -433,7 +433,7 @@ func decomposeToLastBoundary(rb *reorderBuffer, buf []byte) []byte { ...@@ -433,7 +433,7 @@ func decomposeToLastBoundary(rb *reorderBuffer, buf []byte) []byte {
// illegal trailing continuation bytes // illegal trailing continuation bytes
return buf return buf
} }
if rb.f.boundaryAfter(fd, info) { if info.boundaryAfter() {
return buf return buf
} }
var add [maxBackRunes]runeInfo // stores runeInfo in reverse order var add [maxBackRunes]runeInfo // stores runeInfo in reverse order
...@@ -441,13 +441,13 @@ func decomposeToLastBoundary(rb *reorderBuffer, buf []byte) []byte { ...@@ -441,13 +441,13 @@ func decomposeToLastBoundary(rb *reorderBuffer, buf []byte) []byte {
padd := 1 padd := 1
n := 1 n := 1
p := len(buf) - int(info.size) p := len(buf) - int(info.size)
for ; p >= 0 && !rb.f.boundaryBefore(fd, info); p -= int(info.size) { for ; p >= 0 && !info.boundaryBefore(); p -= int(info.size) {
info, i = lastRuneStart(fd, buf[:p]) info, i = lastRuneStart(fd, buf[:p])
if int(info.size) != p-i { if int(info.size) != p-i {
break break
} }
// Check that decomposition doesn't result in overflow. // Check that decomposition doesn't result in overflow.
if info.flags.hasDecomposition() { if info.hasDecomposition() {
dcomp := rb.f.decompose(inputBytes(buf), p-int(info.size)) dcomp := rb.f.decompose(inputBytes(buf), p-int(info.size))
for i := 0; i < len(dcomp); { for i := 0; i < len(dcomp); {
inf := rb.f.info(inputBytes(dcomp), i) inf := rb.f.info(inputBytes(dcomp), i)
......
...@@ -495,11 +495,11 @@ func TestAppend(t *testing.T) { ...@@ -495,11 +495,11 @@ func TestAppend(t *testing.T) {
runAppendTests(t, "TestString", NFKC, stringF, appendTests) runAppendTests(t, "TestString", NFKC, stringF, appendTests)
} }
func doFormBenchmark(b *testing.B, f Form, s string) { func doFormBenchmark(b *testing.B, inf, f Form, s string) {
b.StopTimer() b.StopTimer()
in := []byte(s) in := inf.Bytes([]byte(s))
buf := make([]byte, 2*len(in)) buf := make([]byte, 2*len(in))
b.SetBytes(int64(len(s))) b.SetBytes(int64(len(in)))
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
buf = f.Append(buf[0:0], in...) buf = f.Append(buf[0:0], in...)
...@@ -510,16 +510,43 @@ func doFormBenchmark(b *testing.B, f Form, s string) { ...@@ -510,16 +510,43 @@ func doFormBenchmark(b *testing.B, f Form, s string) {
var ascii = strings.Repeat("There is nothing to change here! ", 500) var ascii = strings.Repeat("There is nothing to change here! ", 500)
func BenchmarkNormalizeAsciiNFC(b *testing.B) { func BenchmarkNormalizeAsciiNFC(b *testing.B) {
doFormBenchmark(b, NFC, ascii) doFormBenchmark(b, NFC, NFC, ascii)
} }
func BenchmarkNormalizeAsciiNFD(b *testing.B) { func BenchmarkNormalizeAsciiNFD(b *testing.B) {
doFormBenchmark(b, NFD, ascii) doFormBenchmark(b, NFC, NFD, ascii)
} }
func BenchmarkNormalizeAsciiNFKC(b *testing.B) { func BenchmarkNormalizeAsciiNFKC(b *testing.B) {
doFormBenchmark(b, NFKC, ascii) doFormBenchmark(b, NFC, NFKC, ascii)
} }
func BenchmarkNormalizeAsciiNFKD(b *testing.B) { func BenchmarkNormalizeAsciiNFKD(b *testing.B) {
doFormBenchmark(b, NFKD, ascii) doFormBenchmark(b, NFC, NFKD, ascii)
}
func BenchmarkNormalizeNFC2NFC(b *testing.B) {
doFormBenchmark(b, NFC, NFC, txt_all)
}
func BenchmarkNormalizeNFC2NFD(b *testing.B) {
doFormBenchmark(b, NFC, NFD, txt_all)
}
func BenchmarkNormalizeNFD2NFC(b *testing.B) {
doFormBenchmark(b, NFD, NFC, txt_all)
}
func BenchmarkNormalizeNFD2NFD(b *testing.B) {
doFormBenchmark(b, NFD, NFD, txt_all)
}
// Hangul is often special-cased, so we test it separately.
func BenchmarkNormalizeHangulNFC2NFC(b *testing.B) {
doFormBenchmark(b, NFC, NFC, txt_kr)
}
func BenchmarkNormalizeHangulNFC2NFD(b *testing.B) {
doFormBenchmark(b, NFC, NFD, txt_kr)
}
func BenchmarkNormalizeHangulNFD2NFC(b *testing.B) {
doFormBenchmark(b, NFD, NFC, txt_kr)
}
func BenchmarkNormalizeHangulNFD2NFD(b *testing.B) {
doFormBenchmark(b, NFD, NFD, txt_kr)
} }
func doTextBenchmark(b *testing.B, s string) { func doTextBenchmark(b *testing.B, s string) {
...@@ -657,3 +684,6 @@ const txt_cn = `您可以自由: 复制、发行、展览、表演、放映、 ...@@ -657,3 +684,6 @@ const txt_cn = `您可以自由: 复制、发行、展览、表演、放映、
署名 — 您必须按照作者或者许可人指定的方式对作品进行署名。 署名 — 您必须按照作者或者许可人指定的方式对作品进行署名。
相同方式共享 — 如果您改变、转换本作品或者以本作品为基础进行创作, 相同方式共享 — 如果您改变、转换本作品或者以本作品为基础进行创作,
您只能采用与本协议相同的许可协议发布基于本作品的演绎作品。` 您只能采用与本协议相同的许可协议发布基于本作品的演绎作品。`
const txt_cjk = txt_cn + txt_jp + txt_kr
const txt_all = txt_vn + twoByteUtf8 + threeByteUtf8 + txt_cjk
...@@ -12,11 +12,13 @@ import ( ...@@ -12,11 +12,13 @@ import (
"testing" "testing"
) )
const sighup = os.UnixSignal(syscall.SIGHUP)
func TestSignal(t *testing.T) { func TestSignal(t *testing.T) {
// Send this process a SIGHUP. // Send this process a SIGHUP.
syscall.Syscall(syscall.SYS_KILL, uintptr(syscall.Getpid()), syscall.SIGHUP, 0) syscall.Syscall(syscall.SYS_KILL, uintptr(syscall.Getpid()), syscall.SIGHUP, 0)
if sig := (<-Incoming).(os.UnixSignal); sig != os.SIGHUP { if sig := (<-Incoming).(os.UnixSignal); sig != sighup {
t.Errorf("signal was %v, want %v", sig, os.SIGHUP) t.Errorf("signal was %v, want %v", sig, sighup)
} }
} }
...@@ -31,7 +31,7 @@ func init() { ...@@ -31,7 +31,7 @@ func init() {
gcPath = gcName gcPath = gcName
return return
} }
gcPath, _ = exec.LookPath(gcName) gcPath = filepath.Join(runtime.GOROOT(), "/bin/tool/", gcName)
} }
func compile(t *testing.T, dirname, filename string) { func compile(t *testing.T, dirname, filename string) {
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
// //
// The package is sometimes only imported for the side effect of // The package is sometimes only imported for the side effect of
// registering its HTTP handler and the above variables. To use it // registering its HTTP handler and the above variables. To use it
// this way, simply link this package into your program: // this way, link this package into your program:
// import _ "expvar" // import _ "expvar"
// //
package expvar package expvar
...@@ -83,7 +83,7 @@ func (v *Float) Set(value float64) { ...@@ -83,7 +83,7 @@ func (v *Float) Set(value float64) {
// Map is a string-to-Var map variable that satisfies the Var interface. // Map is a string-to-Var map variable that satisfies the Var interface.
type Map struct { type Map struct {
m map[string]Var m map[string]Var
mu sync.Mutex mu sync.RWMutex
} }
// KeyValue represents a single entry in a Map. // KeyValue represents a single entry in a Map.
...@@ -93,8 +93,8 @@ type KeyValue struct { ...@@ -93,8 +93,8 @@ type KeyValue struct {
} }
func (v *Map) String() string { func (v *Map) String() string {
v.mu.Lock() v.mu.RLock()
defer v.mu.Unlock() defer v.mu.RUnlock()
b := new(bytes.Buffer) b := new(bytes.Buffer)
fmt.Fprintf(b, "{") fmt.Fprintf(b, "{")
first := true first := true
...@@ -115,8 +115,8 @@ func (v *Map) Init() *Map { ...@@ -115,8 +115,8 @@ func (v *Map) Init() *Map {
} }
func (v *Map) Get(key string) Var { func (v *Map) Get(key string) Var {
v.mu.Lock() v.mu.RLock()
defer v.mu.Unlock() defer v.mu.RUnlock()
return v.m[key] return v.m[key]
} }
...@@ -127,12 +127,17 @@ func (v *Map) Set(key string, av Var) { ...@@ -127,12 +127,17 @@ func (v *Map) Set(key string, av Var) {
} }
func (v *Map) Add(key string, delta int64) { func (v *Map) Add(key string, delta int64) {
v.mu.Lock() v.mu.RLock()
defer v.mu.Unlock()
av, ok := v.m[key] av, ok := v.m[key]
v.mu.RUnlock()
if !ok { if !ok {
av = new(Int) // check again under the write lock
v.m[key] = av v.mu.Lock()
if _, ok = v.m[key]; !ok {
av = new(Int)
v.m[key] = av
}
v.mu.Unlock()
} }
// Add to Int; ignore otherwise. // Add to Int; ignore otherwise.
...@@ -143,12 +148,17 @@ func (v *Map) Add(key string, delta int64) { ...@@ -143,12 +148,17 @@ func (v *Map) Add(key string, delta int64) {
// AddFloat adds delta to the *Float value stored under the given map key. // AddFloat adds delta to the *Float value stored under the given map key.
func (v *Map) AddFloat(key string, delta float64) { func (v *Map) AddFloat(key string, delta float64) {
v.mu.Lock() v.mu.RLock()
defer v.mu.Unlock()
av, ok := v.m[key] av, ok := v.m[key]
v.mu.RUnlock()
if !ok { if !ok {
av = new(Float) // check again under the write lock
v.m[key] = av v.mu.Lock()
if _, ok = v.m[key]; !ok {
av = new(Float)
v.m[key] = av
}
v.mu.Unlock()
} }
// Add to Float; ignore otherwise. // Add to Float; ignore otherwise.
...@@ -157,18 +167,15 @@ func (v *Map) AddFloat(key string, delta float64) { ...@@ -157,18 +167,15 @@ func (v *Map) AddFloat(key string, delta float64) {
} }
} }
// TODO(rsc): Make sure map access in separate thread is safe. // Do calls f for each entry in the map.
func (v *Map) iterate(c chan<- KeyValue) { // The map is locked during the iteration,
// but existing entries may be concurrently updated.
func (v *Map) Do(f func(KeyValue)) {
v.mu.RLock()
defer v.mu.RUnlock()
for k, v := range v.m { for k, v := range v.m {
c <- KeyValue{k, v} f(KeyValue{k, v})
} }
close(c)
}
func (v *Map) Iter() <-chan KeyValue {
c := make(chan KeyValue)
go v.iterate(c)
return c
} }
// String is a string variable, and satisfies the Var interface. // String is a string variable, and satisfies the Var interface.
...@@ -190,8 +197,10 @@ func (f Func) String() string { ...@@ -190,8 +197,10 @@ func (f Func) String() string {
} }
// All published variables. // All published variables.
var vars map[string]Var = make(map[string]Var) var (
var mutex sync.Mutex mutex sync.RWMutex
vars map[string]Var = make(map[string]Var)
)
// Publish declares a named exported variable. This should be called from a // Publish declares a named exported variable. This should be called from a
// package's init function when it creates its Vars. If the name is already // package's init function when it creates its Vars. If the name is already
...@@ -207,17 +216,11 @@ func Publish(name string, v Var) { ...@@ -207,17 +216,11 @@ func Publish(name string, v Var) {
// Get retrieves a named exported variable. // Get retrieves a named exported variable.
func Get(name string) Var { func Get(name string) Var {
mutex.RLock()
defer mutex.RUnlock()
return vars[name] return vars[name]
} }
// RemoveAll removes all exported variables.
// This is for tests; don't call this on a real server.
func RemoveAll() {
mutex.Lock()
defer mutex.Unlock()
vars = make(map[string]Var)
}
// Convenience functions for creating new exported variables. // Convenience functions for creating new exported variables.
func NewInt(name string) *Int { func NewInt(name string) *Int {
...@@ -244,31 +247,28 @@ func NewString(name string) *String { ...@@ -244,31 +247,28 @@ func NewString(name string) *String {
return v return v
} }
// TODO(rsc): Make sure map access in separate thread is safe. // Do calls f for each exported variable.
func iterate(c chan<- KeyValue) { // The global variable map is locked during the iteration,
// but existing entries may be concurrently updated.
func Do(f func(KeyValue)) {
mutex.RLock()
defer mutex.RUnlock()
for k, v := range vars { for k, v := range vars {
c <- KeyValue{k, v} f(KeyValue{k, v})
} }
close(c)
}
func Iter() <-chan KeyValue {
c := make(chan KeyValue)
go iterate(c)
return c
} }
func expvarHandler(w http.ResponseWriter, r *http.Request) { func expvarHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprintf(w, "{\n") fmt.Fprintf(w, "{\n")
first := true first := true
for name, value := range vars { Do(func(kv KeyValue) {
if !first { if !first {
fmt.Fprintf(w, ",\n") fmt.Fprintf(w, ",\n")
} }
first = false first = false
fmt.Fprintf(w, "%q: %s", name, value) fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
} })
fmt.Fprintf(w, "\n}\n") fmt.Fprintf(w, "\n}\n")
} }
...@@ -277,11 +277,13 @@ func cmdline() interface{} { ...@@ -277,11 +277,13 @@ func cmdline() interface{} {
} }
func memstats() interface{} { func memstats() interface{} {
return runtime.MemStats stats := new(runtime.MemStats)
runtime.ReadMemStats(stats)
return *stats
} }
func init() { func init() {
http.Handle("/debug/vars", http.HandlerFunc(expvarHandler)) http.HandleFunc("/debug/vars", expvarHandler)
Publish("cmdline", Func(cmdline)) Publish("cmdline", Func(cmdline))
Publish("memstats", Func(memstats)) Publish("memstats", Func(memstats))
} }
...@@ -9,6 +9,14 @@ import ( ...@@ -9,6 +9,14 @@ import (
"testing" "testing"
) )
// RemoveAll removes all exported variables.
// This is for tests only.
func RemoveAll() {
mutex.Lock()
defer mutex.Unlock()
vars = make(map[string]Var)
}
func TestInt(t *testing.T) { func TestInt(t *testing.T) {
reqs := NewInt("requests") reqs := NewInt("requests")
if reqs.i != 0 { if reqs.i != 0 {
......
...@@ -49,6 +49,7 @@ ...@@ -49,6 +49,7 @@
Integer flags accept 1234, 0664, 0x1234 and may be negative. Integer flags accept 1234, 0664, 0x1234 and may be negative.
Boolean flags may be 1, 0, t, f, true, false, TRUE, FALSE, True, False. Boolean flags may be 1, 0, t, f, true, false, TRUE, FALSE, True, False.
Duration flags accept any input valid for time.ParseDuration.
The default set of command-line flags is controlled by The default set of command-line flags is controlled by
top-level functions. The FlagSet type allows one to define top-level functions. The FlagSet type allows one to define
...@@ -62,6 +63,7 @@ package flag ...@@ -62,6 +63,7 @@ package flag
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"os" "os"
"sort" "sort"
"strconv" "strconv"
...@@ -228,6 +230,7 @@ type FlagSet struct { ...@@ -228,6 +230,7 @@ type FlagSet struct {
args []string // arguments after flags args []string // arguments after flags
exitOnError bool // does the program exit if there's an error? exitOnError bool // does the program exit if there's an error?
errorHandling ErrorHandling errorHandling ErrorHandling
output io.Writer // nil means stderr; use out() accessor
} }
// A Flag represents the state of a flag. // A Flag represents the state of a flag.
...@@ -254,6 +257,19 @@ func sortFlags(flags map[string]*Flag) []*Flag { ...@@ -254,6 +257,19 @@ func sortFlags(flags map[string]*Flag) []*Flag {
return result return result
} }
func (f *FlagSet) out() io.Writer {
if f.output == nil {
return os.Stderr
}
return f.output
}
// SetOutput sets the destination for usage and error messages.
// If output is nil, os.Stderr is used.
func (f *FlagSet) SetOutput(output io.Writer) {
f.output = output
}
// VisitAll visits the flags in lexicographical order, calling fn for each. // VisitAll visits the flags in lexicographical order, calling fn for each.
// It visits all flags, even those not set. // It visits all flags, even those not set.
func (f *FlagSet) VisitAll(fn func(*Flag)) { func (f *FlagSet) VisitAll(fn func(*Flag)) {
...@@ -315,15 +331,16 @@ func Set(name, value string) error { ...@@ -315,15 +331,16 @@ func Set(name, value string) error {
return commandLine.Set(name, value) return commandLine.Set(name, value)
} }
// PrintDefaults prints to standard error the default values of all defined flags in the set. // PrintDefaults prints, to standard error unless configured
// otherwise, the default values of all defined flags in the set.
func (f *FlagSet) PrintDefaults() { func (f *FlagSet) PrintDefaults() {
f.VisitAll(func(f *Flag) { f.VisitAll(func(flag *Flag) {
format := " -%s=%s: %s\n" format := " -%s=%s: %s\n"
if _, ok := f.Value.(*stringValue); ok { if _, ok := flag.Value.(*stringValue); ok {
// put quotes on the value // put quotes on the value
format = " -%s=%q: %s\n" format = " -%s=%q: %s\n"
} }
fmt.Fprintf(os.Stderr, format, f.Name, f.DefValue, f.Usage) fmt.Fprintf(f.out(), format, flag.Name, flag.DefValue, flag.Usage)
}) })
} }
...@@ -334,7 +351,7 @@ func PrintDefaults() { ...@@ -334,7 +351,7 @@ func PrintDefaults() {
// defaultUsage is the default function to print a usage message. // defaultUsage is the default function to print a usage message.
func defaultUsage(f *FlagSet) { func defaultUsage(f *FlagSet) {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", f.name) fmt.Fprintf(f.out(), "Usage of %s:\n", f.name)
f.PrintDefaults() f.PrintDefaults()
} }
...@@ -601,7 +618,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) { ...@@ -601,7 +618,7 @@ func (f *FlagSet) Var(value Value, name string, usage string) {
flag := &Flag{name, usage, value, value.String()} flag := &Flag{name, usage, value, value.String()}
_, alreadythere := f.formal[name] _, alreadythere := f.formal[name]
if alreadythere { if alreadythere {
fmt.Fprintf(os.Stderr, "%s flag redefined: %s\n", f.name, name) fmt.Fprintf(f.out(), "%s flag redefined: %s\n", f.name, name)
panic("flag redefinition") // Happens only if flags are declared with identical names panic("flag redefinition") // Happens only if flags are declared with identical names
} }
if f.formal == nil { if f.formal == nil {
...@@ -624,7 +641,7 @@ func Var(value Value, name string, usage string) { ...@@ -624,7 +641,7 @@ func Var(value Value, name string, usage string) {
// returns the error. // returns the error.
func (f *FlagSet) failf(format string, a ...interface{}) error { func (f *FlagSet) failf(format string, a ...interface{}) error {
err := fmt.Errorf(format, a...) err := fmt.Errorf(format, a...)
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(f.out(), err)
f.usage() f.usage()
return err return err
} }
......
...@@ -5,10 +5,12 @@ ...@@ -5,10 +5,12 @@
package flag_test package flag_test
import ( import (
"bytes"
. "flag" . "flag"
"fmt" "fmt"
"os" "os"
"sort" "sort"
"strings"
"testing" "testing"
"time" "time"
) )
...@@ -206,6 +208,17 @@ func TestUserDefined(t *testing.T) { ...@@ -206,6 +208,17 @@ func TestUserDefined(t *testing.T) {
} }
} }
func TestSetOutput(t *testing.T) {
var flags FlagSet
var buf bytes.Buffer
flags.SetOutput(&buf)
flags.Init("test", ContinueOnError)
flags.Parse([]string{"-unknown"})
if out := buf.String(); !strings.Contains(out, "-unknown") {
t.Logf("expected output mentioning unknown; got %q", out)
}
}
// This tests that one can reset the flags. This still works but not well, and is // This tests that one can reset the flags. This still works but not well, and is
// superseded by FlagSet. // superseded by FlagSet.
func TestChangingArgs(t *testing.T) { func TestChangingArgs(t *testing.T) {
......
...@@ -443,6 +443,14 @@ var fmttests = []struct { ...@@ -443,6 +443,14 @@ var fmttests = []struct {
{"%s", nil, "%!s(<nil>)"}, {"%s", nil, "%!s(<nil>)"},
{"%T", nil, "<nil>"}, {"%T", nil, "<nil>"},
{"%-1", 100, "%!(NOVERB)%!(EXTRA int=100)"}, {"%-1", 100, "%!(NOVERB)%!(EXTRA int=100)"},
// The "<nil>" show up because maps are printed by
// first obtaining a list of keys and then looking up
// each key. Since NaNs can be map keys but cannot
// be fetched directly, the lookup fails and returns a
// zero reflect.Value, which formats as <nil>.
// This test is just to check that it shows the two NaNs at all.
{"%v", map[float64]int{math.NaN(): 1, math.NaN(): 2}, "map[NaN:<nil> NaN:<nil>]"},
} }
func TestSprintf(t *testing.T) { func TestSprintf(t *testing.T) {
...@@ -532,13 +540,14 @@ var _ bytes.Buffer ...@@ -532,13 +540,14 @@ var _ bytes.Buffer
func TestCountMallocs(t *testing.T) { func TestCountMallocs(t *testing.T) {
for _, mt := range mallocTest { for _, mt := range mallocTest {
const N = 100 const N = 100
runtime.UpdateMemStats() memstats := new(runtime.MemStats)
mallocs := 0 - runtime.MemStats.Mallocs runtime.ReadMemStats(memstats)
mallocs := 0 - memstats.Mallocs
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
mt.fn() mt.fn()
} }
runtime.UpdateMemStats() runtime.ReadMemStats(memstats)
mallocs += runtime.MemStats.Mallocs mallocs += memstats.Mallocs
if mallocs/N > uint64(mt.count) { if mallocs/N > uint64(mt.count) {
t.Errorf("%s: expected %d mallocs, got %d", mt.desc, mt.count, mallocs/N) t.Errorf("%s: expected %d mallocs, got %d", mt.desc, mt.count, mallocs/N)
} }
......
...@@ -366,6 +366,7 @@ func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) { ...@@ -366,6 +366,7 @@ func newScanState(r io.Reader, nlIsSpace, nlIsEnd bool) (s *ss, old ssave) {
s.fieldLimit = hugeWid s.fieldLimit = hugeWid
s.maxWid = hugeWid s.maxWid = hugeWid
s.validSave = true s.validSave = true
s.count = 0
return return
} }
......
...@@ -71,6 +71,8 @@ func TestBuild(t *testing.T) { ...@@ -71,6 +71,8 @@ func TestBuild(t *testing.T) {
t.Errorf("ScanDir(%#q): %v", tt.dir, err) t.Errorf("ScanDir(%#q): %v", tt.dir, err)
continue continue
} }
// Don't bother testing import positions.
tt.info.ImportPos, tt.info.TestImportPos = info.ImportPos, info.TestImportPos
if !reflect.DeepEqual(info, tt.info) { if !reflect.DeepEqual(info, tt.info) {
t.Errorf("ScanDir(%#q) = %#v, want %#v\n", tt.dir, info, tt.info) t.Errorf("ScanDir(%#q) = %#v, want %#v\n", tt.dir, info, tt.info)
continue continue
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment